[Bugfix] Fix padding logic in eagle proposer for kimi25 (#7348)

### What this PR does / why we need it?
This PR fixes the padding logic in the eagle proposer for kimi25. The main
changes are:
1. Modify how the draft model's attention builder and backend are obtained (see the sketch after this list).
2. Add block-table padding and related tensor slicing to the common attention metadata when `draft_step > 1`, fixing an FIA (fused infer attention) verification error (see the sketch after the `_pad_tensor` hunk below).
3. Replace the block table passed to `update_graph_params`, also fixing the FIA verification error (see the sketch after the `npu_fused_infer_attention_score` hunk below).
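
As a rough, self-contained sketch of change 1: on vLLM v0.17.0 the proposer takes the draft attention backend from its own attention groups instead of the runner-level backend. The attribute names mirror the diff below, but the stubs are illustrative rather than the vllm-ascend implementation.

```python
# Hypothetical sketch of the version-gated backend lookup (change 1).
# `vllm_version_is`, `draft_attn_groups`, and `runner.attn_backend` mirror
# names in the diff; everything else here is a stand-in.
from dataclasses import dataclass


def vllm_version_is(version: str) -> bool:
    # Stand-in for the real version check in vllm-ascend.
    return version == "0.17.0"


@dataclass
class AttnGroup:
    backend: str


@dataclass
class Runner:
    attn_backend: str


class Proposer:
    def __init__(self, draft_attn_groups: list[AttnGroup], runner: Runner):
        self.draft_attn_groups = draft_attn_groups
        self.runner = runner

    def select_attn_backend(self):
        # On v0.17.0 the draft model keeps its own attention groups, so the
        # backend comes from there; otherwise fall back to the runner's backend.
        if vllm_version_is("0.17.0"):
            assert len(self.draft_attn_groups) > 0
            return self.draft_attn_groups[0].backend
        return self.runner.attn_backend


proposer = Proposer([AttnGroup(backend="ascend-fia")], Runner(attn_backend="runner-default"))
print(proposer.select_attn_backend())  # -> "ascend-fia" when running on v0.17.0
```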

- vLLM version: v0.17.0
- vLLM main: 4034c3d32e

Signed-off-by: Zetong Li <slippersss@126.com>
Zetong Li, 2026-03-21 16:57:22 +08:00 (committed by GitHub)
commit 84a74f0cb1, parent f482c314cf
4 changed files with 51 additions and 29 deletions


@@ -407,6 +407,7 @@ class TestEagleProposerDummyRun(TestBase):
         mock_get_context.return_value = mock_return_context
         mock_get_context_2.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
+        self.proposer.draft_attn_groups = [MagicMock()]
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
         with set_current_vllm_config(self.vllm_config):
             self.proposer.enable_shared_expert_dp = False


@@ -495,10 +495,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
             draft_step = attn_count // num_layers
             seq_lens = attn_metadata[draft_step][key].seq_lens_list
             actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
+            block_tables = attn_metadata[draft_step][key].block_tables
             attn_count = attn_count + 1
         else:
             seq_lens = attn_metadata[key].seq_lens_list
             actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
+            block_tables = attn_metadata[key].block_tables
         torch.npu.graph_task_update_begin(update_stream, handle)
         torch_npu.npu_fused_infer_attention_score.out(
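
The hunk above reads `block_tables` from the same metadata entry that already supplies `seq_lens_list` and `actual_seq_lengths_q`, so the replayed fused-infer-attention kernel sees the padded block table instead of the one captured at graph build time (change 3). A rough sketch of the assumed nesting, with illustrative field and key names rather than the exact vllm-ascend structures:

```python
# Hypothetical sketch: per-layer attention metadata, nested one level deeper
# (per draft step) when multiple speculative steps are captured.
from dataclasses import dataclass

import torch


@dataclass
class LayerAttnMetadata:
    seq_lens_list: list[int]
    actual_seq_lengths_q: list[int]
    block_tables: torch.Tensor


def lookup_layer_metadata(attn_metadata, key, draft_step=None):
    # Multi-step drafting keys the metadata by draft step first.
    entry = attn_metadata[draft_step][key] if draft_step is not None else attn_metadata[key]
    return entry.seq_lens_list, entry.actual_seq_lengths_q, entry.block_tables


meta = LayerAttnMetadata(
    seq_lens_list=[5, 7],
    actual_seq_lengths_q=[1, 2],
    block_tables=torch.zeros((2, 4), dtype=torch.int32),
)
nested = {1: {"model.layers.0.self_attn.attn": meta}}
seq_lens, q_lens, block_tables = lookup_layer_metadata(nested, "model.layers.0.self_attn.attn", draft_step=1)
print(block_tables.shape)  # torch.Size([2, 4])
```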


@@ -17,7 +17,7 @@ from vllm.distributed.parallel_state import (
     init_model_parallel_group,
     patch_tensor_parallel_group,
 )
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
@@ -461,6 +461,7 @@ class SpecDecodeBaseProposer(EagleProposer):
         next_token_ids: torch.Tensor,
         token_indices_to_sample: torch.Tensor | None,
         common_attn_metadata: CommonAttentionMetadata,
+        target_model_batch_desc: BatchDescriptor,
         sampling_metadata: SamplingMetadata,
         mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
         req_scheduled_tokens=None,
@@ -500,23 +501,6 @@ class SpecDecodeBaseProposer(EagleProposer):
         assert self.runner is not None
         if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens]
-            if not (
-                self.speculative_config.disable_padded_drafter_batch
-                and self.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
-            ):
-                # TODO: Due to the inconsistency between the proposer `dispatcher` and model runner, this padding
-                # should have been done in model runner but not. For example, at prefill stage, target model
-                # is run in eager mode currently, which means `_pad_query_start_loc_for_fia` is not called,
-                # while draft model is run in graph model, which means we should pad the `query_start_loc`.
-                # Need to be fixed in the future.
-                num_reqs_padded = self.runner._pad_query_start_loc_for_fia(
-                    num_input_tokens, common_attn_metadata.num_reqs, common_attn_metadata.num_reqs
-                )
-                common_attn_metadata.num_reqs = num_reqs_padded
-                common_attn_metadata.query_start_loc = self.runner.query_start_loc.gpu[: num_reqs_padded + 1]
-                common_attn_metadata.query_start_loc_cpu = self.runner.query_start_loc.cpu[: num_reqs_padded + 1]
-                common_attn_metadata.seq_lens = self.runner.seq_lens.gpu[:num_reqs_padded]
-                common_attn_metadata.seq_lens_cpu = self.runner.seq_lens.cpu[:num_reqs_padded]
         else:
             num_input_tokens = num_tokens
@@ -529,12 +513,30 @@ class SpecDecodeBaseProposer(EagleProposer):
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
         if self.use_cuda_graph:
             aclgraph_runtime_mode, batch_descriptor = self.runner.cudagraph_dispatcher.dispatch(
-                num_tokens=num_input_tokens, uniform_decode=True, has_lora=has_lora
+                num_tokens=num_input_tokens, uniform_decode=target_model_batch_desc.uniform, has_lora=has_lora
             )
         else:
             aclgraph_runtime_mode = CUDAGraphMode.NONE
             batch_descriptor = None
+        if aclgraph_runtime_mode == CUDAGraphMode.FULL:
+            # TODO: Due to the inconsistency between the proposer `dispatcher` and model runner, this padding
+            # should have been done in model runner but not. For example, at prefill stage, target model
+            # is run in eager mode currently, which means `_pad_query_start_loc_for_fia` is not called,
+            # while draft model is run in graph model, which means we should pad the `query_start_loc`.
+            # Need to be fixed in the future.
+            num_reqs_padded = self.runner._pad_query_start_loc_for_fia(
+                num_input_tokens, common_attn_metadata.num_reqs, common_attn_metadata.num_reqs
+            )
+            common_attn_metadata.num_reqs = num_reqs_padded
+            common_attn_metadata.query_start_loc = self.runner.query_start_loc.gpu[: num_reqs_padded + 1]
+            common_attn_metadata.query_start_loc_cpu = self.runner.query_start_loc.cpu[: num_reqs_padded + 1]
+            common_attn_metadata.block_table_tensor = self._pad_tensor(
+                common_attn_metadata.block_table_tensor, num_reqs_padded
+            )
+            common_attn_metadata.seq_lens = self.runner.seq_lens.gpu[:num_reqs_padded]
+            common_attn_metadata.seq_lens_cpu = self.runner.seq_lens.cpu[:num_reqs_padded]
         if self.supports_mm_inputs:
             mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
             inputs_embeds = self.model.embed_input_ids(
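
The block added above pads the request-level metadata (including the block table) only when the dispatcher selects FULL-graph replay, and the uniform-decode decision now follows the target model's batch descriptor instead of being hard-coded to `True`. A toy sketch of that dispatch-then-pad flow; the enum, dispatcher, and padding rule below are simplified stand-ins, not the vllm-ascend implementations:

```python
# Hypothetical sketch: only FULL-graph replay forces the request dimension to
# match the captured graph size; eager and piecewise runs keep the real size.
from enum import Enum


class GraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def dispatch(num_tokens: int, uniform_decode: bool) -> GraphMode:
    # Stand-in dispatcher: assume full graphs are captured only for uniform
    # decode batches.
    return GraphMode.FULL if uniform_decode else GraphMode.PIECEWISE


def padded_num_reqs(num_reqs: int, graph_batch: int, mode: GraphMode) -> int:
    return graph_batch if mode is GraphMode.FULL else num_reqs


mode = dispatch(num_tokens=8, uniform_decode=True)
print(padded_num_reqs(num_reqs=3, graph_batch=8, mode=mode))  # -> 8
```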
@@ -1110,15 +1112,17 @@ class SpecDecodeBaseProposer(EagleProposer):
         common_attn_metadata = self.shallow_copy_metadata(old_common_metadata)
         if draft_step == 1:
-            if aclgraph_runtime_mode == CUDAGraphMode.FULL and (pad_size := input_batch_size - batch_size) > 0:
+            if aclgraph_runtime_mode == CUDAGraphMode.FULL:
                 common_attn_metadata.num_reqs = input_batch_size
                 common_attn_metadata.block_table_tensor = self._pad_tensor(
-                    common_attn_metadata.block_table_tensor, pad_size
+                    common_attn_metadata.block_table_tensor, input_batch_size
+                )
+                common_attn_metadata.seq_lens = self._pad_tensor(common_attn_metadata.seq_lens, input_batch_size)
+                common_attn_metadata.seq_lens_cpu = self._pad_tensor(
+                    common_attn_metadata.seq_lens_cpu, input_batch_size
                 )
-                common_attn_metadata.seq_lens = self._pad_tensor(common_attn_metadata.seq_lens, pad_size)
-                common_attn_metadata.seq_lens_cpu = self._pad_tensor(common_attn_metadata.seq_lens_cpu, pad_size)
                 common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor(
-                    common_attn_metadata.num_computed_tokens_cpu, pad_size
+                    common_attn_metadata.num_computed_tokens_cpu, input_batch_size
                 )
             common_attn_metadata.query_start_loc = self.arange[: input_batch_size + 1]
             common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
@@ -1545,8 +1549,13 @@ class SpecDecodeBaseProposer(EagleProposer):
     # update full-graph params for one spec token
     def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            attn_backend = self.draft_attn_groups[0].backend
+        else:
+            attn_backend = self.runner.attn_backend
         update_full_graph_params(
-            self.runner.attn_backend,
+            attn_backend,
             self.update_stream,
             forward_context,
             num_tokens,
@@ -1556,10 +1565,14 @@ class SpecDecodeBaseProposer(EagleProposer):
         )
 
     # padding tensor into desired size
-    def _pad_tensor(self, tensor, pad_size):
-        pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
-        padded_tensor = F.pad(tensor, pad, mode="constant", value=0)
-        return padded_tensor
+    def _pad_tensor(self, tensor, desired_size):
+        pad_size = desired_size - tensor.shape[0]
+        if pad_size > 0:
+            pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
+            tensor = F.pad(tensor, pad, mode="constant", value=0)
+        else:
+            tensor = tensor[:desired_size]
+        return tensor
 
     def maybe_pad_and_reduce(
         self,
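
The rewritten `_pad_tensor` now takes the desired leading size rather than a pad amount: it zero-pads when the tensor is shorter than the target and slices when it is longer, which is how the draft metadata is made to match a padded graph batch (change 2). A standalone re-implementation for illustration; the in-tree method is the one shown in the hunk above:

```python
import torch
import torch.nn.functional as F


def pad_or_slice(tensor: torch.Tensor, desired_size: int) -> torch.Tensor:
    """Pad dim 0 with zeros up to desired_size, or slice it down to desired_size."""
    pad_size = desired_size - tensor.shape[0]
    if pad_size > 0:
        # F.pad consumes the pad list from the last dimension backwards, so only
        # the final entry touches the end of dim 0.
        pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
        return F.pad(tensor, pad, mode="constant", value=0)
    return tensor[:desired_size]


block_table = torch.arange(6, dtype=torch.int32).reshape(3, 2)
print(pad_or_slice(block_table, 5).shape)  # torch.Size([5, 2]) -- zero rows appended
print(pad_or_slice(block_table, 2).shape)  # torch.Size([2, 2]) -- extra rows dropped
```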


@@ -215,6 +215,7 @@ class ExecuteModelState(NamedTuple):
     positions: torch.Tensor
     ec_connector_output: "ECConnectorOutput | None"
     cudagraph_stats: CUDAGraphStat | None
+    batch_desc: BatchDescriptor
 
 class NPUModelRunner(GPUModelRunner):
@@ -995,6 +996,7 @@ class NPUModelRunner(GPUModelRunner):
         hidden_states: torch.Tensor,
         aux_hidden_states: torch.Tensor = None,
         sample_hidden_states: torch.Tensor = None,
+        target_model_batch_desc: BatchDescriptor = None,
     ) -> list[list[int]] | None:
         if not self.drafter:
             # Speculative decoding is not enabled.
@@ -1115,6 +1117,7 @@ class NPUModelRunner(GPUModelRunner):
             next_token_ids=next_token_ids,
             token_indices_to_sample=token_indices_to_sample,
             common_attn_metadata=common_attn_metadata,
+            target_model_batch_desc=target_model_batch_desc,
             sampling_metadata=sampling_metadata,
             req_scheduled_tokens=req_scheduled_tokens,
             long_seq_metadata=long_seq_metadata,
@@ -1455,6 +1458,7 @@ class NPUModelRunner(GPUModelRunner):
                 positions,
                 ec_connector_output,
                 cudagraph_stats,
+                batch_desc,
             )
             self.kv_connector_output = kv_connector_output
             return None
@@ -1497,6 +1501,7 @@ class NPUModelRunner(GPUModelRunner):
             positions,
             ec_connector_output,
             cudagraph_stats,
+            batch_desc,
         ) = self.execute_model_state
         # Clear ephemeral state.
         self.execute_model_state = None
@@ -1533,6 +1538,7 @@ class NPUModelRunner(GPUModelRunner):
             hidden_states,
             aux_hidden_states,
             sample_hidden_states,
+            batch_desc,
         )
         self._copy_draft_token_ids_to_cpu(scheduler_output)