Fix torch.compile for MoE (#2033)

This commit is contained in:
Lianmin Zheng
2024-11-14 01:30:24 -08:00
committed by GitHub
parent b275ce0043
commit c3eac1b010
10 changed files with 89 additions and 12 deletions

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional
import torch
from torch.nn import functional as F
@@ -98,7 +98,9 @@ def fused_moe_forward_native(
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
assert custom_routing_function is None
topk_weights, topk_ids = select_experts_native(
hidden_states=x,
router_logits=router_logits,
@@ -114,4 +116,4 @@ def fused_moe_forward_native(
x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))