[Feature] Support DSA-CP for Hybrid scenario (#5702)

Signed-off-by: zzhx1 <zzh_201018@outlook.com>

### What this PR does / why we need it?
> Extracted from PR #5513
Based on the Sharded-CP feature PR:#4702;
RFC:https://github.com/vllm-project/vllm/issues/30055

### Support FULL_DECODE_ONLY Mode under PD-Mixed Scenario:
Extends DSA-CP to handle the FULL_DECODE_ONLY execution mode when
running in a prefill-decode mixed (PD-mixed) serving environment,
improving throughput and resource utilization for decode-intensive
workloads.
**In pure prefill nodes:**
- Both q_proj and o_proj are sharded across world ranks, using
**broadcast** for weights distribution.

**In PD-mixed nodes (supporting both prefill and decode):**

- q_proj is fully replicated (not sharded) to avoid communication
overhead during decoding.
- o_proj Using the original TP `RowParallelLinear` method to store
weights

**During prefill execution:**
- o_proj forwards through all_gather to collect weights, reconstructing
the complete o_proj weights on each card.

**During decode (graph replay phase):**
- Additional all_to_all (before o_proj) and reduce_scatter (after
o_proj) are introduced to enable sequence-parallel output aggregation
while maintaining correctness under SFA CP.

### benchmark:
- TTFT increased by **527%**
- TPOT increased by **180%**

<img width="1550" height="938" alt="image"
src="https://github.com/user-attachments/assets/9b7a03d8-a3db-4a99-8923-6e5bfcfecf72"
/>


### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: clrs97 <524936896@qq.com>
This commit is contained in:
zzhxxx
2026-01-22 10:12:09 +08:00
committed by GitHub
parent 69740039b7
commit dd8571860d
4 changed files with 207 additions and 68 deletions

View File

@@ -6,6 +6,7 @@ import torch_npu
import vllm.envs as envs_vllm
from torch import nn
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
@@ -34,7 +35,7 @@ from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_dsa_cp, maybe_trans_nz, vllm_version_is)
enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
# isort: off
@@ -79,7 +80,7 @@ class AscendSFABackend(AttentionBackend):
@dataclass
class SfaCpContext:
class DSACPContext:
num_tokens: int
num_tokens_pad: int
local_start: int
@@ -119,7 +120,7 @@ class AscendSFAMetadata:
attn_mask: torch.Tensor = None
# chunked prefill by default if no attn_states passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
sfa_cp_context: Optional[SfaCpContext] = None
dsa_cp_context: Optional[DSACPContext] = None
reshape_cache_event: torch.npu.Event = None
@@ -159,15 +160,16 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_sfa_cp = enable_dsa_cp()
assert not (
self.enable_sfa_cp
and self.vllm_config.compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.enable_dsa_cp = enable_dsa_cp()
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
self.actual_seq_lengths_key = torch.empty_like(
self.actual_seq_lengths_query)
@staticmethod
def determine_chunked_prefill_workspace_size(
@@ -210,8 +212,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
cos, sin = get_cos_and_sin_mla(input_positions, True)
sfa_cp_context = None
if self.enable_sfa_cp:
dsa_cp_context = None
if self.enable_dsa_cp:
global_tp_size = get_tp_group().world_size
num_tokens = num_input_tokens
num_tokens_pad = _round_up(num_tokens, global_tp_size)
@@ -235,13 +237,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
value=-1)
else:
slot_mapping = slot_mapping[:num_tokens_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad]
slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
fill_value=-1,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
assert cos.shape[0] == num_tokens_per_device, \
f"cos.shape[0] must be equal to num_tokens_per_device, \
got {cos.shape[0]} and {num_tokens_per_device}"
@@ -252,8 +252,9 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
got {slot_mapping.shape[0]} and {num_tokens_pad}"
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
actual_seq_lengths_key = torch.empty_like(seq_lens)
actual_seq_lengths_query = self.actual_seq_lengths_query
actual_seq_lengths_key = self.actual_seq_lengths_key
num_segs = cum_query_lens.shape[0]
last_token = 0
cum = 0
@@ -262,21 +263,24 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
global_end = cum_query_lens[i].item()
last_token = global_end
local_start = max(global_start, local_start)
local_end = min(global_end, local_end_with_pad)
num_local_tokens = local_end - local_start
req_local_start = max(global_start, local_start)
req_local_end = min(global_end, local_end_with_pad)
num_local_tokens = req_local_end - req_local_start
if num_local_tokens > 0:
cum += num_local_tokens
actual_seq_lengths_query[i] = cum
offset = global_end - local_end
offset = global_end - req_local_end
actual_seq_lengths_key[i] = seq_lens[i].item() - offset
else:
actual_seq_lengths_query[i] = cum
actual_seq_lengths_key[i] = 0
sfa_cp_context = SfaCpContext(
actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs]
actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs]
dsa_cp_context = DSACPContext(
num_tokens=num_tokens,
num_tokens_pad=num_tokens_pad,
local_start=local_start,
@@ -300,7 +304,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
block_tables=block_table,
sin=sin[:num_input_tokens],
cos=cos[:num_input_tokens],
sfa_cp_context=sfa_cp_context)
dsa_cp_context=dsa_cp_context)
def build_for_graph_capture(
self,
@@ -329,6 +333,8 @@ class AscendSFAImpl(MLAAttentionImpl):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
o_proj_full_pool: Optional[torch.Tensor] = None
def __init__(
self,
@@ -382,22 +388,9 @@ class AscendSFAImpl(MLAAttentionImpl):
assert self.indexer is not None, "Indexer is required for DSA."
self.enable_sfa_cp = enable_dsa_cp()
self.local_num_heads = self.num_heads
self.vllm_config = get_current_vllm_config()
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
if self.enable_sfa_cp:
self.local_num_heads = self.num_heads * self.tp_size
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
# indexer param
self.n_head: int = self.indexer.n_head # 64
@@ -406,9 +399,24 @@ class AscendSFAImpl(MLAAttentionImpl):
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj
self.k_norm = self.indexer.k_norm
self.cp_size = 1
self.enable_dsa_cp = enable_dsa_cp()
self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard()
if self.enable_dsa_cp:
self.local_num_heads = self.num_heads * self.tp_size
if self.enable_dsa_cp_prefill_only:
self.layer_sharding_kwargs = []
for layer_name in (get_ascend_config().layer_sharding or []):
if layer_name in kwargs:
self.layer_sharding_kwargs.append(kwargs[layer_name])
else:
logger.warning_once(
f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration"
)
register_all_layers_to_shard_weight_series(
self.layer_sharding_kwargs)
def process_weights_after_loading(self, act_dtype: torch.dtype):
# NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
@@ -442,10 +450,14 @@ class AscendSFAImpl(MLAAttentionImpl):
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)
if self.enable_sfa_cp:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
if self.enable_dsa_cp:
if self.enable_dsa_cp_prefill_only:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(
layer)
else:
self._init_o_proj_tp_full_params()
if self.enable_mlapo:
quant_method = getattr(
@@ -460,7 +472,7 @@ class AscendSFAImpl(MLAAttentionImpl):
"Currently mlapo only supports W8A8 quantization in SFA scenario."
"Some layers in your model are not quantized with W8A8,"
"thus mlapo is disabled for these layers.")
if self.enable_sfa_cp:
if self.enable_dsa_cp:
reasons.append("Currently mlapo does not support SFA with CP,"
"thus mlapo is disabled for these layers.")
if reasons:
@@ -525,7 +537,7 @@ class AscendSFAImpl(MLAAttentionImpl):
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA"
if self.enable_sfa_cp:
if self.enable_dsa_cp:
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
@@ -738,7 +750,7 @@ class AscendSFAImpl(MLAAttentionImpl):
forward_context = get_forward_context()
if attn_metadata is None:
# Profiling run.
if self.enable_sfa_cp and not forward_context.in_profile_run:
if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
@@ -748,12 +760,20 @@ class AscendSFAImpl(MLAAttentionImpl):
sin = attn_metadata.sin
actual_seq_lengths_query = attn_metadata.cum_query_lens
actual_seq_lengths_key = attn_metadata.seq_lens
if self.enable_sfa_cp:
if self.enable_dsa_cp:
need_gather_q_kv = False
# Inputs and outputs may be padded for CUDA graphs
num_input_tokens = attn_metadata.num_input_tokens
output_padded = output
# all-gather o_proj weight for prefill stage of PD mix node
o_proj_full_handle = None
# if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj weight for prefill stage.
should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in {
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
}
if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
hidden_states=hidden_states,
@@ -796,16 +816,16 @@ class AscendSFAImpl(MLAAttentionImpl):
wait_for_kv_layer_from_connector(layer_name)
slot_mapping = attn_metadata.slot_mapping
if self.enable_sfa_cp:
assert attn_metadata.sfa_cp_context is not None
slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
if self.enable_dsa_cp:
assert attn_metadata.dsa_cp_context is not None
slot_mapping = attn_metadata.dsa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
slot_mapping)
if self.enable_sfa_cp:
if self.enable_dsa_cp:
assert k_pe is not None
assert k_nope is not None
# support all_gather kv async for communication calculation overlap
@@ -815,17 +835,26 @@ class AscendSFAImpl(MLAAttentionImpl):
k_nope.view(-1, k_nope.shape[-1]),
k.view(-1, k.shape[-1])
],
dim=1), get_tp_group())
dim=1),
get_tp_group(),
async_op=should_shard_weight)
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)
if self.enable_sfa_cp:
if self.enable_dsa_cp:
if kv_ag_handle is not None:
kv_ag_handle.wait()
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
if self.enable_dsa_cp_prefill_only:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
elif should_shard_weight:
_, o_proj_full_handle = all_gather_async(
self.o_proj_tp_weight,
get_tp_group(),
output=AscendSFAImpl.o_proj_full_pool)
if kv_cache is not None:
assert fused_kv_no_split is not None
@@ -841,6 +870,12 @@ class AscendSFAImpl(MLAAttentionImpl):
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
k_pe)
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(
kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(-1, 1),
k.view(-1, k.shape[-1])) # b, s, n, d
topk_indices = self.indexer_select_post_process(
x=hidden_states,
qr=q_c,
@@ -876,6 +911,20 @@ class AscendSFAImpl(MLAAttentionImpl):
dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:
# When using SFA-CP with pd mixed, o_proj has two cases:
# 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1.
# 2. decode: all-to-all the hidden_state before the o_proj forward.
result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward(
attn_output=attn_output,
output=output,
o_proj_full_handle=o_proj_full_handle,
should_shard_weight=should_shard_weight)
if not require_o_proj_forward:
return result
attn_output = result
output[...] = self.o_proj(attn_output)[0]
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
@@ -912,7 +961,10 @@ class AscendSFAImpl(MLAAttentionImpl):
k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]
dim=-1)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
@@ -940,10 +992,7 @@ class AscendSFAImpl(MLAAttentionImpl):
if q is None:
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe, q_nope = torch.split(
q,
@@ -984,3 +1033,92 @@ class AscendSFAImpl(MLAAttentionImpl):
sparse_count=2048,
sparse_mode=3)
return topk_indices
def _init_o_proj_tp_full_params(self):
"""
Initialize TP-mode and Full-mode parameters for o_proj weight,
preparing for weight switching in PD mix stage.
For PD mix stage:
- Use original TP o_proj weight for decode phase
- Need full-gather o_proj weight from all TP ranks for prefill phase
"""
if AscendSFAImpl.o_proj_full_pool is None:
sample = self.o_proj.weight
AscendSFAImpl.o_proj_full_pool = torch.empty(
(sample.shape[0] * self.tp_size, sample.shape[1]),
dtype=sample.dtype,
device=sample.device)
# Save TP-mode parameters (original sharded weights)
self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone(
).detach()
self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone(
).detach()
self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone(
).detach()
# Initially switch to TP mode for graph capture
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)
# Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks
self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(
self.tp_size)
self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(
self.tp_size)
def _handle_o_proj_weight_switch_and_forward(
self, attn_output: torch.Tensor, output: torch.Tensor,
o_proj_full_handle: Optional[torch.distributed.Work],
should_shard_weight: bool) -> Tuple[torch.Tensor, bool]:
"""
Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation.
"""
# Gather o_proj weight from all TP ranks for Full-mode computation
if should_shard_weight:
# Wait for the completion of o_proj weight all-gather operation
if o_proj_full_handle is not None:
o_proj_full_handle.wait()
# Switch o_proj to Full-mode (gathered weight from all TP ranks)
self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool)
self.o_proj.aclnn_input_scale.set_(
self.o_proj_full_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_full_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(
self.o_proj_full_aclnn_input_offset)
# Apply quantization method and execute forward computation
output[...] = self.o_proj.quant_method.quant_method.apply(
self.o_proj, attn_output)
# Switch o_proj back to TP-mode for subsequent decode operations
self.o_proj.weight.set_(self.o_proj_tp_weight)
self.o_proj.aclnn_input_scale.set_(
self.o_proj_tp_aclnn_input_scale)
self.o_proj.aclnn_input_scale_reciprocal.set_(
self.o_proj_tp_aclnn_input_scale_reciprocal)
self.o_proj.aclnn_input_offset.set_(
self.o_proj_tp_aclnn_input_offset)
return output, False
else:
# For decode scenario: perform all-to-all communication on o_proj input activations
# Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim]
send = attn_output.view(-1, self.tp_size, self.num_heads *
self.v_head_dim).permute(1, 0, 2).reshape(
-1, self.num_heads * self.v_head_dim)
attn_output = torch.empty_like(send)
torch.distributed.all_to_all_single(
attn_output, send, group=get_tp_group().device_group)
return attn_output, True

View File

@@ -7,7 +7,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_tp_group,
init_model_parallel_group)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import enable_dsa_cp, flashcomm2_enable
from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -238,7 +238,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
FC2_group_ranks = torch.tensor(
flashcomm2_otp_group_ranks).squeeze(0)
_SHARD_WEIGHT = create_shard_weight_group(FC2_group_ranks)
elif enable_dsa_cp():
elif enable_dsa_cp_with_layer_shard():
# For dsa_cp, all shard layers are replicated.
_SHARD_WEIGHT = create_shard_weight_group(None)
else:

View File

@@ -52,4 +52,4 @@ def all_gather_async(input: torch.Tensor,
return output, dist.all_gather_into_tensor(output,
input,
group=group.device_group,
async_op=async_op)
async_op=async_op)

View File

@@ -62,7 +62,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (enable_dsa_cp, enable_sp, flashcomm2_enable,
from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_layer_shard, enable_sp, flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable, shared_expert_dp_enabled)
@@ -575,7 +575,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
return tensor_model_parallel_all_reduce(output_parallel)
pad_size = forward_context.pad_size
if pad_size > 0:
if pad_size > 0 and not (enable_dsa_cp()
and "o_proj" in self.layer.prefix):
x = F.pad(x, (0, 0, 0, pad_size))
world_size = self.layer.tp_size
@@ -728,7 +729,7 @@ def _get_row_parallel_op(
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
SequenceRowParallelOp, ShardedCPRowParallelOp]]:
if enable_dsa_cp() and "o_proj" in prefix:
if enable_dsa_cp_with_layer_shard() and "o_proj" in prefix:
return ShardedCPRowParallelOp(layer)
if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix):
return MLPRowParallelOp(layer)