# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass import numpy as np import torch @dataclass class SpecDecodeMetadata: # [num_tokens] draft_token_ids: torch.Tensor # [batch_size] num_draft_tokens: list[int] # [batch_size] cu_num_draft_tokens: torch.Tensor # [batch_size] cu_num_sampled_tokens: torch.Tensor # [num_tokens] target_logits_indices: torch.Tensor # [batch_size] bonus_logits_indices: torch.Tensor # [num_tokens + batch_size] logits_indices: torch.Tensor def __post_init__(self): self.max_spec_len = max(self.num_draft_tokens) @classmethod def make_dummy( cls, draft_token_ids: list[list[int]], device: torch.device, ) -> "SpecDecodeMetadata": batch_size = len(draft_token_ids) num_draft_tokens = [len(ids) for ids in draft_token_ids] num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids] flattened_draft_token_ids = sum(draft_token_ids, []) num_tokens = len(flattened_draft_token_ids) draft_token_ids_tensor = torch.tensor( flattened_draft_token_ids, dtype=torch.int32, device=device ) cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device) cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to( device ) target_logits_indices = torch.zeros( num_tokens, dtype=torch.int32, device=device ) bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device) logits_indices = torch.zeros( num_tokens + batch_size, dtype=torch.int32, device=device ) return cls( draft_token_ids=draft_token_ids_tensor, num_draft_tokens=num_draft_tokens, cu_num_draft_tokens=cu_num_draft_tokens_tensor, cu_num_sampled_tokens=cu_num_sampled_tokens_tensor, target_logits_indices=target_logits_indices, bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, )