ctransformers: move thread and seed parameters (#3543)

This commit is contained in:
cal066 2023-08-13 03:04:03 +00:00 committed by GitHub
parent 73421b1fed
commit bf70c19603
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 6 deletions

View file

@ -13,14 +13,12 @@ class CtransformersModel:
def from_pretrained(self, path): def from_pretrained(self, path):
result = self() result = self()
# ctransformers uses -1 for random seed
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
str(path), str(path),
threads=shared.args.threads, threads=shared.args.threads,
gpu_layers=shared.args.n_gpu_layers, gpu_layers=shared.args.n_gpu_layers,
batch_size=shared.args.n_batch, batch_size=shared.args.n_batch,
stream=True, stream=True
seed=(-1 if shared.args.llama_cpp_seed == 0 else shared.args.llama_cpp_seed)
) )
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
@ -49,6 +47,7 @@ class CtransformersModel:
def generate(self, prompt, state, callback=None): def generate(self, prompt, state, callback=None):
prompt = prompt if type(prompt) is str else prompt.decode() prompt = prompt if type(prompt) is str else prompt.decode()
# ctransformers uses -1 for random seed
generator = self.model._stream( generator = self.model._stream(
prompt=prompt, prompt=prompt,
max_new_tokens=state['max_new_tokens'], max_new_tokens=state['max_new_tokens'],
@ -57,7 +56,7 @@ class CtransformersModel:
top_k=state['top_k'], top_k=state['top_k'],
repetition_penalty=state['repetition_penalty'], repetition_penalty=state['repetition_penalty'],
last_n_tokens=state['repetition_penalty_range'], last_n_tokens=state['repetition_penalty_range'],
threads=shared.args.threads seed=state['seed']
) )
output = "" output = ""

View file

@ -95,8 +95,7 @@ loaders_and_params = OrderedDict({
'n_gpu_layers', 'n_gpu_layers',
'n_batch', 'n_batch',
'threads', 'threads',
'model_type', 'model_type'
'llama_cpp_seed',
] ]
}) })