Small fix to make transformers 4.42 functional

This commit is contained in:
oobabooga 2024-06-27 17:05:29 -07:00
parent 66090758df
commit 9dbcb1aeea

View file

@ -359,14 +359,14 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
return scores return scores
def get_logits_warper_patch(self, generation_config): def get_logits_warper_patch(self, generation_config, **kwargs):
# Parameter sanitization # Parameter sanitization
if isinstance(generation_config.temperature, int): if isinstance(generation_config.temperature, int):
generation_config.temperature = float(generation_config.temperature) # Must be float generation_config.temperature = float(generation_config.temperature) # Must be float
# Get the original warpers # Get the original warpers
warpers = self._get_logits_warper_old(generation_config) warpers = self._get_logits_warper_old(generation_config, **kwargs)
# Replace temperature with our modified class. # Replace temperature with our modified class.
# Currently, it behaves identically to the original. # Currently, it behaves identically to the original.