diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 5bc2641a..71ab0c0b 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -6,6 +6,6 @@ import torch_npu
 import vllm.envs as envs_vllm
 from torch import nn
-from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
@@ -34,7 +35,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
-                               enable_dsa_cp, maybe_trans_nz, vllm_version_is)
+                               enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz, vllm_version_is)
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch

 # isort: off
@@ -79,7 +80,7 @@ class AscendSFABackend(AttentionBackend):


 @dataclass
-class SfaCpContext:
+class DSACPContext:
     num_tokens: int
     num_tokens_pad: int
     local_start: int
@@ -119,7 +120,7 @@ class AscendSFAMetadata:
     attn_mask: torch.Tensor = None
     # chunked prefill by default if no attn_states passed
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    sfa_cp_context: Optional[SfaCpContext] = None
+    dsa_cp_context: Optional[DSACPContext] = None

     reshape_cache_event: torch.npu.Event = None
@@ -159,15 +160,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             npu_fused_infer_attention_score TND layout's limit of 16, \
             got {self.decode_threshold}"

-        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-        self.enable_sfa_cp = enable_dsa_cp()
-
-        assert not (
-            self.enable_sfa_cp
-            and self.vllm_config.compilation_config.cudagraph_mode
-            == CUDAGraphMode.FULL_DECODE_ONLY
-        ), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
+        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
+        self.enable_dsa_cp = enable_dsa_cp()
+
+        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
+        self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1,
+                                                    dtype=torch.int32,
+                                                    device=device)
+        self.actual_seq_lengths_key = torch.empty_like(
+            self.actual_seq_lengths_query)

     @staticmethod
     def determine_chunked_prefill_workspace_size(
@@ -210,8 +212,8 @@

         cos, sin = get_cos_and_sin_mla(input_positions, True)

-        sfa_cp_context = None
-        if self.enable_sfa_cp:
+        dsa_cp_context = None
+        if self.enable_dsa_cp:
             global_tp_size = get_tp_group().world_size
             num_tokens = num_input_tokens
             num_tokens_pad = _round_up(num_tokens, global_tp_size)
@@ -235,13 +237,11 @@
                                      value=-1)
             else:
                 slot_mapping = slot_mapping[:num_tokens_pad]
+            slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]

             cos = cos[local_start:local_end_with_pad]
             sin = sin[local_start:local_end_with_pad]
-            slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
-                                         fill_value=-1,
-                                         dtype=slot_mapping.dtype,
-                                         device=slot_mapping.device)
+
             assert cos.shape[0] == num_tokens_per_device, \
                 f"cos.shape[0] must be equal to num_tokens_per_device, \
                 got {cos.shape[0]} and {num_tokens_per_device}"
@@ -252,8 +252,9 @@
                 f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
                 got {slot_mapping.shape[0]} and {num_tokens_pad}"

-            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
-            actual_seq_lengths_key = torch.empty_like(seq_lens)
+            actual_seq_lengths_query = self.actual_seq_lengths_query
+            actual_seq_lengths_key = self.actual_seq_lengths_key
+
             num_segs = cum_query_lens.shape[0]
             last_token = 0
             cum = 0
@@ -262,21 +263,24 @@
                 global_end = cum_query_lens[i].item()
                 last_token = global_end

-                local_start = max(global_start, local_start)
-                local_end = min(global_end, local_end_with_pad)
-                num_local_tokens = local_end - local_start
+                req_local_start = max(global_start, local_start)
+                req_local_end = min(global_end, local_end_with_pad)
+                num_local_tokens = req_local_end - req_local_start
                 if num_local_tokens > 0:
                     cum += num_local_tokens
                     actual_seq_lengths_query[i] = cum
-                    offset = global_end - local_end
+                    offset = global_end - req_local_end
                     actual_seq_lengths_key[i] = seq_lens[i].item() - offset
                 else:
                     actual_seq_lengths_query[i] = cum
                     actual_seq_lengths_key[i] = 0

