init v0.11.0rc0
This commit is contained in:
@@ -23,7 +23,6 @@ import torch_npu
|
||||
from vllm.distributed import GroupCoordinator, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
@@ -417,6 +416,7 @@ def torchair_fused_experts_with_all2all(
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
if expert_map is not None:
|
||||
assert ep_group is not None, "ep_group must be provided when expert_map is given"
|
||||
global_num_experts = len(expert_map) + global_redundant_expert_num
|
||||
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
|
||||
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
|
||||
@@ -436,8 +436,9 @@ def torchair_fused_experts_with_all2all(
|
||||
|
||||
gather_sizes = global_expert_tokens.new_empty(
|
||||
global_expert_tokens.shape[0])
|
||||
dist.all_to_all_single(gather_sizes, global_expert_tokens)
|
||||
|
||||
dist.all_to_all_single(gather_sizes,
|
||||
global_expert_tokens,
|
||||
group=ep_group.device_group)
|
||||
token_counts_combined = torch.stack(
|
||||
[gather_sizes, global_expert_tokens], dim=0)
|
||||
token_counts_combined = token_counts_combined.view(
|
||||
@@ -452,10 +453,16 @@ def torchair_fused_experts_with_all2all(
|
||||
gather_size_list = token_counts_combined_cpu[1]
|
||||
scatter_size_list = token_counts_combined_cpu[0]
|
||||
|
||||
dist.all_to_all_single(gathered_tokens, quantized_tokens,
|
||||
scatter_size_list, gather_size_list)
|
||||
dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
|
||||
gather_size_list)
|
||||
dist.all_to_all_single(gathered_tokens,
|
||||
quantized_tokens,
|
||||
scatter_size_list,
|
||||
gather_size_list,
|
||||
group=ep_group.device_group)
|
||||
dist.all_to_all_single(dynamic_scale,
|
||||
token_scales,
|
||||
scatter_size_list,
|
||||
gather_size_list,
|
||||
group=ep_group.device_group)
|
||||
|
||||
hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
|
||||
gathered_tokens,
|
||||
@@ -503,9 +510,11 @@ def torchair_fused_experts_with_all2all(
|
||||
index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
|
||||
|
||||
hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
|
||||
dist.all_to_all_single(hidden_states, reordered_outputs,
|
||||
gather_size_list, scatter_size_list)
|
||||
|
||||
dist.all_to_all_single(hidden_states,
|
||||
reordered_outputs,
|
||||
gather_size_list,
|
||||
scatter_size_list,
|
||||
group=ep_group.device_group)
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
@@ -824,6 +833,7 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
@@ -937,6 +947,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
|
||||
fused_moe_state = FusedMoEState.All2All
|
||||
shared_gate_up, shared_dequant_scale = None, None
|
||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
@@ -1021,8 +1033,7 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
||||
|
||||
Reference in New Issue
Block a user