[SpecDecode] Fix Draft model proposer (#7230)

### What this PR does / why we need it?
This PR fixes the unified draft parallel feature.
1. In the draft model proposer, the draft model can have more than one
attention layer (unlike a single-layer Eagle-style head), so the assertion
on the draft attention layer count is removed.
2. The block size should be obtained through `draft_attn_groups` instead of
`attn_metadata_builder` after 0.17.0; a sketch follows this list.
3. `attn_update_stack_num_spec_norm` should not be performed when unified
draft parallel is enabled.
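
For reference, here is a minimal, self-contained sketch of the new block-size lookup from item 2. The `KVCacheSpec`/`AttnGroup` stubs and the `resolve_block_size` helper are illustrative stand-ins for the real vLLM types; the actual change lives in the diff below.

```python
from dataclasses import dataclass

@dataclass
class KVCacheSpec:
    block_size: int  # tokens per KV-cache block

@dataclass
class AttnGroup:
    kv_cache_spec: KVCacheSpec

def resolve_block_size(draft_attn_groups: list[AttnGroup],
                       on_vllm_0_17_0: bool,
                       builder_block_size: int) -> int:
    # On vLLM 0.17.0, read the block size from the first draft attention
    # group: all draft layers share the same kv-cache group.
    if on_vllm_0_17_0:
        assert len(draft_attn_groups) > 0
        return draft_attn_groups[0].kv_cache_spec.block_size
    # Otherwise fall back to the attention metadata builder's block size,
    # as before this PR.
    return builder_block_size

# One draft attention group with 128-token blocks:
groups = [AttnGroup(KVCacheSpec(block_size=128))]
assert resolve_block_size(groups, on_vllm_0_17_0=True, builder_block_size=64) == 128
```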

### How was this patch tested?
Tests pass with
`tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_parallel_drafting_acceptance`,
which is already included in CI.
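Locally, the same test can be invoked with `pytest tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_parallel_drafting_acceptance` (assuming a working vllm-ascend development environment).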

- vLLM version: v0.17.0
- vLLM main: 4034c3d32e

Signed-off-by: MengqingCao <cmq0113@163.com>
2 changed files with 33 additions and 29 deletions

First changed file (the spec-decode proposer):

@@ -195,7 +195,6 @@ class SpecDecodeBaseProposer(EagleProposer):
         all_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys())
         self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names - all_indexer_layer_names
-        assert len(self._draft_attn_layer_names) == 1
         self.attn_layer_names = list(sorted(self._draft_attn_layer_names))
         draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         self.kernel_block_size = (
@@ -699,6 +698,7 @@ class SpecDecodeBaseProposer(EagleProposer):
             multi_steps_attn_metadata.append(per_layer_attn_metadata)
         else:
             # Copy the old attn_metadata and update
+            if not self.parallel_drafting:
                 for draft_step in range(1, self.num_speculative_tokens):
                     per_layer_attn_metadata = dict()
                     if vllm_version_is("0.17.0"):
@@ -1064,16 +1064,21 @@ class SpecDecodeBaseProposer(EagleProposer):
         # 2.
         # Recompute the slot mapping based on the new positions and
         # rejection mask.
-        builder = (
-            self._get_attention_metadata_builder()
-            if self.attn_metadata_builder is None
-            else self.attn_metadata_builder
-        )
+        if vllm_version_is("0.17.0"):
+            # Use the first draft attention group's kv_cache_spec for block_size
+            # (all draft layers share the same kv-cache group)
+            assert len(self.draft_attn_groups) > 0
+            block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
+        else:
+            if self.attn_metadata_builder is None:
+                block_size = self._get_attention_metadata_builder().kv_cache_spec.block_size
+            else:
+                block_size = self.attn_metadata_builder.kv_cache_spec.block_size
         new_slot_mapping = compute_new_slot_mapping(
             cad=cad,
             new_positions=self.positions[:total_num_output_tokens],
             is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens],
-            block_size=builder.kv_cache_spec.block_size,
+            block_size=block_size,
             num_new_tokens=self.net_num_new_slots_per_request,
             max_model_len=self.max_model_len,
         )
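
Note that the replaced `builder = (...)` ternary is folded into the non-0.17.0 branch, so on 0.17.0 the block size comes exclusively from the draft attention groups and `attn_metadata_builder` is never consulted, matching item 2 of the description.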
@@ -1152,14 +1157,14 @@ class SpecDecodeBaseProposer(EagleProposer):
         # out-of-range access during the model execution. The draft tokens
         # generated with this adjustment should be ignored.
         if self.uses_mrope:
-            exceeds_max_model_len = used_update_positions[0] >= self.vllm_config.model_config.max_model_len
+            exceeds_max_model_len = used_update_positions[0] >= self.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
             clamped_positions = torch.where(
                 exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions
             )
         else:
-            exceeds_max_model_len = used_update_positions >= self.vllm_config.model_config.max_model_len
+            exceeds_max_model_len = used_update_positions >= self.max_model_len
             clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions)
         # For data integrity when async scheduling, we shouldn't use in place
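
For intuition, a tiny standalone example of the clamping pattern used in both branches above (the tensor values are made up; `max_model_len` stands in for `self.max_model_len`):

```python
import torch

max_model_len = 8
positions = torch.tensor([3, 7, 8, 12])

# Positions at or beyond the max model length would index out of range
# in RoPE, so they are clamped to 0; the draft tokens produced from the
# clamped positions are discarded downstream.
exceeds_max_model_len = positions >= max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
print(clamped_positions)  # tensor([3, 7, 0, 0])
```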

Second changed file (the NPU model runner):

@@ -74,7 +74,6 @@ from vllm.v1.outputs import (
 from vllm.v1.sample.logits_processor import build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
-from vllm.v1.spec_decode.draft_model import DraftModelProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import record_function_or_nullcontext
@@ -2561,7 +2560,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.speculative_config and (
             self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()
         ):
-            assert isinstance(self.drafter, AscendEagleProposer | DraftModelProposer)
+            assert isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer)
             self.drafter.initialize_attn_backend(kv_cache_config, self.kernel_block_sizes)
         if has_kv_transfer_group():
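
As a side note, the `A | B` form inside `isinstance` is a PEP 604 union (Python 3.10+); a minimal illustration with stub classes standing in for the real proposers:

```python
class AscendEagleProposer: ...
class AscendDraftModelProposer: ...

drafter = AscendDraftModelProposer()
# On Python 3.10+, isinstance accepts a union type, equivalent to the
# tuple (AscendEagleProposer, AscendDraftModelProposer).
assert isinstance(drafter, AscendEagleProposer | AscendDraftModelProposer)
```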