32 lines
1.3 KiB
Python
32 lines
1.3 KiB
Python
|
|
import msgspec
|
|
import torch
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
|
|
class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Mutable per-sequence generation state.

    NOTE(review): only ``append_token_id`` is visible in this chunk; the
    struct's fields (``_output_token_ids``, ``_cached_all_token_ids``,
    ``_cached_all_token_embeds``, ``_cumulative_logprob``, ...) are
    declared elsewhere in the file.
    """

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Append one generated token to the sequence's running state.

        Args:
            token_id: The newly generated token id; appended to the output
                list, the new-token list, and the all-token-id cache.
            logprob: Log-probability of ``token_id``; added to the
                cumulative logprob.
            token_embed: Optional embedding for the token; assumed to be a
                1-D tensor of shape ``(hidden_dim,)`` -- TODO confirm
                against callers.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # BUG FIX: the previous code assigned
            # ``self._cached_all_token_embeds = token_embed.unsqueeze(0)``,
            # overwriting the cache with only the newest token's embedding
            # and discarding everything accumulated so far (including any
            # prompt embeddings). Concatenate instead, matching the append
            # semantics of the token-id caches above and the original
            # (previously commented-out) implementation.
            new_row = token_embed.unsqueeze(0)
            if self._cached_all_token_embeds is None:
                self._cached_all_token_embeds = new_row
            else:
                self._cached_all_token_embeds = torch.cat(
                    (self._cached_all_token_embeds,
                     # Keep the cache on its existing device.
                     new_row.to(device=self._cached_all_token_embeds.device)),
                    dim=0)