Skip the flaky test_stateful_custom_logit_processor (#6251)

This commit is contained in:
Lianmin Zheng
2025-05-12 18:29:41 -07:00
committed by GitHub
parent ef8ec07b2c
commit ac2324c177
2 changed files with 31 additions and 7 deletions

View File

@@ -28,11 +28,26 @@ class CustomLogitProcessor(ABC):
"""Define the callable behavior."""
raise NotImplementedError
def to_str(self) -> str:
@classmethod
def to_str(cls) -> str:
"""Serialize the callable function to a JSON-compatible string."""
return json.dumps({"callable": dill.dumps(self).hex()})
return json.dumps({"callable": dill.dumps(cls).hex()})
@classmethod
def from_str(cls, json_str: str):
"""Deserialize a callable function from a JSON string."""
return _cache_from_str(json_str)
return _cache_from_str(json_str)()
class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
def __call__(
self,
logits: torch.Tensor,
custom_param_list: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
disallowed_token_ids = custom_param_list[0]["token_ids"]
assert all(
disallowed_token_ids == c["token_ids"] for c in custom_param_list
), f"{custom_param_list=}"
logits[..., disallowed_token_ids] = -float("inf")
return logits