Improve weight loading and code style (#3174)

Lianmin Zheng
2025-01-27 03:00:41 -08:00
committed by GitHub
parent 351a72d40b
commit 53cef81587
11 changed files with 171 additions and 65 deletions

@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
     ):
         super().__init__()
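
For illustration, here is what a callable passed as custom_routing_function might look like. The expected signature is an assumption (modeled on the (hidden_states, gating_output, topk, renormalize) -> (topk_weights, topk_ids) convention used by similar MoE routers), and uniform_debug_routing is a hypothetical name, not part of this commit.

import torch

def uniform_debug_routing(hidden_states, gating_output, topk, renormalize):
    # Toy router for illustration only: ignore the gating logits and send
    # every token to the first `topk` experts with equal weight.
    # `hidden_states` and `renormalize` are unused by this toy policy.
    num_tokens = gating_output.shape[0]
    topk_ids = torch.arange(topk, device=gating_output.device, dtype=torch.int32)
    topk_ids = topk_ids.unsqueeze(0).expand(num_tokens, topk).contiguous()
    topk_weights = torch.full(
        (num_tokens, topk), 1.0 / topk,
        device=gating_output.device, dtype=torch.float32,
    )
    return topk_weights, topk_ids

Such a callable would then be supplied at construction time via custom_routing_function=uniform_debug_routing.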
@@ -141,6 +142,7 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.custom_routing_function = custom_routing_function
         self.activation = activation
         if quant_config is None:
@@ -184,6 +186,7 @@ class EPMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
         )
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
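
Conceptually, select_experts can now dispatch between the default routing policy and the user-supplied callable. The sketch below shows only that dispatch idea; it is not the actual select_experts implementation, and the default branch is a deliberately simplified softmax top-k.

import torch

def select_experts_sketch(
    hidden_states,
    router_logits,
    top_k,
    renormalize,
    custom_routing_function=None,
    **other_routing_kwargs,  # topk_group, num_expert_group, correction_bias, ...
):
    # Sketch only: a user-supplied routing callable, when given, takes over
    # expert selection entirely; otherwise a built-in policy runs.
    if custom_routing_function is not None:
        return custom_routing_function(hidden_states, router_logits, top_k, renormalize)
    scores = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids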
@@ -257,16 +260,20 @@ class EPMoE(torch.nn.Module):
             dtype=torch.float32,
             device=hidden_states.device,
         )
-        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-            gateup_output,
-            down_input,
-            gateup_output.shape[1],
-            reorder_topk_ids,
-            self.w2_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            BLOCK_SIZE=512,
-        )
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
         # GroupGemm-1
         down_output = torch.empty(
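
For context on the kernel being gated here: a "SiLU and mul" activation splits the fused gate/up projection output in half along the last dimension, applies SiLU to the gate half, and multiplies by the up half. The eager-mode sketch below shows that reference math only; it is not the Triton kernel and ignores the expert reordering and input-scale arguments visible in the call above. Guarding the launch with self.activation == "silu" turns an unsupported activation into an explicit error rather than silently applying SiLU.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(gateup_output: torch.Tensor) -> torch.Tensor:
    # Reference semantics of the fused activation: split the last dimension
    # into gate and up halves, then compute SiLU(gate) * up.
    d = gateup_output.shape[-1] // 2
    gate, up = gateup_output[..., :d], gateup_output[..., d:]
    return F.silu(gate) * up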
@@ -312,7 +319,6 @@ class EPMoE(torch.nn.Module):
ckpt_up_proj_name: str,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)
(
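
The mapping built here pairs each fused per-expert parameter with the checkpoint tensor names it should be filled from, as (param_name, weight_name, expert_id, shard_id) tuples. The loop below is a typical way such a mapping is consumed inside a model's load_weights routine; it is an illustrative sketch, and the function and variable names are not taken from this commit.

def load_expert_weights(params_dict, checkpoint_weights, expert_params_mapping):
    # Hypothetical consumption of the mapping: translate per-expert checkpoint
    # names onto the fused module parameter and delegate to its weight_loader.
    for name, loaded_weight in checkpoint_weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            target_name = name.replace(weight_name, param_name)
            param = params_dict[target_name]
            param.weight_loader(
                param,
                loaded_weight,
                target_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )
            break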
@@ -357,7 +363,6 @@ class EPMoE(torch.nn.Module):
)
return
expert_data = param.data[expert_id]
if shard_id == "w2":
param.data[expert_id] = loaded_weight
elif shard_id == "w1":