From de93790d08905a497b29741f593d7c5f1f98df73 Mon Sep 17 00:00:00 2001
From: drslark <96540755+drslark@users.noreply.github.com>
Date: Thu, 12 Mar 2026 18:38:50 +0800
Subject: [PATCH] [main][bugfix] Fixed the problem of the drafter crashing in FULL mode (#7158)

### What this PR does / why we need it?
The merged graph of the drafter in `FULL` mode is currently broken. This PR fixes it.

In addition, `actual_seq_lengths_q` in `model_runner` was found to be redundant, so it is removed.

It depends on https://github.com/vllm-project/vllm-ascend/pull/7144 and https://github.com/vllm-project/vllm-ascend/pull/7148.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
The test code is shown below:

```python
from vllm import LLM, SamplingParams

prompts = [
    "1.Who are you?",
    "2. Who are you?",
]

sampling_params = SamplingParams(temperature=0.0,
                                 top_p=0.95,
                                 top_k=40,
                                 max_tokens=200)

llm = LLM(
    model="/home/some-model/Meta-Llama-3.1-8B-Instruct",
    tensor_parallel_size=1,
    max_num_seqs=32,
    # enforce_eager=True,
    disable_log_stats=False,
    distributed_executor_backend="mp",
    gpu_memory_utilization=0.7,
    async_scheduling=True,
    speculative_config={
        "enforce_eager": True,
        "model": "/home/some-model/EAGLE3-LLaMA3.1-Instruct-8B",
        "disable_padded_drafter_batch": False,
        "method": "eagle3",
        "num_speculative_tokens": 3,
    },
    compilation_config={
        "cudagraph_mode": "FULL",
        "cudagraph_num_of_warmups": 1,
    },
    max_model_len=4096,
    enable_prefix_caching=False,
)

outputs = llm.generate(prompts, sampling_params)
```

The result before:

```text
  File "/vllm-workspace/vllm-ascend/vllm_ascend/attention/attention_v1.py", line 575, in full_graph_fia
    graph_params.events[num_tokens].append(event)
    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 132
```
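For context on the failure mode: `graph_params.events` is keyed by the token counts that were captured for the full graph, so a lookup with a token count that was never captured raises the `KeyError` shown above. Building `query_start_loc` from the unpadded `actual_seq_lengths_q` can produce such a count, which is presumably why copying the runner's already padded `query_start_loc` avoids the crash. The snippet below is only a minimal illustration of that mechanism; the capture sizes and the `events` dict are hypothetical stand-ins, not the actual vllm-ascend structures:

```python
# Minimal illustration only -- hypothetical capture sizes and a plain dict
# standing in for graph_params.events from attention_v1.py.
captured_sizes = [8, 16, 32, 64, 128, 256]
events = {size: [] for size in captured_sizes}

num_tokens = 132  # a total token count that was never captured
try:
    events[num_tokens].append("event")
except KeyError as exc:
    print(f"KeyError: {exc}")  # KeyError: 132, matching the traceback above

# Padding the token count up to a captured size keeps the lookup valid.
padded_num_tokens = next(size for size in captured_sizes if size >= num_tokens)
events[padded_num_tokens].append("event")  # OK
```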
The result after:

```text
--------------------------------------------------
total_num_output_tokens: 400
num_drafts: 242
num_draft_tokens: 726
num_accepted_tokens: 156
mean acceptance length: 1.64
--------------------------------------------------
acceptance at token 0: 0.42
acceptance at token 1: 0.16
acceptance at token 2: 0.07
```

We also tested `FULL_DECODE_ONLY` mode. The result is:

```text
--------------------------------------------------
total_num_output_tokens: 400
num_drafts: 244
num_draft_tokens: 732
num_accepted_tokens: 155
mean acceptance length: 1.64
--------------------------------------------------
acceptance at token 0: 0.42
acceptance at token 1: 0.16
acceptance at token 2: 0.06
```

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

Signed-off-by: drslark <96540755+drslark@users.noreply.github.com>
---
 .../singlecard/spec_decode/test_mtp_eagle_correctness.py | 4 ++--
 vllm_ascend/spec_decode/eagle_proposer.py                | 7 ++++---
 vllm_ascend/worker/model_runner_v1.py                    | 3 ---
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py b/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py
index 77447e8f..ef89d328 100644
--- a/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py
+++ b/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py
@@ -170,8 +170,8 @@ def test_llama_qwen3_eagle_correctness(
                 "max_model_len": 128,
             },
             compilation_config=CompilationConfig(
-                cudagraph_mode="FULL_DECODE_ONLY",
-                cudagraph_capture_sizes=[12])) as llm:
+                cudagraph_mode="FULL",
+                cudagraph_capture_sizes=[5, 12])) as llm:
         spec_outputs = llm.generate(example_prompts, sampling_params)
     cleanup_dist_env_and_memory()
     del llm
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index b8e0c2f9..b9d4e3bd 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -315,10 +315,11 @@ class AscendEagleProposer(EagleProposer):
             aclgraph_runtime_mode = CUDAGraphMode.NONE
         if aclgraph_runtime_mode == CUDAGraphMode.FULL and len(self.runner.attn_groups) > 0:
             num_computed_tokens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
-            self.query_start_loc.cpu[: num_reqs + 1] = torch.tensor(
-                [0] + self.runner.actual_seq_lengths_q[:num_reqs], device="cpu", dtype=torch.int32
-            )
+
+            # num_reqs is already the padded version
+            self.query_start_loc.cpu[: num_reqs + 1].copy_(self.runner.query_start_loc.cpu[: num_reqs + 1])
             self.query_start_loc.copy_to_gpu()
+
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc.gpu[: num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1],
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8000dbca..a102d0b1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -420,9 +420,6 @@ class NPUModelRunner(GPUModelRunner):
             assert isinstance(self.drafter, AscendEagleProposer)
             self.use_aux_hidden_state_outputs = self.drafter.eagle3_use_aux_hidden_state
             self.rejection_sampler = RejectionSampler(self.sampler)
-            self.actual_seq_lengths_q = list(
-                range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req)
-            )
             self.discard_request_indices = self._make_buffer(
                 self.max_num_reqs, dtype=torch.int64)
             self.num_discarded_requests = 0