init
This commit is contained in:
32
vllm_vacc/vllm/sequence.py
Normal file
32
vllm_vacc/vllm/sequence.py
Normal file
@@ -0,0 +1,32 @@
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
|
||||
class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Per-sequence generation state: token ids, cumulative logprob, and
    (optionally) cached token embeddings.

    NOTE(review): only ``append_token_id`` is visible in this chunk; the
    struct's fields (``_output_token_ids``, ``_new_appended_tokens``,
    ``_cached_all_token_ids``, ``_cumulative_logprob``,
    ``_cached_all_token_embeds``) are declared elsewhere in the file.
    """

    def append_token_id(self,
                        token_id: int,
                        logprob: float,
                        token_embed: Optional[torch.Tensor] = None) -> None:
        """Record one newly generated token for this sequence.

        Args:
            token_id: id of the generated token; appended to the output,
                newly-appended, and all-token id lists.
            logprob: log-probability of the token; accumulated into
                ``self._cumulative_logprob``.
            token_embed: optional embedding of the token. Assumes a 1-D
                tensor (it is unsqueezed to a leading batch dim of 1 before
                caching) — TODO confirm expected shape against callers.
        """
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob
        if token_embed is not None:
            # NOTE(review): this REPLACES the cached embeddings with only the
            # latest token's embedding. Removed commented-out code at this
            # site instead concatenated onto the existing cache via
            # torch.cat(..., dim=0); confirm that overwriting (keeping only
            # the most recent token's embedding) is the intended behavior
            # for ``_cached_all_token_embeds``.
            self._cached_all_token_embeds = token_embed.unsqueeze(0)
|
||||
Reference in New Issue
Block a user