diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index 31bae11..cf4d1dc 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -31,6 +31,7 @@ The following table lists the additional configuration options available in vLLM | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. | | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. | +| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. | | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. | | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. | diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 3ca7210..4a29dd9 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -210,6 +210,7 @@ class TestAscendMLAMetadataBuilder(TestBase): with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) + builder.decode_threshold = 1 input_batch = MagicMock() input_batch.req_ids = [0, 1, 2, 3] @@ -303,18 +304,16 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(self.impl.num_queries_per_kv, 32) self.assertEqual(self.impl.tp_size, 2) - def test_v_up_proj_and_o_proj(self): + def test_v_up_proj(self): batch_size = 4 x = torch.randn(batch_size, self.impl.num_heads, self.impl.kv_lora_rank) - self.impl.o_proj.return_value = (torch.randn( - batch_size, self.impl.num_heads * self.impl.v_head_dim), ) if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None: self.impl.W_UV = torch.randn(self.impl.num_heads, self.impl.kv_lora_rank, self.impl.v_head_dim) - result = self.impl._v_up_proj_and_o_proj(x) + result = self.impl._v_up_proj(x) self.assertEqual(result.shape[0], batch_size) self.assertEqual(result.shape[1], @@ -371,8 +370,11 @@ class TestAscendMLAImpl(TestBase): metadata.prefill = None prefix_out = torch.randn(2, 16, 128) prefix_lse = torch.randn(2, 16, 8) - out, lse = self.impl._compute_prefill_context(query, kv_cache, 32, - metadata, prefix_out, + q_pe = query[..., self.impl.qk_nope_head_dim:] + q_nope = query[..., :self.impl.qk_nope_head_dim] + + out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache, + 32, metadata, prefix_out, prefix_lse) self.assertTrue(torch.equal(prefix_out, out)) @@ -386,6 +388,8 @@ class TestAscendMLAImpl(TestBase): latent_kv_dim = self.impl.kv_lora_rank num_blocks, block_size = 100, 20 query = torch.randn(S, N, D) + q_nope = query[..., :self.impl.qk_nope_head_dim] + q_pe = query[..., self.impl.qk_nope_head_dim:] kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim) kv_cache_1 = torch.randn(num_blocks, block_size, N, D) kv_cache = [kv_cache_0, kv_cache_1] @@ -406,9 +410,11 @@ class TestAscendMLAImpl(TestBase): meta = MagicMock() meta.prefill = prefill_meta + self.impl.prefill_mask = torch.triu( + torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype), 1) - out, lse = self.impl._compute_prefill_context(query, kv_cache, 32, - meta, prefix_out, + out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache, + 32, meta, prefix_out, prefix_lse) mock_load.assert_called_once() @@ -417,67 +423,36 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(out.shape, prefix_out.shape) self.assertEqual(lse.shape, prefix_lse.shape) - @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj") - @patch("torch_npu._npu_paged_attention_mla") - def test_forward_decode_without_graph(self, mock_page_attention_mla, + @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") + @patch("torch_npu.npu_fused_infer_attention_score") + def test_forward_decode_without_graph(self, + mock_npu_fused_infer_attention_score, mock_up_proj): num_tokens = 100 - num_blocks = 256 block_size = 4 q_nope = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_nope_head_dim) q_pe = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim) - kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size, - self.impl.num_heads, - self.impl.kv_lora_rank) + k_nope = torch.randn(num_tokens, self.impl.num_heads, + self.impl.qk_nope_head_dim) + k_pe = torch.randn(num_tokens, self.impl.num_heads, + self.impl.qk_rope_head_dim) metadata = MagicMock() metadata.decode = MagicMock() metadata.decode.block_table = MagicMock() metadata.decode.seq_lens = 10 - mock_page_attention_mla.return_value = torch.randn( - num_tokens, self.impl.num_heads, self.impl.kv_lora_rank) + mock_npu_fused_infer_attention_score.return_value = [ + torch.randn(num_tokens, self.impl.num_heads, + self.impl.kv_lora_rank), None + ] mock_up_proj.return_value = torch.randn(num_tokens, self.impl.num_heads, self.impl.v_head_dim) - result = self.impl._forward_decode(q_nope, q_pe, None, None, - kv_c_and_k_pe_cache, metadata) + result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, + block_size, metadata) self.assertEqual(result.shape[0], num_tokens) self.assertEqual(result.shape[1], self.impl.num_heads) self.assertEqual(result.shape[2], self.impl.v_head_dim) mock_up_proj.assert_called_once() - mock_page_attention_mla.assert_called_once() - - @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill") - @patch("torch_npu._npu_reshape_and_cache") - def test_forward_without_graph(self, _, mock_forward_prefill): - num_tokens = 100 - num_blocks = 256 - block_size = 4 - rotary_emb_return_value = (torch.randn(num_tokens, 16, - self.impl.kv_lora_rank), - torch.randn(0, 1, self.impl.kv_lora_rank)) - self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value - self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn( - 1, num_blocks, 128) - - hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank) - hidden_states_or_kv_c_normed = torch.randn(num_tokens, - self.impl.kv_lora_rank) - k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim) - kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads, - self.impl.kv_lora_rank), - torch.randn(num_blocks, block_size, self.impl.num_heads, - self.impl.qk_rope_head_dim)) - output = torch.randn(num_tokens, self.impl.num_heads, - self.impl.v_head_dim) - - metadata = MagicMock() - metadata.num_decodes = 0 - metadata.num_prefills = num_tokens - mock_forward_prefill.return_value = torch.randn( - 0, self.impl.num_heads * self.impl.v_head_dim) - result = self.impl.forward(None, hidden_states_or_q_c, - hidden_states_or_kv_c_normed, k_pe, - kv_cache, metadata, output, False) - self.assertEqual(result.shape[0], num_tokens) + mock_npu_fused_infer_attention_score.assert_called_once() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 3769bcb..81cf177 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -50,6 +50,7 @@ class AscendConfig: self.enable_shared_expert_dp = additional_config.get( "enable_shared_expert_dp", False ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel + self.enable_prefetch = additional_config.get("enable_prefetch", False) class TorchairGraphConfig: diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 605d8c1..dcdd627 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1,19 +1,18 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar import torch -import torch.nn as nn import torch_npu -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, +from torch import nn +from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, MLAAttentionImpl) from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, @@ -22,6 +21,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla +from vllm_ascend.utils import npu_prefetch from vllm_ascend.worker.npu_input_batch import InputBatch if TYPE_CHECKING: @@ -184,6 +184,9 @@ class AscendMLAMetadataBuilder: self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + self.decode_threshold = 1 + if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( # Max sure there is enough for 8 full length request or at least @@ -224,7 +227,7 @@ class AscendMLAMetadataBuilder: for i, req_id in enumerate(input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] - if num_tokens == 1: + if num_tokens <= self.decode_threshold: decodes.append(i) else: prefills.append(i) @@ -270,9 +273,8 @@ class AscendMLAMetadataBuilder: query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding - decode_threshold = 1 num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) + split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_actual_tokens @@ -312,8 +314,8 @@ class AscendMLAMetadataBuilder: if num_prefills > 0: reqs_start = num_decodes # prefill_start tokens_start = num_decode_tokens - max_query_len = query_lens[tokens_start:].max().item() - max_seq_lens = seq_lens[tokens_start:].max().item() + max_query_len = query_lens[reqs_start:].max().item() + max_seq_lens = seq_lens[reqs_start:].max().item() prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] @@ -359,9 +361,9 @@ class AscendMLAMetadataBuilder: 1).unsqueeze(2) prefill_metadata = AscendMLAPrefillMetadata( attn_mask=common_attn_metadata.attn_mask, - query_lens=query_lens[tokens_start:], + query_lens=query_lens[reqs_start:], seq_lens=seq_lens, - context_lens=seq_lens[tokens_start:], + context_lens=seq_lens[reqs_start:], input_positions=prefill_input_positions, block_table=block_table[reqs_start:, ...], max_query_len=max_query_len, @@ -416,6 +418,21 @@ class AscendMLAMetadataBuilder: ) +class DecodeMLAPreprocessResult(NamedTuple): + ql_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + + +class PrefillMLAPreprocessResult(NamedTuple): + q_nope: Optional[torch.Tensor] = None + q_pe: Optional[torch.Tensor] = None + k_nope: Optional[torch.Tensor] = None + k_pe: Optional[torch.Tensor] = None + value: Optional[torch.Tensor] = None + + class AscendMLAImpl(MLAAttentionImpl): """ NOTE: Please read the comment at the top of the file before trying to @@ -455,11 +472,18 @@ class AscendMLAImpl(MLAAttentionImpl): self.o_proj = kwargs['o_proj'] self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) + self.q_a_proj = kwargs.get('q_a_proj', None) + self.q_a_layernorm = kwargs.get('q_a_layernorm', None) self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.tp_size = get_tensor_model_parallel_world_size() ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_prefetch = ascend_config.enable_prefetch + self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz + self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla + + self.prefill_mask = None # Adapt torch air graph mode with spec decoding. speculative_config = get_current_vllm_config().speculative_config @@ -467,7 +491,7 @@ class AscendMLAImpl(MLAAttentionImpl): self.spec_token_num = speculative_config.num_speculative_tokens assert self.spec_token_num > 0 - def _v_up_proj_and_o_proj(self, x): + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V) @@ -546,7 +570,8 @@ class AscendMLAImpl(MLAAttentionImpl): def _compute_prefill_context( self, - query: torch.Tensor, + q_nope: torch.Tensor, + q_pe: torch.Tensor, kv_c_and_k_pe_cache: Tuple[torch.Tensor], rope_dim: int, attn_metadata: AscendMLAMetadata, @@ -559,8 +584,6 @@ class AscendMLAImpl(MLAAttentionImpl): return prefix_output, prefix_lse iters = len(prefill_metadata.chunked_context.seq_tot) - q_pe = query[..., self.qk_nope_head_dim:] - q_nope = query[..., :self.qk_nope_head_dim] seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) cache_kv_c = kv_c_and_k_pe_cache[0] @@ -575,19 +598,19 @@ class AscendMLAImpl(MLAAttentionImpl): kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, - dtype=query.dtype, - device=query.device) + dtype=q_nope.dtype, + device=q_nope.device) k_pe = torch.empty(toks, num_heads, rope_dim, - dtype=query.dtype, - device=query.device) + dtype=q_nope.dtype, + device=q_nope.device) torch_npu.atb.npu_paged_cache_load( cache_kv_c, cache_k_pe, prefill_metadata.block_table, - seq_len2.to(query.device), + seq_len2.to(q_nope.device), seq_starts=prefill_metadata.chunked_context.starts[i], key=kv_c_normed, value=k_pe, @@ -599,16 +622,13 @@ class AscendMLAImpl(MLAAttentionImpl): k_nope, v = kv_nope\ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - mask = torch.triu( - torch.ones(512, 512, device=query.device, dtype=query.dtype), - 1) torch_npu.atb.npu_ring_mla( q_nope=q_nope, q_rope=q_pe, k_nope=k_nope, k_rope=k_pe, value=v, - mask=mask, + mask=self.prefill_mask, seqlen=seq_len, head_num=self.num_heads, kv_head_num=self.num_heads, @@ -625,33 +645,74 @@ class AscendMLAImpl(MLAAttentionImpl): def _forward_prefill( self, - query: torch.Tensor, - kv_c_normed: torch.Tensor, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + k_nope: torch.Tensor, k_pe: torch.Tensor, + value: torch.Tensor, kv_c_and_k_pe_cache: Tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: assert attn_metadata.prefill is not None assert len(kv_c_and_k_pe_cache) > 1 - - num_tokens = query.size(0) + num_tokens = q_nope.size(0) attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, - dtype=query.dtype, - device=query.device) - k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache - ascend_config = get_ascend_config() - - if attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ] and not ascend_config.chunked_prefill_for_mla: + dtype=q_nope.dtype, + device=q_nope.device) + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + query = torch.cat((q_nope, q_pe), dim=-1) + key = torch.cat((k_nope, k_pe), dim=-1) + torch_npu._npu_flash_attention( + query=query, + key=key, + value=value, + mask=attn_metadata.attn_mask, + seq_len=attn_metadata.prefill.context_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_heads, + out=attn_output) + elif self.chunked_prefill_for_mla: + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=q_nope.device) + if self.prefill_mask is None: + self.prefill_mask = torch.triu( + torch.ones(512, + 512, + device=q_nope.device, + dtype=q_nope.dtype), + 1) # 512: mask only support 512 + if attn_metadata.num_prefills > 1: + self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat( + attn_metadata.num_prefills, 1, 1) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=self.prefill_mask, + seqlen=torch.tensor(attn_metadata.prefill.query_lens, + dtype=torch.int32), + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) + else: + query = torch.cat((q_nope, q_pe), dim=-1) attn_output_torch = torch.empty(num_tokens, self.num_heads * self.v_head_dim, dtype=query.dtype, @@ -673,240 +734,318 @@ class AscendMLAImpl(MLAAttentionImpl): scale=self.scale, alibi_slopes=None, causal=True) - elif attn_metadata.attn_state in [ - AscendAttentionState.ChunkedPrefill, - AscendAttentionState.SpecDecoding, - AscendAttentionState.PrefillCacheHit - ]: - attn_lse = torch.empty(self.num_heads, - num_tokens, - dtype=torch.float32, - device=query.device) - q_pe = query[..., self.qk_nope_head_dim:] - q_nope = query[..., :self.qk_nope_head_dim] - mask = torch.triu( - torch.ones(512, 512, device=query.device, dtype=query.dtype), - 1) # 512: mask only support 512 - if attn_metadata.num_prefills > 1: - mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1, - 1) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=value, - mask=mask, - seqlen=torch.tensor(attn_metadata.prefill.query_lens, - dtype=torch.int32), - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=None, - prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse) - attn_output, attn_lse = self._compute_prefill_context( \ - query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - key = torch.cat((k_nope, k_pe), dim=-1) - torch_npu._npu_flash_attention( - query=query, - key=key, - value=value, - mask=attn_metadata.attn_mask, - seq_len=attn_metadata.prefill.context_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_heads, - out=attn_output) - attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim) - else: - raise RuntimeError( - "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !" - ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) if attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, AscendAttentionState.SpecDecoding, AscendAttentionState.PrefillCacheHit - ] and not ascend_config.chunked_prefill_for_mla: + ] and not self.chunked_prefill_for_mla: attn_output = attn_output_torch - return attn_output + def exec_kv_decode( + self, + kv_no_split: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + B = kv_no_split.shape[0] + N = self.num_kv_heads + S = 1 + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv_no_split = kv_no_split.view( + B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + ) + return k_pe, k_nope + + def exec_kv_prefill( + self, + kv_no_split: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: Tuple, + slots: torch.Tensor, + ): + B = kv_no_split.shape[0] + N = self.num_kv_heads + S = 1 + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv_no_split = kv_no_split.view( + B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA" + _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + is_output_kv=True, + ) + return k_pe, k_nope + + def rope_single( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + B, N, D = x.shape + S = 1 + x = x.view(B, N, S, D) + x = torch_npu.npu_interleave_rope(x, cos, sin) + return x.view(B, N, D) + def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_nope: torch.Tensor, k_pe: torch.Tensor, - kv_c_and_k_pe_cache: Tuple[torch.Tensor], + block_size: int, attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None num_tokens = q_nope.size(0) - # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will - # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become - # public available - assert len(kv_c_and_k_pe_cache) > 1 - if envs_ascend.VLLM_ASCEND_MLA_PA: - attn_output = torch_npu.atb.npu_multi_head_latent_attention( - q_nope, q_pe, kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1], - attn_metadata.decode.block_table, - attn_metadata.decode.seq_lens, self.num_heads, self.scale, - self.num_kv_heads) + # shape of knope/k_pe for npu graph mode should be: + # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] + actual_seq_lengths = None + if self.enable_kv_nz: + k_nope = k_nope.view(-1, self.num_kv_heads, + self.kv_lora_rank // 16, block_size, 16) + k_pe = k_pe.view(-1, self.num_kv_heads, + self.qk_rope_head_dim // 16, block_size, 16) + input_layout = "BSND" else: - q = torch.cat([q_nope, q_pe], dim=-1) - attn_output = torch.empty( - [num_tokens, self.num_heads, self.kv_lora_rank], - dtype=q.dtype, - device=q.device) - k_cache = torch.cat( - [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) - torch_npu._npu_paged_attention_mla( - query=q, - key_cache=k_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.decode.block_table, # type:ignore - context_lens=attn_metadata.decode.seq_lens, # type:ignore - mla_vheadsize=self.kv_lora_rank, - out=attn_output) + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + input_layout = "BNSD" + + if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: + assert num_tokens % self.spec_token_num == 0 + input_layout = "TND" + # [bs * q_seq_len, num_heads_per_rank, dim] + q_nope = q_nope.view(num_tokens, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, -1) + sparse_mode = 3 + spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + actual_seq_lengths = decode_meta.actual_seq_lengths_q + else: + if self.enable_kv_nz: + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + sparse_mode = 0 + spec_attn_mask = None + + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout=input_layout, + atten_mask=spec_attn_mask, + sparse_mode=sparse_mode, + scale=self.scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=decode_meta.block_table, + block_size=block_size, + actual_seq_lengths_kv=decode_meta.seq_lens_list, + actual_seq_lengths=actual_seq_lengths) + current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: - return self._v_up_proj_and_o_proj(attn_output) + return self._v_up_proj(attn_output) else: current_ms_metadata.before_comm_event.record() with torch.npu.stream(current_ms_metadata.comm_stream): current_ms_metadata.before_comm_event.wait() - return self._v_up_proj_and_o_proj(attn_output) + return self._v_up_proj(attn_output) + + def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata, + need_gather_q_kv): + # MLA Preprocess: + # 1. Perform q_a_proj and q_a_layernorm to obtain q_c + # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split + # 3. If need_gather_q_kv, perform all_gather. + # 4. Preprocess decode tokens, write kv cache and get: + # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope + # 5. Preprocess prefill tokens, write kv cache and get: + # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + num_actual_tokens = attn_metadata.num_actual_tokens + if self.q_a_proj is not None: + npu_prefetch(self.q_a_proj.weight, + hidden_states, + enabled=self.enable_prefetch) + ckq = self.q_a_proj(hidden_states)[0] + q_c = self.q_a_layernorm(ckq) + else: + q_c = hidden_states + + kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] + # Process for shared_expert_dp + if need_gather_q_kv: + q_c = get_tp_group().all_gather(q_c, 0) + kv_no_split = get_tp_group().all_gather(kv_no_split, 0) + decode_preprocess_res = None + prefill_preprocess_res = None + # Preprocess for decode tokens + if has_decode: + decode_q_c = q_c[:num_decode_tokens] + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin + decode_ql_nope, decode_q_pe = \ + self._q_proj_and_k_up_proj(decode_q_c) + decode_q_pe = self.rope_single(decode_q_pe, cos, sin) + decode_slots = attn_metadata.slot_mapping[:num_decode_tokens] + decode_kv_no_split = kv_no_split[:num_decode_tokens] + decode_k_pe, decode_k_nope = self.exec_kv_decode( + decode_kv_no_split, cos, sin, kv_cache, decode_slots) + decode_preprocess_res = DecodeMLAPreprocessResult( + decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe) + # Preprocess for prefill tokens + if has_prefill: + prefill_kv_no_split = kv_no_split[ + num_decode_tokens:num_actual_tokens] + prefill_q_c = q_c[num_decode_tokens:num_actual_tokens] + prefill_q = self.q_proj(prefill_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin + prefill_slots = attn_metadata.slot_mapping[ + num_decode_tokens:num_actual_tokens] + prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( + prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) + prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], + self.num_kv_heads, -1) + prefill_k_nope, prefill_value = self.kv_b_proj( + prefill_k_c_normed)[0].view( + -1, self.num_heads, + self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + prefill_k_pe = prefill_k_pe.expand( + (*prefill_k_nope.shape[:-1], -1)) + prefill_preprocess_res = PrefillMLAPreprocessResult( + prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, + prefill_value) + return decode_preprocess_res, prefill_preprocess_res def forward( self, - layer: AttentionLayer, - hidden_states_or_q_c: torch.Tensor, # query in unified attn - hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn + hidden_states: torch.Tensor, # query in unified attn kv_cache: Tuple[torch.Tensor], attn_metadata: M, + need_gather_q_kv: bool = False, output: Optional[torch.Tensor] = None, - ckq: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." if attn_metadata is None: # Profiling run. return output - num_actual_toks = attn_metadata.num_actual_tokens - if k_pe is None: - kv_c, k_pe = self.kv_a_proj_with_mqa( - hidden_states_or_kv_c_normed)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - else: - kv_c_normed = hidden_states_or_kv_c_normed + num_actual_tokens = attn_metadata.num_actual_tokens assert attn_metadata.num_decodes is not None and \ attn_metadata.num_prefills is not None and \ attn_metadata.num_decode_tokens is not None - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens # Inputs and outputs may be padded for CUDA graphs output_padded = output - output = output[:num_actual_toks, ...] - kv_c_normed = kv_c_normed[:num_actual_toks, ...] - prefill_k_c_normed = kv_c_normed[num_decode_tokens:] - hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] - prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] - k_pe = k_pe[:num_actual_toks, ...] - k_pe = k_pe.unsqueeze(1) - decode_k_pe = k_pe[:num_decode_tokens] - prefill_k_pe = k_pe[num_decode_tokens:] - if has_decode: - decode_k_nope = None - assert attn_metadata.decode is not None - decode_ql_nope, decode_q_pe = \ - self._q_proj_and_k_up_proj(decode_hs_or_q_c) - decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, - decode_q_pe.contiguous(), - decode_k_pe, - max_seq_len=attn_metadata.decode.max_seq_lens) - if has_prefill: - assert attn_metadata.prefill is not None - prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ - .view(-1, self.num_heads, self.qk_head_dim) - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), - prefill_k_pe, - max_seq_len=attn_metadata.prefill.max_seq_lens) - - assert len( - kv_cache - ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" - kv_c_normed = kv_c_normed.view( - [num_actual_toks, self.num_kv_heads, -1]) - torch_npu._npu_reshape_and_cache( - key=kv_c_normed, - value=k_pe, - key_cache=kv_cache[0], - value_cache=kv_cache[1], - slot_indices=attn_metadata.slot_mapping) - o_proj_input_shape = (num_actual_toks, + output = output[:num_actual_tokens, ...] + o_proj_input_shape = (num_actual_tokens, self.num_heads * self.v_head_dim) o_proj_input = torch.empty(o_proj_input_shape, - dtype=hidden_states_or_q_c.dtype, - device=hidden_states_or_q_c.device) - if has_prefill: - # FIX: aicore move should be also placed on the comm stream in dbo, - # otherwise it may affect the accuracy - # TODO: use an elegant way to overlap - output_prefill = self._forward_prefill(prefill_q, - prefill_k_c_normed, - prefill_k_pe, kv_cache, - attn_metadata) - current_ms_metadata = get_multistream_comm_context() - if current_ms_metadata is not None: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - o_proj_input[num_decode_tokens:] = output_prefill - else: - o_proj_input[num_decode_tokens:] = output_prefill + dtype=hidden_states.dtype, + device=hidden_states.device) - if has_decode: - output_decode = self._forward_decode(decode_ql_nope, decode_q_pe, - decode_k_nope, decode_k_pe, - kv_cache, attn_metadata) + # MLA Preprocess + decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess( + hidden_states, kv_cache, attn_metadata, need_gather_q_kv) + + if decode_preprocess_res is not None: + # MLA Preprocess for decoding + output_decode = self._forward_decode(decode_preprocess_res.ql_nope, + decode_preprocess_res.q_pe, + decode_preprocess_res.k_nope, + decode_preprocess_res.k_pe, + kv_cache[0].shape[1], + attn_metadata) current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is not None: with torch.npu.stream(current_ms_metadata.comm_stream): o_proj_input[:num_decode_tokens] = output_decode + current_ms_metadata.after_comm_event.record() else: o_proj_input[:num_decode_tokens] = output_decode + if prefill_preprocess_res is not None: + # FIX: aicore move should be also placed on the comm stream in dbo, + # otherwise it may affect the accuracy + # TODO: use an elegant way to overlap + output_prefill = self._forward_prefill( + prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, + prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, + prefill_preprocess_res.value, kv_cache, attn_metadata) + current_ms_metadata = get_multistream_comm_context() + if current_ms_metadata is not None: + with torch.npu.stream(current_ms_metadata.comm_stream): + o_proj_input[num_decode_tokens:] = output_prefill + current_ms_metadata.after_comm_event.record() + else: + o_proj_input[num_decode_tokens:] = output_prefill + # O proj current_ms_metadata = get_multistream_comm_context() + MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 if current_ms_metadata is None: + npu_prefetch(self.o_proj.weight, + o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch) + output[...] = self.o_proj( o_proj_input, - is_prefill=True, + is_prefill=prefill_preprocess_res is not None, is_force_scatter=self.enable_shared_expert_dp)[0] else: with torch.npu.stream(current_ms_metadata.comm_stream): + npu_prefetch(self.o_proj.weight, + o_proj_input, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch) output[...] = self.o_proj( o_proj_input, - is_prefill=True, + is_prefill=prefill_preprocess_res is not None, is_force_scatter=self.enable_shared_expert_dp)[0] current_ms_metadata.after_comm_event.record() del o_proj_input diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 0e4cf83..6d0913c 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -37,7 +37,6 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, split_tensor_along_last_dim, - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) from vllm.distributed.parallel_state import get_dp_group, get_ep_group @@ -73,7 +72,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod -from vllm_ascend.utils import dispose_tensor, npu_prefetch +from vllm_ascend.utils import dispose_tensor class CustomDeepseekV2SiluAndMul(SiluAndMul): @@ -471,9 +470,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): self.debug_layer_idx = int(self.prefix.split(".")[-2]) ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_mla = \ - ascend_config.torchair_graph_config.enable_multistream_mla self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp if self.q_lora_rank is not None: @@ -515,8 +511,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): if (config.n_routed_experts is not None and self.debug_layer_idx >= config.first_k_dense_replace and self.debug_layer_idx % config.moe_layer_freq == 0 - and (ascend_config.torchair_graph_config.enable_multistream_moe - or self.enable_shared_expert_dp)): + and self.enable_shared_expert_dp): self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce( self.num_heads * self.v_head_dim, self.hidden_size, @@ -568,6 +563,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, rotary_emb=self.rotary_emb, + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, kv_a_layernorm=self.kv_a_layernorm, @@ -582,55 +580,29 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: forward_context = get_forward_context() - enable_multistream_mla = (self.enable_multistream_mla - and attn_metadata is not None - and not forward_context.with_prefill - and attn_metadata.num_decodes > 0) - forward_kwargs = {"enable_multistream_mla": enable_multistream_mla} - if self.q_lora_rank is not None: - npu_prefetch(self.q_a_proj.weight, - hidden_states, - enabled=enable_multistream_mla) - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) - forward_kwargs['ckq'] = ckq - else: - hidden_states_or_q_c = hidden_states - if self.torchair_graph_enabled: + if kv_cache is None: + kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine] + num_tokens = hidden_states.shape[0] + need_gather_q_kv = False + if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: + # Simulate all gather to calculate output shape + num_tokens = num_tokens * self.tp_size + need_gather_q_kv = True + if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: output_shape = hidden_states.shape - output = torch.empty(output_shape, - dtype=hidden_states_or_q_c.dtype, - device=hidden_states_or_q_c.device) - forward_kwargs['output'] = output - output = self.mla_attn.impl.forward(self.mla_attn, - hidden_states_or_q_c, - hidden_states, None, kv_cache, - attn_metadata, - **forward_kwargs) - output = output.view(-1, output_shape[-1]) - return output else: - kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] - if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - hidden_states_or_q_c = get_tp_group().all_gather( - hidden_states_or_q_c, 0) - kv_no_split = get_tp_group().all_gather(kv_no_split, 0) - - kv_c, k_pe = kv_no_split.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: - output_shape = hidden_states.shape - else: - num_tokens = hidden_states_or_q_c.shape[0] - rows = num_tokens // self.tp_size - if num_tokens % self.tp_size: - rows += 1 - output_shape = (rows, hidden_states.shape[1]) - return self.mla_attn(hidden_states_or_q_c, - kv_c_normed, - k_pe, - output_shape=output_shape) + rows = num_tokens // self.tp_size + if num_tokens % self.tp_size: + rows += 1 + output_shape = (rows, hidden_states.shape[1]) + output = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + output = self.mla_attn.impl.forward(hidden_states, kv_cache, + forward_context.attn_metadata, + need_gather_q_kv, output) + output = output.view(-1, output_shape[-1]) + return output class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): @@ -688,8 +660,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \ - and model_config.use_mla and self.tp_size > 1 else: self.mlp = CustomDeepseekV2MLP( hidden_size=config.hidden_size, @@ -698,7 +668,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.mla_moe_communication = False self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -718,10 +687,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): replace_allreduce: bool = False, ) -> torch.Tensor: # Self Attention - if attn_metadata is not None and attn_metadata.num_decodes > 0: - mla_moe_communication = self.mla_moe_communication and replace_allreduce - else: - mla_moe_communication = False if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -733,9 +698,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): # to save npu memory because they're no longer used. dispose_tensor(previous_hidden_states) dispose_tensor(previous_residual) - if mla_moe_communication and self.layer_idx > self.first_k_dense_replace: - hidden_states = tensor_model_parallel_all_gather(hidden_states, - dim=0) hidden_states = self.self_attn( positions=positions, @@ -744,13 +706,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): attn_metadata=attn_metadata, ) - if mla_moe_communication and residual.shape[0] != hidden_states.shape[ - 0]: - chunk_hidden_states = torch.tensor_split(residual, - self.tp_size, - dim=0) - residual = chunk_hidden_states[self.tp_rank] - if hidden_states.dtype == torch.float16: # Fix FP16 overflow # We scale both hidden_states and residual before @@ -778,9 +733,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): hidden_states, residual) if isinstance(self.mlp, CustomDeepseekV2MoE): - hidden_states = self.mlp(hidden_states, - attn_metadata, - replace_allreduce=mla_moe_communication) + hidden_states = self.mlp(hidden_states, attn_metadata) else: hidden_states = self.mlp(hidden_states) @@ -793,10 +746,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): # The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE hidden_states *= 1. / self.routed_scaling_factor - if mla_moe_communication and self.layer_idx == self.layers - 1: - hidden_states = tensor_model_parallel_all_gather(hidden_states, - dim=0) - residual = tensor_model_parallel_all_gather(residual, dim=0) # for last layer of main model and mtp layer. if self.enable_shared_expert_dp and self.layer_idx >= (