Improve weight loading and code style (#3174)
This commit is contained in:
@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
activation: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -141,6 +142,7 @@ class EPMoE(torch.nn.Module):
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.correction_bias = correction_bias
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.activation = activation
|
||||
|
||||
if quant_config is None:
|
||||
@@ -184,6 +186,7 @@ class EPMoE(torch.nn.Module):
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
correction_bias=self.correction_bias,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
)
|
||||
|
||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
||||
@@ -257,16 +260,20 @@ class EPMoE(torch.nn.Module):
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||
gateup_output,
|
||||
down_input,
|
||||
gateup_output.shape[1],
|
||||
reorder_topk_ids,
|
||||
self.w2_input_scale,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
|
||||
if self.activation == "silu":
|
||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||
gateup_output,
|
||||
down_input,
|
||||
gateup_output.shape[1],
|
||||
reorder_topk_ids,
|
||||
self.w2_input_scale,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||
|
||||
# GroupGemm-1
|
||||
down_output = torch.empty(
|
||||
@@ -312,7 +319,6 @@ class EPMoE(torch.nn.Module):
|
||||
ckpt_up_proj_name: str,
|
||||
num_experts: int,
|
||||
) -> List[Tuple[str, str, int, str]]:
|
||||
|
||||
return [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
(
|
||||
@@ -357,7 +363,6 @@ class EPMoE(torch.nn.Module):
|
||||
)
|
||||
return
|
||||
|
||||
expert_data = param.data[expert_id]
|
||||
if shard_id == "w2":
|
||||
param.data[expert_id] = loaded_weight
|
||||
elif shard_id == "w1":
|
||||
|
||||
Reference in New Issue
Block a user