Files
2026-04-02 04:55:00 +00:00

32 lines
1.3 KiB
Python

import msgspec
import torch
from typing import Any, Callable, Optional, Union
class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Sequence state for generated tokens (fragment: fields declared elsewhere).

    NOTE(review): this view shows only ``append_token_id``; the attributes it
    mutates (``_output_token_ids``, ``_new_appended_tokens``,
    ``_cached_all_token_ids``, ``_cumulative_logprob``, ``_output_embeds``,
    ``_cached_all_token_embeds``) are defined outside this fragment.
    """

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one generated token (and optionally its embedding).

        Args:
            token_id: The newly sampled token id.
            logprob: Log-probability of ``token_id``; accumulated into
                ``_cumulative_logprob``.
            token_embed: Optional 1-D embedding vector for the token.
                Accumulated (not overwritten) into the embedding caches.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # BUG FIX: the previous code replaced _cached_all_token_embeds
            # with only the newest token's embedding each call, so the
            # "all token embeds" cache never accumulated, and _output_embeds
            # was never updated. Restore the concatenation behavior shown in
            # the (previously commented-out) intended implementation.
            assert token_embed.ndim == 1
            # Detach and move to CPU so the cache does not pin GPU memory or
            # keep autograd history alive; shape becomes (1, dim).
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            if self._cached_all_token_embeds is None:
                # First embedding seen: start the cache (matches the old live
                # line's behavior for the initial call).
                self._cached_all_token_embeds = token_embed
            else:
                self._cached_all_token_embeds = torch.cat(
                    (self._cached_all_token_embeds,
                     token_embed.to(
                         device=self._cached_all_token_embeds.device)),
                    dim=0)