### What this PR does / why we need it?
1. ✅ Upgrade the vLLM commit to 0115 (8471b27df97c3eb79f891802fc0e858f8f7ac6a0).
   Modify import paths to follow the refactors in
   https://github.com/vllm-project/vllm/pull/32245 and
   https://github.com/vllm-project/vllm/pull/32060.
   Test result:
   https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
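   A generic sketch of how such upstream path moves can be absorbed; the helper below is illustrative, and the concrete modules changed by #32245/#32060 are only in the diff, so no real paths are repeated here:

   ```python
   import importlib


   def import_from_first(candidates: list[tuple[str, str]]):
       """Return the first importable (module, attribute) pair.

       Illustrative helper for tracking upstream path moves; pass the
       pre- and post-refactor locations of the symbol in question.
       """
       for module_name, attr in candidates:
           try:
               return getattr(importlib.import_module(module_name), attr)
           except (ImportError, AttributeError):
               continue
       raise ImportError(f"No importable candidate among: {candidates}")
   ```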
2. ✅ Upgrade the vLLM commit to 0119 (9a1f16da1e423ede2c2f52a9850cbfbb39cefe96).
   Fix `WorkerProc.__init__() missing 1 required positional argument:
   'is_driver_worker'`, caused by
   https://github.com/vllm-project/vllm/pull/28506.
   Test result:
   https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
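   One version-tolerant way to handle the new argument, as a sketch; only the `is_driver_worker` name comes from the error above, while the wrapper, its default, and the assumption that `WorkerProc` lives in `vllm.v1.executor.multiproc_executor` are illustrative:

   ```python
   import inspect

   from vllm.v1.executor.multiproc_executor import WorkerProc


   def construct_worker_proc(*args, **kwargs):
       # Forward is_driver_worker only when the installed vLLM (post #28506)
       # declares it; treating rank 0 as the driver is an assumption here.
       params = inspect.signature(WorkerProc.__init__).parameters
       if "is_driver_worker" in params and "is_driver_worker" not in kwargs:
           kwargs["is_driver_worker"] = kwargs.get("rank", 0) == 0
       return WorkerProc(*args, **kwargs)
   ```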
3. ✅ Upgrade the vLLM commit to 0120 (148117ea2e689cd43df4be6892671a17cdae5833).
   1. Add the `skip_compiled` param in `set_forward_context` due to
      https://github.com/vllm-project/vllm/pull/30385 (a compatibility sketch
      follows this item).
   2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
      https://github.com/vllm-project/vllm/pull/24322, which changes
      `self.max_num_tokens` to
      `vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`.
   3. Modify UT import paths due to the refactor:
      https://github.com/vllm-project/vllm/pull/32060
   Test result:
   https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
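   For sub-item 1, a sketch of forwarding the new argument only when the installed vLLM supports it; `set_forward_context` comes from `vllm.forward_context`, and the wrapper name and behaviour are assumptions:

   ```python
   import inspect

   from vllm.forward_context import set_forward_context

   # True only when the installed vLLM already exposes the skip_compiled
   # parameter added in vllm-project/vllm#30385.
   _HAS_SKIP_COMPILED = (
       "skip_compiled" in inspect.signature(set_forward_context).parameters)


   def forward_context_compat(attn_metadata, vllm_config, **kwargs):
       """Drop skip_compiled when running against an older vLLM commit."""
       if not _HAS_SKIP_COMPILED:
           kwargs.pop("skip_compiled", None)
       return set_forward_context(attn_metadata, vllm_config, **kwargs)
   ```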
4. ✅ Upgrade the vLLM commit to 0121 (f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9).
   1. vLLM moved `uses_mrope` from the target to the draft model config, making
      `positions`/`mrope_positions` mutually exclusive. This breaks vllm-ascend's
      direct `self.positions` access and tests that lack
      `draft_model_config.uses_mrope`:
      https://github.com/vllm-project/vllm/pull/32048 (see the sketch after this
      item).
   2. `bs_to_padded_graph_size` moved from `CompilationConfig` to
      `CudagraphDispatcher` in the refactor
      https://github.com/vllm-project/vllm/pull/30143.
   3. Remove the unused `maybe_setup_kv_connector` due to
      https://github.com/vllm-project/vllm/pull/32077.
   Test result:
   https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
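   For sub-item 1, a sketch of the buffer selection the change requires on the draft side; the helper and its signature are illustrative, and only reading `uses_mrope` from the draft model config comes from the upstream change:

   ```python
   import torch


   def select_draft_positions(draft_model_config,
                              positions: torch.Tensor,
                              mrope_positions: torch.Tensor,
                              num_tokens: int) -> torch.Tensor:
       # After vllm-project/vllm#32048 only one of the two buffers is valid,
       # and the flag must be read from the draft (not target) model config.
       if draft_model_config.uses_mrope:
           # M-RoPE positions are kept in a [3, num_tokens] layout.
           return mrope_positions[:, :num_tokens]
       return positions[:num_tokens]
   ```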
5. ✅ Upgrade the vLLM commit to 0122 (8ebf271bb6d1e7e9b1a55be73d755ef1a57dbbe5).
   Update `FusedMoEParallelConfig` (which gains `enable_eplb`) and
   `FusedMoEConfig` due to https://github.com/vllm-project/vllm/pull/32414.
   Test result:
   https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
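   A sketch of passing the new field only when it exists, so the same code builds the config on both sides of the change; the import path reflects current vLLM main and the helper itself is illustrative:

   ```python
   import inspect

   from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig


   def parallel_config_kwargs(base_kwargs: dict, enable_eplb: bool) -> dict:
       # Only forward enable_eplb when the installed vLLM already defines it
       # (added in vllm-project/vllm#32414).
       kwargs = dict(base_kwargs)
       if "enable_eplb" in inspect.signature(FusedMoEParallelConfig).parameters:
           kwargs["enable_eplb"] = enable_eplb
       return kwargs
   ```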
6. ✅ Upgrade the vLLM commit to 0123 (dc917cceb877dfd13f98c538c4c96158047d98bd).
   Set `temperature=0.0` explicitly, since the default temperature value was
   removed in https://github.com/vllm-project/vllm/pull/32723.
   Test result:
   https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
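   A minimal example of the explicit setting the affected tests now need; the model name is a placeholder:

   ```python
   from vllm import LLM, SamplingParams

   # vllm-project/vllm#32723 removed the implicit default, so pass the
   # temperature explicitly; 0.0 keeps decoding greedy and deterministic.
   sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
   llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
   outputs = llm.generate(["Hello, my name is"], sampling_params)
   print(outputs[0].outputs[0].text)
   ```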
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
---------
Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
```python
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/spec_decode/eagle.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import numpy as np
import torch
import vllm
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.spec_decode.eagle import (prepare_eagle_decode,
                                                  prepare_eagle_inputs)

from vllm_ascend.worker.v2.attn_utils import build_attn_metadata


@torch.inference_mode()
def propose(
    self,
    input_batch: InputBatch,
    sampling_metadata: SamplingMetadata,
    # [num_tokens, hidden_size]
    last_hidden_states: torch.Tensor,
    # num_layers x [num_tokens, hidden_size]
    aux_hidden_states: list[torch.Tensor] | None,
    # [num_reqs]
    num_sampled: torch.Tensor,
    # [num_reqs]
    num_rejected: torch.Tensor,
    # [num_reqs]
    last_sampled: torch.Tensor,
    # [num_reqs]
    next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
    # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
    # number of rejected tokens, we maintain the size of eagle's input_ids and
    # hidden_states the same as the target model's. This means, we pad each
    # request's query length to include any rejected positions. By doing so,
    # we can also reuse the attention metadata (e.g., query_start_loc,
    # seq_lens) of the target model.
    if aux_hidden_states:
        assert self.method == "eagle3"
        hidden_states = self.model.combine_hidden_states(
            torch.cat(aux_hidden_states, dim=-1))
    else:
        hidden_states = last_hidden_states
    num_tokens = input_batch.num_tokens_after_padding
    self.hidden_states[:num_tokens] = hidden_states

    # Get the input ids and last token indices for the speculator.
    last_token_indices = prepare_eagle_inputs(
        self.input_buffers,
        input_batch,
        num_sampled,
        num_rejected,
        last_sampled,
        next_prefill_tokens,
    )

    # Prefill: Run the eagle speculator with eager mode.
    # TODO(woosuk): Support CUDA graph for prefill.
    last_hidden_states, hidden_states = self.run_model(
        num_tokens,
        input_batch.attn_metadata,
        num_tokens_across_dp=None,  # FIXME
    )
    sample_hidden_states = last_hidden_states[last_token_indices]
    logits = self.model.compute_logits(sample_hidden_states)

    num_reqs = input_batch.num_reqs
    cu_num_logits = input_batch.cu_num_logits[:num_reqs]
    # NOTE(woosuk): For draft sampling, we only consider the temperature
    # and ignore the other sampling parameters such as top_k and top_p,
    # for simplicity and performance.
    # While this may slightly degrade the acceptance rate, it does not
    # affect the output distribution after rejection sampling.
    # NOTE(Ronald1995): torch.gather pollutes cached buffers such as
    # self.input_buffers.positions. The bug has been reported to the Huawei
    # CANN team but is not fixed yet, so we clone the tensors before calling
    # torch.gather to avoid the issue.
    temperature = self.temperature[:num_reqs].clone()
    seeds = self.seeds[:num_reqs].clone()
    pos = self.input_buffers.positions[:num_reqs].clone()
    # Gather the values and copy them to the pre-allocated buffers.
    torch.gather(sampling_metadata.temperature,
                 0,
                 cu_num_logits,
                 out=temperature)
    torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
    torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
    # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
    # used for draft and target sampling.
    draft_tokens = gumbel_sample(logits,
                                 temperature,
                                 seeds,
                                 pos + 1,
                                 apply_temperature=True)
    if self.num_speculative_steps == 1:
        # Early exit.
        return draft_tokens.view(-1, 1)

    # Save the draft tokens for the first step.
    self.draft_tokens[:num_reqs, 0] = draft_tokens
    # Prepare the inputs for the decode steps.
    prepare_eagle_decode(
        draft_tokens,
        hidden_states,
        last_token_indices,
        input_batch.seq_lens,
        num_rejected,
        self.input_buffers,
        self.hidden_states,
        self.max_model_len,
        self.max_num_reqs,
    )
    query_start_loc = self.input_buffers.query_start_loc
    query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
    slot_mappings = self.block_tables.compute_slot_mappings(
        query_start_loc_gpu, pos)

    cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
    if cudagraph_size is not None:
        # Run CUDA graph.
        self.cudagraph_manager.run(cudagraph_size)
        return self.draft_tokens[:num_reqs]

    # Run eager mode.
    query_start_loc.np[:num_reqs + 1] = np.arange(num_reqs + 1)
    query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
    # HACK(woosuk)
    seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
    block_tables = [
        x[:num_reqs] for x in self.block_tables.input_block_tables
    ]

    # FIXME(woosuk): This is UNSAFE!!
    attn_metadata = build_attn_metadata(
        attn_metadata_builders=self.attn_metadata_builders,
        num_reqs=num_reqs,
        num_tokens=num_reqs,
        query_start_loc_gpu=query_start_loc_gpu,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=self.input_buffers.seq_lens[:num_reqs],
        seq_lens_np=seq_lens_np,
        num_computed_tokens_cpu=None,  # FIXME
        block_tables=block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=self.kv_cache_config,
    )
    self.generate_draft(num_reqs, attn_metadata,
                        num_tokens_across_dp=None)  # FIXME
    return self.draft_tokens[:num_reqs]


vllm.v1.worker.gpu.spec_decode.eagle.EagleSpeculator.propose = propose
```
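The file acts as a runtime monkey patch: the final assignment rebinds `EagleSpeculator.propose` to the NPU-friendly implementation above, so the override takes effect as soon as the module is imported and vLLM itself stays unmodified. A minimal usage sketch, with a placeholder module path for this file:

```python
# Importing the adapter module installs the patched propose(); the path
# "vllm_ascend.patch.eagle_propose" is a placeholder, not the real location.
import vllm_ascend.patch.eagle_propose  # noqa: F401
```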