diff --git a/tests/e2e/multicard/4-cards/test_qwen3_next.py b/tests/e2e/multicard/4-cards/test_qwen3_next.py
index 445cd36e..f3799cb0 100644
--- a/tests/e2e/multicard/4-cards/test_qwen3_next.py
+++ b/tests/e2e/multicard/4-cards/test_qwen3_next.py
@@ -73,3 +73,38 @@ def test_qwen3_next_w8a8dynamic_distributed_tp4_ep():
             quantization="ascend",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
+@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
+def test_qwen3_next_distributed_mp_flash_comm_tp4():
+    example_prompts = [
+        "Hello, my name is",
+    ] * 4
+    max_tokens = 5
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.7,
+                    distributed_executor_backend="mp",
+                    enable_expert_parallel=True,
+                    enforce_eager=True) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+
+@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
+def test_qwen3_next_distributed_mp_graph_mode_tp4():
+    example_prompts = [
+        "Hello, my name is",
+    ] * 4
+    max_tokens = 5
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.7,
+                    distributed_executor_backend="mp",
+                    enable_expert_parallel=True,
+                    enforce_eager=False) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
\ No newline at end of file
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 6706313b..c7510941 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -705,7 +705,11 @@ def _get_row_parallel_op(
 
 
 def get_parallel_op(disable_tp, prefix, layer, direct):
-    if disable_tp or ("shared_experts" in prefix and shared_expert_dp_enabled()):
+    if (
+        disable_tp
+        or ("shared_experts" in prefix and shared_expert_dp_enabled())
+        or ("shared_expert" in prefix and shared_expert_dp_enabled())
+    ):
         return None, 0, 1
     custom_op: (
         MLPColumnParallelOp
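Note on the linear_op.py change above: since "shared_expert" is a substring of "shared_experts", the new clause already matches both module names, and the original "shared_experts" clause is technically subsumed; keeping both presumably makes the intent explicit. A minimal sketch of the matching behavior, runnable on its own (the shared_expert_dp_enabled stub below is hypothetical and stands in for the real helper):

    def shared_expert_dp_enabled() -> bool:
        # Hypothetical stub; the real helper lives in vllm_ascend.
        return True

    def is_shared_expert_prefix(prefix: str) -> bool:
        # Mirrors the substring checks in get_parallel_op: "shared_expert"
        # also matches "shared_experts", so the first clause is subsumed.
        return (("shared_experts" in prefix and shared_expert_dp_enabled())
                or ("shared_expert" in prefix and shared_expert_dp_enabled()))

    assert is_shared_expert_prefix("model.layers.0.mlp.shared_experts.down_proj")
    assert is_shared_expert_prefix("model.layers.0.mlp.shared_expert.down_proj")
    assert not is_shared_expert_prefix("model.layers.0.mlp.experts.0.down_proj")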
diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py
index 8ce2f0bd..8b642587 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_next.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_next.py
@@ -30,6 +30,7 @@ from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
 from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
 from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
+from vllm_ascend.utils import enable_sp
 
 
 class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
@@ -44,13 +45,13 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
         2. Core attention (custom op)
         3. Output projection
         """
-        num_tokens = hidden_states.size(0)
 
         # ============================================================
         # Part 1: Input Projection
         # ============================================================
         projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
         projected_states_ba, _ = self.in_proj_ba(hidden_states)
+        num_tokens = projected_states_qkvz.size(0)
 
         mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
             projected_states_qkvz,
@@ -126,9 +127,10 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
 
-        mixed_qkv = mixed_qkv[:num_actual_tokens]
-        b = b[:num_actual_tokens]
-        a = a[:num_actual_tokens]
+        if not enable_sp():
+            mixed_qkv = mixed_qkv[:num_actual_tokens]
+            b = b[:num_actual_tokens]
+            a = a[:num_actual_tokens]
 
         # 1. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
@@ -292,11 +294,20 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
             )
             merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
             merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
+            else:
+                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens]
         elif spec_sequence_masks is not None:
-            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
+            else:
+                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens]
         else:
-            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
+            else:
+                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens]
 
 
 Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f574b9e4..7fa58a44 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1214,7 +1214,12 @@ class NPUModelRunner(GPUModelRunner):
         # Currently, Graph Mode and SP will both pad num_tokens,
         # Another possible condition is num_tokens_padded != num_tokens_unpadded
         # but this scope is way too big and the consequences are unpredictable
+        old_num_reqs_padded = num_reqs_padded
         num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs)
+        if enable_sp() and num_tokens_padded == num_tokens_unpadded:
+            if num_reqs_padded > old_num_reqs_padded:
+                num_reqs_padded = old_num_reqs_padded
+                self.query_start_loc.np[num_reqs_padded + 1] = 0
         (attn_metadata,
          spec_decode_common_attn_metadata) = self._build_attention_metadata(
             num_tokens=num_tokens_unpadded
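The recurring pattern in patch_qwen3_next.py is the same in all three write-back branches: with sequence parallelism (SP) disabled, the intermediate tensors are truncated to num_actual_tokens up front, so the branch output already matches the destination slice; with SP enabled, the tensors keep their padded length through the core attention path and the extra rows are only sliced off at the final assignment. A minimal, self-contained sketch of that write-back, assuming a plain sp_enabled flag in place of the real enable_sp() check (function and variable names here are illustrative, not from the patch):

    import torch

    def write_back(core_attn_out: torch.Tensor,
                   branch_out: torch.Tensor,
                   num_actual_tokens: int,
                   sp_enabled: bool) -> None:
        # branch_out has shape [1, num_tokens, ...]; drop the leading dim.
        out = branch_out.squeeze(0)
        if sp_enabled:
            # Under SP the tensor stays padded, so trim it to the actual
            # token count before copying; without SP it was truncated
            # earlier and the shapes already match.
            out = out[:num_actual_tokens]
        core_attn_out[:num_actual_tokens] = out

    # Example: 8 padded rows, 5 actual tokens, SP enabled.
    dst = torch.zeros(8, 4)
    src = torch.ones(1, 8, 4)
    write_back(dst, src, num_actual_tokens=5, sp_enabled=True)
    assert dst[:5].eq(1).all() and dst[5:].eq(0).all()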