[Feature] implement eagle spec decoding for model runner v2 (#5840)
### What this PR does / why we need it?

This PR implements EAGLE speculative decoding for model runner v2. Please see the RFC: https://github.com/vllm-project/vllm-ascend/issues/5208

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

vLLM version: v0.13.0

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
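For readers new to speculative decoding: a lightweight draft head proposes several tokens per step, and the target model verifies them, keeping the longest accepted prefix and rejecting the rest (these are the `num_sampled` / `num_rejected` counts that flow through the diff below). The sketch below is a greedy-acceptance simplification for intuition only, not code from this PR; all names in it are hypothetical.

```python
import torch

def count_accepted(draft_tokens: torch.Tensor,
                   target_tokens: torch.Tensor) -> int:
    """Greedy acceptance: length of the longest prefix where the
    target model's choice matches the draft model's proposal."""
    # draft_tokens / target_tokens: [num_spec_tokens], integer token ids.
    matches = draft_tokens == target_tokens
    # The first mismatch ends the accepted prefix.
    rejected = (~matches).nonzero()
    return int(rejected[0]) if rejected.numel() > 0 else matches.numel()

# Example: 3 of 4 draft tokens accepted, 1 rejected.
draft = torch.tensor([5, 9, 2, 7])
target = torch.tensor([5, 9, 2, 4])
num_accepted = count_accepted(draft, target)   # -> 3
num_rejected = draft.numel() - num_accepted    # -> 1
```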
```diff
@@ -21,20 +21,20 @@ import numpy as np
 import torch

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu.input_batch import (InputBatch,
                                             combine_sampled_and_draft_tokens,
                                             prepare_pos_seq_lens,
                                             prepare_prefill_inputs)
 from vllm.v1.worker.gpu.model_runner import GPUModelRunner
 from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu.sample.output import SamplerOutput

 from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
 from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
                                               build_attn_state)
 from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
 from vllm_ascend.worker.v2.sample.sampler import AscendSampler
+from vllm_ascend.worker.v2.spec_decode import init_speculator
+from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
 from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
 from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
```
```diff
@@ -54,12 +54,21 @@ class NPUModelRunner(GPUModelRunner):
         del self.req_states
         del self.input_buffers
         del self.sampler
         del self.speculator

         # NPU specific initializations can be added below.
         self.cudagraph_manager: AclGraphManager = AclGraphManager(
             vllm_config,
             device,
         )

+        # AscendEagleSpeculator is defined in
+        # vllm_ascend.worker.v2.spec_decode.eagle; init_speculator returns an
+        # AscendEagleSpeculator when EAGLE is used, so we just call it here
+        # to reinitialize the speculator.
+        self.speculator: AscendEagleSpeculator | None = None
+        if self.speculative_config is not None:
+            self.speculator = init_speculator(self.vllm_config, self.device)
+
+        # AscendRequestState has an extra `num_computed_tokens_cpu` attribute,
+        # so reinitialize req_states here.
         self.req_states: AscendRequestState = AscendRequestState(
```
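As the comment in the hunk above notes, `init_speculator` hands back an `AscendEagleSpeculator` when EAGLE is configured. Below is a minimal sketch of that factory pattern; it assumes a `speculative_config.method` field and uses a stub class in place of the real speculator, so treat every detail as illustrative rather than the actual vllm-ascend implementation.

```python
import torch

class AscendEagleSpeculator:
    """Stub standing in for the real EAGLE speculator."""
    def __init__(self, vllm_config, device: torch.device):
        self.vllm_config = vllm_config
        self.device = device

def init_speculator(vllm_config, device: torch.device):
    # Hypothetical dispatch: the real factory likely inspects the
    # configured speculative method (an assumption here).
    method = vllm_config.speculative_config.method
    if method in ("eagle", "eagle3"):
        return AscendEagleSpeculator(vllm_config, device)
    raise NotImplementedError(f"Unsupported speculative method: {method}")
```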
```diff
@@ -87,29 +96,18 @@ class NPUModelRunner(GPUModelRunner):
         self.sampler: AscendSampler = AscendSampler(
             logprobs_mode=self.model_config.logprobs_mode, )

         # actual seq lengths for query (used in attention backends).
         self.actual_seq_lengths_q: list[int] = []
         # decode tokens per request (used in attention backends).
         self.decode_token_per_req = 1

         # These attributes are for async scheduling with speculative decoding.
-        # Because the npu attention backend still needs seq_lens_cpu, we need
-        # to copy num_rejected_tokens back to the cpu to help update the
-        # actual seq_lens_cpu. Gpu attention backends do not need these
-        # attributes because they do not use seq_lens_cpu.
-        self.num_rejected_tokens_event = None
-        self.num_rejected_tokens_cpu = None
-        self.num_rejected_token_stream = None
+        # Because the npu attention backend still needs seq_lens_cpu, we need
+        # to copy num_computed_tokens back to the cpu to help update the
+        # actual seq_lens_cpu. Gpu attention backends do not need these
+        # attributes: they do not use seq_lens_cpu, which is deprecated in
+        # gpu_model_runner_v2.
         if self.use_async_scheduling and self.do_spec_decode:
-            self.num_rejected_tokens_event = torch.npu.Event()
-            self.num_rejected_token_stream = torch.npu.Stream()
-            self.num_rejected_tokens_cpu = torch.empty(
-                self.max_num_reqs,
-                dtype=torch.int64,
-                device="cpu",
-                pin_memory=self.pin_memory,
-            )
+            self.num_computed_tokens_event = torch.npu.Event()
+            self.num_computed_tokens_stream = torch.npu.Stream()
+            self.num_computed_tokens_cpu = torch.empty(
+                self.max_num_reqs,
+                dtype=torch.int32,
+                device="cpu",
+                pin_memory=self.pin_memory,
+            )

     def prepare_inputs(
         self,
```
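The pinned host buffer, dedicated stream, and event allocated above are the standard ingredients for overlapping a device-to-host copy with ongoing device work. A self-contained sketch of the same pattern follows, written against `torch.cuda` so it runs on stock PyTorch; the diff uses the `torch.npu` equivalents, and the variable names mirror the diff rather than any real API.

```python
import torch

max_num_reqs = 8
device = "cuda"

# Pinned (page-locked) host memory allows truly asynchronous D2H copies.
num_computed_tokens_cpu = torch.empty(
    max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=True)
copy_stream = torch.cuda.Stream()
copy_done = torch.cuda.Event()

num_computed_tokens = torch.arange(max_num_reqs, dtype=torch.int32,
                                   device=device)

# Producer side: issue the copy on a side stream so the default stream
# keeps running.
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(copy_stream):
    # Make the side stream wait for work already queued on the default
    # stream, so it copies up-to-date values.
    copy_stream.wait_stream(default_stream)
    num_computed_tokens_cpu.copy_(num_computed_tokens, non_blocking=True)
    copy_done.record()

# Consumer side: block the host only when the values are actually needed.
copy_done.synchronize()
print(num_computed_tokens_cpu.tolist())
```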
```diff
@@ -161,9 +159,6 @@ class NPUModelRunner(GPUModelRunner):
         idx_mapping = self.input_buffers.idx_mapping
         idx_mapping.np[:num_reqs] = idx_mapping_list
         idx_mapping_np = idx_mapping.np[:num_reqs]
+        # Add `idx_mapping_cpu` here: vllm-ascend's
+        # self.req_states.num_computed_tokens_cpu is actually a cpu tensor,
+        # whereas it is a gpu tensor in vllm's gpu_model_runner_v2.
+        idx_mapping_cpu = idx_mapping.cpu[:num_reqs]
         idx_mapping_npu = idx_mapping.copy_to_gpu(num_reqs)
```
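The reason a CPU-side slice of the index mapping is needed: gathering from host-resident per-request state requires a host index tensor. A small illustrative sketch with hypothetical values and simplified shapes:

```python
import torch

max_num_reqs = 16
# Persistent per-request state, kept on the CPU in vllm-ascend.
num_computed_tokens_cpu = torch.zeros(max_num_reqs, dtype=torch.int32)
num_computed_tokens_cpu[:5] = torch.tensor([7, 3, 9, 1, 4], dtype=torch.int32)

# This step's batch touches request slots 3, 0, and 4, in that order.
idx_mapping_cpu = torch.tensor([3, 0, 4])

# A CPU index tensor can gather from CPU state directly; indexing a CPU
# tensor with a device tensor would raise an error in PyTorch.
batch_num_computed = num_computed_tokens_cpu[idx_mapping_cpu]
print(batch_num_computed.tolist())  # [1, 7, 4]
```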
```diff
@@ -267,16 +262,12 @@ class NPUModelRunner(GPUModelRunner):
             query_start_loc_gpu=query_start_loc_gpu,
             query_start_loc_cpu=query_start_loc_cpu,
             seq_lens=self.input_buffers.seq_lens,
             seq_lens_cpu=self.input_buffers.seq_lens_cpu,
             actual_seq_lengths_q=self.actual_seq_lengths_q,
             seq_lens_np=self.input_buffers.seq_lens_np,
+            num_computed_tokens_cpu=self.req_states.
+            num_computed_tokens_cpu[idx_mapping_cpu],
             block_tables=block_tables,
-            # torch_npu._reshape_and_cache operator requires slot_mappings
-            # to be torch.int32.
-            slot_mappings=slot_mappings.to(torch.int32),
+            slot_mappings=slot_mappings,
             kv_cache_config=self.kv_cache_config,
             decode_token_per_req=self.decode_token_per_req,
             attn_state=attn_state,
         )
```
```diff
@@ -302,40 +293,35 @@ class NPUModelRunner(GPUModelRunner):
             cu_num_logits=cu_num_logits,
         )

-    def sample(
-        self,
-        hidden_states: torch.Tensor,
-        input_batch: InputBatch,
-        sampling_metadata: SamplingMetadata,
-        grammar_output: GrammarOutput | None,
-    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
-        """Override GPUModelRunner.sample for Ascend NPUs.
-        When using async scheduling with speculative decoding, we need to
-        copy the npu's num_rejected tensor to the cpu. These operations
-        aren't needed in gpu_model_runner_v2 because gpu attention
-        backends do not use seq_lens_cpu anymore.
-        """
-        sampler_output, num_sampled, num_rejected = super().sample(
-            hidden_states,
-            input_batch,
-            sampling_metadata,
-            grammar_output,
-        )
-        if self.num_rejected_tokens_event is not None:
-            # The npu attention backend still needs seq_lens_cpu, so when
-            # doing speculative decoding with async scheduling we need to
-            # copy num_rejected_tokens back to the cpu.
-            default_stream = torch.cuda.current_stream()
-            assert self.num_rejected_token_stream is not None
-            assert self.num_rejected_tokens_cpu is not None
-            with torch.npu.stream(self.num_rejected_token_stream):
-                self.num_rejected_token_stream.wait_stream(default_stream)
-                self.num_rejected_tokens_cpu.copy_(
-                    num_rejected,
-                    non_blocking=True,
-                )
-                self.num_rejected_tokens_event.record()
-        return sampler_output, num_sampled, num_rejected
+    def postprocess(
+        self,
+        input_batch,
+        sampled_tokens,
+        num_sampled,
+        num_rejected,
+    ):
+        """Override GPUModelRunner.postprocess for Ascend NPUs.
+        Npu attention backends need seq_lens_cpu to work, so we need to
+        copy num_computed_tokens back to the cpu here.
+        """
+        super().postprocess(
+            input_batch,
+            sampled_tokens,
+            num_sampled,
+            num_rejected,
+        )
+        # The npu attention backend still needs seq_lens_cpu, so we need
+        # to copy num_computed_tokens back to the cpu.
+        default_stream = torch.cuda.current_stream()
+        assert self.num_computed_tokens_stream is not None
+        assert self.num_computed_tokens_cpu is not None
+        with torch.npu.stream(self.num_computed_tokens_stream):
+            self.num_computed_tokens_stream.wait_stream(default_stream)
+            self.num_computed_tokens_cpu.copy_(
+                self.req_states.num_computed_tokens,
+                non_blocking=True,
+            )
+            self.num_computed_tokens_event.record()

     def _update_seq_lens_cpu(
         self,
```
```diff
@@ -343,17 +329,14 @@ class NPUModelRunner(GPUModelRunner):
         req_ids: list[str],
     ):
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens

         # update num_computed_tokens_cpu
-        # TODO(Ronald1995): update num_computed_tokens_cpu by considering
-        # num_rejected_tokens.
-        for req_id, num_computed_token in zip(
-                scheduler_output.scheduled_cached_reqs.req_ids,
-                scheduler_output.scheduled_cached_reqs.num_computed_tokens,
-        ):
+        # Wait for the num_computed_tokens copy on the copy stream to finish.
+        self.num_computed_tokens_event.synchronize()
+        for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
             req_index = self.req_states.req_id_to_index[req_id]
-            self.req_states.num_computed_tokens_cpu[
-                req_index] = num_computed_token
+            # num_computed_tokens_cpu has already been reverted by
+            # num_rejected_tokens in the super().postprocess method.
+            self.req_states.num_computed_tokens_cpu[
+                req_index] = self.num_computed_tokens_cpu[req_index]

         # update seq_lens_cpu
         for i, req_id in enumerate(req_ids):
```
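The second loop (truncated above) then refreshes `seq_lens_cpu` from the updated state. Presumably, as elsewhere in vLLM, a request's sequence length for the step is its computed-token count plus the tokens newly scheduled for it; a toy sketch of that bookkeeping with hypothetical values, not the PR's actual loop body:

```python
# Simplified bookkeeping mirroring _update_seq_lens_cpu:
# seq_len = tokens already computed + tokens scheduled this step.
num_computed_tokens_cpu = {"req-0": 7, "req-1": 12}
num_scheduled_tokens = {"req-0": 2, "req-1": 1}

req_ids = ["req-0", "req-1"]
seq_lens_cpu = [0] * len(req_ids)
for i, req_id in enumerate(req_ids):
    seq_lens_cpu[i] = (num_computed_tokens_cpu[req_id] +
                       num_scheduled_tokens[req_id])
print(seq_lens_cpu)  # [9, 13]
```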