[Refactor] [SP]The sequence parallelism characteristics in the MoE and Dense models are integrated into a single solution. (#3085)
What this PR does / why we need it?
there are two sets of sp implementations for moe and dense models. One
is called sequence_parallelism, and the other is flashcomm_v1.
We did the following things:
Merge two sets of code with the same implementation into one.
Remove the implementation of sequence_parallelism, as this solution
cannot support aclgraph.
Does this PR introduce any user-facing change?
No
How was this patch tested?
e2e&ut
- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -20,10 +20,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
|
||||
return residual
|
||||
|
||||
if x.size(0) != residual.size(0):
|
||||
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
||||
assert flashcomm_v1_enabled is True, (
|
||||
"Currently, this situation only occurs "
|
||||
"when flashcomm_v1 is enabled")
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
assert sp_enabled is True, ("Currently, this situation only occurs "
|
||||
"when sp is enabled")
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
residual = F.pad(residual, (0, 0, 0, pad_size))
|
||||
@@ -41,8 +40,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
||||
except AssertionError:
|
||||
return x
|
||||
|
||||
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
||||
if flashcomm_v1_enabled and label:
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
if sp_enabled and label:
|
||||
x = tensor_model_parallel_all_gather(x, 0)
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
@@ -56,8 +55,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
except AssertionError:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
|
||||
if flashcomm_v1_enabled:
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
if sp_enabled:
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_size))
|
||||
|
||||
Reference in New Issue
Block a user