Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -64,3 +64,45 @@ class SpecDecodeMetadata:
|
||||
bonus_logits_indices=bonus_logits_indices,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiLayerEagleMetadata:
|
||||
# [batch_size]
|
||||
cached_len: torch.Tensor | None = None
|
||||
# [batch_size, layer_num]
|
||||
cached_token_ids: torch.Tensor | None = None
|
||||
# [batch_size, layer_num, hidden_size]
|
||||
cached_hidden_states: torch.Tensor | None = None
|
||||
# [batch_size, layer_num]
|
||||
cached_slot_mappings: torch.Tensor | None = None
|
||||
# [batch_size, layer_num]
|
||||
cached_positions: torch.Tensor | None = None
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
layer_num: int,
|
||||
hidden_size: int,
|
||||
device: torch.device,
|
||||
) -> "MultiLayerEagleMetadata":
|
||||
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
|
||||
cached_token_ids = torch.zeros(
|
||||
(1, layer_num), dtype=torch.int32, device=device
|
||||
)
|
||||
cached_hidden_states = torch.zeros(
|
||||
(1, layer_num, hidden_size), dtype=torch.float32, device=device
|
||||
)
|
||||
cached_slot_mappings = torch.zeros(
|
||||
(1, layer_num), dtype=torch.int64, device=device
|
||||
)
|
||||
cached_positions = torch.zeros(
|
||||
(1, layer_num), dtype=torch.int64, device=device
|
||||
)
|
||||
return cls(
|
||||
cached_len=cached_len,
|
||||
cached_token_ids=cached_token_ids,
|
||||
cached_hidden_states=cached_hidden_states,
|
||||
cached_slot_mappings=cached_slot_mappings,
|
||||
cached_positions=cached_positions,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user