diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index 7cc6368c..f66aa804 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -407,6 +407,7 @@ class TestEagleProposerDummyRun(TestBase):
         mock_get_context.return_value = mock_return_context
         mock_get_context_2.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
+        self.proposer.draft_attn_groups = [MagicMock()]
         # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
         with set_current_vllm_config(self.vllm_config):
             self.proposer.enable_shared_expert_dp = False
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index b887b114..015dd90b 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -495,10 +495,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
             draft_step = attn_count // num_layers
             seq_lens = attn_metadata[draft_step][key].seq_lens_list
             actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
+            block_tables = attn_metadata[draft_step][key].block_tables
             attn_count = attn_count + 1
         else:
             seq_lens = attn_metadata[key].seq_lens_list
             actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
+            block_tables = attn_metadata[key].block_tables
 
         torch.npu.graph_task_update_begin(update_stream, handle)
         torch_npu.npu_fused_infer_attention_score.out(
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 869cd287..12bcd8a5 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -17,7 +17,7 @@ from vllm.distributed.parallel_state import (
     init_model_parallel_group,
     patch_tensor_parallel_group,
 )
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
@@ -461,6 +461,7 @@ class SpecDecodeBaseProposer(EagleProposer):
         next_token_ids: torch.Tensor,
         token_indices_to_sample: torch.Tensor | None,
         common_attn_metadata: CommonAttentionMetadata,
+        target_model_batch_desc: BatchDescriptor,
         sampling_metadata: SamplingMetadata,
         mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
         req_scheduled_tokens=None,
@@ -500,23 +501,6 @@ class SpecDecodeBaseProposer(EagleProposer):
         assert self.runner is not None
         if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens]
-            if not (
-                self.speculative_config.disable_padded_drafter_batch
-                and self.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
-            ):
-                # TODO: Due to the inconsistency between the proposer `dispatcher` and model runner, this padding
-                # should have been done in model runner but not. For example, at prefill stage, target model
-                # is run in eager mode currently, which means `_pad_query_start_loc_for_fia` is not called,
-                # while draft model is run in graph model, which means we should pad the `query_start_loc`.
-                # Need to be fixed in the future.
-                num_reqs_padded = self.runner._pad_query_start_loc_for_fia(
-                    num_input_tokens, common_attn_metadata.num_reqs, common_attn_metadata.num_reqs
-                )
-                common_attn_metadata.num_reqs = num_reqs_padded
-                common_attn_metadata.query_start_loc = self.runner.query_start_loc.gpu[: num_reqs_padded + 1]
-                common_attn_metadata.query_start_loc_cpu = self.runner.query_start_loc.cpu[: num_reqs_padded + 1]
-                common_attn_metadata.seq_lens = self.runner.seq_lens.gpu[:num_reqs_padded]
-                common_attn_metadata.seq_lens_cpu = self.runner.seq_lens.cpu[:num_reqs_padded]
         else:
             num_input_tokens = num_tokens
 
@@ -529,12 +513,30 @@ class SpecDecodeBaseProposer(EagleProposer):
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
         if self.use_cuda_graph:
             aclgraph_runtime_mode, batch_descriptor = self.runner.cudagraph_dispatcher.dispatch(
-                num_tokens=num_input_tokens, uniform_decode=True, has_lora=has_lora
+                num_tokens=num_input_tokens, uniform_decode=target_model_batch_desc.uniform, has_lora=has_lora
             )
         else:
             aclgraph_runtime_mode = CUDAGraphMode.NONE
             batch_descriptor = None
 
+        if aclgraph_runtime_mode == CUDAGraphMode.FULL:
+            # TODO: Due to the inconsistency between the proposer `dispatcher` and the model runner, this
+            # padding should have been done in the model runner, but is not. For example, at the prefill
+            # stage the target model currently runs in eager mode, so `_pad_query_start_loc_for_fia` is
+            # not called, while the draft model runs in graph mode, so we must pad `query_start_loc` here.
+            # This needs to be fixed in the future.
+            num_reqs_padded = self.runner._pad_query_start_loc_for_fia(
+                num_input_tokens, common_attn_metadata.num_reqs, common_attn_metadata.num_reqs
+            )
+            common_attn_metadata.num_reqs = num_reqs_padded
+            common_attn_metadata.query_start_loc = self.runner.query_start_loc.gpu[: num_reqs_padded + 1]
+            common_attn_metadata.query_start_loc_cpu = self.runner.query_start_loc.cpu[: num_reqs_padded + 1]
+            common_attn_metadata.block_table_tensor = self._pad_tensor(
+                common_attn_metadata.block_table_tensor, num_reqs_padded
+            )
+            common_attn_metadata.seq_lens = self.runner.seq_lens.gpu[:num_reqs_padded]
+            common_attn_metadata.seq_lens_cpu = self.runner.seq_lens.cpu[:num_reqs_padded]
+
         if self.supports_mm_inputs:
             mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
             inputs_embeds = self.model.embed_input_ids(
@@ -1110,15 +1112,17 @@ class SpecDecodeBaseProposer(EagleProposer):
             common_attn_metadata = self.shallow_copy_metadata(old_common_metadata)
 
             if draft_step == 1:
-                if aclgraph_runtime_mode == CUDAGraphMode.FULL and (pad_size := input_batch_size - batch_size) > 0:
+                if aclgraph_runtime_mode == CUDAGraphMode.FULL:
                     common_attn_metadata.num_reqs = input_batch_size
                     common_attn_metadata.block_table_tensor = self._pad_tensor(
-                        common_attn_metadata.block_table_tensor, pad_size
+                        common_attn_metadata.block_table_tensor, input_batch_size
+                    )
+                    common_attn_metadata.seq_lens = self._pad_tensor(common_attn_metadata.seq_lens, input_batch_size)
+                    common_attn_metadata.seq_lens_cpu = self._pad_tensor(
+                        common_attn_metadata.seq_lens_cpu, input_batch_size
                     )
-                    common_attn_metadata.seq_lens = self._pad_tensor(common_attn_metadata.seq_lens, pad_size)
-                    common_attn_metadata.seq_lens_cpu = self._pad_tensor(common_attn_metadata.seq_lens_cpu, pad_size)
                     common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor(
-                        common_attn_metadata.num_computed_tokens_cpu, pad_size
+                        common_attn_metadata.num_computed_tokens_cpu, input_batch_size
                     )
                     common_attn_metadata.query_start_loc = self.arange[: input_batch_size + 1]
                     common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
@@ -1545,8 +1549,13 @@ class SpecDecodeBaseProposer(EagleProposer):
 
     # update full-graph params for one spec token
     def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            attn_backend = self.draft_attn_groups[0].backend
+        else:
+            attn_backend = self.runner.attn_backend
         update_full_graph_params(
-            self.runner.attn_backend,
+            attn_backend,
             self.update_stream,
             forward_context,
             num_tokens,
@@ -1556,10 +1565,14 @@ class SpecDecodeBaseProposer(EagleProposer):
         )
 
-    # padding tensor into desired size
-    def _pad_tensor(self, tensor, pad_size):
-        pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
-        padded_tensor = F.pad(tensor, pad, mode="constant", value=0)
-        return padded_tensor
+    # pad (or truncate) the first dimension of a tensor to the desired size
+    def _pad_tensor(self, tensor, desired_size):
+        pad_size = desired_size - tensor.shape[0]
+        if pad_size > 0:
+            pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
+            tensor = F.pad(tensor, pad, mode="constant", value=0)
+        else:
+            tensor = tensor[:desired_size]
+        return tensor
 
     def maybe_pad_and_reduce(
         self,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f8227f3f..d6389402 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -215,6 +215,7 @@ class ExecuteModelState(NamedTuple):
     positions: torch.Tensor
     ec_connector_output: "ECConnectorOutput | None"
    cudagraph_stats: CUDAGraphStat | None
+    batch_desc: BatchDescriptor
 
 
 class NPUModelRunner(GPUModelRunner):
@@ -995,6 +996,7 @@ class NPUModelRunner(GPUModelRunner):
         hidden_states: torch.Tensor,
         aux_hidden_states: torch.Tensor = None,
         sample_hidden_states: torch.Tensor = None,
+        target_model_batch_desc: BatchDescriptor = None,
     ) -> list[list[int]] | None:
         if not self.drafter:
             # Speculative decoding is not enabled.
@@ -1115,6 +1117,7 @@ class NPUModelRunner(GPUModelRunner):
             next_token_ids=next_token_ids,
             token_indices_to_sample=token_indices_to_sample,
             common_attn_metadata=common_attn_metadata,
+            target_model_batch_desc=target_model_batch_desc,
             sampling_metadata=sampling_metadata,
             req_scheduled_tokens=req_scheduled_tokens,
             long_seq_metadata=long_seq_metadata,
@@ -1455,6 +1458,7 @@ class NPUModelRunner(GPUModelRunner):
                 positions,
                 ec_connector_output,
                 cudagraph_stats,
+                batch_desc,
             )
             self.kv_connector_output = kv_connector_output
             return None
@@ -1497,6 +1501,7 @@ class NPUModelRunner(GPUModelRunner):
             positions,
             ec_connector_output,
             cudagraph_stats,
+            batch_desc,
         ) = self.execute_model_state
         # Clear ephemeral state.
         self.execute_model_state = None
@@ -1533,6 +1538,7 @@ class NPUModelRunner(GPUModelRunner):
                 hidden_states,
                 aux_hidden_states,
                 sample_hidden_states,
+                batch_desc,
             )
 
         self._copy_draft_token_ids_to_cpu(scheduler_output)
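
Reviewer note: `_pad_tensor` now takes the desired total size of dim 0 rather than a pad amount, and truncates when the input is already larger. A minimal standalone sketch of the same pad-or-truncate logic (the `pad_or_truncate` name and the toy shapes are illustrative only, not part of this change):

```python
import torch
import torch.nn.functional as F


def pad_or_truncate(tensor: torch.Tensor, desired_size: int) -> torch.Tensor:
    """Zero-pad or slice dim 0 so the result has exactly `desired_size` rows."""
    pad_size = desired_size - tensor.shape[0]
    if pad_size > 0:
        # F.pad takes pad widths starting from the last dimension; only the
        # final entry (the trailing pad of dim 0) is non-zero here.
        pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
        return F.pad(tensor, pad, mode="constant", value=0)
    return tensor[:desired_size]


block_table = torch.arange(6).reshape(3, 2)              # 3 requests, 2 blocks each
assert pad_or_truncate(block_table, 5).shape == (5, 2)   # padded with zero rows
assert pad_or_truncate(block_table, 2).shape == (2, 2)   # truncated
```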
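One property the new signature buys (my reading of the change, not stated explicitly in the diff): padding to a target size is idempotent, so metadata that was already padded once, e.g. in `propose()` under `CUDAGraphMode.FULL`, can safely flow through the per-draft-step padding again, whereas the old pad-by-amount grows the tensor on every call. A small demonstration under that assumption, using illustrative helper names:

```python
import torch
import torch.nn.functional as F


def pad_by_amount(t: torch.Tensor, pad_size: int) -> torch.Tensor:
    # old semantics: unconditionally append `pad_size` zeros at the end
    return F.pad(t, [0, pad_size]) if pad_size > 0 else t


def pad_to_size(t: torch.Tensor, desired_size: int) -> torch.Tensor:
    # new semantics: converge on exactly `desired_size` entries
    pad_size = desired_size - t.shape[0]
    return F.pad(t, [0, pad_size]) if pad_size > 0 else t[:desired_size]


seq_lens = torch.tensor([7, 3])  # 2 real requests, padded graph batch of 4
assert pad_to_size(pad_to_size(seq_lens, 4), 4).shape[0] == 4       # stable when re-applied
assert pad_by_amount(pad_by_amount(seq_lens, 2), 2).shape[0] == 6   # over-pads on the second pass
```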