[SpecDecode] Fix Draft model proposer (#7230)

### What this PR does / why we need it?
This PR fixes the unified draft parallel feature.
1. In the draft model proposer, the target model may contain more than one
attention layer, so the assertion on the draft layer count is removed.
2. We should get the block size through `draft_attn_groups` instead of
`attn_metadata_builder` after 0.17.0.
3. `attn_update_stack_num_spec_norm` should not be called when unified
draft parallel is enabled.

### How was this patch tested?
Tests pass with
`tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_parallel_drafting_acceptance`,
which is already included in CI

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2026-03-14 18:26:37 +08:00
committed by GitHub
parent 0ad52517a1
commit e7aa2c285c
2 changed files with 33 additions and 29 deletions

View File

@@ -195,7 +195,6 @@ class SpecDecodeBaseProposer(EagleProposer):
all_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys()) all_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys())
self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names - all_indexer_layer_names self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names - all_indexer_layer_names
assert len(self._draft_attn_layer_names) == 1
self.attn_layer_names = list(sorted(self._draft_attn_layer_names)) self.attn_layer_names = list(sorted(self._draft_attn_layer_names))
draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
self.kernel_block_size = ( self.kernel_block_size = (
@@ -699,10 +698,24 @@ class SpecDecodeBaseProposer(EagleProposer):
multi_steps_attn_metadata.append(per_layer_attn_metadata) multi_steps_attn_metadata.append(per_layer_attn_metadata)
else: else:
# Copy the old attn_metadata and update # Copy the old attn_metadata and update
for draft_step in range(1, self.num_speculative_tokens): if not self.parallel_drafting:
per_layer_attn_metadata = dict() for draft_step in range(1, self.num_speculative_tokens):
if vllm_version_is("0.17.0"): per_layer_attn_metadata = dict()
for attn_group in self.draft_attn_groups: if vllm_version_is("0.17.0"):
for attn_group in self.draft_attn_groups:
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
draft_step,
attn_metadata,
common_attn_metadata,
batch_size,
num_input_tokens,
used_update_positions,
aclgraph_runtime_mode,
attn_group=attn_group,
)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
else:
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
draft_step, draft_step,
attn_metadata, attn_metadata,
@@ -711,23 +724,10 @@ class SpecDecodeBaseProposer(EagleProposer):
num_input_tokens, num_input_tokens,
used_update_positions, used_update_positions,
aclgraph_runtime_mode, aclgraph_runtime_mode,
attn_group=attn_group,
) )
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
else: multi_steps_attn_metadata.append(per_layer_attn_metadata)
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
draft_step,
attn_metadata,
common_attn_metadata,
batch_size,
num_input_tokens,
used_update_positions,
aclgraph_runtime_mode,
)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
multi_steps_attn_metadata.append(per_layer_attn_metadata)
token_indices_to_sample_len = token_indices_to_sample.shape[0] token_indices_to_sample_len = token_indices_to_sample.shape[0]
self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample) self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample)
@@ -1064,16 +1064,21 @@ class SpecDecodeBaseProposer(EagleProposer):
# 2. # 2.
# Recompute the slot mapping based on the new positions and # Recompute the slot mapping based on the new positions and
# rejection mask. # rejection mask.
builder = ( if vllm_version_is("0.17.0"):
self._get_attention_metadata_builder() # Use the first draft attention group's kv_cache_spec for block_size
if self.attn_metadata_builder is None # (all draft layers share the same kv-cache group)
else self.attn_metadata_builder assert len(self.draft_attn_groups) > 0
) block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
else:
if self.attn_metadata_builder is None:
block_size = self._get_attention_metadata_builder().kv_cache_spec.block_size
else:
block_size = self.attn_metadata_builder.kv_cache_spec.block_size
new_slot_mapping = compute_new_slot_mapping( new_slot_mapping = compute_new_slot_mapping(
cad=cad, cad=cad,
new_positions=self.positions[:total_num_output_tokens], new_positions=self.positions[:total_num_output_tokens],
is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens], is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens],
block_size=builder.kv_cache_spec.block_size, block_size=block_size,
num_new_tokens=self.net_num_new_slots_per_request, num_new_tokens=self.net_num_new_slots_per_request,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
) )
@@ -1152,14 +1157,14 @@ class SpecDecodeBaseProposer(EagleProposer):
# out-of-range access during the model execution. The draft tokens # out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored. # generated with this adjustment should be ignored.
if self.uses_mrope: if self.uses_mrope:
exceeds_max_model_len = used_update_positions[0] >= self.vllm_config.model_config.max_model_len exceeds_max_model_len = used_update_positions[0] >= self.max_model_len
# Mask out the position ids that exceed the max model length. # Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE. # Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where( clamped_positions = torch.where(
exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions
) )
else: else:
exceeds_max_model_len = used_update_positions >= self.vllm_config.model_config.max_model_len exceeds_max_model_len = used_update_positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions) clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions)
# For data integrity when async scheduling, we shouldn't use in place # For data integrity when async scheduling, we shouldn't use in place

View File

@@ -74,7 +74,6 @@ from vllm.v1.outputs import (
from vllm.v1.sample.logits_processor import build_logitsprocs from vllm.v1.sample.logits_processor import build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.utils import record_function_or_nullcontext
@@ -2561,7 +2560,7 @@ class NPUModelRunner(GPUModelRunner):
if self.speculative_config and ( if self.speculative_config and (
self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model() self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()
): ):
assert isinstance(self.drafter, AscendEagleProposer | DraftModelProposer) assert isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer)
self.drafter.initialize_attn_backend(kv_cache_config, self.kernel_block_sizes) self.drafter.initialize_attn_backend(kv_cache_config, self.kernel_block_sizes)
if has_kv_transfer_group(): if has_kv_transfer_group():