[Spec-Decode] Fix spec decode proposer in 0.18.0 (#7544)
### What this PR does / why we need it?
As the vllm-ascend main branch no longer maintains v0.17.0, we just apply
the single (non-version-gated) branch in the eagle proposer. Otherwise it
will raise an error in v0.18.0.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI passed with the existing tests.
- vLLM version: v0.18.0
- vLLM main:
8b6325758c
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -46,7 +46,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
|
||||
from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is
|
||||
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled
|
||||
|
||||
# Currently we will fix block size to a small one since `num_reqs` can't be too large
|
||||
_PREPARE_INPUTS_BLOCK_SIZE = 4
|
||||
@@ -390,11 +390,8 @@ class SpecDecodeBaseProposer(EagleProposer):
|
||||
: num_reqs * self.decode_threshold
|
||||
]
|
||||
|
||||
if vllm_version_is("0.17.0"):
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
else:
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
# update the tensor's address for each step.
|
||||
for draft_step in range(self.num_speculative_tokens):
|
||||
common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata)
|
||||
@@ -558,11 +555,8 @@ class SpecDecodeBaseProposer(EagleProposer):
|
||||
common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
|
||||
common_attn_metadata.num_input_tokens = num_input_tokens
|
||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||
if vllm_version_is("0.17.0"):
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
else:
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
|
||||
|
||||
if self.uses_mrope:
|
||||
@@ -1549,11 +1543,8 @@ class SpecDecodeBaseProposer(EagleProposer):
|
||||
|
||||
# update full-graph params for one spec token
|
||||
def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
|
||||
if vllm_version_is("0.17.0"):
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
attn_backend = self.draft_attn_groups[0].backend
|
||||
else:
|
||||
attn_backend = self.runner.attn_backend
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
attn_backend = self.draft_attn_groups[0].backend
|
||||
update_full_graph_params(
|
||||
attn_backend,
|
||||
self.update_stream,
|
||||
|
||||
Reference in New Issue
Block a user