[Bugfix] fix bug for mtp (#6514)
### What this PR does / why we need it? fix(mtp): resolve MTP core bugs and enhance eager mode test cases 1. Resolved critical issues in eager mode MTP core execution logic; 2. Fixed functional bugs in the _update_states_after_model_execute function; 3. Updated and released test_mtp_qwen3_next.py to validate eager mode acceptance rate. ### Does this PR introduce _any_ user-facing change? None ### How was this patch tested? - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 Signed-off-by: Bowen-Leee <caoshankuangren@gmail.com>
This commit is contained in:
1
.github/workflows/scripts/config.yaml
vendored
1
.github/workflows/scripts/config.yaml
vendored
@@ -152,4 +152,3 @@ e2e-multicard-4-cards:
|
|||||||
is_skipped: true
|
is_skipped: true
|
||||||
- name: tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py
|
- name: tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py
|
||||||
estimated_time: 60
|
estimated_time: 60
|
||||||
is_skipped: true
|
|
||||||
|
|||||||
@@ -54,7 +54,6 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
|
|||||||
distributed_executor_backend="mp",
|
distributed_executor_backend="mp",
|
||||||
disable_log_stats=False,
|
disable_log_stats=False,
|
||||||
speculative_config={
|
speculative_config={
|
||||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
|
||||||
"method": "qwen3_next_mtp",
|
"method": "qwen3_next_mtp",
|
||||||
"num_speculative_tokens": 3,
|
"num_speculative_tokens": 3,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
|||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
|
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
|
||||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla, update_cos_sin
|
||||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||||
from vllm_ascend.utils import lmhead_tp_enable, vllm_version_is
|
from vllm_ascend.utils import lmhead_tp_enable, vllm_version_is
|
||||||
|
|
||||||
@@ -329,6 +329,7 @@ class MtpProposer(EagleProposer):
|
|||||||
for layer_name in self.attn_layer_names:
|
for layer_name in self.attn_layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_mtp
|
attn_metadata[layer_name] = attn_metadata_mtp
|
||||||
|
|
||||||
|
update_cos_sin(self._get_positions(num_input_tokens))
|
||||||
for step in range(self.num_speculative_tokens):
|
for step in range(self.num_speculative_tokens):
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
|
|||||||
@@ -1374,6 +1374,9 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
with record_function_or_nullcontext("sample_token"):
|
with record_function_or_nullcontext("sample_token"):
|
||||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||||
|
|
||||||
|
if self.need_accepted_tokens:
|
||||||
|
self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output)
|
||||||
|
|
||||||
def propose_draft_token_ids(sampled_token_ids):
|
def propose_draft_token_ids(sampled_token_ids):
|
||||||
assert spec_decode_common_attn_metadata is not None
|
assert spec_decode_common_attn_metadata is not None
|
||||||
self._draft_token_ids = self.propose_draft_token_ids(
|
self._draft_token_ids = self.propose_draft_token_ids(
|
||||||
@@ -1474,8 +1477,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
logits,
|
logits,
|
||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
)
|
)
|
||||||
if self.need_accepted_tokens: # TODO remove this if
|
|
||||||
self._update_states_after_model_execute(sampler_output.sampled_token_ids)
|
|
||||||
return sampler_output
|
return sampler_output
|
||||||
|
|
||||||
# TODO: remove this func after eagle_proposer is refactored and
|
# TODO: remove this func after eagle_proposer is refactored and
|
||||||
|
|||||||
Reference in New Issue
Block a user