From 153dfeb4dde562ef0bad6743e832e76f28dc9200 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 6 Mar 2023 20:12:54 -0300
Subject: [PATCH] Add --rwkv-cuda-on parameter, bump rwkv version

---
 modules/RWKV.py   | 2 +-
 modules/shared.py | 3 ++-
 requirements.txt  | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/modules/RWKV.py b/modules/RWKV.py
index 9a806a00..acc97044 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -9,7 +9,7 @@ import modules.shared as shared
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
 os.environ['RWKV_JIT_ON'] = '1'
-os.environ["RWKV_CUDA_ON"] = '0' # '1' : use CUDA kernel for seq mode (much faster)
+os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
 
 from rwkv.model import RWKV
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
diff --git a/modules/shared.py b/modules/shared.py
index e1d3765b..b609045c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -81,7 +81,8 @@ parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, defaul
 parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
 parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
 parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
-parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".')
+parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
+parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
 parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
diff --git a/requirements.txt b/requirements.txt
index 2051dc0b..3a2ac25d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ bitsandbytes==0.37.0
 flexgen==0.1.7
 gradio==3.18.0
 numpy
-rwkv==0.0.7
+rwkv==0.0.8
 safetensors==0.2.8
 sentencepiece
 git+https://github.com/oobabooga/transformers@llama_push
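
For context, a minimal standalone sketch of what the modules/RWKV.py change amounts to: the new --rwkv-cuda-on flag gates the RWKV_CUDA_ON environment variable, which must be set before rwkv.model is imported. The argparse setup below is a trimmed stand-in for the project's shared parser in modules/shared.py, not part of the patch itself.

import argparse
import os

# Trimmed stand-in for modules/shared.py: only the flag added by this patch.
parser = argparse.ArgumentParser()
parser.add_argument('--rwkv-cuda-on', action='store_true',
                    help='RWKV: Compile the CUDA kernel for better performance.')
args = parser.parse_args()

# Mirrors modules/RWKV.py: set the environment variables before importing
# rwkv.model, since the patch places the import after these assignments.
os.environ['RWKV_JIT_ON'] = '1'
os.environ['RWKV_CUDA_ON'] = '1' if args.rwkv_cuda_on else '0'

from rwkv.model import RWKV  # noqa: E402 (import deliberately after the env vars, as in the patch)

With the patch applied, the flag would be passed on the command line in the usual way, e.g. appending --rwkv-cuda-on when launching the web UI with an RWKV model (exact launch command depends on your setup).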