diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index ad74d658..9fb661ae 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -359,14 +359,14 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor): return scores -def get_logits_warper_patch(self, generation_config): +def get_logits_warper_patch(self, generation_config, **kwargs): # Parameter sanitization if isinstance(generation_config.temperature, int): generation_config.temperature = float(generation_config.temperature) # Must be float # Get the original warpers - warpers = self._get_logits_warper_old(generation_config) + warpers = self._get_logits_warper_old(generation_config, **kwargs) # Replace temperature with our modified class. # Currently, it behaves identically to the original.