From 9dcb37e8d4fafb5c1b59f7a56e25fcb9c21e1398 Mon Sep 17 00:00:00 2001 From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com> Date: Sat, 5 Aug 2023 16:45:47 +0000 Subject: [PATCH] Fix: Mirostat fails on models split across multiple GPUs --- modules/sampler_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 0a86b4fd..d5ebbb76 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -104,7 +104,7 @@ class MirostatLogitsWarper(LogitsWarper): break # Normalize the probabilities of the remaining words - prob_topk = torch.softmax(sorted_logits, dim=0) + prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda') prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')