From e20813f441cd38df86a6302370f3f9d96d2a904c Mon Sep 17 00:00:00 2001 From: Ronald Date: Wed, 14 Jan 2026 09:18:05 +0800 Subject: [PATCH] [Feature] implement eagle spec decoding for model runner v2 (#5840) ### What this PR does / why we need it? this pr implement eagle spec decoding for model runner v2, please see RFC https://github.com/vllm-project/vllm-ascend/issues/5208 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? vLLM version: v0.13.0 --------- Signed-off-by: Ronald1995 --- .../singlecard/model_runner_v2/test_basic.py | 36 ++++ vllm_ascend/patch/__init__.py | 15 +- vllm_ascend/patch/worker/__init__.py | 1 + vllm_ascend/patch/worker/patch_triton.py | 4 + vllm_ascend/patch/worker/patch_v2_egale.py | 166 ++++++++++++++++++ vllm_ascend/worker/v2/attn_utils.py | 15 +- vllm_ascend/worker/v2/model_runner.py | 129 ++++++-------- vllm_ascend/worker/v2/spec_decode/__init__.py | 38 ++++ vllm_ascend/worker/v2/spec_decode/eagle.py | 146 +++++++++++++++ 9 files changed, 468 insertions(+), 82 deletions(-) create mode 100644 vllm_ascend/patch/worker/patch_v2_egale.py create mode 100644 vllm_ascend/worker/v2/spec_decode/__init__.py create mode 100644 vllm_ascend/worker/v2/spec_decode/eagle.py diff --git a/tests/e2e/singlecard/model_runner_v2/test_basic.py b/tests/e2e/singlecard/model_runner_v2/test_basic.py index 83bbe898..672cd274 100644 --- a/tests/e2e/singlecard/model_runner_v2/test_basic.py +++ b/tests/e2e/singlecard/model_runner_v2/test_basic.py @@ -25,6 +25,9 @@ from tests.e2e.conftest import VllmRunner MODELS = ["Qwen/Qwen3-0.6B"] +MAIN_MODELS = ["LLM-Research/Meta-Llama-3.1-8B-Instruct"] +EGALE_MODELS = ["vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"] + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) @@ -49,3 +52,36 @@ def test_qwen3_dense_eager_mode( enforce_eager=enforce_eager, ) as runner: runner.model.generate(prompts, sampling_params) + + +@pytest.mark.parametrize("model", MAIN_MODELS) +@pytest.mark.parametrize("eagle_model", EGALE_MODELS) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("enforce_eager", [True]) +@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"}) +def test_egale_spec_decoding( + model: str, + eagle_model: str, + max_tokens: int, + enforce_eager: bool, +) -> None: + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0) + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=enforce_eager, + async_scheduling=True, + speculative_config={ + "model": eagle_model, + "method": "eagle", + "num_speculative_tokens": 3, + }, + ) as runner: + runner.model.generate(prompts, sampling_params) diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index a1037855..21601d1f 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -174,7 +174,8 @@ # # ** 6. File: worker/patch_triton.py** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.model_executor.layers.mamba.ops`, `vllm.model_executor.layers.fla.ops` +# 1. `vllm.model_executor.layers.mamba.ops`, `vllm.model_executor.layers.fla.ops`, +# `vllm.v1.worker.gpu.sample.gumbel.gumbel_sample` # Why: # triton ops in vLLM perform not good on NPU. And there is no dispatch mechanism for triton ops. # How: @@ -263,3 +264,15 @@ # Future Plan: # Remove this patch when vLLM support these operators. # +# ** 12. 
File: worker/patch_v2_eagle.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.worker.gpu.spec_decode.eagle.EagleSpeculator.propose` +# Why: +# `propose` method use torch.gather, but the gather operator will +# pollute the arguments passed to it. the bug is reported to huawei +# CANN team, but not fixed yet. +# How: +# clone the out attribute ahead of gather to avoid the bug. +# Future Plan: +# Remove this patch when cann fix the gather bug. +# diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index b66e2f4b..2abd9302 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -32,3 +32,4 @@ import vllm_ascend.patch.worker.patch_qwen3_next # noqa import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa import vllm_ascend.patch.worker.patch_rejection_sampler # noqa import vllm_ascend.patch.worker.patch_qwen3_next # noqa +import vllm_ascend.patch.worker.patch_v2_egale # noqa diff --git a/vllm_ascend/patch/worker/patch_triton.py b/vllm_ascend/patch/worker/patch_triton.py index af0909c1..ca731123 100644 --- a/vllm_ascend/patch/worker/patch_triton.py +++ b/vllm_ascend/patch/worker/patch_triton.py @@ -1,4 +1,5 @@ import vllm.model_executor.layers.mamba.ops.causal_conv1d +import vllm.v1.worker.gpu.sample.gumbel from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn @@ -6,9 +7,12 @@ from vllm_ascend.ops.triton.fla.sigmoid_gating import \ fused_recurrent_gated_delta_rule_fwd_kernel from vllm_ascend.ops.triton.mamba.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update_npu) +from vllm_ascend.worker.v2.sample.gumbel import \ + gumbel_sample as ascend_gumbel_sample vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn vllm.model_executor.layers.fla.ops.chunk_gated_delta_rule = chunk_gated_delta_rule +vllm.v1.worker.gpu.sample.gumbel.gumbel_sample = ascend_gumbel_sample diff --git a/vllm_ascend/patch/worker/patch_v2_egale.py b/vllm_ascend/patch/worker/patch_v2_egale.py new file mode 100644 index 00000000..108df8cc --- /dev/null +++ b/vllm_ascend/patch/worker/patch_v2_egale.py @@ -0,0 +1,166 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/spec_decode/eagle.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
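+# NOTE: this module monkey-patches EagleSpeculator.propose so that the
+# tensors used as `out=` buffers for torch.gather are cloned first, working
+# around a CANN gather bug that corrupts the buffers passed to it
+# (see the explanation in vllm_ascend/patch/__init__.py).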
+# +import numpy as np +import torch +import vllm +from vllm.v1.worker.gpu.input_batch import InputBatch +from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample +from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu.spec_decode.eagle import (prepare_eagle_decode, + prepare_eagle_inputs) + +from vllm_ascend.worker.v2.attn_utils import build_attn_metadata + + +@torch.inference_mode() +def propose( + self, + input_batch: InputBatch, + sampling_metadata: SamplingMetadata, + # [num_tokens, hidden_size] + last_hidden_states: torch.Tensor, + # num_layers x [num_tokens, hidden_size] + aux_hidden_states: list[torch.Tensor] | None, + # [num_reqs] + num_sampled: torch.Tensor, + # [num_reqs] + num_rejected: torch.Tensor, + # [num_reqs] + last_sampled: torch.Tensor, + # [num_reqs] + next_prefill_tokens: torch.Tensor, +) -> torch.Tensor: + # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the + # number of rejected tokens, we maintain the size of eagle's input_ids and + # hidden_states the same as the target model's. This means, we pad each + # request's query length to include any rejected positions. By doing so, + # we can also reuse the attention metadata (e.g., query_start_loc, + # seq_lens) of the target model. + if aux_hidden_states: + assert self.method == "eagle3" + hidden_states = self.model.combine_hidden_states( + torch.cat(aux_hidden_states, dim=-1)) + else: + hidden_states = last_hidden_states + num_tokens = input_batch.num_tokens_after_padding + self.hidden_states[:num_tokens] = hidden_states + + # Get the input ids and last token indices for the speculator. + last_token_indices = prepare_eagle_inputs( + self.input_buffers, + input_batch, + num_sampled, + num_rejected, + last_sampled, + next_prefill_tokens, + ) + + # Prefill: Run the eagle speculator with eager mode. + # TODO(woosuk): Support CUDA graph for prefill. + last_hidden_states, hidden_states = self.run_model( + num_tokens, + input_batch.attn_metadata, + num_tokens_across_dp=None, # FIXME + ) + sample_hidden_states = last_hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states) + + num_reqs = input_batch.num_reqs + cu_num_logits = input_batch.cu_num_logits[:num_reqs] + # NOTE(woosuk): For draft sampling, we only consider the temperature + # and ignore the other sampling parameters such as top_k and top_p, + # for simplicity and performance. + # While this may slightly degrade the acceptance rate, it does not + # affect the output distribution after rejection sampling. + # NOTE(Ronald1995): torch.gather will pollute the cache such as self.input_buffers.positions + # the bug is reported to huawei CANN team, but not fixed yet. + # So we clone the tensors before calling torch.gather to avoid the issue. + temperature = self.temperature[:num_reqs].clone() + seeds = self.seeds[:num_reqs].clone() + pos = self.input_buffers.positions[:num_reqs].clone() + # Gather the values and copy them to the pre-allocated buffers. + torch.gather(sampling_metadata.temperature, + 0, + cu_num_logits, + out=temperature) + torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds) + torch.gather(input_batch.positions, 0, last_token_indices, out=pos) + # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise + # used for draft and target sampling. + draft_tokens = gumbel_sample(logits, + temperature, + seeds, + pos + 1, + apply_temperature=True) + if self.num_speculative_steps == 1: + # Early exit. 
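+        # A single speculative step needs no follow-up decode iterations,
+        # so the tokens sampled above form the whole proposal.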
+ return draft_tokens.view(-1, 1) + + # Save the draft tokens for the first step. + self.draft_tokens[:num_reqs, 0] = draft_tokens + # Prepare the inputs for the decode steps. + prepare_eagle_decode( + draft_tokens, + hidden_states, + last_token_indices, + input_batch.seq_lens, + num_rejected, + self.input_buffers, + self.hidden_states, + self.max_model_len, + self.max_num_reqs, + ) + query_start_loc = self.input_buffers.query_start_loc + query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1] + slot_mappings = self.block_tables.compute_slot_mappings( + query_start_loc_gpu, pos) + + cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) + if cudagraph_size is not None: + # Run CUDA graph. + self.cudagraph_manager.run(cudagraph_size) + return self.draft_tokens[:num_reqs] + + # Run eager mode. + query_start_loc.np[:num_reqs + 1] = np.arange(num_reqs + 1) + query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1] + # HACK(woosuk) + seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32) + block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] + + # FIXME(woosuk): This is UNSAFE!! + attn_metadata = build_attn_metadata( + attn_metadata_builders=self.attn_metadata_builders, + num_reqs=num_reqs, + num_tokens=num_reqs, + query_start_loc_gpu=query_start_loc_gpu, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=self.input_buffers.seq_lens[:num_reqs], + seq_lens_np=seq_lens_np, + num_computed_tokens_cpu=None, # FIXME + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + self.generate_draft(num_reqs, attn_metadata, + num_tokens_across_dp=None) # FIXME + return self.draft_tokens[:num_reqs] + + +vllm.v1.worker.gpu.spec_decode.eagle.EagleSpeculator.propose = propose diff --git a/vllm_ascend/worker/v2/attn_utils.py b/vllm_ascend/worker/v2/attn_utils.py index 738a84c3..aef319e2 100644 --- a/vllm_ascend/worker/v2/attn_utils.py +++ b/vllm_ascend/worker/v2/attn_utils.py @@ -18,7 +18,7 @@ # from collections.abc import Sequence -from typing import Any +from typing import Any, Tuple import numpy as np import torch @@ -50,13 +50,11 @@ def build_attn_metadata( query_start_loc_gpu: torch.Tensor, query_start_loc_cpu: torch.Tensor, seq_lens: torch.Tensor, - seq_lens_cpu: torch.Tensor, - num_computed_tokens_cpu: torch.Tensor, + seq_lens_np: np.ndarray, + num_computed_tokens_cpu: torch.Tensor | None, block_tables: Sequence[torch.Tensor], slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig, - decode_token_per_req: int, - actual_seq_lengths_q: list[int], positions: torch.Tensor | None = None, attn_state: Any | None = None, graph_pad_size: int = -1, @@ -67,7 +65,11 @@ def build_attn_metadata( """Build attention metadata for Ascend NPUs.""" # TODO(Ronald1995): optimize AscendCommonAttentionMetadata. max_query_len = int(query_start_loc_cpu.max()) + seq_lens_cpu = torch.from_numpy(seq_lens_np) max_seq_len = int(seq_lens_cpu.max()) + # torch_npu._reshape_and_cache operator requires slot_mappings to + # be torch.int32. 
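+    # NOTE: the narrowing cast assumes every slot index fits in int32.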
+ slot_mappings = slot_mappings.to(torch.int32) attn_metadata: dict[str, Any] = {} kv_cache_groups = kv_cache_config.kv_cache_groups @@ -80,14 +82,11 @@ def build_attn_metadata( query_start_loc_cpu=query_start_loc_cpu, seq_lens_cpu=seq_lens_cpu[:num_reqs], seq_lens=seq_lens[:num_reqs], - num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, - decode_token_per_req=decode_token_per_req, block_table_tensor=block_table, slot_mapping=slot_mapping, - actual_seq_lengths_q=actual_seq_lengths_q, positions=positions, attn_state=attn_state, graph_pad_size=graph_pad_size, diff --git a/vllm_ascend/worker/v2/model_runner.py b/vllm_ascend/worker/v2/model_runner.py index 99987c5d..2ab579c5 100644 --- a/vllm_ascend/worker/v2/model_runner.py +++ b/vllm_ascend/worker/v2/model_runner.py @@ -21,20 +21,20 @@ import numpy as np import torch from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu.input_batch import (InputBatch, combine_sampled_and_draft_tokens, prepare_pos_seq_lens, prepare_prefill_inputs) from vllm.v1.worker.gpu.model_runner import GPUModelRunner -from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata, build_attn_state) from vllm_ascend.worker.v2.input_batch import AscendInputBuffers from vllm_ascend.worker.v2.sample.sampler import AscendSampler +from vllm_ascend.worker.v2.spec_decode import init_speculator +from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper from vllm_ascend.worker.v2.utils import torch_cuda_wrapper @@ -54,12 +54,21 @@ class NPUModelRunner(GPUModelRunner): del self.req_states del self.input_buffers del self.sampler + del self.speculator # NPU specific initializations can be added below. self.cudagraph_manager: AclGraphManager = AclGraphManager( vllm_config, device, ) + + # we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle + # init_speculator will return AscendEagleSpeculator when eagle is used. + # so here we just call init_speculator to reinitialize speculator. + self.speculator: AscendEagleSpeculator | None = None + if self.speculative_config is not None: + self.speculator = init_speculator(self.vllm_config, self.device) + # AscendRequestState has extra `num_computed_tokens_cpu` attribute. # so reinitialize req_states here. self.req_states: AscendRequestState = AscendRequestState( @@ -87,29 +96,18 @@ class NPUModelRunner(GPUModelRunner): self.sampler: AscendSampler = AscendSampler( logprobs_mode=self.model_config.logprobs_mode, ) - # actual seq lengths for query (used in attention backends). - self.actual_seq_lengths_q: list[int] = [] - # decode token per request (used in attention backends). - self.decode_token_per_req = 1 - - # there attributes are for async scheduling with speculative decoding. - # because npu attention backend still need to use seq_lens_cpu, - # we need to copy num_rejected_tokens back to cpu to help - # update actual seq_lens_cpu. gpu attention backend do not need these - # attributes, cause their attention backends do not use seq_lens_cpu. 
+ # we need to copy num_computed_tokens back to cpu to help + # update actual seq_lens_cpu. gpu attention backend doesn't need these + # attributes, cause their attention backends doesn't use seq_lens_cpu. # and seq_lens_cpu is deprecated in gpu_model_runner_v2. - self.num_rejected_tokens_event = None - self.num_rejectd_tokens_cpu = None - self.num_rejected_token_stream = None - if self.use_async_scheduling and self.do_spec_decode: - self.num_rejected_tokens_event = torch.npu.Event() - self.num_rejected_token_stream = torch.npu.Stream() - self.num_rejectd_tokens_cpu = torch.empty( - self.max_num_reqs, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) + self.num_computed_tokens_event = torch.npu.Event() + self.num_computed_tokens_stream = torch.npu.Stream() + self.num_computed_tokens_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) def prepare_inputs( self, @@ -161,9 +159,6 @@ class NPUModelRunner(GPUModelRunner): idx_mapping = self.input_buffers.idx_mapping idx_mapping.np[:num_reqs] = idx_mapping_list idx_mapping_np = idx_mapping.np[:num_reqs] - # add `idx_mapping_cpu` here, because vllm-ascend's self.req_states. - # num_computed_tokens_cpu is actually cpu's tensor, while it's a gpu's - # tensor in vllm gpu_model_runner_v2. idx_mapping_cpu = idx_mapping.cpu[:num_reqs] idx_mapping_npu = idx_mapping.copy_to_gpu(num_reqs) @@ -267,16 +262,12 @@ class NPUModelRunner(GPUModelRunner): query_start_loc_gpu=query_start_loc_gpu, query_start_loc_cpu=query_start_loc_cpu, seq_lens=self.input_buffers.seq_lens, - seq_lens_cpu=self.input_buffers.seq_lens_cpu, - actual_seq_lengths_q=self.actual_seq_lengths_q, + seq_lens_np=self.input_buffers.seq_lens_np, num_computed_tokens_cpu=self.req_states. num_computed_tokens_cpu[idx_mapping_cpu], block_tables=block_tables, - # torch_npu._reshape_and_cache operator requires slot_mappings to - # be torch.int32. - slot_mappings=slot_mappings.to(torch.int32), + slot_mappings=slot_mappings, kv_cache_config=self.kv_cache_config, - decode_token_per_req=self.decode_token_per_req, attn_state=attn_state, ) @@ -302,40 +293,35 @@ class NPUModelRunner(GPUModelRunner): cu_num_logits=cu_num_logits, ) - def sample( + def postprocess( self, - hidden_states: torch.Tensor, - input_batch: InputBatch, - sampling_metadata: SamplingMetadata, - grammar_output: GrammarOutput | None, - ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]: - """Override GPUModelRunner.sample for Ascend NPUs. - when using async scheduling with speculative decoding, - we need to copy mpu's num_rejected tensor to cpu. - these operations aren't needed in gpu_model_runner_v2, - because gpu attention backends do not use seq_lens_cpu anymore. + input_batch, + sampled_tokens, + num_sampled, + num_rejected, + ): + """Override GPUModelRunner.postprocess for Ascend NPUs. + npu attention backends need seq_lens_cpu to work. + so we need to copy num_computed_tokens back to cpu here. """ - sampler_output, num_sampled, num_rejected = super().sample( - hidden_states, + super().postprocess( input_batch, - sampling_metadata, - grammar_output, + sampled_tokens, + num_sampled, + num_rejected, ) - if self.num_rejected_tokens_event is not None: - # npu attention backend still need to use seq_lens_cpu, - # when doing speculative decoding with async_scheduling, - # we need to copy num_rejected_tokens back to cpu. 
- default_stream = torch.cuda.current_stream() - assert self.num_rejected_token_stream is not None - assert self.num_rejectd_tokens_cpu is not None - with torch.npu.stream(self.num_rejected_token_stream): - self.num_rejected_token_stream.wait_stream(default_stream) - self.num_rejectd_tokens_cpu.copy_( - num_rejected, - non_blocking=True, - ) - self.num_rejected_tokens_event.record() - return sampler_output, num_sampled, num_rejected + # npu attention backend still need to use seq_lens_cpu, + # we need to copy num_computed_tokens back to cpu. + default_stream = torch.cuda.current_stream() + assert self.num_computed_tokens_stream is not None + assert self.num_computed_tokens_cpu is not None + with torch.npu.stream(self.num_computed_tokens_stream): + self.num_computed_tokens_stream.wait_stream(default_stream) + self.num_computed_tokens_cpu.copy_( + self.req_states.num_computed_tokens, + non_blocking=True, + ) + self.num_computed_tokens_event.record() def _update_seq_lens_cpu( self, @@ -343,17 +329,14 @@ class NPUModelRunner(GPUModelRunner): req_ids: list[str], ): num_scheduled_tokens = scheduler_output.num_scheduled_tokens - - # update num_computed_tokens_cpu - # TODO(Ronald1995): update num_computed_tokens_cpu by considering - # num_rejectd_tokens. - for req_id, num_computed_token in zip( - scheduler_output.scheduled_cached_reqs.req_ids, - scheduler_output.scheduled_cached_reqs.num_computed_tokens, - ): + # wait for num_computed_tokens copy to cpu stream to finish. + self.num_computed_tokens_event.synchronize() + for req_id in scheduler_output.scheduled_cached_reqs.req_ids: req_index = self.req_states.req_id_to_index[req_id] + # num_computed_tokens_cpu has reverted by num_rejected_tokens already. + # in super postprocess method. self.req_states.num_computed_tokens_cpu[ - req_index] = num_computed_token + req_index] = self.num_computed_tokens_cpu[req_index] # update seq_lens_cpu for i, req_id in enumerate(req_ids): diff --git a/vllm_ascend/worker/v2/spec_decode/__init__.py b/vllm_ascend/worker/v2/spec_decode/__init__.py new file mode 100644 index 00000000..a2841cab --- /dev/null +++ b/vllm_ascend/worker/v2/spec_decode/__init__.py @@ -0,0 +1,38 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/spec_decode/__init__.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +import torch +from vllm.config import VllmConfig + + +def init_speculator( + vllm_config: VllmConfig, + device: torch.device, +): + """Override GPU init_speculator for Ascend NPUs. + Use AscendEagleSpeculator when eagle is used. 
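+    Other speculative methods raise NotImplementedError for now.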
+ """ + speculative_config = vllm_config.speculative_config + assert speculative_config is not None + if speculative_config.use_eagle(): + from vllm_ascend.worker.v2.spec_decode.eagle import \ + AscendEagleSpeculator + + return AscendEagleSpeculator(vllm_config, device) + raise NotImplementedError( + f"{speculative_config.method} is not supported yet.") diff --git a/vllm_ascend/worker/v2/spec_decode/eagle.py b/vllm_ascend/worker/v2/spec_decode/eagle.py new file mode 100644 index 00000000..fe23b9ce --- /dev/null +++ b/vllm_ascend/worker/v2/spec_decode/eagle.py @@ -0,0 +1,146 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/spec_decode/eagle.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +from contextlib import contextmanager +from typing import Any + +import torch +import vllm +from vllm.config import VllmConfig +from vllm.v1.worker.gpu.input_batch import InputBatch +from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator + +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.worker.v2.attn_utils import build_attn_metadata + + +class AscendEagleSpeculator(EagleSpeculator): + + def __init__(self, vllm_config: VllmConfig, device: torch.device): + """Override GPU EagleSpeculator.__init__ for Ascend NPUs. + attnention metadata building in Ascend backend needs more information, + such as seq_lens_cpu from input_batch, so we need to override __init__. + """ + super().__init__(vllm_config, device) + # when in decode phase of eagle speculator, we need some value in + # main model's input_batch. so we keep a reference here. + self.input_batch: InputBatch | None = None + + def propose( + self, + input_batch, + sampling_metadata, + last_hidden_states, + aux_hidden_states, + num_sampled, + num_rejected, + last_sampled, + next_prefill_tokens, + ): + """Override GPU EagleSpeculator.propose for Ascend NPUs, + because npu attention metadata needs more information, + we need to cache input_batch, so we can use it later in + generate_draft. + """ + self.input_batch = input_batch + # wrap build_attn_metadata to use Ascend attention metadata building. + # so we can call super().propose() directly. + with build_attn_metadata_wrapper(): + return super().propose( + input_batch, + sampling_metadata, + last_hidden_states, + aux_hidden_states, + num_sampled, + num_rejected, + last_sampled, + next_prefill_tokens, + ) + + def generate_draft( + self, + num_reqs: int, + attn_metadata: dict[str, Any], + num_tokens_across_dp, + ): + """Override GPU EagleSpeculator.generate_draft for Ascend NPUs, because + attn_metadata is created in super propose method, it does not have some + attribute that Ascend attention backend needs, so we update it. 
+ """ + self._update_decode_attn_metadata(attn_metadata) + + return super().generate_draft( + num_reqs, + attn_metadata, + num_tokens_across_dp, + ) + + @torch.inference_mode() + def run_model( + self, + num_tokens: int, + attn_metadata: dict[str, Any], + num_tokens_across_dp: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Override GPU EagleSpeculator.run_model for Ascend NPUs, because + in decode phase, we need to update seq_lens_cpu in attn_metadata after + run model. + """ + last_hidden_states, hidden_states = super().run_model( + num_tokens, + attn_metadata, + num_tokens_across_dp, + ) + + # attn_metadata is None in profile_run and dummy_run. + if attn_metadata is not None: + for attn_meta in attn_metadata.values(): + # seq_lens in AscendMetadata is a cpu tensor. + attn_meta.seq_lens = attn_meta.seq_lens + 1 + attn_meta.seq_len_list = attn_meta.seq_lens.tolist() + return last_hidden_states, hidden_states + + def _update_decode_attn_metadata( + self, + attn_metadata: dict[str, Any], + ): + """Update attention metadata for decode phase on Ascend NPUs.""" + attn_state = AscendAttentionState.DecodeOnly + seq_lens_cpu = self._get_seq_lens_cpu() + # attn_metadata is build in vllm's super class. + # We need to update attn_state for each layer's metadata. + for metadata in attn_metadata.values(): + metadata.attn_state = attn_state + metadata.seq_lens_cpu = seq_lens_cpu + + def _get_seq_lens_cpu(self) -> torch.Tensor: + """Get seq_lens_cpu from input_batch.""" + assert self.input_batch is not None + seq_lens_cpu = torch.from_numpy(self.input_batch.seq_lens_np) + return seq_lens_cpu + + +@contextmanager +def build_attn_metadata_wrapper(): + """Context manager to override attention metadata building for Ascend NPUs.""" + original_func = vllm.v1.worker.gpu.spec_decode.eagle.build_attn_metadata + try: + vllm.v1.worker.gpu.spec_decode.eagle.build_attn_metadata = build_attn_metadata + yield + finally: + vllm.v1.worker.gpu.spec_decode.eagle.build_attn_metadata = original_func