[Refactor] Refactor Spec Decode (#2668)

### What this PR does / why we need it?
Refactor spec decode

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.


- vLLM version: v0.10.1.1
- vLLM main:
6997a25ac6

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
Icey
2025-09-04 11:34:47 +08:00
committed by GitHub
parent 7e16b4a7cd
commit d4370ebc42
7 changed files with 1021 additions and 932 deletions

View File

@@ -86,14 +86,16 @@ from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p,
lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
@@ -235,16 +237,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.model_config.max_model_len, self.dtype)
# Set up speculative decoding.
self.use_aux_hidden_state_outputs = False
self.use_spec_decode = False
self.spec_attn_mask = None
self.use_eagle = False
self.drafter: Optional[Union[NgramProposer, EagleProposer,
MtpProposer]] = None
self.actual_seq_lengths_q = []
self.decode_token_per_req = 1
if self.speculative_config:
self.use_spec_decode = True
spec_token_num = self.speculative_config.num_speculative_tokens
assert spec_token_num > 0
self.decode_token_per_req = 1 + spec_token_num
@@ -258,19 +256,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.bool),
diagonal=1).to(self.device)
if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.method in ["eagle", "eagle3"]:
self.use_eagle = True
self.drafter = EagleProposer(self.vllm_config, self.device,
self) # type: ignore
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == 'deepseek_mtp':
self.drafter = MtpProposer(self.vllm_config, self)
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
self.drafter = get_spec_decode_method(
self.speculative_config.method, self.vllm_config,
self.device, self)
self.rejection_sampler = AscendRejectionSampler()
# Persistent batch.
@@ -649,152 +637,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return False
return True
def get_eagle_atten_dict(
self,
scheduler_output: "SchedulerOutput",
) -> dict[str, Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata, AscendMLATorchairMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit_block_table(num_reqs)
# Get the number of scheduled tokens for each request.
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# NOTE(Chen): there is exactly one KV cache group that contains all
# attetnion layers in the model for now, so the current logic for
# getting attn_metadata is not related to kv_cache_group information.
# Will extend this part to support multiple KV cache groups later.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table = self.input_batch.block_table[kv_cache_group_id]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
# Copy the tensors to the NPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
else:
# Common case (1D positions)
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
self.query_start_loc[:num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)
attn_metadata: dict[str, Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata,
AscendMLATorchairMetadata]] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu,
num_reqs=num_reqs,
max_query_len=max_num_scheduled_tokens,
num_actual_tokens=total_num_scheduled_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
block_table_tensor=self.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=self.slot_mapping_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
decode_token_per_req=self.decode_token_per_req,
)
attn_metadata_i = self.attn_metadata_builder.build(
common_attn_metadata, self.get_model())
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
return attn_metadata
def get_model(self) -> nn.Module:
# get raw model out of the aclgraph wrapper.
if isinstance(self.model, ACLGraphWrapper):
@@ -1317,7 +1159,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.use_eagle:
if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
or self.drafter.name == SpecDcodeType.EAGLE3):
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.SpecDecoding
@@ -1338,26 +1181,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions = self.mrope_positions[:, :num_input_tokens]
return input_ids, positions
def _get_cumsum_and_arange(
self,
num_tokens: np.ndarray,
cumsum_dtype: Optional[np.dtype] = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Get the cumulative sum and batched arange of the given array.
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
# Equivalent to but faster than:
# np.concatenate([np.arange(n) for n in num_tokens])
"""
# Step 1. [2, 5, 3] -> [2, 7, 10]
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
total_num_tokens = cu_num_tokens[-1]
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = self.arange_np[:total_num_tokens] - cumsums_offsets
return cu_num_tokens, arange
def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
@@ -1511,24 +1334,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
AscendMLATorchairMetadata],
aux_hidden_states: torch.Tensor = None,
) -> Optional[list[list[int]]]:
if not self.use_spec_decode:
if not self.drafter:
# Speculative decoding is not enabled.
draft_token_ids = None
elif self.speculative_config.method == "ngram":
draft_token_ids = self._generate_ngram_token_ids(
valid_sampled_token_ids)
elif self.speculative_config.method == "eagle":
raise NotImplementedError("Eagle Is Not Supported Yet.")
elif self.speculative_config.method == "eagle3":
draft_token_ids = self._generate_eagle3_token_ids(
else:
draft_token_ids = self.drafter.generate_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, aux_hidden_states)
elif self.speculative_config.method == 'deepseek_mtp':
draft_token_ids = self._generate_mtp_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, attn_metadata)
hidden_states, attn_metadata, aux_hidden_states)
return draft_token_ids
def _pool_v010(
@@ -1692,7 +1505,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output)
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
hidden_states, aux_hidden_states = hidden_states
kv_connector_output = None
@@ -1981,12 +1794,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
if self.use_aux_hidden_state_outputs:
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
hidden_states, _ = hidden_states
else:
hidden_states = hidden_states
if self.use_spec_decode and isinstance(self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens)
return hidden_states
@torch.inference_mode()
@@ -2137,8 +1948,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if need_dummy_logits:
dummy_compute_logits(hidden_states)
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
assert isinstance(self.drafter, MtpProposer)
if self.drafter:
self.drafter.dummy_run(
num_tokens=num_tokens,
with_prefill=with_prefill,
@@ -2296,13 +2106,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
module.weight.data)
if self.drafter:
logger.info("Loading drafter model...")
if isinstance(self.drafter, EagleProposer):
if self.use_aux_hidden_state_outputs:
self.drafter.load_model(self.model)
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
else:
self.drafter.load_model()
self.drafter.load_model(self.model)
if self.drafter.name == SpecDcodeType.EAGLE3:
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
@@ -2590,193 +2398,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))
def _generate_ngram_token_ids(
self,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
continue
# Skip requests that require top-p, top-k, etc.
req_id = self.input_batch.req_ids[i]
if req_id in self.input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
assert isinstance(self.drafter, NgramProposer)
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def _generate_eagle3_token_ids(self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
aux_hidden_states: torch.Tensor = None):
assert isinstance(self.drafter, EagleProposer)
attn_metadata = self.get_eagle_atten_dict(scheduler_output)
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc, num_rejected_tokens,
num_tokens)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
def _generate_mtp_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata,
AscendMLATorchairMetadata],
):
assert isinstance(self.drafter, MtpProposer)
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
accepted_token_indices = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
cu_num_tokens = attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, accepted_token_indices, target_token_ids, \
target_positions, target_hidden_states, target_slot_mapping = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
self.input_ids[:num_scheduled_tokens],
positions[:num_scheduled_tokens],
hidden_states[:num_scheduled_tokens],
attn_metadata.slot_mapping[:num_scheduled_tokens],
is_torchair_graph=self._build_drafter_prepare_inputs_torchair_param(),
)
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
token_indices=accepted_token_indices)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,