[Bugfix][CI] Add Qwen3Next MTP + Full Decode (#6047)
### What this PR does / why we need it?
Fix a bug in the repo and add a test case for MTP + Full Decode Only +
Qwen3Next.
The `_build_dummy_attn_metadata` function in `NPUModelRunner` appears to be
missing a `query_start_loc.copy_to_gpu` operation, which leads to a mismatch
between `query_start_loc` and `query_start_loc_cpu`; the two are required to
be identical in the MTP + Full Decode Only + Qwen3Next case.
Before this PR:
```
self.query_start_loc     = [0, 0, 0, 0, ..., 0]
self.query_start_loc_cpu = [0, 2, 4, 6, ..., 128]
```
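A minimal, runnable sketch of the mismatch and the shape of the fix. The
buffer helper, names, and sizes below are assumptions reconstructed from the
description above, not the actual vllm-ascend code; the point is only the
missing device copy:
```python
import numpy as np
import torch

# Stand-in for the host/device mirrored buffer used by the model runner;
# the class name and fields are assumptions modeled on the PR description,
# not the real vllm-ascend implementation.
class CpuGpuBuffer:
    def __init__(self, n: int):
        self.np = np.zeros(n, dtype=np.int32)         # host-side values
        self.gpu = torch.zeros(n, dtype=torch.int32)  # device-side mirror

    def copy_to_gpu(self) -> None:
        self.gpu.copy_(torch.from_numpy(self.np))

# Dummy metadata for a full-decode capture: with 1 speculative token, each
# of the 64 requests contributes 2 tokens, so the host-side cumulative
# offsets are [0, 2, 4, ..., 128] -- the values shown above.
num_reqs, tokens_per_req = 64, 2
query_start_loc = CpuGpuBuffer(num_reqs + 1)
query_start_loc.np[1:] = np.cumsum(np.full(num_reqs, tokens_per_req))

# The call the PR adds: without it the device tensor stays all zeros while
# the CPU copy holds the real offsets, producing the mismatch shown above.
query_start_loc.copy_to_gpu()
assert int(query_start_loc.gpu[-1]) == 128
```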
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: d68209402d
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
```diff
@@ -152,3 +152,37 @@ def test_qwen3_next_mtp_correctness_tp4(model_name: str,
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.66 * len(ref_outputs))
     cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("model_name", MODELS)
+@pytest.mark.parametrize("num_speculative_tokens", [1])
+def test_qwen3_next_mtp_full_decode(model_name: str,
+                                    num_speculative_tokens: int):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    max_tokens = 20
+    '''
+    The outputs of the original LLM and the speculative LLM
+    should be the same when using MTP speculative decoding.
+    '''
+    with VllmRunner(model_name,
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp",
+                    speculative_config={
+                        "method": "qwen3_next_mtp",
+                        "num_speculative_tokens": num_speculative_tokens,
+                    },
+                    compilation_config=CompilationConfig(
+                        cudagraph_mode="FULL_DECODE_ONLY",
+                        cudagraph_capture_sizes=[4])) as llm:
+        outputs = llm.generate_greedy(example_prompts, max_tokens)
+    print(outputs)
+    del llm
+    cleanup_dist_env_and_memory()
```
```diff
@@ -217,13 +217,8 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
                 mixed_qkv_non_spec)
 
         if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
-            is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE
-            if (is_cuda_graph):
-                g, beta = fused_gdn_gating_patch(self.A_log, a, b,
-                                                 self.dt_bias)
-            else:
-                g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
-
+            g, beta = fused_gdn_gating_patch(self.A_log, a, b,
+                                             self.dt_bias)
         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                 g_spec = g
```