[BugFix]Fix eplb problems when using dynamic eplb. (#3364)
### What this PR does / why we need it? When using dynamic eplb, execution is blocked by nz tensors. We fix these problems by cloning the src tensor and the recv tensor. ### Does this PR introduce any user-facing change? ### How was this patch tested? Qwen3_moe in A3. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: offline0806 <3337230449@qq.com> Co-authored-by: offline0806 <3337230449@qq.com>
This commit is contained in:
@@ -48,13 +48,7 @@ def test_generate_task_and_state_flow(mock_adaptor):
|
|||||||
|
|
||||||
loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0)
|
loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0)
|
||||||
assert loader_obj.comm_op_list is None
|
assert loader_obj.comm_op_list is None
|
||||||
|
assert loader_obj.state == loader.ExpertWeightUpdateState.WAITING
|
||||||
updated_map = {20: torch.tensor(0)}
|
|
||||||
loader_obj.generate_expert_d2d_transfer_task([(1, 10)], [(2, 20)],
|
|
||||||
updated_map, 0)
|
|
||||||
assert loader_obj.state == loader.ExpertWeightUpdateState.READY
|
|
||||||
assert loader_obj.comm_op_list
|
|
||||||
assert loader_obj.recv_expert_list
|
|
||||||
|
|
||||||
|
|
||||||
def test_asyn_transfer_and_update(mock_adaptor):
|
def test_asyn_transfer_and_update(mock_adaptor):
|
||||||
|
|||||||
@@ -80,15 +80,15 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
self.all_topk_ids = []
|
self.all_topk_ids = []
|
||||||
|
|
||||||
def init_buffer_tensor(self, num_buffer_tensor):
|
def init_buffer_tensor(self, num_buffer_tensor):
|
||||||
|
for buffer_id in range(num_buffer_tensor):
|
||||||
for name in self.expert_weight_names:
|
for name in self.expert_weight_names:
|
||||||
complete_name = "model.layers." + str(
|
complete_name = "model.layers." + str(
|
||||||
self.num_dense_layers) + ".mlp.experts." + name
|
self.num_dense_layers) + ".mlp.experts." + name
|
||||||
expert_tensor = self.param_dict[complete_name].data[
|
expert_tensor = self.param_dict[complete_name].data[0]
|
||||||
0:num_buffer_tensor]
|
if name in ["w13_weight", "w2_weight"]:
|
||||||
buffer_tensors = torch.empty_like(expert_tensor)
|
expert_tensor = expert_tensor.clone()
|
||||||
for buffer_id in range(num_buffer_tensor):
|
buffer_tensor = torch.empty_like(expert_tensor)
|
||||||
self.buffer_tensor_list[buffer_id].append(
|
self.buffer_tensor_list[buffer_id].append(buffer_tensor)
|
||||||
buffer_tensors[buffer_id])
|
|
||||||
|
|
||||||
def init_expert_param_per_layer(self):
|
def init_expert_param_per_layer(self):
|
||||||
num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) + \
|
num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) + \
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class D2DExpertWeightLoader:
|
|||||||
layer_id):
|
layer_id):
|
||||||
# When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task
|
# When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task
|
||||||
if self.state != ExpertWeightUpdateState.WAITING:
|
if self.state != ExpertWeightUpdateState.WAITING:
|
||||||
logger.error(
|
logger.warning_once(
|
||||||
"current d2d weight update tasks are on-going, cannot accept new weight update task"
|
"current d2d weight update tasks are on-going, cannot accept new weight update task"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -64,6 +64,7 @@ class D2DExpertWeightLoader:
|
|||||||
layer_id][global_expert_id_to_send].item()
|
layer_id][global_expert_id_to_send].item()
|
||||||
for src_tensor in self.eplb_adaptor.expert_param_per_layer[
|
for src_tensor in self.eplb_adaptor.expert_param_per_layer[
|
||||||
layer_id][local_expert_id]:
|
layer_id][local_expert_id]:
|
||||||
|
src_tensor = src_tensor.clone()
|
||||||
self.comm_op_list.append(
|
self.comm_op_list.append(
|
||||||
dist.P2POp(dist.isend, src_tensor, dst_rank))
|
dist.P2POp(dist.isend, src_tensor, dst_rank))
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
|||||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||||
@@ -185,13 +186,23 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
os.R_OK):
|
os.R_OK):
|
||||||
self.expert_load_balancer = ExpertLoadBalancer(
|
self.expert_load_balancer = ExpertLoadBalancer(
|
||||||
self.expert_map_path, self.global_num_experts)
|
self.expert_map_path, self.global_num_experts)
|
||||||
|
self.global_redundant_expert_num = (
|
||||||
|
self.expert_load_balancer.get_global_redundant_expert_num())
|
||||||
|
try:
|
||||||
self.local_num_experts, self.expert_map = (
|
self.local_num_experts, self.expert_map = (
|
||||||
self.expert_load_balancer.get_rank_placement_map(
|
self.expert_load_balancer.get_rank_placement_map(
|
||||||
self.moe_instance_id, self.ep_rank))
|
self.moe_instance_id, self.ep_rank))
|
||||||
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
||||||
self.moe_instance_id, self.ep_rank).npu()
|
self.moe_instance_id, self.ep_rank).npu()
|
||||||
self.global_redundant_expert_num = (
|
except Exception as e:
|
||||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
logger.warning(
|
||||||
|
f"Init expert map of mtp/eagle when using sample.{e}")
|
||||||
|
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
||||||
|
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||||
|
self.global_redundant_expert_num)
|
||||||
|
self.log2phy = determine_default_log2phy_map(
|
||||||
|
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||||
|
self.global_redundant_expert_num).npu()
|
||||||
else:
|
else:
|
||||||
# init moe.
|
# init moe.
|
||||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||||
@@ -227,6 +238,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
if (self.quant_method.__class__.__name__
|
if (self.quant_method.__class__.__name__
|
||||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||||
|
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||||
|
|
||||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,8 @@ class MoECommMethod(ABC):
|
|||||||
with_quant=use_int8_w8a8
|
with_quant=use_int8_w8a8
|
||||||
or use_int4_w4a8,
|
or use_int4_w4a8,
|
||||||
fusion=use_int8_w8a8,
|
fusion=use_int8_w8a8,
|
||||||
need_trans=need_trans)
|
need_trans=need_trans,
|
||||||
|
dynamic_eplb=dynamic_eplb)
|
||||||
|
|
||||||
final_hidden_states = self.token_dispatcher.token_combine(
|
final_hidden_states = self.token_dispatcher.token_combine(
|
||||||
hidden_states=mlp_output)
|
hidden_states=mlp_output)
|
||||||
|
|||||||
@@ -63,7 +63,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
dynamic_scale: torch.Tensor = None,
|
dynamic_scale: torch.Tensor = None,
|
||||||
w1_scale_bias: torch.Tensor = None,
|
w1_scale_bias: torch.Tensor = None,
|
||||||
w2_scale_bias: torch.Tensor = None,
|
w2_scale_bias: torch.Tensor = None,
|
||||||
fusion: bool = False) -> torch.Tensor:
|
fusion: bool = False,
|
||||||
|
dynamic_eplb: bool = False) -> torch.Tensor:
|
||||||
if dynamic_scale is None:
|
if dynamic_scale is None:
|
||||||
unquantized_hidden_states = hidden_states
|
unquantized_hidden_states = hidden_states
|
||||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||||
@@ -79,7 +80,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
|
|
||||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||||
if w1_scale_bias is None and is_mc2:
|
if w1_scale_bias is None and is_mc2:
|
||||||
if fusion:
|
if fusion and not dynamic_eplb:
|
||||||
# gmm1: gate_up_proj & act_fn: swiglu
|
# gmm1: gate_up_proj & act_fn: swiglu
|
||||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
@@ -134,7 +135,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||||
_output_dtype = torch.bfloat16
|
_output_dtype = torch.bfloat16
|
||||||
|
|
||||||
if fusion:
|
if fusion and not dynamic_eplb:
|
||||||
# gmm1: gate_up_proj & act_fn: swiglu
|
# gmm1: gate_up_proj & act_fn: swiglu
|
||||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
@@ -229,7 +230,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
topk_scales: Optional[torch.Tensor] = None,
|
topk_scales: Optional[torch.Tensor] = None,
|
||||||
with_quant: bool = False,
|
with_quant: bool = False,
|
||||||
fusion: bool = False,
|
fusion: bool = False,
|
||||||
need_trans: bool = True) -> torch.Tensor:
|
need_trans: bool = True,
|
||||||
|
dynamic_eplb: bool = False) -> torch.Tensor:
|
||||||
if with_quant:
|
if with_quant:
|
||||||
return quant_apply_mlp(hidden_states=hidden_states,
|
return quant_apply_mlp(hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
@@ -241,7 +243,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
w1_scale_bias=w1_scale_bias,
|
w1_scale_bias=w1_scale_bias,
|
||||||
w2_scale_bias=w2_scale_bias,
|
w2_scale_bias=w2_scale_bias,
|
||||||
fusion=fusion)
|
fusion=fusion,
|
||||||
|
dynamic_eplb=dynamic_eplb)
|
||||||
else:
|
else:
|
||||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
|
|||||||
@@ -236,7 +236,9 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
w2_scale=layer.w2_weight_scale,
|
w2_scale=layer.w2_weight_scale,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
dynamic_eplb=self.dynamic_eplb)
|
dynamic_eplb=self.dynamic_eplb,
|
||||||
|
log2phy=log2phy,
|
||||||
|
global_redundant_expert_num=global_redundant_expert_num)
|
||||||
|
|
||||||
topk_weights = topk_weights.to(x.dtype)
|
topk_weights = topk_weights.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
|||||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||||
get_tp_group)
|
get_tp_group)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.fused_moe.config import \
|
from vllm.model_executor.layers.fused_moe.config import \
|
||||||
FusedMoEConfig # isort: skip
|
FusedMoEConfig # isort: skip
|
||||||
from vllm.model_executor.layers.fused_moe.config import \
|
from vllm.model_executor.layers.fused_moe.config import \
|
||||||
@@ -1027,13 +1028,23 @@ class TorchairAscendFusedMoE(FusedMoE):
|
|||||||
os.R_OK):
|
os.R_OK):
|
||||||
self.expert_load_balancer = ExpertLoadBalancer(
|
self.expert_load_balancer = ExpertLoadBalancer(
|
||||||
self.expert_map_path, self.global_num_experts)
|
self.expert_map_path, self.global_num_experts)
|
||||||
|
self.global_redundant_expert_num = (
|
||||||
|
self.expert_load_balancer.get_global_redundant_expert_num())
|
||||||
|
try:
|
||||||
self.local_num_experts, self.expert_map = (
|
self.local_num_experts, self.expert_map = (
|
||||||
self.expert_load_balancer.get_rank_placement_map(
|
self.expert_load_balancer.get_rank_placement_map(
|
||||||
self.moe_instance_id, self.ep_rank))
|
self.moe_instance_id, self.ep_rank))
|
||||||
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
|
||||||
self.moe_instance_id, self.ep_rank).npu()
|
self.moe_instance_id, self.ep_rank).npu()
|
||||||
self.global_redundant_expert_num = (
|
except Exception as e:
|
||||||
self.expert_load_balancer.get_global_redundant_expert_num())
|
logger.warning(
|
||||||
|
f"Init expert map of mtp/eagle when using sample.{e}")
|
||||||
|
self.local_num_experts, self.expert_map = determine_default_expert_map(
|
||||||
|
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||||
|
self.global_redundant_expert_num)
|
||||||
|
self.log2phy = determine_default_log2phy_map(
|
||||||
|
self.global_num_experts, self.ep_size, self.ep_rank,
|
||||||
|
self.global_redundant_expert_num).npu()
|
||||||
else:
|
else:
|
||||||
# init moe.
|
# init moe.
|
||||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||||
|
|||||||
Reference in New Issue
Block a user