[Feat]support sequence parallelism by pass for VL models (#5632)

This commit is contained in:
realliujiaxu
2026-02-27 08:27:41 +08:00
committed by GitHub
parent ed175d6d92
commit 5def28dcd3
22 changed files with 460 additions and 101 deletions

View File

@@ -70,7 +70,7 @@ from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
from vllm_ascend.utils import (
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
enable_sp,
enable_flash_comm_v1,
flashcomm2_enable,
get_flashcomm2_reorgnized_batch_ids,
get_weight_prefetch_method,
@@ -368,7 +368,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
else:
output = output_parallel
if not forward_context.sp_enabled:
if not forward_context.flash_comm_v1_enabled:
# flashcomm1 not enabled
output = get_tp_group().all_gather(output, 0)
if num_padding_tokens > 0:
@@ -466,7 +466,7 @@ class Flashcomm2OshardQKVParallelOp(CustomColumnParallelOp):
# Matrix multiply.
assert self.quant_method is not None
if enable_sp():
if enable_flash_comm_v1():
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
# Trigger async broadcast before matmul to overlap communication.
@@ -515,15 +515,15 @@ class SequenceRowParallelOp(CustomRowParallelOp):
assert self.quant_method is not None
try:
forward_context = get_forward_context()
sp_enabled = forward_context.sp_enabled
flash_comm_v1_enabled = forward_context.flash_comm_v1_enabled
mmrs_fusion = forward_context.mmrs_fusion
except AssertionError:
sp_enabled = False
flash_comm_v1_enabled = False
mmrs_fusion = False
x = input_parallel
if not sp_enabled:
if not flash_comm_v1_enabled:
output_parallel = self.layer.quant_method.apply(self.layer, x, bias=bias_)
return tensor_model_parallel_all_reduce(output_parallel)
@@ -649,7 +649,7 @@ def _get_column_parallel_op(
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
if any(p in prefix for p in ("qkv_proj", "conv1d", "query_key_value")):
return Flashcomm2OshardQKVParallelOp(layer)
if enable_sp():
if enable_flash_comm_v1():
if "shared_expert" in prefix:
return None
sp_column_prefix = [
@@ -688,7 +688,7 @@ def _get_row_parallel_op(
if flashcomm2_enable():
if "o_proj" in prefix or "out_proj" in prefix:
return Flashcomm2OProjRowParallelOp(layer)
if enable_sp():
if enable_flash_comm_v1():
if "shared_expert" in prefix:
return None
sp_row_prefixes = [