Make transformers 4.49 functional
oobabooga committed Feb 18, 2025
1 parent 16f4f1a commit dba17c4
Showing 1 changed file with 9 additions and 9 deletions.
modules/sampler_hijack.py: 18 changes (9 additions & 9 deletions)
@@ -5,7 +5,7 @@

 import torch
 import transformers
-from transformers import LogitsWarper
+from transformers import LogitsProcessor
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -19,7 +19,7 @@
 global_scores = None


-class TemperatureLogitsWarperCustom(LogitsWarper):
+class TemperatureLogitsWarperCustom(LogitsProcessor):
     '''
     A copy of the original Transformers temperature logits warper.
     '''
@@ -42,7 +42,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores


-class DynamicTemperatureLogitsWarper(LogitsWarper):
+class DynamicTemperatureLogitsWarper(LogitsProcessor):
     '''
     Dynamic temperature.
     '''
@@ -100,7 +100,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores


-class QuadraticSamplingLogitsWarper(LogitsWarper):
+class QuadraticSamplingLogitsWarper(LogitsProcessor):
     '''
     Quadratic sampling with smoothing factor and smoothing curve parameters.
     '''
@@ -127,7 +127,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return transformed_logits


-class TailFreeLogitsWarper(LogitsWarper):
+class TailFreeLogitsWarper(LogitsProcessor):
     def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         tfs = float(tfs)
         if tfs < 0 or tfs > 1.0:
@@ -167,7 +167,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores


-class TopALogitsWarper(LogitsWarper):
+class TopALogitsWarper(LogitsProcessor):
     def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         top_a = float(top_a)
         if top_a < 0 or top_a > 1.0:
@@ -194,7 +194,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:


 # Exclude Top Choices (XTC)
-class XTCLogitsWarper(LogitsWarper):
+class XTCLogitsWarper(LogitsProcessor):
     def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
         self.threshold = threshold
         self.probability = probability
@@ -312,7 +312,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores


-class MirostatLogitsWarper(LogitsWarper):
+class MirostatLogitsWarper(LogitsProcessor):
     def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         if mirostat_mode not in [2]:
             raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
@@ -361,7 +361,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores


-class SpyLogitsWarper(LogitsWarper):
+class SpyLogitsWarper(LogitsProcessor):
     def __init__(self):
         pass


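For context on why a plain rename is enough: LogitsWarper only defined the same __call__(input_ids, scores) hook that LogitsProcessor defines, and the class is no longer importable from transformers 4.49, so the custom samplers keep working once they inherit from LogitsProcessor instead. Below is a minimal sketch, not part of this commit; TemperatureScaler is a hypothetical example class illustrating the interface the hijacked warpers above implement.

# A minimal sketch, not taken from the repository: TemperatureScaler is a
# hypothetical example of a custom sampler written against LogitsProcessor,
# the base class used after this commit (transformers >= 4.49).
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class TemperatureScaler(LogitsProcessor):
    '''
    Divides the logits by a fixed temperature before sampling.
    '''

    def __init__(self, temperature: float):
        if temperature <= 0:
            raise ValueError(f"`temperature` has to be > 0, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Same signature as the warpers above; transformers calls this once
        # per generation step with the logits for the next token.
        return scores / self.temperature


# Usage: pass it to model.generate() through a LogitsProcessorList, e.g.
#   processors = LogitsProcessorList([TemperatureScaler(0.7)])
#   model.generate(input_ids, logits_processor=processors)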