67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
@dataclass
class SpecDecodeMetadata:
    """Per-batch bookkeeping for draft tokens and their logits indices.

    In the shape comments below, ``num_tokens`` is the total number of draft
    tokens over the whole batch and ``batch_size`` is the number of batch
    entries (i.e. ``len(num_draft_tokens)``).

    ``max_spec_len`` is not a constructor argument: it is derived in
    ``__post_init__`` as the largest entry of ``num_draft_tokens`` (``0`` when
    the batch is empty).
    """

    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [batch_size]
    cu_num_sampled_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        # A bare max() raises ValueError on an empty list; default=0 keeps an
        # empty batch constructible.
        self.max_spec_len = max(self.num_draft_tokens, default=0)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build an instance from nested Python lists of draft token ids.

        Only the draft token ids and the count/cumulative-count fields are
        derived from the input. The three ``*_logits_indices`` tensors are
        zero-filled placeholders of the matching length.

        Args:
            draft_token_ids: One list of draft token ids per batch entry;
                inner lists may be empty.
            device: Device on which all tensors are placed.

        Returns:
            A ``SpecDecodeMetadata`` whose tensors are all ``torch.int32``.
        """
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        # Each batch entry samples one token more than it has draft tokens.
        num_sampled_tokens = [count + 1 for count in num_draft_tokens]

        # Flatten in a single pass; sum(lists, []) re-copies the accumulator
        # on every step and is quadratic in the total token count.
        flattened_draft_token_ids = [
            token_id for ids in draft_token_ids for token_id in ids
        ]
        num_tokens = len(flattened_draft_token_ids)

        draft_token_ids_tensor = torch.tensor(
            flattened_draft_token_ids, dtype=torch.int32, device=device
        )

        # Inclusive prefix sums, computed on the host and then moved to device.
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
            device
        )

        # Placeholder index tensors: correct length and dtype, all zeros.
        target_logits_indices = torch.zeros(
            num_tokens, dtype=torch.int32, device=device
        )
        bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
        logits_indices = torch.zeros(
            num_tokens + batch_size, dtype=torch.int32, device=device
        )

        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
|