Fix linear.py and improve weight loading (#2851)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-01-13 01:39:14 -08:00
committed by GitHub
parent 4093aa4660
commit 72c7776355
12 changed files with 113 additions and 125 deletions

View File

@@ -24,7 +24,9 @@ def fused_topk_native(
topk: int,
renormalize: bool,
):
-assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+assert (
+hidden_states.shape[0] == gating_output.shape[0]
+), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
@@ -180,7 +182,7 @@ def select_experts(
num_expert_group=num_expert_group,
topk_group=topk_group,
)
-elif torch_native:
+elif torch_native and custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,