Skip the flaky test_stateful_custom_logit_processor (#6251)
This commit is contained in:
@@ -28,11 +28,26 @@ class CustomLogitProcessor(ABC):
|
||||
"""Define the callable behavior."""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_str(self) -> str:
|
||||
@classmethod
|
||||
def to_str(cls) -> str:
|
||||
"""Serialize the callable function to a JSON-compatible string."""
|
||||
return json.dumps({"callable": dill.dumps(self).hex()})
|
||||
return json.dumps({"callable": dill.dumps(cls).hex()})
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, json_str: str):
|
||||
"""Deserialize a callable function from a JSON string."""
|
||||
return _cache_from_str(json_str)
|
||||
return _cache_from_str(json_str)()
|
||||
|
||||
|
||||
class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
|
||||
def __call__(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
custom_param_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> torch.Tensor:
|
||||
disallowed_token_ids = custom_param_list[0]["token_ids"]
|
||||
assert all(
|
||||
disallowed_token_ids == c["token_ids"] for c in custom_param_list
|
||||
), f"{custom_param_list=}"
|
||||
logits[..., disallowed_token_ids] = -float("inf")
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user