[Bugfix][CI]Add qwen3Next MTP+Full Decode (#6047)
### What this PR does / why we need it?
Fix a bug in the repo and add a test case for MTP + Full Decode Only +
Qwen3Next.
The _build_dummy_attn_metadata function in NPUModelRunner seems losed a
query_star_loc.copy_to_gpu operation, which will lead to difference
between query_start_loc and query_start_loc_cpu, and they are required
to be same in MTP + Full Decode Only + Qwen3Next case.
Before this pr:
`self.query_start_loc = [0, 0, 0, 0, ... , 0]
self.query_start_loc_cpu = [0, 2, 4, 6, ... ,128]`
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
d68209402d
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -152,3 +152,37 @@ def test_qwen3_next_mtp_correctness_tp4(model_name: str,
|
|||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
assert matches > int(0.66 * len(ref_outputs))
|
assert matches > int(0.66 * len(ref_outputs))
|
||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", MODELS)
|
||||||
|
@pytest.mark.parametrize("num_speculative_tokens", [1])
|
||||||
|
def test_qwen3_next_mtp_full_decode(model_name: str,
|
||||||
|
num_speculative_tokens: int):
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
|
||||||
|
max_tokens = 20
|
||||||
|
'''
|
||||||
|
Compare the outputs of a original LLM and a speculative LLM
|
||||||
|
should be the same when using mtp speculative decoding.
|
||||||
|
'''
|
||||||
|
with VllmRunner(model_name,
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
max_model_len=4096,
|
||||||
|
gpu_memory_utilization=0.8,
|
||||||
|
distributed_executor_backend="mp",
|
||||||
|
speculative_config={
|
||||||
|
"method": "qwen3_next_mtp",
|
||||||
|
"num_speculative_tokens": num_speculative_tokens,
|
||||||
|
},
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
cudagraph_mode="FULL_DECODE_ONLY",
|
||||||
|
cudagraph_capture_sizes=[4])) as llm:
|
||||||
|
outputs = llm.generate_greedy(example_prompts, max_tokens)
|
||||||
|
print(outputs)
|
||||||
|
del llm
|
||||||
|
cleanup_dist_env_and_memory()
|
||||||
|
|||||||
@@ -217,13 +217,8 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
|
|||||||
mixed_qkv_non_spec)
|
mixed_qkv_non_spec)
|
||||||
|
|
||||||
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
|
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
|
||||||
is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
|
g, beta = fused_gdn_gating_patch(self.A_log, a, b,
|
||||||
if (is_cuda_graph):
|
self.dt_bias)
|
||||||
g, beta = fused_gdn_gating_patch(self.A_log, a, b,
|
|
||||||
self.dt_bias)
|
|
||||||
else:
|
|
||||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
|
||||||
|
|
||||||
if spec_sequence_masks is not None:
|
if spec_sequence_masks is not None:
|
||||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||||
g_spec = g
|
g_spec = g
|
||||||
|
|||||||
Reference in New Issue
Block a user