[Feature] implement eagle spec decoding for model runner v2 (#5840)
### What this PR does / why we need it? This PR implements EAGLE speculative decoding for model runner v2; please see RFC https://github.com/vllm-project/vllm-ascend/issues/5208 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? vLLM version: v0.13.0 --------- Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -25,6 +25,9 @@ from tests.e2e.conftest import VllmRunner
|
|||||||
|
|
||||||
MODELS = ["Qwen/Qwen3-0.6B"]
|
MODELS = ["Qwen/Qwen3-0.6B"]
|
||||||
|
|
||||||
|
MAIN_MODELS = ["LLM-Research/Meta-Llama-3.1-8B-Instruct"]
|
||||||
|
EGALE_MODELS = ["vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("max_tokens", [32])
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
@@ -49,3 +52,36 @@ def test_qwen3_dense_eager_mode(
|
|||||||
enforce_eager=enforce_eager,
|
enforce_eager=enforce_eager,
|
||||||
) as runner:
|
) as runner:
|
||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MAIN_MODELS)
|
||||||
|
@pytest.mark.parametrize("eagle_model", EGALE_MODELS)
|
||||||
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
|
@pytest.mark.parametrize("enforce_eager", [True])
|
||||||
|
@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
|
||||||
|
def test_egale_spec_decoding(
|
||||||
|
model: str,
|
||||||
|
eagle_model: str,
|
||||||
|
max_tokens: int,
|
||||||
|
enforce_eager: bool,
|
||||||
|
) -> None:
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
|
||||||
|
with VllmRunner(
|
||||||
|
model,
|
||||||
|
max_model_len=1024,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
async_scheduling=True,
|
||||||
|
speculative_config={
|
||||||
|
"model": eagle_model,
|
||||||
|
"method": "eagle",
|
||||||
|
"num_speculative_tokens": 3,
|
||||||
|
},
|
||||||
|
) as runner:
|
||||||
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|||||||
@@ -174,7 +174,8 @@
|
|||||||
#
|
#
|
||||||
# ** 6. File: worker/patch_triton.py**
|
# ** 6. File: worker/patch_triton.py**
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# 1. `vllm.model_executor.layers.mamba.ops`, `vllm.model_executor.layers.fla.ops`
|
# 1. `vllm.model_executor.layers.mamba.ops`, `vllm.model_executor.layers.fla.ops`,
|
||||||
|
# `vllm.v1.worker.gpu.sample.gumbel.gumbel_sample`
|
||||||
# Why:
|
# Why:
|
||||||
# Triton ops in vLLM do not perform well on NPU, and there is no dispatch mechanism for Triton ops.
|
# Triton ops in vLLM do not perform well on NPU, and there is no dispatch mechanism for Triton ops.
|
||||||
# How:
|
# How:
|
||||||
@@ -263,3 +264,15 @@
|
|||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch when vLLM support these operators.
|
# Remove this patch when vLLM support these operators.
|
||||||
#
|
#
|
||||||
|
# ** 12. File: worker/patch_v2_egale.py**
|
||||||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
# 1. `vllm.v1.worker.gpu.spec_decode.eagle.EagleSpeculator.propose`
|
||||||
|
# Why:
|
||||||
|
# The `propose` method uses torch.gather, but the gather operator will
|
||||||
|
# pollute the arguments passed to it. The bug has been reported to the Huawei
|
||||||
|
# CANN team, but not fixed yet.
|
||||||
|
# How:
|
||||||
|
# Clone the `out` tensors before calling gather to avoid the bug.
|
||||||
|
# Future Plan:
|
||||||
|
# Remove this patch when CANN fixes the gather bug.
|
||||||
|
#
|
||||||
|
|||||||
@@ -32,3 +32,4 @@ import vllm_ascend.patch.worker.patch_qwen3_next # noqa
|
|||||||
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
|
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
|
||||||
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
|
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
|
||||||
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
|
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_v2_egale # noqa
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import vllm.model_executor.layers.mamba.ops.causal_conv1d
|
import vllm.model_executor.layers.mamba.ops.causal_conv1d
|
||||||
|
import vllm.v1.worker.gpu.sample.gumbel
|
||||||
|
|
||||||
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
|
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
|
||||||
from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn
|
from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn
|
||||||
@@ -6,9 +7,12 @@ from vllm_ascend.ops.triton.fla.sigmoid_gating import \
|
|||||||
fused_recurrent_gated_delta_rule_fwd_kernel
|
fused_recurrent_gated_delta_rule_fwd_kernel
|
||||||
from vllm_ascend.ops.triton.mamba.causal_conv1d import (
|
from vllm_ascend.ops.triton.mamba.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update_npu)
|
causal_conv1d_fn, causal_conv1d_update_npu)
|
||||||
|
from vllm_ascend.worker.v2.sample.gumbel import \
|
||||||
|
gumbel_sample as ascend_gumbel_sample
|
||||||
|
|
||||||
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
|
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
|
||||||
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
|
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
|
||||||
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
|
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
|
||||||
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
|
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
|
||||||
vllm.model_executor.layers.fla.ops.chunk_gated_delta_rule = chunk_gated_delta_rule
|
vllm.model_executor.layers.fla.ops.chunk_gated_delta_rule = chunk_gated_delta_rule
|
||||||
|
vllm.v1.worker.gpu.sample.gumbel.gumbel_sample = ascend_gumbel_sample
|
||||||
|
|||||||
166
vllm_ascend/patch/worker/patch_v2_egale.py
Normal file
166
vllm_ascend/patch/worker/patch_v2_egale.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/spec_decode/eagle.py
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import vllm
|
||||||
|
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||||
|
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||||
|
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.worker.gpu.spec_decode.eagle import (prepare_eagle_decode,
|
||||||
|
prepare_eagle_inputs)
|
||||||
|
|
||||||
|
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def propose(
|
||||||
|
self,
|
||||||
|
input_batch: InputBatch,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
# [num_tokens, hidden_size]
|
||||||
|
last_hidden_states: torch.Tensor,
|
||||||
|
# num_layers x [num_tokens, hidden_size]
|
||||||
|
aux_hidden_states: list[torch.Tensor] | None,
|
||||||
|
# [num_reqs]
|
||||||
|
num_sampled: torch.Tensor,
|
||||||
|
# [num_reqs]
|
||||||
|
num_rejected: torch.Tensor,
|
||||||
|
# [num_reqs]
|
||||||
|
last_sampled: torch.Tensor,
|
||||||
|
# [num_reqs]
|
||||||
|
next_prefill_tokens: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
|
||||||
|
# number of rejected tokens, we maintain the size of eagle's input_ids and
|
||||||
|
# hidden_states the same as the target model's. This means, we pad each
|
||||||
|
# request's query length to include any rejected positions. By doing so,
|
||||||
|
# we can also reuse the attention metadata (e.g., query_start_loc,
|
||||||
|
# seq_lens) of the target model.
|
||||||
|
if aux_hidden_states:
|
||||||
|
assert self.method == "eagle3"
|
||||||
|
hidden_states = self.model.combine_hidden_states(
|
||||||
|
torch.cat(aux_hidden_states, dim=-1))
|
||||||
|
else:
|
||||||
|
hidden_states = last_hidden_states
|
||||||
|
num_tokens = input_batch.num_tokens_after_padding
|
||||||
|
self.hidden_states[:num_tokens] = hidden_states
|
||||||
|
|
||||||
|
# Get the input ids and last token indices for the speculator.
|
||||||
|
last_token_indices = prepare_eagle_inputs(
|
||||||
|
self.input_buffers,
|
||||||
|
input_batch,
|
||||||
|
num_sampled,
|
||||||
|
num_rejected,
|
||||||
|
last_sampled,
|
||||||
|
next_prefill_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill: Run the eagle speculator with eager mode.
|
||||||
|
# TODO(woosuk): Support CUDA graph for prefill.
|
||||||
|
last_hidden_states, hidden_states = self.run_model(
|
||||||
|
num_tokens,
|
||||||
|
input_batch.attn_metadata,
|
||||||
|
num_tokens_across_dp=None, # FIXME
|
||||||
|
)
|
||||||
|
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||||
|
logits = self.model.compute_logits(sample_hidden_states)
|
||||||
|
|
||||||
|
num_reqs = input_batch.num_reqs
|
||||||
|
cu_num_logits = input_batch.cu_num_logits[:num_reqs]
|
||||||
|
# NOTE(woosuk): For draft sampling, we only consider the temperature
|
||||||
|
# and ignore the other sampling parameters such as top_k and top_p,
|
||||||
|
# for simplicity and performance.
|
||||||
|
# While this may slightly degrade the acceptance rate, it does not
|
||||||
|
# affect the output distribution after rejection sampling.
|
||||||
|
# NOTE(Ronald1995): torch.gather will pollute the cache such as self.input_buffers.positions
|
||||||
|
# The bug has been reported to the Huawei CANN team, but is not fixed yet.
|
||||||
|
# So we clone the tensors before calling torch.gather to avoid the issue.
|
||||||
|
temperature = self.temperature[:num_reqs].clone()
|
||||||
|
seeds = self.seeds[:num_reqs].clone()
|
||||||
|
pos = self.input_buffers.positions[:num_reqs].clone()
|
||||||
|
# Gather the values and copy them to the pre-allocated buffers.
|
||||||
|
torch.gather(sampling_metadata.temperature,
|
||||||
|
0,
|
||||||
|
cu_num_logits,
|
||||||
|
out=temperature)
|
||||||
|
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
|
||||||
|
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
|
||||||
|
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
|
||||||
|
# used for draft and target sampling.
|
||||||
|
draft_tokens = gumbel_sample(logits,
|
||||||
|
temperature,
|
||||||
|
seeds,
|
||||||
|
pos + 1,
|
||||||
|
apply_temperature=True)
|
||||||
|
if self.num_speculative_steps == 1:
|
||||||
|
# Early exit.
|
||||||
|
return draft_tokens.view(-1, 1)
|
||||||
|
|
||||||
|
# Save the draft tokens for the first step.
|
||||||
|
self.draft_tokens[:num_reqs, 0] = draft_tokens
|
||||||
|
# Prepare the inputs for the decode steps.
|
||||||
|
prepare_eagle_decode(
|
||||||
|
draft_tokens,
|
||||||
|
hidden_states,
|
||||||
|
last_token_indices,
|
||||||
|
input_batch.seq_lens,
|
||||||
|
num_rejected,
|
||||||
|
self.input_buffers,
|
||||||
|
self.hidden_states,
|
||||||
|
self.max_model_len,
|
||||||
|
self.max_num_reqs,
|
||||||
|
)
|
||||||
|
query_start_loc = self.input_buffers.query_start_loc
|
||||||
|
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
|
||||||
|
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||||
|
query_start_loc_gpu, pos)
|
||||||
|
|
||||||
|
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
|
||||||
|
if cudagraph_size is not None:
|
||||||
|
# Run CUDA graph.
|
||||||
|
self.cudagraph_manager.run(cudagraph_size)
|
||||||
|
return self.draft_tokens[:num_reqs]
|
||||||
|
|
||||||
|
# Run eager mode.
|
||||||
|
query_start_loc.np[:num_reqs + 1] = np.arange(num_reqs + 1)
|
||||||
|
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
|
||||||
|
# HACK(woosuk)
|
||||||
|
seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
|
||||||
|
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||||
|
|
||||||
|
# FIXME(woosuk): This is UNSAFE!!
|
||||||
|
attn_metadata = build_attn_metadata(
|
||||||
|
attn_metadata_builders=self.attn_metadata_builders,
|
||||||
|
num_reqs=num_reqs,
|
||||||
|
num_tokens=num_reqs,
|
||||||
|
query_start_loc_gpu=query_start_loc_gpu,
|
||||||
|
query_start_loc_cpu=query_start_loc_cpu,
|
||||||
|
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||||
|
seq_lens_np=seq_lens_np,
|
||||||
|
num_computed_tokens_cpu=None, # FIXME
|
||||||
|
block_tables=block_tables,
|
||||||
|
slot_mappings=slot_mappings,
|
||||||
|
kv_cache_config=self.kv_cache_config,
|
||||||
|
)
|
||||||
|
self.generate_draft(num_reqs, attn_metadata,
|
||||||
|
num_tokens_across_dp=None) # FIXME
|
||||||
|
return self.draft_tokens[:num_reqs]
|
||||||
|
|
||||||
|
|
||||||
|
vllm.v1.worker.gpu.spec_decode.eagle.EagleSpeculator.propose = propose
|
||||||
@@ -18,7 +18,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -50,13 +50,11 @@ def build_attn_metadata(
|
|||||||
query_start_loc_gpu: torch.Tensor,
|
query_start_loc_gpu: torch.Tensor,
|
||||||
query_start_loc_cpu: torch.Tensor,
|
query_start_loc_cpu: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_cpu: torch.Tensor,
|
seq_lens_np: np.ndarray,
|
||||||
num_computed_tokens_cpu: torch.Tensor,
|
num_computed_tokens_cpu: torch.Tensor | None,
|
||||||
block_tables: Sequence[torch.Tensor],
|
block_tables: Sequence[torch.Tensor],
|
||||||
slot_mappings: torch.Tensor,
|
slot_mappings: torch.Tensor,
|
||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig,
|
||||||
decode_token_per_req: int,
|
|
||||||
actual_seq_lengths_q: list[int],
|
|
||||||
positions: torch.Tensor | None = None,
|
positions: torch.Tensor | None = None,
|
||||||
attn_state: Any | None = None,
|
attn_state: Any | None = None,
|
||||||
graph_pad_size: int = -1,
|
graph_pad_size: int = -1,
|
||||||
@@ -67,7 +65,11 @@ def build_attn_metadata(
|
|||||||
"""Build attention metadata for Ascend NPUs."""
|
"""Build attention metadata for Ascend NPUs."""
|
||||||
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
|
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
|
||||||
max_query_len = int(query_start_loc_cpu.max())
|
max_query_len = int(query_start_loc_cpu.max())
|
||||||
|
seq_lens_cpu = torch.from_numpy(seq_lens_np)
|
||||||
max_seq_len = int(seq_lens_cpu.max())
|
max_seq_len = int(seq_lens_cpu.max())
|
||||||
|
# torch_npu._reshape_and_cache operator requires slot_mappings to
|
||||||
|
# be torch.int32.
|
||||||
|
slot_mappings = slot_mappings.to(torch.int32)
|
||||||
|
|
||||||
attn_metadata: dict[str, Any] = {}
|
attn_metadata: dict[str, Any] = {}
|
||||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||||
@@ -80,14 +82,11 @@ def build_attn_metadata(
|
|||||||
query_start_loc_cpu=query_start_loc_cpu,
|
query_start_loc_cpu=query_start_loc_cpu,
|
||||||
seq_lens_cpu=seq_lens_cpu[:num_reqs],
|
seq_lens_cpu=seq_lens_cpu[:num_reqs],
|
||||||
seq_lens=seq_lens[:num_reqs],
|
seq_lens=seq_lens[:num_reqs],
|
||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
decode_token_per_req=decode_token_per_req,
|
|
||||||
block_table_tensor=block_table,
|
block_table_tensor=block_table,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
|
||||||
positions=positions,
|
positions=positions,
|
||||||
attn_state=attn_state,
|
attn_state=attn_state,
|
||||||
graph_pad_size=graph_pad_size,
|
graph_pad_size=graph_pad_size,
|
||||||
|
|||||||
@@ -21,20 +21,20 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.worker.gpu.input_batch import (InputBatch,
|
from vllm.v1.worker.gpu.input_batch import (InputBatch,
|
||||||
combine_sampled_and_draft_tokens,
|
combine_sampled_and_draft_tokens,
|
||||||
prepare_pos_seq_lens,
|
prepare_pos_seq_lens,
|
||||||
prepare_prefill_inputs)
|
prepare_prefill_inputs)
|
||||||
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
|
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
|
||||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
|
||||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
|
||||||
|
|
||||||
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
|
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
|
||||||
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
|
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
|
||||||
build_attn_state)
|
build_attn_state)
|
||||||
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
|
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
|
||||||
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
|
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
|
||||||
|
from vllm_ascend.worker.v2.spec_decode import init_speculator
|
||||||
|
from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
|
||||||
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
|
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
|
||||||
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
|
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
|
||||||
|
|
||||||
@@ -54,12 +54,21 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
del self.req_states
|
del self.req_states
|
||||||
del self.input_buffers
|
del self.input_buffers
|
||||||
del self.sampler
|
del self.sampler
|
||||||
|
del self.speculator
|
||||||
|
|
||||||
# NPU specific initializations can be added below.
|
# NPU specific initializations can be added below.
|
||||||
self.cudagraph_manager: AclGraphManager = AclGraphManager(
|
self.cudagraph_manager: AclGraphManager = AclGraphManager(
|
||||||
vllm_config,
|
vllm_config,
|
||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle
|
||||||
|
# init_speculator will return AscendEagleSpeculator when eagle is used.
|
||||||
|
# so here we just call init_speculator to reinitialize speculator.
|
||||||
|
self.speculator: AscendEagleSpeculator | None = None
|
||||||
|
if self.speculative_config is not None:
|
||||||
|
self.speculator = init_speculator(self.vllm_config, self.device)
|
||||||
|
|
||||||
# AscendRequestState has extra `num_computed_tokens_cpu` attribute.
|
# AscendRequestState has extra `num_computed_tokens_cpu` attribute.
|
||||||
# so reinitialize req_states here.
|
# so reinitialize req_states here.
|
||||||
self.req_states: AscendRequestState = AscendRequestState(
|
self.req_states: AscendRequestState = AscendRequestState(
|
||||||
@@ -87,26 +96,15 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self.sampler: AscendSampler = AscendSampler(
|
self.sampler: AscendSampler = AscendSampler(
|
||||||
logprobs_mode=self.model_config.logprobs_mode, )
|
logprobs_mode=self.model_config.logprobs_mode, )
|
||||||
|
|
||||||
# actual seq lengths for query (used in attention backends).
|
# we need to copy num_computed_tokens back to cpu to help
|
||||||
self.actual_seq_lengths_q: list[int] = []
|
# update actual seq_lens_cpu. gpu attention backend doesn't need these
|
||||||
# decode token per request (used in attention backends).
|
# attributes, because GPU attention backends don't use seq_lens_cpu.
|
||||||
self.decode_token_per_req = 1
|
|
||||||
|
|
||||||
# there attributes are for async scheduling with speculative decoding.
|
|
||||||
# because npu attention backend still need to use seq_lens_cpu,
|
|
||||||
# we need to copy num_rejected_tokens back to cpu to help
|
|
||||||
# update actual seq_lens_cpu. gpu attention backend do not need these
|
|
||||||
# attributes, cause their attention backends do not use seq_lens_cpu.
|
|
||||||
# and seq_lens_cpu is deprecated in gpu_model_runner_v2.
|
# and seq_lens_cpu is deprecated in gpu_model_runner_v2.
|
||||||
self.num_rejected_tokens_event = None
|
self.num_computed_tokens_event = torch.npu.Event()
|
||||||
self.num_rejectd_tokens_cpu = None
|
self.num_computed_tokens_stream = torch.npu.Stream()
|
||||||
self.num_rejected_token_stream = None
|
self.num_computed_tokens_cpu = torch.empty(
|
||||||
if self.use_async_scheduling and self.do_spec_decode:
|
|
||||||
self.num_rejected_tokens_event = torch.npu.Event()
|
|
||||||
self.num_rejected_token_stream = torch.npu.Stream()
|
|
||||||
self.num_rejectd_tokens_cpu = torch.empty(
|
|
||||||
self.max_num_reqs,
|
self.max_num_reqs,
|
||||||
dtype=torch.int64,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
@@ -161,9 +159,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
idx_mapping = self.input_buffers.idx_mapping
|
idx_mapping = self.input_buffers.idx_mapping
|
||||||
idx_mapping.np[:num_reqs] = idx_mapping_list
|
idx_mapping.np[:num_reqs] = idx_mapping_list
|
||||||
idx_mapping_np = idx_mapping.np[:num_reqs]
|
idx_mapping_np = idx_mapping.np[:num_reqs]
|
||||||
# add `idx_mapping_cpu` here, because vllm-ascend's self.req_states.
|
|
||||||
# num_computed_tokens_cpu is actually cpu's tensor, while it's a gpu's
|
|
||||||
# tensor in vllm gpu_model_runner_v2.
|
|
||||||
idx_mapping_cpu = idx_mapping.cpu[:num_reqs]
|
idx_mapping_cpu = idx_mapping.cpu[:num_reqs]
|
||||||
idx_mapping_npu = idx_mapping.copy_to_gpu(num_reqs)
|
idx_mapping_npu = idx_mapping.copy_to_gpu(num_reqs)
|
||||||
|
|
||||||
@@ -267,16 +262,12 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
query_start_loc_gpu=query_start_loc_gpu,
|
query_start_loc_gpu=query_start_loc_gpu,
|
||||||
query_start_loc_cpu=query_start_loc_cpu,
|
query_start_loc_cpu=query_start_loc_cpu,
|
||||||
seq_lens=self.input_buffers.seq_lens,
|
seq_lens=self.input_buffers.seq_lens,
|
||||||
seq_lens_cpu=self.input_buffers.seq_lens_cpu,
|
seq_lens_np=self.input_buffers.seq_lens_np,
|
||||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
|
||||||
num_computed_tokens_cpu=self.req_states.
|
num_computed_tokens_cpu=self.req_states.
|
||||||
num_computed_tokens_cpu[idx_mapping_cpu],
|
num_computed_tokens_cpu[idx_mapping_cpu],
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
# torch_npu._reshape_and_cache operator requires slot_mappings to
|
slot_mappings=slot_mappings,
|
||||||
# be torch.int32.
|
|
||||||
slot_mappings=slot_mappings.to(torch.int32),
|
|
||||||
kv_cache_config=self.kv_cache_config,
|
kv_cache_config=self.kv_cache_config,
|
||||||
decode_token_per_req=self.decode_token_per_req,
|
|
||||||
attn_state=attn_state,
|
attn_state=attn_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -302,40 +293,35 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
cu_num_logits=cu_num_logits,
|
cu_num_logits=cu_num_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
def sample(
|
def postprocess(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
input_batch: InputBatch,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
grammar_output: GrammarOutput | None,
|
|
||||||
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
|
|
||||||
"""Override GPUModelRunner.sample for Ascend NPUs.
|
|
||||||
when using async scheduling with speculative decoding,
|
|
||||||
we need to copy mpu's num_rejected tensor to cpu.
|
|
||||||
these operations aren't needed in gpu_model_runner_v2,
|
|
||||||
because gpu attention backends do not use seq_lens_cpu anymore.
|
|
||||||
"""
|
|
||||||
sampler_output, num_sampled, num_rejected = super().sample(
|
|
||||||
hidden_states,
|
|
||||||
input_batch,
|
input_batch,
|
||||||
sampling_metadata,
|
sampled_tokens,
|
||||||
grammar_output,
|
num_sampled,
|
||||||
)
|
|
||||||
if self.num_rejected_tokens_event is not None:
|
|
||||||
# npu attention backend still need to use seq_lens_cpu,
|
|
||||||
# when doing speculative decoding with async_scheduling,
|
|
||||||
# we need to copy num_rejected_tokens back to cpu.
|
|
||||||
default_stream = torch.cuda.current_stream()
|
|
||||||
assert self.num_rejected_token_stream is not None
|
|
||||||
assert self.num_rejectd_tokens_cpu is not None
|
|
||||||
with torch.npu.stream(self.num_rejected_token_stream):
|
|
||||||
self.num_rejected_token_stream.wait_stream(default_stream)
|
|
||||||
self.num_rejectd_tokens_cpu.copy_(
|
|
||||||
num_rejected,
|
num_rejected,
|
||||||
|
):
|
||||||
|
"""Override GPUModelRunner.postprocess for Ascend NPUs.
|
||||||
|
npu attention backends need seq_lens_cpu to work.
|
||||||
|
so we need to copy num_computed_tokens back to cpu here.
|
||||||
|
"""
|
||||||
|
super().postprocess(
|
||||||
|
input_batch,
|
||||||
|
sampled_tokens,
|
||||||
|
num_sampled,
|
||||||
|
num_rejected,
|
||||||
|
)
|
||||||
|
# npu attention backend still need to use seq_lens_cpu,
|
||||||
|
# we need to copy num_computed_tokens back to cpu.
|
||||||
|
default_stream = torch.cuda.current_stream()
|
||||||
|
assert self.num_computed_tokens_stream is not None
|
||||||
|
assert self.num_computed_tokens_cpu is not None
|
||||||
|
with torch.npu.stream(self.num_computed_tokens_stream):
|
||||||
|
self.num_computed_tokens_stream.wait_stream(default_stream)
|
||||||
|
self.num_computed_tokens_cpu.copy_(
|
||||||
|
self.req_states.num_computed_tokens,
|
||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
self.num_rejected_tokens_event.record()
|
self.num_computed_tokens_event.record()
|
||||||
return sampler_output, num_sampled, num_rejected
|
|
||||||
|
|
||||||
def _update_seq_lens_cpu(
|
def _update_seq_lens_cpu(
|
||||||
self,
|
self,
|
||||||
@@ -343,17 +329,14 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
req_ids: list[str],
|
req_ids: list[str],
|
||||||
):
|
):
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
|
# wait for num_computed_tokens copy to cpu stream to finish.
|
||||||
# update num_computed_tokens_cpu
|
self.num_computed_tokens_event.synchronize()
|
||||||
# TODO(Ronald1995): update num_computed_tokens_cpu by considering
|
for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
|
||||||
# num_rejectd_tokens.
|
|
||||||
for req_id, num_computed_token in zip(
|
|
||||||
scheduler_output.scheduled_cached_reqs.req_ids,
|
|
||||||
scheduler_output.scheduled_cached_reqs.num_computed_tokens,
|
|
||||||
):
|
|
||||||
req_index = self.req_states.req_id_to_index[req_id]
|
req_index = self.req_states.req_id_to_index[req_id]
|
||||||
|
# num_computed_tokens_cpu has already been reverted by num_rejected_tokens
|
||||||
|
# in the parent class's postprocess method.
|
||||||
self.req_states.num_computed_tokens_cpu[
|
self.req_states.num_computed_tokens_cpu[
|
||||||
req_index] = num_computed_token
|
req_index] = self.num_computed_tokens_cpu[req_index]
|
||||||
|
|
||||||
# update seq_lens_cpu
|
# update seq_lens_cpu
|
||||||
for i, req_id in enumerate(req_ids):
|
for i, req_id in enumerate(req_ids):
|
||||||
|
|||||||
38
vllm_ascend/worker/v2/spec_decode/__init__.py
Normal file
38
vllm_ascend/worker/v2/spec_decode/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/spec_decode/__init__.py
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
import torch
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
|
||||||
|
def init_speculator(
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
"""Override GPU init_speculator for Ascend NPUs.
|
||||||
|
Use AscendEagleSpeculator when eagle is used.
|
||||||
|
"""
|
||||||
|
speculative_config = vllm_config.speculative_config
|
||||||
|
assert speculative_config is not None
|
||||||
|
if speculative_config.use_eagle():
|
||||||
|
from vllm_ascend.worker.v2.spec_decode.eagle import \
|
||||||
|
AscendEagleSpeculator
|
||||||
|
|
||||||
|
return AscendEagleSpeculator(vllm_config, device)
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{speculative_config.method} is not supported yet.")
|
||||||
146
vllm_ascend/worker/v2/spec_decode/eagle.py
Normal file
146
vllm_ascend/worker/v2/spec_decode/eagle.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/spec_decode/eagle.py
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import vllm
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||||
|
from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
|
||||||
|
|
||||||
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
class AscendEagleSpeculator(EagleSpeculator):
    """Eagle speculator adjusted for the Ascend attention backend.

    The proposing logic is inherited from ``EagleSpeculator``. This subclass
    only (a) swaps in the Ascend ``build_attn_metadata`` while the parent
    ``propose`` runs, and (b) patches the attention metadata with the extra
    fields the Ascend backend reads (``attn_state``, ``seq_lens_cpu``,
    ``seq_len_list``).
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Override GPU EagleSpeculator.__init__ for Ascend NPUs.

        Attention metadata building in the Ascend backend needs more
        information, such as seq_lens_cpu from input_batch, so we need to
        override __init__.
        """
        super().__init__(vllm_config, device)
        # In the decode phase of the eagle speculator we need some values
        # from the main model's input_batch, so we keep a reference here.
        # It is set by propose() and read by _get_seq_lens_cpu().
        self.input_batch: InputBatch | None = None

    def propose(
        self,
        input_batch,
        sampling_metadata,
        last_hidden_states,
        aux_hidden_states,
        num_sampled,
        num_rejected,
        last_sampled,
        next_prefill_tokens,
    ):
        """Override GPU EagleSpeculator.propose for Ascend NPUs.

        NPU attention metadata needs more information than the GPU one, so
        input_batch is cached here for later use in generate_draft. All
        arguments are forwarded unchanged to the parent implementation, whose
        result is returned as is.
        """
        self.input_batch = input_batch
        # Wrap build_attn_metadata so the parent uses the Ascend attention
        # metadata builder; this lets us call super().propose() directly.
        with build_attn_metadata_wrapper():
            return super().propose(
                input_batch,
                sampling_metadata,
                last_hidden_states,
                aux_hidden_states,
                num_sampled,
                num_rejected,
                last_sampled,
                next_prefill_tokens,
            )

    def generate_draft(
        self,
        num_reqs: int,
        attn_metadata: dict[str, Any],
        num_tokens_across_dp,
    ):
        """Override GPU EagleSpeculator.generate_draft for Ascend NPUs.

        attn_metadata is created in the parent propose method and lacks some
        attributes the Ascend attention backend needs, so it is updated in
        place before delegating to the parent implementation.
        """
        self._update_decode_attn_metadata(attn_metadata)

        return super().generate_draft(
            num_reqs,
            attn_metadata,
            num_tokens_across_dp,
        )

    @torch.inference_mode()
    def run_model(
        self,
        num_tokens: int,
        attn_metadata: dict[str, Any] | None,
        num_tokens_across_dp: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Override GPU EagleSpeculator.run_model for Ascend NPUs.

        After the parent runs the model, the sequence lengths stored in the
        attention metadata are advanced by one and ``seq_len_list`` is
        refreshed from them.

        Args:
            num_tokens: Forwarded to the parent ``run_model``.
            attn_metadata: Metadata objects keyed by string, or ``None``.
            num_tokens_across_dp: Forwarded to the parent ``run_model``.

        Returns:
            The ``(last_hidden_states, hidden_states)`` pair produced by the
            parent ``run_model``.
        """
        last_hidden_states, hidden_states = super().run_model(
            num_tokens,
            attn_metadata,
            num_tokens_across_dp,
        )

        # attn_metadata is None in profile_run and dummy_run.
        if attn_metadata is not None:
            # Advance each distinct metadata object exactly once. If several
            # keys map to the same object, iterating over .values() directly
            # would add 1 to its seq_lens once per key.
            # NOTE(review): upstream vLLM's build_attn_metadata shares one
            # object across all layers of a KV-cache group; confirm the
            # Ascend builder does the same.
            for attn_meta in self._unique_metadata(attn_metadata):
                # seq_lens in AscendMetadata is a cpu tensor.
                attn_meta.seq_lens = attn_meta.seq_lens + 1
                attn_meta.seq_len_list = attn_meta.seq_lens.tolist()
        return last_hidden_states, hidden_states

    def _update_decode_attn_metadata(
        self,
        attn_metadata: dict[str, Any],
    ):
        """Update attention metadata for decode phase on Ascend NPUs.

        Sets ``attn_state`` to ``AscendAttentionState.DecodeOnly`` and
        ``seq_lens_cpu`` to the cached batch's sequence lengths on every
        metadata object in ``attn_metadata``.
        """
        attn_state = AscendAttentionState.DecodeOnly
        seq_lens_cpu = self._get_seq_lens_cpu()
        # attn_metadata is built in vllm's super class.
        # We need to update attn_state for each layer's metadata. These are
        # plain assignments, so repeating them on a shared object is harmless.
        for metadata in attn_metadata.values():
            metadata.attn_state = attn_state
            metadata.seq_lens_cpu = seq_lens_cpu

    def _get_seq_lens_cpu(self) -> torch.Tensor:
        """Get seq_lens_cpu from the input_batch cached by propose()."""
        assert self.input_batch is not None
        seq_lens_cpu = torch.from_numpy(self.input_batch.seq_lens_np)
        return seq_lens_cpu

    @staticmethod
    def _unique_metadata(attn_metadata: dict[str, Any]) -> list[Any]:
        """Return the distinct metadata objects, in first-seen order."""
        # Keyed by id() because identity, not equality, is what matters. The
        # objects stay alive in attn_metadata, so ids cannot be reused here.
        return list({id(meta): meta
                     for meta in attn_metadata.values()}.values())
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def build_attn_metadata_wrapper():
    """Temporarily swap in the Ascend ``build_attn_metadata``.

    While the ``with`` body runs, the ``build_attn_metadata`` attribute of
    ``vllm.v1.worker.gpu.spec_decode.eagle`` is bound to the Ascend
    implementation imported by this file. The previous binding is restored on
    exit, including when the body raises.
    """
    eagle_module = vllm.v1.worker.gpu.spec_decode.eagle
    saved_builder = eagle_module.build_attn_metadata
    try:
        eagle_module.build_attn_metadata = build_attn_metadata
        yield
    finally:
        # Always restore the original builder so the patch cannot leak.
        eagle_module.build_attn_metadata = saved_builder
|
||||||
Reference in New Issue
Block a user