[Feature] implement eagle spec decoding for model runner v2 (#5840)

### What this PR does / why we need it?
This PR implements EAGLE speculative decoding for model runner v2; see the RFC at
https://github.com/vllm-project/vllm-ascend/issues/5208.
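For context, a minimal usage sketch of how EAGLE speculative decoding is typically enabled through vLLM's `speculative_config`. The model names/paths are placeholders and the config keys follow the upstream vLLM interface as generally documented; nothing here is added by this PR.

```python
# Hypothetical example: enabling EAGLE speculative decoding via speculative_config.
# Model names/paths are placeholders; only the config shape is illustrative.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",   # placeholder target model
    speculative_config={
        "method": "eagle",                      # dispatches to the EAGLE speculator
        "model": "/path/to/eagle-draft-model",  # placeholder draft model
        "num_speculative_tokens": 2,
    },
)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32)))
```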

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
vLLM version: v0.13.0

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Ronald authored on 2026-01-14 09:18:05 +08:00; committed by GitHub
parent 0415e694cd
commit e20813f441
9 changed files with 468 additions and 82 deletions

View File

@@ -18,7 +18,7 @@
#
from collections.abc import Sequence
from typing import Any
from typing import Any, Tuple
import numpy as np
import torch
@@ -50,13 +50,11 @@ def build_attn_metadata(
query_start_loc_gpu: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
num_computed_tokens_cpu: torch.Tensor,
seq_lens_np: np.ndarray,
num_computed_tokens_cpu: torch.Tensor | None,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
decode_token_per_req: int,
actual_seq_lengths_q: list[int],
positions: torch.Tensor | None = None,
attn_state: Any | None = None,
graph_pad_size: int = -1,
@@ -67,7 +65,11 @@ def build_attn_metadata(
"""Build attention metadata for Ascend NPUs."""
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
max_query_len = int(query_start_loc_cpu.max())
seq_lens_cpu = torch.from_numpy(seq_lens_np)
max_seq_len = int(seq_lens_cpu.max())
# torch_npu._reshape_and_cache operator requires slot_mappings to
# be torch.int32.
slot_mappings = slot_mappings.to(torch.int32)
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -80,14 +82,11 @@ def build_attn_metadata(
query_start_loc_cpu=query_start_loc_cpu,
seq_lens_cpu=seq_lens_cpu[:num_reqs],
seq_lens=seq_lens[:num_reqs],
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
decode_token_per_req=decode_token_per_req,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
actual_seq_lengths_q=actual_seq_lengths_q,
positions=positions,
attn_state=attn_state,
graph_pad_size=graph_pad_size,

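A condensed sketch of the metadata-preparation change above: `seq_lens_cpu` is now derived inside the builder from `seq_lens_np`, and the slot mapping is cast to int32 because `torch_npu._reshape_and_cache` expects int32 indices. Names are simplified; this is not the full `build_attn_metadata`.

```python
# Simplified sketch (not the full build_attn_metadata): derive CPU seq lens
# from the NumPy buffer and cast slot mappings to the dtype Ascend requires.
import numpy as np
import torch

def prepare_lengths_and_slots(seq_lens_np: np.ndarray, slot_mappings: torch.Tensor):
    seq_lens_cpu = torch.from_numpy(seq_lens_np)   # zero-copy view of the NumPy buffer
    max_seq_len = int(seq_lens_cpu.max())
    slot_mappings = slot_mappings.to(torch.int32)  # torch_npu._reshape_and_cache needs int32
    return seq_lens_cpu, max_seq_len, slot_mappings
```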
View File

@@ -21,20 +21,20 @@ import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu.input_batch import (InputBatch,
combine_sampled_and_draft_tokens,
prepare_pos_seq_lens,
prepare_prefill_inputs)
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
build_attn_state)
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
from vllm_ascend.worker.v2.spec_decode import init_speculator
from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
@@ -54,12 +54,21 @@ class NPUModelRunner(GPUModelRunner):
del self.req_states
del self.input_buffers
del self.sampler
del self.speculator
# NPU specific initializations can be added below.
self.cudagraph_manager: AclGraphManager = AclGraphManager(
vllm_config,
device,
)
# AscendEagleSpeculator is defined in vllm_ascend.worker.v2.spec_decode.eagle.
# init_speculator returns an AscendEagleSpeculator when EAGLE is used,
# so here we just call init_speculator to reinitialize the speculator.
self.speculator: AscendEagleSpeculator | None = None
if self.speculative_config is not None:
self.speculator = init_speculator(self.vllm_config, self.device)
# AscendRequestState has an extra `num_computed_tokens_cpu` attribute,
# so we reinitialize req_states here.
self.req_states: AscendRequestState = AscendRequestState(
@@ -87,29 +96,18 @@ class NPUModelRunner(GPUModelRunner):
self.sampler: AscendSampler = AscendSampler(
logprobs_mode=self.model_config.logprobs_mode, )
# actual seq lengths for query (used in attention backends).
self.actual_seq_lengths_q: list[int] = []
# decode token per request (used in attention backends).
self.decode_token_per_req = 1
# these attributes are for async scheduling with speculative decoding.
# because the npu attention backend still needs to use seq_lens_cpu,
# we need to copy num_rejected_tokens back to cpu to help
# update the actual seq_lens_cpu. gpu attention backends do not need these
# attributes, because their attention backends do not use seq_lens_cpu.
# we need to copy num_computed_tokens back to cpu to help
# update the actual seq_lens_cpu. gpu attention backends don't need these
# attributes, because their attention backends don't use seq_lens_cpu.
# and seq_lens_cpu is deprecated in gpu_model_runner_v2.
self.num_rejected_tokens_event = None
self.num_rejectd_tokens_cpu = None
self.num_rejected_token_stream = None
if self.use_async_scheduling and self.do_spec_decode:
self.num_rejected_tokens_event = torch.npu.Event()
self.num_rejected_token_stream = torch.npu.Stream()
self.num_rejectd_tokens_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
self.num_computed_tokens_event = torch.npu.Event()
self.num_computed_tokens_stream = torch.npu.Stream()
self.num_computed_tokens_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
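A small sketch of the resources set up above, written against the CUDA API purely for illustration (the runner uses the torch.npu equivalents); the pinned CPU buffer is what lets the later non-blocking device-to-host copy overlap with compute.

```python
# Illustrative setup only; the size is a placeholder and torch.cuda stands in
# for torch.npu here.
import torch

max_num_reqs = 256  # placeholder
num_computed_tokens_event = torch.cuda.Event()
num_computed_tokens_stream = torch.cuda.Stream()
num_computed_tokens_cpu = torch.empty(
    max_num_reqs,
    dtype=torch.int32,
    device="cpu",
    pin_memory=True,  # pinned memory so copy_(..., non_blocking=True) can overlap compute
)
```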
def prepare_inputs(
self,
@@ -161,9 +159,6 @@ class NPUModelRunner(GPUModelRunner):
idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs]
# add `idx_mapping_cpu` here, because vllm-ascend's
# self.req_states.num_computed_tokens_cpu is actually a CPU tensor, while it
# is a GPU tensor in vllm's gpu_model_runner_v2.
idx_mapping_cpu = idx_mapping.cpu[:num_reqs]
idx_mapping_npu = idx_mapping.copy_to_gpu(num_reqs)
@@ -267,16 +262,12 @@ class NPUModelRunner(GPUModelRunner):
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens,
seq_lens_cpu=self.input_buffers.seq_lens_cpu,
actual_seq_lengths_q=self.actual_seq_lengths_q,
seq_lens_np=self.input_buffers.seq_lens_np,
num_computed_tokens_cpu=self.req_states.
num_computed_tokens_cpu[idx_mapping_cpu],
block_tables=block_tables,
# torch_npu._reshape_and_cache operator requires slot_mappings to
# be torch.int32.
slot_mappings=slot_mappings.to(torch.int32),
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
decode_token_per_req=self.decode_token_per_req,
attn_state=attn_state,
)
@@ -302,40 +293,35 @@ class NPUModelRunner(GPUModelRunner):
cu_num_logits=cu_num_logits,
)
def sample(
def postprocess(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
"""Override GPUModelRunner.sample for Ascend NPUs.
when using async scheduling with speculative decoding,
we need to copy mpu's num_rejected tensor to cpu.
these operations aren't needed in gpu_model_runner_v2,
because gpu attention backends do not use seq_lens_cpu anymore.
input_batch,
sampled_tokens,
num_sampled,
num_rejected,
):
"""Override GPUModelRunner.postprocess for Ascend NPUs.
npu attention backends need seq_lens_cpu to work,
so we need to copy num_computed_tokens back to cpu here.
"""
sampler_output, num_sampled, num_rejected = super().sample(
hidden_states,
super().postprocess(
input_batch,
sampling_metadata,
grammar_output,
sampled_tokens,
num_sampled,
num_rejected,
)
if self.num_rejected_tokens_event is not None:
# npu attention backend still needs to use seq_lens_cpu,
# when doing speculative decoding with async_scheduling,
# we need to copy num_rejected_tokens back to cpu.
default_stream = torch.cuda.current_stream()
assert self.num_rejected_token_stream is not None
assert self.num_rejectd_tokens_cpu is not None
with torch.npu.stream(self.num_rejected_token_stream):
self.num_rejected_token_stream.wait_stream(default_stream)
self.num_rejectd_tokens_cpu.copy_(
num_rejected,
non_blocking=True,
)
self.num_rejected_tokens_event.record()
return sampler_output, num_sampled, num_rejected
# npu attention backend still needs to use seq_lens_cpu,
# we need to copy num_computed_tokens back to cpu.
default_stream = torch.cuda.current_stream()
assert self.num_computed_tokens_stream is not None
assert self.num_computed_tokens_cpu is not None
with torch.npu.stream(self.num_computed_tokens_stream):
self.num_computed_tokens_stream.wait_stream(default_stream)
self.num_computed_tokens_cpu.copy_(
self.req_states.num_computed_tokens,
non_blocking=True,
)
self.num_computed_tokens_event.record()
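A condensed sketch of the hand-off between `postprocess` and `_update_seq_lens_cpu`, again using the CUDA API for illustration (the runner uses torch.npu streams and events): the producer copies to pinned CPU memory on a side stream and records an event, and the consumer synchronizes on that event before reading the buffer.

```python
# Producer/consumer sketch of the async device-to-host publication pattern;
# torch.cuda stands in for torch.npu.
import torch

def publish_num_computed_tokens(src_gpu: torch.Tensor, dst_cpu_pinned: torch.Tensor,
                                copy_stream: torch.cuda.Stream, done: torch.cuda.Event):
    default_stream = torch.cuda.current_stream()  # capture before switching streams
    with torch.cuda.stream(copy_stream):
        # Don't start the copy until work queued on the default stream has finished.
        copy_stream.wait_stream(default_stream)
        dst_cpu_pinned.copy_(src_gpu, non_blocking=True)
        done.record()

def read_num_computed_tokens(dst_cpu_pinned: torch.Tensor, done: torch.cuda.Event):
    done.synchronize()  # wait for the side-stream copy before touching CPU data
    return dst_cpu_pinned
```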
def _update_seq_lens_cpu(
self,
@@ -343,17 +329,14 @@ class NPUModelRunner(GPUModelRunner):
req_ids: list[str],
):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
# update num_computed_tokens_cpu
# TODO(Ronald1995): update num_computed_tokens_cpu by considering
# num_rejected_tokens.
for req_id, num_computed_token in zip(
scheduler_output.scheduled_cached_reqs.req_ids,
scheduler_output.scheduled_cached_reqs.num_computed_tokens,
):
# wait for the num_computed_tokens copy-to-cpu stream to finish.
self.num_computed_tokens_event.synchronize()
for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
req_index = self.req_states.req_id_to_index[req_id]
# num_computed_tokens_cpu has already been reverted by num_rejected_tokens
# in the super postprocess method.
self.req_states.num_computed_tokens_cpu[
req_index] = num_computed_token
req_index] = self.num_computed_tokens_cpu[req_index]
# update seq_lens_cpu
for i, req_id in enumerate(req_ids):

View File

@@ -0,0 +1,38 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/spec_decode/__init__.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.config import VllmConfig
def init_speculator(
vllm_config: VllmConfig,
device: torch.device,
):
"""Override GPU init_speculator for Ascend NPUs.
Use AscendEagleSpeculator when eagle is used.
"""
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
if speculative_config.use_eagle():
from vllm_ascend.worker.v2.spec_decode.eagle import \
AscendEagleSpeculator
return AscendEagleSpeculator(vllm_config, device)
raise NotImplementedError(
f"{speculative_config.method} is not supported yet.")

View File

@@ -0,0 +1,146 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/spec_decode/eagle.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from contextlib import contextmanager
from typing import Any
import torch
import vllm
from vllm.config import VllmConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
class AscendEagleSpeculator(EagleSpeculator):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
"""Override GPU EagleSpeculator.__init__ for Ascend NPUs.
attention metadata building in the Ascend backend needs more information,
such as seq_lens_cpu from input_batch, so we need to override __init__.
"""
super().__init__(vllm_config, device)
# when in the decode phase of the eagle speculator, we need some values from
# the main model's input_batch, so we keep a reference here.
self.input_batch: InputBatch | None = None
def propose(
self,
input_batch,
sampling_metadata,
last_hidden_states,
aux_hidden_states,
num_sampled,
num_rejected,
last_sampled,
next_prefill_tokens,
):
"""Override GPU EagleSpeculator.propose for Ascend NPUs,
because npu attention metadata needs more information,
we need to cache input_batch, so we can use it later in
generate_draft.
"""
self.input_batch = input_batch
# wrap build_attn_metadata to use Ascend attention metadata building,
# so we can call super().propose() directly.
with build_attn_metadata_wrapper():
return super().propose(
input_batch,
sampling_metadata,
last_hidden_states,
aux_hidden_states,
num_sampled,
num_rejected,
last_sampled,
next_prefill_tokens,
)
def generate_draft(
self,
num_reqs: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp,
):
"""Override GPU EagleSpeculator.generate_draft for Ascend NPUs, because
attn_metadata is created in super propose method, it does not have some
attribute that Ascend attention backend needs, so we update it.
"""
self._update_decode_attn_metadata(attn_metadata)
return super().generate_draft(
num_reqs,
attn_metadata,
num_tokens_across_dp,
)
@torch.inference_mode()
def run_model(
self,
num_tokens: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Override GPU EagleSpeculator.run_model for Ascend NPUs, because
in the decode phase we need to update seq_lens_cpu in attn_metadata after
running the model.
"""
last_hidden_states, hidden_states = super().run_model(
num_tokens,
attn_metadata,
num_tokens_across_dp,
)
# attn_metadata is None in profile_run and dummy_run.
if attn_metadata is not None:
for attn_meta in attn_metadata.values():
# seq_lens in AscendMetadata is a cpu tensor.
attn_meta.seq_lens = attn_meta.seq_lens + 1
attn_meta.seq_len_list = attn_meta.seq_lens.tolist()
return last_hidden_states, hidden_states
def _update_decode_attn_metadata(
self,
attn_metadata: dict[str, Any],
):
"""Update attention metadata for decode phase on Ascend NPUs."""
attn_state = AscendAttentionState.DecodeOnly
seq_lens_cpu = self._get_seq_lens_cpu()
# attn_metadata is built in vllm's super class.
# We need to update attn_state for each layer's metadata.
for metadata in attn_metadata.values():
metadata.attn_state = attn_state
metadata.seq_lens_cpu = seq_lens_cpu
def _get_seq_lens_cpu(self) -> torch.Tensor:
"""Get seq_lens_cpu from input_batch."""
assert self.input_batch is not None
seq_lens_cpu = torch.from_numpy(self.input_batch.seq_lens_np)
return seq_lens_cpu
@contextmanager
def build_attn_metadata_wrapper():
"""Context manager to override attention metadata building for Ascend NPUs."""
original_func = vllm.v1.worker.gpu.spec_decode.eagle.build_attn_metadata
try:
vllm.v1.worker.gpu.spec_decode.eagle.build_attn_metadata = build_attn_metadata
yield
finally:
vllm.v1.worker.gpu.spec_decode.eagle.build_attn_metadata = original_func
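The wrapper above swaps the module-level `build_attn_metadata` that the GPU `EagleSpeculator` looks up at call time and restores it afterwards. The same effect could be achieved with `unittest.mock.patch.object` from the standard library; a sketch of that alternative, shown for comparison only and not what this PR uses:

```python
# Alternative formulation using the standard library; restores the original
# attribute on exit even if an exception is raised.
from unittest import mock

import vllm.v1.worker.gpu.spec_decode.eagle as gpu_eagle
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata as ascend_build

with mock.patch.object(gpu_eagle, "build_attn_metadata", ascend_build):
    ...  # any GPU EagleSpeculator call in here resolves the Ascend builder
```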