# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

import numpy as np
import torch
@dataclass
class SpecDecodeMetadata:
    """Index bookkeeping for speculative decoding.

    Shape legend for the tensor fields below:
      num_tokens = total number of draft tokens across the batch
      batch_size = number of requests in the batch
    """

    # [num_tokens] - flattened draft token ids for the whole batch
    draft_token_ids: torch.Tensor
    # [batch_size] - number of draft tokens proposed per request
    num_draft_tokens: list[int]
    # [batch_size] - inclusive cumulative sum of num_draft_tokens
    cu_num_draft_tokens: torch.Tensor
    # [batch_size] - inclusive cumulative sum of sampled tokens
    # (each request samples its draft tokens plus one bonus token)
    cu_num_sampled_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self) -> None:
        # Longest draft length in the batch. default=0 keeps an empty batch
        # from raising ValueError (max() of an empty sequence).
        self.max_spec_len = max(self.num_draft_tokens, default=0)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build metadata with real draft ids but all-zero index tensors.

        The index tensors only need the correct shapes/dtypes (e.g. for
        warmup or graph capture), so they are filled with zeros.

        Args:
            draft_token_ids: Per-request lists of draft token ids.
            device: Device on which all tensors are allocated.

        Returns:
            A ``SpecDecodeMetadata`` whose index tensors are zero-filled.
        """
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        # One bonus token is sampled per request on top of its drafts.
        num_sampled_tokens = [n + 1 for n in num_draft_tokens]
        # Flatten with a comprehension: sum(list_of_lists, []) is quadratic.
        flattened_draft_token_ids = [
            token_id for ids in draft_token_ids for token_id in ids
        ]
        num_tokens = len(flattened_draft_token_ids)

        draft_token_ids_tensor = torch.tensor(
            flattened_draft_token_ids, dtype=torch.int32, device=device
        )
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
            device
        )

        # Placeholder (all-zero) index tensors with the correct shapes.
        target_logits_indices = torch.zeros(
            num_tokens, dtype=torch.int32, device=device
        )
        bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
        logits_indices = torch.zeros(
            num_tokens + batch_size, dtype=torch.int32, device=device
        )
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
@dataclass
|
|
class MultiLayerEagleMetadata:
|
|
# [batch_size]
|
|
cached_len: torch.Tensor | None = None
|
|
# [batch_size, layer_num]
|
|
cached_token_ids: torch.Tensor | None = None
|
|
# [batch_size, layer_num, hidden_size]
|
|
cached_hidden_states: torch.Tensor | None = None
|
|
# [batch_size, layer_num]
|
|
cached_slot_mappings: torch.Tensor | None = None
|
|
# [batch_size, layer_num]
|
|
cached_positions: torch.Tensor | None = None
|
|
|
|
@classmethod
|
|
def make_dummy(
|
|
cls,
|
|
layer_num: int,
|
|
hidden_size: int,
|
|
device: torch.device,
|
|
) -> "MultiLayerEagleMetadata":
|
|
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
|
|
cached_token_ids = torch.zeros(
|
|
(1, layer_num), dtype=torch.int32, device=device
|
|
)
|
|
cached_hidden_states = torch.zeros(
|
|
(1, layer_num, hidden_size), dtype=torch.float32, device=device
|
|
)
|
|
cached_slot_mappings = torch.zeros(
|
|
(1, layer_num), dtype=torch.int64, device=device
|
|
)
|
|
cached_positions = torch.zeros(
|
|
(1, layer_num), dtype=torch.int64, device=device
|
|
)
|
|
return cls(
|
|
cached_len=cached_len,
|
|
cached_token_ids=cached_token_ids,
|
|
cached_hidden_states=cached_hidden_states,
|
|
cached_slot_mappings=cached_slot_mappings,
|
|
cached_positions=cached_positions,
|
|
)