"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
|
|
|
|
from typing import Callable, Optional

import torch
from torch.nn import functional as F
|
|
|
|
|
|
def fused_topk_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Select the top-k experts per token from router logits.

    Args:
        hidden_states: (num_tokens, hidden_dim) activations; only used to
            validate the token count against ``gating_output``.
        gating_output: (num_tokens, num_experts) router logits.
        topk: number of experts to select per token.
        renormalize: if True, rescale the selected weights to sum to 1.

    Returns:
        Tuple ``(topk_weights, topk_ids)``, both shaped (num_tokens, topk).
        Weights are float32 softmax probabilities; ids are int64 expert
        indices (as produced by ``torch.topk``).
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    # NOTE: the original allocated `torch.empty` buffers for the outputs and
    # immediately overwrote them below — dead code, removed. (It also claimed
    # an int32 id dtype that torch.topk never delivers.)
    # Softmax in float32 for numerical stability regardless of input dtype.
    topk_weights = F.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|
|
|
|
|
|
# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    """Two-stage top-k routing: first choose expert *groups*, then experts.

    Experts are partitioned into ``num_expert_group`` contiguous groups.
    Each group is scored by its strongest expert, the best ``topk_group``
    groups are kept, and the final ``topk`` experts are selected only from
    those surviving groups.

    Returns:
        Tuple ``(topk_weights, topk_ids)``, both shaped (num_tokens, topk).
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    probs = torch.softmax(gating_output, dim=-1)
    n_tokens = probs.shape[0]
    experts_per_group = probs.shape[-1] // num_expert_group

    # Score each group by its strongest expert.
    grouped = probs.view(n_tokens, num_expert_group, experts_per_group)
    group_scores = grouped.max(dim=-1).values  # [n, n_group]

    # Keep only the top `topk_group` groups for every token (0/1 mask).
    chosen_groups = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_keep = torch.zeros_like(group_scores)  # [n, n_group]
    group_keep.scatter_(1, chosen_groups, 1)  # [n, n_group]

    # Broadcast the group mask down to individual experts.
    expert_keep = (
        group_keep.unsqueeze(-1)
        .expand(n_tokens, num_expert_group, experts_per_group)
        .reshape(n_tokens, -1)
    )  # [n, e]

    # Zero out experts in discarded groups, then take the final top-k.
    survivors = probs.masked_fill(~expert_keep.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(survivors, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|
|
|
|
|
|
def select_experts_native(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
):
    """Dispatch router logits to the appropriate top-k routine.

    DeepSeek-V2-style models route with ``grouped_topk``; every other model
    uses the plain softmax top-k in ``fused_topk_native``.

    Returns:
        Tuple ``(topk_weights, topk_ids)`` from the chosen routine.
    """
    if not use_grouped_topk:
        return fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )

    # Grouped routing needs both group parameters to be supplied.
    assert topk_group is not None
    assert num_expert_group is not None
    return grouped_topk(
        hidden_states=hidden_states,
        gating_output=router_logits,
        topk=top_k,
        renormalize=renormalize,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
    )
|
|
|
|
|
|
def fused_moe_forward_native(
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
    """Pure-PyTorch MoE forward pass (used under torch.compile).

    Routes each token to ``top_k`` experts, runs the SiLU-gated expert MLPs
    via einsum-gathered weights, and returns the routing-weighted sum of the
    expert outputs.

    Args:
        layer: module holding ``w13_weight`` (fused gate/up projections) and
            ``w2_weight`` (down projection), both indexable by expert id.
        x: (num_tokens, hidden_dim) input activations.
        use_grouped_topk: route with ``grouped_topk`` (DeepSeek-V2 style).
        top_k: experts selected per token.
        router_logits: (num_tokens, num_experts) gating logits.
        renormalize: renormalize the selected routing weights to sum to 1.
        topk_group / num_expert_group: required when ``use_grouped_topk``.
        custom_routing_function: optional override for the default top-k.

    Returns:
        (num_tokens, hidden_dim) combined expert output.
    """
    # --- Routing: pick per-token experts and their mixing weights. ---
    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        routing_weights, expert_indices = grouped_topk(
            x,
            router_logits,
            top_k,
            renormalize,
            num_expert_group,
            topk_group,
        )
    elif custom_routing_function is None:
        routing_weights, expert_indices = fused_topk_native(x, router_logits, top_k, renormalize)
    else:
        routing_weights, expert_indices = custom_routing_function(
            x, router_logits, top_k, renormalize
        )

    # --- Gather per-token expert weights; w13 fuses gate (w1) and up (w3). ---
    w13 = layer.w13_weight[expert_indices]
    gate_proj, up_proj = torch.chunk(w13, 2, dim=2)
    down_proj = layer.w2_weight[expert_indices]

    # SwiGLU expert MLP: silu(x @ w1) * (x @ w3) @ w2, batched per token/expert.
    gated = F.silu(torch.einsum("ti,taoi -> tao", x, gate_proj))
    up = torch.einsum("ti, taoi -> tao", x, up_proj)
    expert_outs = torch.einsum("tao, taio -> tai", gated * up, down_proj)
    # Combine the per-expert outputs with their routing weights.
    return torch.einsum("tai,ta -> ti", expert_outs, routing_weights.to(expert_outs.dtype))
|