[Feat][main] Support full-graph mode with Qwen3-Next-MTP (#5477)
### What this PR does / why we need it?
Support full-graph mode with Qwen3-Next-MTP.
Specifically, we adapted `AscendAttentionState.ChunkedPrefill` in both the
main model and the MTP model.
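For illustration, a minimal sketch of how a user might exercise this path. The model id and config values are assumptions, and the exact speculative method string may vary by vLLM version:
```python
# Hypothetical usage sketch, not the exact test setup.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",  # illustrative model id
    tensor_parallel_size=4,
    # MTP drafter; the runner below checks `speculative_config.method == "mtp"`.
    speculative_config={"method": "mtp", "num_speculative_tokens": 1},
    # Capture the decode path as a full graph.
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```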
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
We changed the Qwen3-Next-MTP test in
`tests/e2e/multicard/test_qwen3_next.py` so that it runs in
`FULL_DECODE_ONLY` mode, then ran `pytest -s
tests/e2e/multicard/test_qwen3_next.py::test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4`.
The test passed:
```text
.
================================================================================================================================= warnings summary =================================================================================================================================
<frozen importlib._bootstrap>:241
<frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute
<frozen importlib._bootstrap>:241
<frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================================== 1 passed, 2 warnings in 271.89s (0:04:31) =====================================================================================================================
```
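For reference, the mode switch in the test presumably boils down to something like the following sketch, assuming vLLM's typed `CompilationConfig`/`CUDAGraphMode` API; the actual test code may differ:
```python
# Sketch of the test adaptation; the real change lives in
# tests/e2e/multicard/test_qwen3_next.py.
from vllm.config import CompilationConfig, CUDAGraphMode

# Run the graph-mode side of the similarity test with full-graph decode,
# then pass this as `LLM(..., compilation_config=compilation_config)`.
compilation_config = CompilationConfig(
    cudagraph_mode=CUDAGraphMode.FULL_DECODE_ONLY)
```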
- vLLM version: v0.13.0
- vLLM main: 5326c89803
Signed-off-by: drslark <slarksblood@qq.com>
```diff
@@ -894,9 +894,6 @@ class NPUModelRunner(GPUModelRunner):
         self.logits_indices = logits_indices

         # Used in the below loop.
-        # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
-        num_computed_tokens_cpu = (
-            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
         self.spec_decode_common_attn_metadata = None
         if use_spec_decode and self.need_accepted_tokens:
             self.num_accepted_tokens.np[:num_reqs] = (
@@ -991,7 +988,8 @@ class NPUModelRunner(GPUModelRunner):
             # TODO: change this to the right block table for linear attn
             block_table_tensor=blk_table_tensor[:num_reqs],
             slot_mapping=slot_mapping,
-            num_computed_tokens_cpu=num_computed_tokens_cpu,
+            num_computed_tokens_cpu=self.input_batch.
+            num_computed_tokens_cpu_tensor[:num_reqs],
             positions=self.positions.gpu,
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
@@ -1822,7 +1820,11 @@ class NPUModelRunner(GPUModelRunner):
         attn_state = AscendAttentionState.DecodeOnly
         if self.speculative_config and \
                 self.speculative_config.method == "mtp":
-            attn_state = AscendAttentionState.SpecDecoding
+            # `AscendAttentionState.SpecDecoding` is only designed for mla
+            if self.vllm_config.model_config.use_mla:
+                attn_state = AscendAttentionState.SpecDecoding
+            else:
+                attn_state = AscendAttentionState.ChunkedPrefill

         common_metadata = CommonAttentionMetadata(
             query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
```
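Distilled from the last hunk, the selection rule during dummy runs is: `SpecDecoding` only when the model uses MLA; otherwise MTP falls back to `ChunkedPrefill`. A standalone sketch of that rule (the helper function and enum values here are illustrative, not the vllm-ascend definitions):
```python
# Standalone distillation of the dispatch rule above; `pick_attn_state`
# is a hypothetical helper, not part of vllm-ascend.
from enum import Enum, auto

class AscendAttentionState(Enum):
    DecodeOnly = auto()
    SpecDecoding = auto()
    ChunkedPrefill = auto()

def pick_attn_state(is_mtp: bool, use_mla: bool) -> AscendAttentionState:
    if not is_mtp:
        return AscendAttentionState.DecodeOnly
    # `SpecDecoding` is only designed for MLA; non-MLA models such as
    # Qwen3-Next take the `ChunkedPrefill` path instead.
    if use_mla:
        return AscendAttentionState.SpecDecoding
    return AscendAttentionState.ChunkedPrefill

assert pick_attn_state(is_mtp=True, use_mla=False) is \
    AscendAttentionState.ChunkedPrefill
```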