Fix linear.py and improve weight loading (#2851)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -24,7 +24,9 @@ def fused_topk_native(
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
assert (
|
||||
hidden_states.shape[0] == gating_output.shape[0]
|
||||
), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
|
||||
M, _ = hidden_states.shape
|
||||
topk_weights = torch.empty(
|
||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||
@@ -180,7 +182,7 @@ def select_experts(
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
elif torch_native:
|
||||
elif torch_native and custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk_native(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
|
||||
Reference in New Issue
Block a user