diff --git a/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py b/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py index 709bb3e6..695011db 100644 --- a/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py +++ b/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py @@ -152,3 +152,37 @@ def test_qwen3_next_mtp_correctness_tp4(model_name: str, # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.66 * len(ref_outputs)) cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("model_name", MODELS) +@pytest.mark.parametrize("num_speculative_tokens", [1]) +def test_qwen3_next_mtp_full_decode(model_name: str, + num_speculative_tokens: int): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + max_tokens = 20 + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using mtp speculative decoding. + ''' + with VllmRunner(model_name, + tensor_parallel_size=4, + max_model_len=4096, + gpu_memory_utilization=0.8, + distributed_executor_backend="mp", + speculative_config={ + "method": "qwen3_next_mtp", + "num_speculative_tokens": num_speculative_tokens, + }, + compilation_config=CompilationConfig( + cudagraph_mode="FULL_DECODE_ONLY", + cudagraph_capture_sizes=[4])) as llm: + outputs = llm.generate_greedy(example_prompts, max_tokens) + print(outputs) + del llm + cleanup_dist_env_and_memory() diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py index 21c549eb..1c327bad 100644 --- a/vllm_ascend/patch/worker/patch_qwen3_next.py +++ b/vllm_ascend/patch/worker/patch_qwen3_next.py @@ -217,13 +217,8 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase): mixed_qkv_non_spec) if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None: - is_cuda_graph = forward_context.cudagraph_runtime_mode != CUDAGraphMode.NONE - if (is_cuda_graph): - g, beta = fused_gdn_gating_patch(self.A_log, a, b, - self.dt_bias) - else: - g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) - + g, beta = fused_gdn_gating_patch(self.A_log, a, b, + self.dt_bias) if spec_sequence_masks is not None: if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: g_spec = g