-            sfa_cp_context = SfaCpContext(
+            actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs]
+            actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs]
+
+            dsa_cp_context = DSACPContext(
                 num_tokens=num_tokens,
                 num_tokens_pad=num_tokens_pad,
                 local_start=local_start,
@@ -300,7 +304,7 @@
             block_tables=block_table,
             sin=sin[:num_input_tokens],
             cos=cos[:num_input_tokens],
-            sfa_cp_context=sfa_cp_context)
+            dsa_cp_context=dsa_cp_context)

     def build_for_graph_capture(
         self,
@@ -329,6 +333,8 @@ class AscendSFAImpl(MLAAttentionImpl):
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
+    # Buffer for the full (all-gathered) o_proj weight, used for prefill requests when DSA CP runs without layer sharding (PD-mix mode).
+    o_proj_full_pool: Optional[torch.Tensor] = None

     def __init__(
         self,
@@ -382,22 +388,9 @@

         assert self.indexer is not None, "Indexer is required for DSA."

-        self.enable_sfa_cp = enable_dsa_cp()
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
         self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
-        if self.enable_sfa_cp:
-            self.local_num_heads = self.num_heads * self.tp_size
-            self.layer_sharding_kwargs = []
-            for layer_name in (get_ascend_config().layer_sharding or []):
-                if layer_name in kwargs:
-                    self.layer_sharding_kwargs.append(kwargs[layer_name])
-                else:
-                    logger.warning_once(
-                        f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
-                    )
-            register_all_layers_to_shard_weight_series(
-                self.layer_sharding_kwargs)

         # indexer param
         self.n_head: int = self.indexer.n_head  # 64
@@ -406,9 +399,24 @@
         self.wk = self.indexer.wk
         self.weights_proj = self.indexer.weights_proj
         self.k_norm = self.indexer.k_norm
-        self.cp_size = 1
+        self.enable_dsa_cp = enable_dsa_cp()
+        self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
+        if self.enable_dsa_cp:
+            self.local_num_heads = self.num_heads * self.tp_size
+            if self.enable_dsa_cp_prefill_only:
+                self.layer_sharding_kwargs = []
+                for layer_name in (get_ascend_config().layer_sharding or []):
+                    if layer_name in kwargs:
+                        self.layer_sharding_kwargs.append(kwargs[layer_name])
+                    else:
+                        logger.warning_once(
+                            f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
+                        )
+                register_all_layers_to_shard_weight_series(
+                    self.layer_sharding_kwargs)
+
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         # NOTE: We currently do not support quant kv_b_proj.
         assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
@@ -442,10 +450,14 @@
         # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
         dispose_layer(self.kv_b_proj)

-        if self.enable_sfa_cp:
-            for layer in (self.layer_sharding_kwargs or []):
-                if is_hidden_layer(layer):
-                    post_process_after_loading_for_shard_weight_series(layer)
+        if self.enable_dsa_cp:
+            if self.enable_dsa_cp_prefill_only:
+                for layer in (self.layer_sharding_kwargs or []):
+                    if is_hidden_layer(layer):
+                        post_process_after_loading_for_shard_weight_series(
+                            layer)
+            else:
+                self._init_o_proj_tp_full_params()

         if self.enable_mlapo:
             quant_method = getattr(
@@ -460,7 +472,7 @@
                     "Currently mlapo only supports W8A8 quantization in SFA scenario."
                     "Some layers in your model are not quantized with W8A8,"
                     "thus mlapo is disabled for these layers.")
-            if self.enable_sfa_cp:
+            if self.enable_dsa_cp:
                 reasons.append("Currently mlapo does not support SFA with CP,"
                                "thus mlapo is disabled for these layers.")
             if reasons:
@@ -525,7 +537,7 @@
             B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)

         cache_mode = "PA"
-        if self.enable_sfa_cp:
+        if self.enable_dsa_cp:
             _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
                 kv_no_split,
                 self.kv_a_layernorm.weight,
@@ -738,7 +750,7 @@
         forward_context = get_forward_context()
         if attn_metadata is None:
             # Profiling run.
-            if self.enable_sfa_cp and not forward_context.in_profile_run:
+            if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run:
                 for layer in (self.layer_sharding_kwargs or []):
                     if is_hidden_layer(layer):
                         reach_layer_for_shard_weight_series(layer)
@@ -748,12 +760,20 @@
         sin = attn_metadata.sin
         actual_seq_lengths_query = attn_metadata.cum_query_lens
         actual_seq_lengths_key = attn_metadata.seq_lens
-        if self.enable_sfa_cp:
+        if self.enable_dsa_cp:
             need_gather_q_kv = False

         # Inputs and outputs may be padded for CUDA graphs
         num_input_tokens = attn_metadata.num_input_tokens
         output_padded = output
+        # Handle for the async all-gather of the o_proj weight used by the
+        # prefill stage on a PD-mix node.
+        o_proj_full_handle = None
+        # In PD-mix mode the decode path keeps the original TP o_proj weight,
+        # while the prefill path needs the full (all-gathered) o_proj weight.
+        should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in {
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        }
+
         if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
@@ -796,16 +816,16 @@
             wait_for_kv_layer_from_connector(layer_name)

         slot_mapping = attn_metadata.slot_mapping
-        if self.enable_sfa_cp:
-            assert attn_metadata.sfa_cp_context is not None
-            slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
-            actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
-            actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
+        if self.enable_dsa_cp:
+            assert attn_metadata.dsa_cp_context is not None
+            slot_mapping = attn_metadata.dsa_cp_context.slot_mapping_cp
+            actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
+            actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key

         k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
                                     slot_mapping)
