import msgspec
import torch
from typing import Any, Callable, Optional, Union


class SequenceData(msgspec.Struct, omit_defaults=True):  # type: ignore[call-arg]
    """Mutable per-sequence generation state.

    Tracks generated token ids, the running cumulative logprob and,
    optionally, per-token embeddings. Only ``append_token_id`` is visible
    in this chunk; the attributes it mutates are declared elsewhere.
    """

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one generated token (and, optionally, its embedding).

        Args:
            token_id: Newly sampled token id; appended to the output,
                newly-appended and all-token id lists.
            logprob: Log-probability of ``token_id``; accumulated into
                ``_cumulative_logprob``.
            token_embed: Optional 1-D embedding of the token.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # BUG FIX: the previous live code overwrote
            # _cached_all_token_embeds with only the newest embedding every
            # call, losing the history the name implies, and never updated
            # _output_embeds. Restore the append semantics from the
            # commented-out reference implementation that sat below it.
            assert token_embed.ndim == 1
            # Detach from the autograd graph and move to CPU so stored
            # embeddings don't pin GPU memory or the computation graph.
            token_embed = token_embed.detach().cpu().unsqueeze(0)
            if self._output_embeds is None:
                self._output_embeds = token_embed
            else:
                self._output_embeds = torch.cat(
                    (self._output_embeds, token_embed), dim=0)
            if self._cached_all_token_embeds is None:
                # None-safe fallback: start the cache with this embedding
                # (the reference code asserted non-None here instead).
                self._cached_all_token_embeds = token_embed
            else:
                self._cached_all_token_embeds = torch.cat(
                    (self._cached_all_token_embeds,
                     token_embed.to(
                         device=self._cached_all_token_embeds.device)),
                    dim=0)