[Feat] Flash comm allgher ep (#3334)
Support flash comm v1(Sequence Parallelism) for Allgather EP. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com> Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.distributed import (tensor_model_parallel_all_gather,
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.forward_context import get_forward_context
|
||||
@@ -13,8 +15,10 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
||||
label: bool) -> torch.Tensor:
|
||||
def _maybe_all_gather_and_maybe_unpad_impl(
|
||||
x: torch.Tensor,
|
||||
label: bool,
|
||||
is_ep_comm: bool = False) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
@@ -22,27 +26,66 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
|
||||
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
if sp_enabled and label:
|
||||
x = tensor_model_parallel_all_gather(x, 0)
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
x = x[:-pad_size, :]
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
if dp_metadata is None or not is_ep_comm:
|
||||
x = tensor_model_parallel_all_gather(x, 0)
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
x = x[:-pad_size, :]
|
||||
else:
|
||||
x = get_ep_group().all_gather(x, 0)
|
||||
# unpad
|
||||
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
|
||||
result = torch.empty(
|
||||
(num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
dp_size = get_dp_group().world_size
|
||||
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
|
||||
offset = 0
|
||||
for idx in range(dp_size):
|
||||
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
||||
result[offset:offset +
|
||||
num_tokens_dp, :] = x[idx, :num_tokens_dp, :]
|
||||
offset += num_tokens_dp
|
||||
x = result
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
def _maybe_pad_and_reduce_impl(x: torch.Tensor,
|
||||
is_ep_comm: bool = False) -> torch.Tensor:
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
except AssertionError:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
if sp_enabled:
|
||||
if not forward_context.sp_enabled:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
if dp_metadata is None or not is_ep_comm:
|
||||
pad_size = forward_context.pad_size
|
||||
if pad_size > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_size))
|
||||
return tensor_model_parallel_reduce_scatter(x, 0)
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
# padding
|
||||
dp_size = get_dp_group().world_size
|
||||
num_tokens_across_dp_cpu = \
|
||||
get_forward_context().dp_metadata.num_tokens_across_dp_cpu
|
||||
padded_x = torch.empty(
|
||||
(dp_size, forward_context.padded_length, *x.shape[1:]),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
offset = 0
|
||||
for idx in range(dp_size):
|
||||
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
||||
padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp]
|
||||
offset += num_tokens_dp
|
||||
|
||||
return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]),
|
||||
0)
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
||||
@@ -71,6 +114,33 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
|
||||
return
|
||||
|
||||
|
||||
def _maybe_all_gather_and_maybe_unpad_fake(
|
||||
x: torch.Tensor,
|
||||
label: bool,
|
||||
is_ep_comm: bool = False) -> torch.Tensor:
|
||||
|
||||
if get_forward_context().sp_enabled and label:
|
||||
return torch.empty(
|
||||
(x.shape[0] * get_tensor_model_parallel_world_size(),
|
||||
*x.shape[1:]),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _maybe_pad_and_reduce_fake(x: torch.Tensor,
|
||||
is_ep_comm: bool = False) -> torch.Tensor:
|
||||
if get_forward_context().sp_enabled:
|
||||
return torch.empty(
|
||||
(x.shape[0] // get_tensor_model_parallel_world_size(),
|
||||
*x.shape[1:]),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
|
||||
prefix: str) -> None:
|
||||
return
|
||||
@@ -158,7 +228,8 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
final_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
|
||||
} or forward_context.sp_enabled:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
@@ -166,13 +237,13 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
|
||||
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
|
||||
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
||||
fake_impl=lambda x, label: x,
|
||||
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="maybe_pad_and_reduce",
|
||||
op_func=_maybe_pad_and_reduce_impl,
|
||||
fake_impl=lambda x: x,
|
||||
fake_impl=_maybe_pad_and_reduce_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user