-        if self.enable_sfa_cp:
+        if self.enable_dsa_cp:
             assert k_pe is not None
             assert k_nope is not None
             # support all_gather kv async for communication calculation overlap
@@ -815,17 +835,26 @@
                 k_nope.view(-1, k_nope.shape[-1]),
                 k.view(-1, k.shape[-1])
             ],
-                      dim=1), get_tp_group())
+                      dim=1),
+                get_tp_group(),
+                async_op=should_shard_weight)

         ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
         q_pe = self.rope_single(q_pe, cos, sin)
-        if self.enable_sfa_cp:
+        if self.enable_dsa_cp:
             if kv_ag_handle is not None:
                 kv_ag_handle.wait()
-            for layer in (self.layer_sharding_kwargs or []):
-                if is_hidden_layer(layer):
-                    reach_layer_for_shard_weight_series(layer)
+
+            if self.enable_dsa_cp_prefill_only:
+                for layer in (self.layer_sharding_kwargs or []):
+                    if is_hidden_layer(layer):
+                        reach_layer_for_shard_weight_series(layer)
+            elif should_shard_weight:
+                _, o_proj_full_handle = all_gather_async(
+                    self.o_proj_tp_weight,
+                    get_tp_group(),
+                    output=AscendSFAImpl.o_proj_full_pool)

         if kv_cache is not None:
             assert fused_kv_no_split is not None
@@ -841,6 +870,12 @@
                 kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe)

+        if kv_cache is not None:
+            torch_npu.npu_scatter_nd_update_(
+                kv_cache[2].view(-1, k.shape[-1]),
+                attn_metadata.slot_mapping.view(-1, 1),
+                k.view(-1, k.shape[-1]))  # b, s, n, d
+
         topk_indices = self.indexer_select_post_process(
             x=hidden_states,
             qr=q_c,
@@ -876,6 +911,20 @@
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,
                            enabled=self.enable_prefetch)
+
+        if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
+            # When DSA CP runs with PD mixed, o_proj has two cases:
+            # 1. prefill: o_proj holds a TP weight; all-gather it to switch to TP=1.
+            # 2. decode: all-to-all the hidden_state before the o_proj forward.
+            result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward(
+                attn_output=attn_output,
+                output=output,
+                o_proj_full_handle=o_proj_full_handle,
+                should_shard_weight=should_shard_weight)
+            if not require_o_proj_forward:
+                return result
+            attn_output = result
+
         output[...] = self.o_proj(attn_output)[0]

         maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
@@ -912,7 +961,10 @@
         k_pe, k_nope = torch.split(
             k,
             [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
-            dim=-1)  # [b,s,64+64]
+            dim=-1)
+
+        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
+        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)

         k_pe = k_pe.unsqueeze(2)
         k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
@@ -940,10 +992,7 @@
         if q is None:
             q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
             q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]
-        cos_q, sin_q = cos, sin
-        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
-        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)

         q_pe, q_nope = torch.split(
             q,
@@ -984,3 +1033,92 @@
             sparse_count=2048,
             sparse_mode=3)
         return topk_indices
