[SpecDecode] Fix Draft model proposer (#7230)
### What this PR does / why we need it?
This PR fixes the unified draft parallel feature.
1. In the draft model proposer, the target model can contain more than one
attention layer, so the assertion that exactly one layer exists is removed.
2. Since vLLM 0.17.0, the block size should be obtained through
`draft_attn_groups` instead of `attn_metadata_builder`.
3. `attn_update_stack_num_spec_norm` should not be called when unified
draft parallel is enabled.
### How was this patch tested?
Test pass with
`tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_parallel_drafting_acceptance`,
which is already included in CI
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -195,7 +195,6 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
all_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys())
|
all_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys())
|
||||||
self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names - all_indexer_layer_names
|
self._draft_attn_layer_names = set(all_attn_layers.keys()) - target_attn_layer_names - all_indexer_layer_names
|
||||||
|
|
||||||
assert len(self._draft_attn_layer_names) == 1
|
|
||||||
self.attn_layer_names = list(sorted(self._draft_attn_layer_names))
|
self.attn_layer_names = list(sorted(self._draft_attn_layer_names))
|
||||||
draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
|
draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
|
||||||
self.kernel_block_size = (
|
self.kernel_block_size = (
|
||||||
@@ -699,10 +698,24 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
multi_steps_attn_metadata.append(per_layer_attn_metadata)
|
multi_steps_attn_metadata.append(per_layer_attn_metadata)
|
||||||
else:
|
else:
|
||||||
# Copy the old attn_metadata and update
|
# Copy the old attn_metadata and update
|
||||||
for draft_step in range(1, self.num_speculative_tokens):
|
if not self.parallel_drafting:
|
||||||
per_layer_attn_metadata = dict()
|
for draft_step in range(1, self.num_speculative_tokens):
|
||||||
if vllm_version_is("0.17.0"):
|
per_layer_attn_metadata = dict()
|
||||||
for attn_group in self.draft_attn_groups:
|
if vllm_version_is("0.17.0"):
|
||||||
|
for attn_group in self.draft_attn_groups:
|
||||||
|
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
|
||||||
|
draft_step,
|
||||||
|
attn_metadata,
|
||||||
|
common_attn_metadata,
|
||||||
|
batch_size,
|
||||||
|
num_input_tokens,
|
||||||
|
used_update_positions,
|
||||||
|
aclgraph_runtime_mode,
|
||||||
|
attn_group=attn_group,
|
||||||
|
)
|
||||||
|
for layer_name in self.attn_layer_names:
|
||||||
|
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||||
|
else:
|
||||||
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
|
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
|
||||||
draft_step,
|
draft_step,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
@@ -711,23 +724,10 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
num_input_tokens,
|
num_input_tokens,
|
||||||
used_update_positions,
|
used_update_positions,
|
||||||
aclgraph_runtime_mode,
|
aclgraph_runtime_mode,
|
||||||
attn_group=attn_group,
|
|
||||||
)
|
)
|
||||||
for layer_name in self.attn_layer_names:
|
for layer_name in self.attn_layer_names:
|
||||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||||
else:
|
multi_steps_attn_metadata.append(per_layer_attn_metadata)
|
||||||
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
|
|
||||||
draft_step,
|
|
||||||
attn_metadata,
|
|
||||||
common_attn_metadata,
|
|
||||||
batch_size,
|
|
||||||
num_input_tokens,
|
|
||||||
used_update_positions,
|
|
||||||
aclgraph_runtime_mode,
|
|
||||||
)
|
|
||||||
for layer_name in self.attn_layer_names:
|
|
||||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
|
||||||
multi_steps_attn_metadata.append(per_layer_attn_metadata)
|
|
||||||
|
|
||||||
token_indices_to_sample_len = token_indices_to_sample.shape[0]
|
token_indices_to_sample_len = token_indices_to_sample.shape[0]
|
||||||
self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample)
|
self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample)
|
||||||
@@ -1064,16 +1064,21 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
# 2.
|
# 2.
|
||||||
# Recompute the slot mapping based on the new positions and
|
# Recompute the slot mapping based on the new positions and
|
||||||
# rejection mask.
|
# rejection mask.
|
||||||
builder = (
|
if vllm_version_is("0.17.0"):
|
||||||
self._get_attention_metadata_builder()
|
# Use the first draft attention group's kv_cache_spec for block_size
|
||||||
if self.attn_metadata_builder is None
|
# (all draft layers share the same kv-cache group)
|
||||||
else self.attn_metadata_builder
|
assert len(self.draft_attn_groups) > 0
|
||||||
)
|
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
|
||||||
|
else:
|
||||||
|
if self.attn_metadata_builder is None:
|
||||||
|
block_size = self._get_attention_metadata_builder().kv_cache_spec.block_size
|
||||||
|
else:
|
||||||
|
block_size = self.attn_metadata_builder.kv_cache_spec.block_size
|
||||||
new_slot_mapping = compute_new_slot_mapping(
|
new_slot_mapping = compute_new_slot_mapping(
|
||||||
cad=cad,
|
cad=cad,
|
||||||
new_positions=self.positions[:total_num_output_tokens],
|
new_positions=self.positions[:total_num_output_tokens],
|
||||||
is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens],
|
is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens],
|
||||||
block_size=builder.kv_cache_spec.block_size,
|
block_size=block_size,
|
||||||
num_new_tokens=self.net_num_new_slots_per_request,
|
num_new_tokens=self.net_num_new_slots_per_request,
|
||||||
max_model_len=self.max_model_len,
|
max_model_len=self.max_model_len,
|
||||||
)
|
)
|
||||||
@@ -1152,14 +1157,14 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
# out-of-range access during the model execution. The draft tokens
|
# out-of-range access during the model execution. The draft tokens
|
||||||
# generated with this adjustment should be ignored.
|
# generated with this adjustment should be ignored.
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
exceeds_max_model_len = used_update_positions[0] >= self.vllm_config.model_config.max_model_len
|
exceeds_max_model_len = used_update_positions[0] >= self.max_model_len
|
||||||
# Mask out the position ids that exceed the max model length.
|
# Mask out the position ids that exceed the max model length.
|
||||||
# Otherwise, we may get out-of-range error in RoPE.
|
# Otherwise, we may get out-of-range error in RoPE.
|
||||||
clamped_positions = torch.where(
|
clamped_positions = torch.where(
|
||||||
exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions
|
exceeds_max_model_len.unsqueeze(0), torch.zeros_like(used_update_positions), used_update_positions
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exceeds_max_model_len = used_update_positions >= self.vllm_config.model_config.max_model_len
|
exceeds_max_model_len = used_update_positions >= self.max_model_len
|
||||||
clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions)
|
clamped_positions = torch.where(exceeds_max_model_len, 0, used_update_positions)
|
||||||
|
|
||||||
# For data integrity when async scheduling, we shouldn't use in place
|
# For data integrity when async scheduling, we shouldn't use in place
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ from vllm.v1.outputs import (
|
|||||||
from vllm.v1.sample.logits_processor import build_logitsprocs
|
from vllm.v1.sample.logits_processor import build_logitsprocs
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||||
from vllm.v1.spec_decode.draft_model import DraftModelProposer
|
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.structured_output.utils import apply_grammar_bitmask
|
from vllm.v1.structured_output.utils import apply_grammar_bitmask
|
||||||
from vllm.v1.utils import record_function_or_nullcontext
|
from vllm.v1.utils import record_function_or_nullcontext
|
||||||
@@ -2561,7 +2560,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
if self.speculative_config and (
|
if self.speculative_config and (
|
||||||
self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()
|
self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()
|
||||||
):
|
):
|
||||||
assert isinstance(self.drafter, AscendEagleProposer | DraftModelProposer)
|
assert isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer)
|
||||||
self.drafter.initialize_attn_backend(kv_cache_config, self.kernel_block_sizes)
|
self.drafter.initialize_attn_backend(kv_cache_config, self.kernel_block_sizes)
|
||||||
|
|
||||||
if has_kv_transfer_group():
|
if has_kv_transfer_group():
|
||||||
|
|||||||
Reference in New Issue
Block a user