diff --git a/extensions/long_replies/script.py b/extensions/long_replies/script.py index 035e8c9e..a30b05a7 100644 --- a/extensions/long_replies/script.py +++ b/extensions/long_replies/script.py @@ -28,7 +28,7 @@ class MyLogits(LogitsProcessor): def __call__(self, input_ids, scores): if input_ids.shape[-1] - initial_size < params["min_length"]: scores[...,self.newline_id] = -1000 - # scores[...,shared.tokenizer.eos_token_id] = -1000 + scores[...,shared.tokenizer.eos_token_id] = -1000 # probs = torch.softmax(scores, dim=-1, dtype=torch.float) # probs[0] /= probs[0].sum()