+
+    def _init_o_proj_tp_full_params(self):
+        """
+        Initialize TP-mode and Full-mode parameters for the o_proj weight,
+        preparing for weight switching in the PD-mix stage.
+
+        For the PD-mix stage:
+        - the decode phase uses the original TP o_proj weight;
+        - the prefill phase needs the o_proj weight gathered from all TP ranks.
+        """
+        if AscendSFAImpl.o_proj_full_pool is None:
+            sample = self.o_proj.weight
+            AscendSFAImpl.o_proj_full_pool = torch.empty(
+                (sample.shape[0] * self.tp_size, sample.shape[1]),
+                dtype=sample.dtype,
+                device=sample.device)
+
+        # Save TP-mode parameters (original sharded weights)
+        self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
+        self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone(
+        ).detach()
+        self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone(
+        ).detach()
+        self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone(
+        ).detach()
+
+        # Initially switch to TP mode for graph capture
+        self.o_proj.weight.set_(self.o_proj_tp_weight)
+        self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
+        self.o_proj.aclnn_input_scale_reciprocal.set_(
+            self.o_proj_tp_aclnn_input_scale_reciprocal)
+        self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
+
+        # Precompute Full-mode quantization parameters by repeating the TP
+        # parameters across all TP ranks
+        self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(
+            self.tp_size)
+        self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(
+            self.tp_size)
+        self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(
+            self.tp_size)
+
+    def _handle_o_proj_weight_switch_and_forward(
+            self, attn_output: torch.Tensor, output: torch.Tensor,
+            o_proj_full_handle: Optional[torch.distributed.Work],
+            should_shard_weight: bool) -> Tuple[torch.Tensor, bool]:
+        """
+        Handle o_proj weight switching between TP-mode and Full-mode,
+        and execute the forward computation.
+        """
+        # Full-mode path: use the o_proj weight gathered from all TP ranks
+        if should_shard_weight:
+            # Wait for the async all-gather of the o_proj weight to complete
+            if o_proj_full_handle is not None:
+                o_proj_full_handle.wait()
+
+            # Switch o_proj to Full-mode (weight gathered from all TP ranks)
+            self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
+            self.o_proj.aclnn_input_scale.set_(
+                self.o_proj_full_aclnn_input_scale)
+            self.o_proj.aclnn_input_scale_reciprocal.set_(
+                self.o_proj_full_aclnn_input_scale_reciprocal)
+            self.o_proj.aclnn_input_offset.set_(
+                self.o_proj_full_aclnn_input_offset)
+
+            # Apply the quantization method and execute the forward computation
+            output[...] = self.o_proj.quant_method.quant_method.apply(
+                self.o_proj, attn_output)
+
+            # Switch o_proj back to TP-mode for subsequent decode operations
+            self.o_proj.weight.set_(self.o_proj_tp_weight)
+            self.o_proj.aclnn_input_scale.set_(
+                self.o_proj_tp_aclnn_input_scale)
+            self.o_proj.aclnn_input_scale_reciprocal.set_(
+                self.o_proj_tp_aclnn_input_scale_reciprocal)
+            self.o_proj.aclnn_input_offset.set_(
+                self.o_proj_tp_aclnn_input_offset)
+
+            return output, False
+        else:
+            # Decode: all-to-all the o_proj input activations across TP ranks.
+            # [tokens, tp_size, num_heads * v_head_dim] -> [tp_size, tokens, num_heads * v_head_dim]
+            send = attn_output.view(-1, self.tp_size, self.num_heads *
+                                    self.v_head_dim).permute(1, 0, 2).reshape(
+                                        -1, self.num_heads * self.v_head_dim)
+
+            attn_output = torch.empty_like(send)
+            torch.distributed.all_to_all_single(
+                attn_output, send, group=get_tp_group().device_group)
+
+            return attn_output, True
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 9932867c..635546de 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -7,7 +7,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group,
                                              init_model_parallel_group)

 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable
+from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable

 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
@@ -238,7 +238,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         FC2_group_ranks = torch.tensor(
             flashcomm2_otp_group_ranks).squeeze(0)
         _SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
-    elif enable_dsa_cp():
+    elif enable_dsa_cp_with_layer_shard():
         # For dsa_cp, all shard layers are replicated.
         _SHARD_WEIGHT = create_shard_weight_group(None)
     else:
diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py
index 4a73e6c5..d2773d51 100644
--- a/vllm_ascend/distributed/utils.py
+++ b/vllm_ascend/distributed/utils.py
@@ -52,4 +52,4 @@ def all_gather_async(input: torch.Tensor,
     return output, dist.all_gather_into_tensor(output,
                                                input,
                                                group=group.device_group,
-                                               async_op=async_op)
\ No newline at end of file
+                                               async_op=async_op)
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index adeaa26a..9755df47 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -62,7 +62,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
                                                     get_mlp_tp_group,
                                                     get_otp_group)
 from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
-from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
+from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
                                matmul_allreduce_enable, mlp_tp_enable,
                                oproj_tp_enable, shared_expert_dp_enabled)
@@ -575,7 +575,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
             return tensor_model_parallel_all_reduce(output_parallel)

         pad_size = forward_context.pad_size
-        if pad_size > 0:
+        if pad_size > 0 and not (enable_dsa_cp()
+                                 and "o_proj" in self.layer.prefix):
             x = F.pad(x, (0, 0, 0, pad_size))

         world_size = self.layer.tp_size
@@ -728,7 +729,7 @@ def _get_row_parallel_op(
 ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                     Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
                     SequenceRowParallelOp, ShardedCPRowParallelOp]]:
-    if enable_dsa_cp() and "o_proj" in prefix:
+    if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix:
         return ShardedCPRowParallelOp(layer)
     if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
         return MLPRowParallelOp(layer)