### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/fused_moe/comm_utils.py` |
| `vllm_ascend/ops/fused_moe/experts_selector.py` |
| `vllm_ascend/ops/fused_moe/fused_moe.py` |
| `vllm_ascend/ops/fused_moe/moe_comm_method.py` |
| `vllm_ascend/ops/fused_moe/moe_mlp.py` |
| `vllm_ascend/ops/fused_moe/prepare_finalize.py` |
| `vllm_ascend/ops/fused_moe/token_dispatcher.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
@@ -23,11 +23,7 @@ import torch_npu
|
||||
COMM_STREAM = None
|
||||
|
||||
|
||||
def async_all_to_all(input_,
|
||||
output_split_sizes,
|
||||
input_split_sizes,
|
||||
group,
|
||||
event=None):
|
||||
def async_all_to_all(input_, output_split_sizes, input_split_sizes, group, event=None):
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
a2a_out = torch.empty_like(input_)
|
||||
@@ -43,8 +39,7 @@ def async_all_to_all(input_,
|
||||
# multi stream wait event
|
||||
global COMM_STREAM
|
||||
if COMM_STREAM is None:
|
||||
COMM_STREAM = torch_npu.npu.Stream(
|
||||
device=torch.npu.current_device())
|
||||
COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
|
||||
with torch_npu.npu.stream(COMM_STREAM):
|
||||
event.wait()
|
||||
handle = dist.all_to_all_single(
|
||||
@@ -53,14 +48,17 @@ def async_all_to_all(input_,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
async_op=True,
|
||||
)
|
||||
else:
|
||||
handle = dist.all_to_all_single(a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
handle = dist.all_to_all_single(
|
||||
a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True,
|
||||
)
|
||||
return input_, a2a_out, handle
|
||||
|
||||
|
||||
@@ -86,19 +84,12 @@ def _gather_along_first_dim(input_, group, output_split_sizes=None):
|
||||
if output_split_sizes is None:
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output, input_.contiguous(), group=group)
|
||||
else:
|
||||
dim_size[0] = sum(output_split_sizes)
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
output_tensor_list = list(
|
||||
torch.split(output, output_split_sizes, dim=0))
|
||||
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
|
||||
output_tensor_list = list(torch.split(output, output_split_sizes, dim=0))
|
||||
torch.distributed.all_gather(output_tensor_list, input_, group=group)
|
||||
|
||||
return output
|
||||
@@ -110,4 +101,4 @@ def gather_from_sequence_parallel_region(
|
||||
output_split_sizes=None,
|
||||
):
|
||||
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
|
||||
Reference in New Issue
Block a user