[Feature] implement eagle spec decoding for model runner v2 (#5840)

### What this PR does / why we need it?
this pr implement eagle spec decoding for model runner v2, please see
RFC https://github.com/vllm-project/vllm-ascend/issues/5208

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
vLLM version: v0.13.0

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-01-14 09:18:05 +08:00
committed by GitHub
parent 0415e694cd
commit e20813f441
9 changed files with 468 additions and 82 deletions

View File

@@ -21,20 +21,20 @@ import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu.input_batch import (InputBatch,
combine_sampled_and_draft_tokens,
prepare_pos_seq_lens,
prepare_prefill_inputs)
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
build_attn_state)
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
from vllm_ascend.worker.v2.spec_decode import init_speculator
from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
@@ -54,12 +54,21 @@ class NPUModelRunner(GPUModelRunner):
del self.req_states
del self.input_buffers
del self.sampler
del self.speculator
# NPU specific initializations can be added below.
self.cudagraph_manager: AclGraphManager = AclGraphManager(
vllm_config,
device,
)
# we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle
# init_speculator will return AscendEagleSpeculator when eagle is used.
# so here we just call init_speculator to reinitialize speculator.
self.speculator: AscendEagleSpeculator | None = None
if self.speculative_config is not None:
self.speculator = init_speculator(self.vllm_config, self.device)
# AscendRequestState has extra `num_computed_tokens_cpu` attribute.
# so reinitialize req_states here.
self.req_states: AscendRequestState = AscendRequestState(
@@ -87,29 +96,18 @@ class NPUModelRunner(GPUModelRunner):
self.sampler: AscendSampler = AscendSampler(
logprobs_mode=self.model_config.logprobs_mode, )
# actual seq lengths for query (used in attention backends).
self.actual_seq_lengths_q: list[int] = []
# decode token per request (used in attention backends).
self.decode_token_per_req = 1
# there attributes are for async scheduling with speculative decoding.
# because npu attention backend still need to use seq_lens_cpu,
# we need to copy num_rejected_tokens back to cpu to help
# update actual seq_lens_cpu. gpu attention backend do not need these
# attributes, cause their attention backends do not use seq_lens_cpu.
# we need to copy num_computed_tokens back to cpu to help
# update actual seq_lens_cpu. gpu attention backend doesn't need these
# attributes, cause their attention backends doesn't use seq_lens_cpu.
# and seq_lens_cpu is deprecated in gpu_model_runner_v2.
self.num_rejected_tokens_event = None
self.num_rejectd_tokens_cpu = None
self.num_rejected_token_stream = None
if self.use_async_scheduling and self.do_spec_decode:
self.num_rejected_tokens_event = torch.npu.Event()
self.num_rejected_token_stream = torch.npu.Stream()
self.num_rejectd_tokens_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
self.num_computed_tokens_event = torch.npu.Event()
self.num_computed_tokens_stream = torch.npu.Stream()
self.num_computed_tokens_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
def prepare_inputs(
self,
@@ -161,9 +159,6 @@ class NPUModelRunner(GPUModelRunner):
idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs]
# add `idx_mapping_cpu` here, because vllm-ascend's self.req_states.
# num_computed_tokens_cpu is actually cpu's tensor, while it's a gpu's
# tensor in vllm gpu_model_runner_v2.
idx_mapping_cpu = idx_mapping.cpu[:num_reqs]
idx_mapping_npu = idx_mapping.copy_to_gpu(num_reqs)
@@ -267,16 +262,12 @@ class NPUModelRunner(GPUModelRunner):
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens,
seq_lens_cpu=self.input_buffers.seq_lens_cpu,
actual_seq_lengths_q=self.actual_seq_lengths_q,
seq_lens_np=self.input_buffers.seq_lens_np,
num_computed_tokens_cpu=self.req_states.
num_computed_tokens_cpu[idx_mapping_cpu],
block_tables=block_tables,
# torch_npu._reshape_and_cache operator requires slot_mappings to
# be torch.int32.
slot_mappings=slot_mappings.to(torch.int32),
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
decode_token_per_req=self.decode_token_per_req,
attn_state=attn_state,
)
@@ -302,40 +293,35 @@ class NPUModelRunner(GPUModelRunner):
cu_num_logits=cu_num_logits,
)
def sample(
def postprocess(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
"""Override GPUModelRunner.sample for Ascend NPUs.
when using async scheduling with speculative decoding,
we need to copy mpu's num_rejected tensor to cpu.
these operations aren't needed in gpu_model_runner_v2,
because gpu attention backends do not use seq_lens_cpu anymore.
input_batch,
sampled_tokens,
num_sampled,
num_rejected,
):
"""Override GPUModelRunner.postprocess for Ascend NPUs.
npu attention backends need seq_lens_cpu to work.
so we need to copy num_computed_tokens back to cpu here.
"""
sampler_output, num_sampled, num_rejected = super().sample(
hidden_states,
super().postprocess(
input_batch,
sampling_metadata,
grammar_output,
sampled_tokens,
num_sampled,
num_rejected,
)
if self.num_rejected_tokens_event is not None:
# npu attention backend still need to use seq_lens_cpu,
# when doing speculative decoding with async_scheduling,
# we need to copy num_rejected_tokens back to cpu.
default_stream = torch.cuda.current_stream()
assert self.num_rejected_token_stream is not None
assert self.num_rejectd_tokens_cpu is not None
with torch.npu.stream(self.num_rejected_token_stream):
self.num_rejected_token_stream.wait_stream(default_stream)
self.num_rejectd_tokens_cpu.copy_(
num_rejected,
non_blocking=True,
)
self.num_rejected_tokens_event.record()
return sampler_output, num_sampled, num_rejected
# npu attention backend still need to use seq_lens_cpu,
# we need to copy num_computed_tokens back to cpu.
default_stream = torch.cuda.current_stream()
assert self.num_computed_tokens_stream is not None
assert self.num_computed_tokens_cpu is not None
with torch.npu.stream(self.num_computed_tokens_stream):
self.num_computed_tokens_stream.wait_stream(default_stream)
self.num_computed_tokens_cpu.copy_(
self.req_states.num_computed_tokens,
non_blocking=True,
)
self.num_computed_tokens_event.record()
def _update_seq_lens_cpu(
self,
@@ -343,17 +329,14 @@ class NPUModelRunner(GPUModelRunner):
req_ids: list[str],
):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
# update num_computed_tokens_cpu
# TODO(Ronald1995): update num_computed_tokens_cpu by considering
# num_rejectd_tokens.
for req_id, num_computed_token in zip(
scheduler_output.scheduled_cached_reqs.req_ids,
scheduler_output.scheduled_cached_reqs.num_computed_tokens,
):
# wait for num_computed_tokens copy to cpu stream to finish.
self.num_computed_tokens_event.synchronize()
for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
req_index = self.req_states.req_id_to_index[req_id]
# num_computed_tokens_cpu has reverted by num_rejected_tokens already.
# in super postprocess method.
self.req_states.num_computed_tokens_cpu[
req_index] = num_computed_token
req_index] = self.num_computed_tokens_cpu[req_index]
# update seq_lens_cpu
for i, req_id in enumerate(req_ids):