diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index df01430..75d0149 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -32,6 +32,7 @@ The following table lists the additional configuration options available in vLLM
 | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
 | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
+| `enable_shared_expert_dp` | bool | `True` | When the shared experts are computed in DP (data parallel), performance improves but more memory is consumed. If memory is constrained, this switch can be turned off manually. |
 
 The details of each config option are as follows:
 
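A minimal usage sketch (not part of the diff): the new flag is read from `additional_config` like the other options in this table, and per the `ascend_config.py` change below it only takes effect when expert parallelism is on and torchair graph mode is off. Model name and parallel sizes here are illustrative only.

from vllm import LLM

# Hedged sketch: turn the option off to trade the speedup for lower memory use.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative model
    tensor_parallel_size=2,
    enable_expert_parallel=True,
    additional_config={"enable_shared_expert_dp": False},
)
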
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 652cff3..497b7b5 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -691,3 +691,40 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[2], self.impl.v_head_dim)
         mock_up_proj.assert_called_once()
         mock_page_attention_mla.assert_called_once()
+
+    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
+    @patch("torch_npu._npu_reshape_and_cache")
+    def test_forward_without_graph(self, _, mock_forward_prefill):
+        self.impl.running_in_graph = False
+        self.impl.torchair_graph_enabled = False
+
+        num_tokens = 100
+        num_blocks = 256
+        block_size = 4
+        rotary_emb_return_value = (torch.randn(num_tokens, 16,
+                                               self.impl.kv_lora_rank),
+                                   torch.randn(0, 1, self.impl.kv_lora_rank))
+        self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
+        self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
+            1, num_blocks, 128)
+
+        hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
+        hidden_states_or_kv_c_normed = torch.randn(num_tokens,
+                                                   self.impl.kv_lora_rank)
+        k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
+        kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
+                                self.impl.kv_lora_rank),
+                    torch.randn(num_blocks, block_size, self.impl.num_heads,
+                                self.impl.qk_rope_head_dim))
+        output = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.v_head_dim)
+
+        metadata = MagicMock()
+        metadata.num_decodes = 0
+        metadata.num_prefills = num_tokens
+        mock_forward_prefill.return_value = torch.randn(
+            0, self.impl.num_heads * self.impl.v_head_dim)
+        result = self.impl.forward(None, hidden_states_or_q_c,
+                                   hidden_states_or_kv_c_normed, k_pe,
+                                   kv_cache, metadata, output, False)
+        self.assertEqual(result.shape[0], num_tokens)
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 659f441..777ff9f 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -47,6 +47,9 @@ class AscendConfig:
         self.expert_map_path = additional_config.get("expert_map_path", None)
         self.chunked_prefill_for_mla = additional_config.get(
             "chunked_prefill_for_mla", False)
+        self.enable_shared_expert_dp = additional_config.get(
+            "enable_shared_expert_dp", True
+        ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
 
 
 class TorchairGraphConfig:
@@ -166,6 +169,10 @@ def check_ascend_config(vllm_config, enforce_eager):
                 raise NotImplementedError(
                     "Torchair graph mode only works with following model types:"
                     f"{TORCHAIR_MODEL_LIST}.")
+        if ascend_config.enable_shared_expert_dp:
+            logger.warning(
+                "enable_shared_expert_dp is currently not supported in torchair graph mode, "
+                "so it has been disabled automatically.")
     # aclgraph case
     else:
         # aclgraph doesn't work with deepseek model and only qwen model is well tested.
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 48713fc..e7dccf3 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -621,6 +621,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -635,6 +636,8 @@ class AscendMLAImpl(MLAAttentionImpl):
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        if hasattr(self, "running_in_graph") and not self.running_in_graph:
+            return x
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
         npu_prefetch(self.o_proj.weight,
                      x,
@@ -905,14 +908,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         ] and not ascend_config.chunked_prefill_for_mla:
             attn_output = attn_output_torch
 
-        current_ms_metadata = get_multistream_comm_context()
-        if current_ms_metadata is None:
-            return self.o_proj(attn_output, is_prefill=True)[0]
-        else:
-            current_ms_metadata.before_comm_event.record()
-            with torch.npu.stream(current_ms_metadata.comm_stream):
-                current_ms_metadata.before_comm_event.wait()
-                return self.o_proj(attn_output, is_prefill=True)[0]
+        return attn_output
 
     def exec_kv(
         self,
@@ -1249,6 +1245,12 @@ class AscendMLAImpl(MLAAttentionImpl):
                                             key_cache=kv_cache[0],
                                             value_cache=kv_cache[1],
                                             slot_indices=attn_metadata.slot_mapping)
+        if not self.running_in_graph:
+            o_proj_input_shape = (num_actual_toks,
+                                  self.num_heads * self.v_head_dim)
+            o_proj_input = torch.empty(o_proj_input_shape,
+                                       dtype=hidden_states_or_q_c.dtype,
+                                       device=hidden_states_or_q_c.device)
         if has_prefill:
             # FIX: aicore move should be also placed on the comm stream in dbo,
             # otherwise it may affect the accuracy
@@ -1259,11 +1261,12 @@ class AscendMLAImpl(MLAAttentionImpl):
                                                    attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
             if current_ms_metadata is not None:
+                current_ms_metadata.before_comm_event.record()
                 with torch.npu.stream(current_ms_metadata.comm_stream):
-                    output[num_decode_tokens:] = output_prefill
-                    current_ms_metadata.after_comm_event.record()
+                    current_ms_metadata.before_comm_event.wait()
+                    o_proj_input[num_decode_tokens:] = output_prefill
             else:
-                output[num_decode_tokens:] = output_prefill
+                o_proj_input[num_decode_tokens:] = output_prefill
 
         if has_decode:
             if self.running_in_graph:
@@ -1280,9 +1283,32 @@ class AscendMLAImpl(MLAAttentionImpl):
                 current_ms_metadata = get_multistream_comm_context()
                 if current_ms_metadata is not None:
                     with torch.npu.stream(current_ms_metadata.comm_stream):
-                        output[:num_decode_tokens] = output_decode
-                        current_ms_metadata.after_comm_event.record()
+                        o_proj_input[:num_decode_tokens] = output_decode
                 else:
-                    output[:num_decode_tokens] = output_decode
+                    o_proj_input[:num_decode_tokens] = output_decode
+
+        current_ms_metadata = get_multistream_comm_context()
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        if current_ms_metadata is None:
+            npu_prefetch(self.o_proj.weight,
+                         o_proj_input,
+                         max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                         enabled=enable_multistream_mla)
+            output[...] = self.o_proj(
+                o_proj_input,
+                is_prefill=True,
+                is_force_scatter=self.enable_shared_expert_dp)[0]
+        else:
+            with torch.npu.stream(current_ms_metadata.comm_stream):
+                npu_prefetch(self.o_proj.weight,
+                             o_proj_input,
+                             max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                             enabled=enable_multistream_mla)
+                output[...] = self.o_proj(
+                    o_proj_input,
+                    is_prefill=True,
+                    is_force_scatter=self.enable_shared_expert_dp)[0]
+                current_ms_metadata.after_comm_event.record()
+        del o_proj_input
         return output_padded
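The reworked forward above stages prefill and decode attention outputs into one preallocated o_proj_input buffer and then runs a single output projection (optionally on the multistream comm stream). A stripped-down, hedged sketch of that staging pattern; `staged_o_proj` is illustrative, not the real implementation, and streams/prefetch are omitted:

import torch

def staged_o_proj(output_decode: torch.Tensor, output_prefill: torch.Tensor,
                  o_proj) -> torch.Tensor:
    # One buffer sized for all tokens, filled slice by slice.
    num_decode = output_decode.shape[0]
    total = num_decode + output_prefill.shape[0]
    o_proj_input = torch.empty(total, output_decode.shape[1],
                               dtype=output_decode.dtype)
    o_proj_input[:num_decode] = output_decode
    o_proj_input[num_decode:] = output_prefill
    # Single projection over decode + prefill tokens.
    return o_proj(o_proj_input)

proj = torch.nn.Linear(2048, 512)  # stands in for o_proj (num_heads * v_head_dim -> hidden)
assert staged_o_proj(torch.randn(4, 2048), torch.randn(12, 2048), proj).shape == (16, 512)
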
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index ce051c4..0e4cf83 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
@@ -160,7 +161,13 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
                                                   input_parallel,
                                                   bias=bias_)
         if self.reduce_results and self.tp_size > 1:
-            if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
+            num_tokens = output_parallel.shape[0]
+            if is_force_scatter and num_tokens % self.tp_size:
+                output_parallel = nn.functional.pad(
+                    output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
+            if is_force_scatter or (not is_prefill
+                                    and output_parallel.shape[0] % self.tp_size
+                                    == 0):
                 output = tensor_model_parallel_reduce_scatter(output_parallel,
                                                               dim=0)
             else:
@@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
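The `(0, 0, 0, -num_tokens % self.tp_size)` pad above rounds the token dimension up to the next multiple of tp_size so the reduce-scatter splits evenly across ranks. A self-contained sketch of just that arithmetic (illustrative, single process):

import torch
import torch.nn.functional as F

def pad_rows(x: torch.Tensor, tp_size: int) -> torch.Tensor:
    pad = -x.shape[0] % tp_size      # e.g. 5 tokens, tp_size 4 -> pad 3
    return F.pad(x, (0, 0, 0, pad)) if pad else x

assert pad_rows(torch.randn(5, 8), 4).shape[0] == 8  # 2 rows per rank after scatter
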
@@ -347,13 +355,15 @@ class CustomDeepseekV2MoE(nn.Module):
             reduce_results = not self.all_reduce_merge
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
+            enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
             self.shared_experts = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=reduce_results,
-                force_replicate=self.enable_multistream_moe,
+                force_replicate=self.enable_multistream_moe
+                or enable_shared_expert_dp,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -447,9 +457,11 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
         self.kv_lora_rank = kv_lora_rank
         self.num_heads = num_heads
 
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % self.tp_size == 0
+        self.num_local_heads = num_heads // self.tp_size
+        self.layers = config.num_hidden_layers
+        self.first_k_dense_replace = config.first_k_dense_replace
 
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
@@ -462,6 +474,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_mla = \
             ascend_config.torchair_graph_config.enable_multistream_mla
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
@@ -501,8 +514,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
                                              prefix=f"{prefix}.kv_b_proj")
         if (config.n_routed_experts is not None
                 and self.debug_layer_idx >= config.first_k_dense_replace
-                and self.debug_layer_idx % config.moe_layer_freq == 0 and
-                ascend_config.torchair_graph_config.enable_multistream_moe):
+                and self.debug_layer_idx % config.moe_layer_freq == 0
+                and (ascend_config.torchair_graph_config.enable_multistream_moe
+                     or self.enable_shared_expert_dp)):
             self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
@@ -596,13 +610,27 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             output = output.view(-1, output_shape[-1])
             return output
         else:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
+            if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
+                hidden_states_or_q_c = get_tp_group().all_gather(
+                    hidden_states_or_q_c, 0)
+                kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+
+            kv_c, k_pe = kv_no_split.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
             kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
+                output_shape = hidden_states.shape
+            else:
+                num_tokens = hidden_states_or_q_c.shape[0]
+                rows = num_tokens // self.tp_size
+                if num_tokens % self.tp_size:
+                    rows += 1
+                output_shape = (rows, hidden_states.shape[1])
             return self.mla_attn(hidden_states_or_q_c,
                                  kv_c_normed,
                                  k_pe,
-                                 output_shape=hidden_states.shape)
+                                 output_shape=output_shape)
 
 
 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
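The rows computation above is a ceil division: with shared-expert DP the attention output is reduce-scattered over TP ranks, so each rank keeps ceil(num_tokens / tp_size) rows. A hedged one-liner equivalent of that branch (illustrative only):

def rows_per_rank(num_tokens: int, tp_size: int) -> int:
    return -(-num_tokens // tp_size)  # ceil division, same as the floor-div-plus-one above

assert rows_per_rank(100, 8) == 13
assert rows_per_rank(96, 8) == 12
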
@@ -677,6 +705,8 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
                                                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
         self.first_k_dense_replace = config.first_k_dense_replace
+        self.tp_group = get_tp_group().device_group
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
     def forward(
         self,
@@ -731,6 +761,18 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
                 # first layer.
                 residual *= 1. / self.routed_scaling_factor
 
+        tp_size = get_tensor_model_parallel_world_size()
+        if self.enable_shared_expert_dp and (
+                self.layer_idx == self.first_k_dense_replace
+                or self.layer_idx == self.layers) and tp_size > 1:
+            num_tokens, _ = residual.shape
+            if num_tokens % tp_size:
+                residual = nn.functional.pad(residual,
+                                             (0, 0, 0, -num_tokens % tp_size))
+            chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
+            tp_rank = get_tensor_model_parallel_rank()
+            residual = chunk_residual[tp_rank]
+
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
@@ -756,6 +798,22 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
                                                              dim=0)
             residual = tensor_model_parallel_all_gather(residual, dim=0)
 
+        # For the last layer of the main model and the MTP layer.
+        if self.enable_shared_expert_dp and self.layer_idx >= (
+                self.layers - 1) and tp_size > 1:
+            hidden_states = get_tp_group().all_gather(hidden_states, 0)
+            residual = get_tp_group().all_gather(residual, 0)
+
+            attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata is not None:
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                num_tokens = hidden_states.shape[0]
+
+            if num_tokens < hidden_states.shape[0]:
+                hidden_states = hidden_states[:num_tokens]
+                residual = residual[:num_tokens]
+
         return hidden_states, residual
 
 
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 625146d..aec6e72 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -1268,6 +1268,7 @@ class AscendFusedMoE(FusedMoE):
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1408,22 +1409,24 @@ class AscendFusedMoE(FusedMoE):
         else:
             # TODO: Determine if we can remove the padding
             padding_size = tp_size
-            if num_tokens < padding_size:
+            if num_tokens < padding_size and not self.enable_shared_expert_dp:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, padding_size - num_tokens))
                 router_logits = nn.functional.pad(
                     router_logits, (0, 0, 0, padding_size - num_tokens))
         if tp_size > 1:
-            chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                     tp_size,
-                                                     dim=0)
-            chunk_router_logits = torch.tensor_split(router_logits,
-                                                     tp_size,
-                                                     dim=0)
-            chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
             tp_rank = get_tensor_model_parallel_rank()
-            hidden_states = chunk_hidden_states[tp_rank]
-            router_logits = chunk_router_logits[tp_rank]
+            if not self.enable_shared_expert_dp:
+                chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                         tp_size,
+                                                         dim=0)
+                chunk_router_logits = torch.tensor_split(router_logits,
+                                                         tp_size,
+                                                         dim=0)
+                hidden_states = chunk_hidden_states[tp_rank]
+                router_logits = chunk_router_logits[tp_rank]
+
+            chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
             mc2_mask = chunk_mc2_mask[tp_rank]
 
         if self.dp_size > 1:
@@ -1490,7 +1493,7 @@ class AscendFusedMoE(FusedMoE):
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
-        ] and not replace_allreduce):
+        ] and not replace_allreduce and not self.enable_shared_expert_dp):
             if tp_size > 1:
                 dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                                 self.tp_group)
@@ -1500,7 +1503,7 @@ class AscendFusedMoE(FusedMoE):
                 final_hidden_states = e_hidden_states
                 if num_tokens < padding_size:
                     final_hidden_states = final_hidden_states[:num_tokens]
-        elif self.dp_size > 1:
+        elif self.dp_size > 1 and not self.enable_shared_expert_dp:
             if fused_moe_state == FusedMoEState.NaiveMulticast:
                 start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
                     self.dp_rank - 1]
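With enable_shared_expert_dp the MoE input already arrives as a per-rank slice, so the tensor_split / all_gather round trip above is skipped for hidden_states and router_logits (only mc2_mask is still chunked). A hedged single-process sketch of the round trip being avoided; plain chunks stand in for per-rank shards and torch.distributed is not used:

import torch
import torch.nn.functional as F

tp_size, num_tokens, hidden = 4, 6, 8
x = torch.randn(num_tokens, hidden)
x_padded = F.pad(x, (0, 0, 0, -num_tokens % tp_size))   # pad 6 -> 8 rows
chunks = torch.tensor_split(x_padded, tp_size, dim=0)   # one shard per TP rank
restored = torch.cat(chunks, dim=0)[:num_tokens]        # what all_gather rebuilds
assert torch.equal(restored, x)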