[Refactor] [SP]The sequence parallelism characteristics in the MoE and Dense models are integrated into a single solution. (#3085)

What this PR does / why we need it?

there are two sets of sp implementations for moe and dense models. One
is called sequence_parallelism, and the other is flashcomm_v1.
We did the following things:

Merge two sets of code with the same implementation into one.
Remove the implementation of sequence_parallelism, as this solution
cannot support aclgraph.
Does this PR introduce any user-facing change?

No

How was this patch tested?

e2e&ut

- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-09-24 11:29:59 +08:00
committed by GitHub
parent e7618d9414
commit 6aa4253798
14 changed files with 90 additions and 215 deletions

View File

@@ -20,10 +20,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
return residual
if x.size(0) != residual.size(0):
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
assert flashcomm_v1_enabled is True, (
"Currently, this situation only occurs "
"when flashcomm_v1 is enabled")
sp_enabled = forward_context.sp_enabled
assert sp_enabled is True, ("Currently, this situation only occurs "
"when sp is enabled")
pad_size = forward_context.pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -41,8 +40,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
except AssertionError:
return x
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
if flashcomm_v1_enabled and label:
sp_enabled = forward_context.sp_enabled
if sp_enabled and label:
x = tensor_model_parallel_all_gather(x, 0)
pad_size = forward_context.pad_size
if pad_size > 0:
@@ -56,8 +55,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
except AssertionError:
return tensor_model_parallel_all_reduce(x)
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
if flashcomm_v1_enabled:
sp_enabled = forward_context.sp_enabled
if sp_enabled:
pad_size = forward_context.pad_size
if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size))