MoE torch compile (#1497)
python/sglang/srt/layers/fused_moe/patch.py (new file, 117 lines)
@@ -0,0 +1,117 @@
from typing import Optional

import torch
from torch.nn import functional as F


def fused_topk_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    M, _ = hidden_states.shape
    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    topk_weights = F.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

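For intuition, a toy run of fused_topk_native (not part of the commit; tensors are made up, and the function above is assumed in scope):

```python
import torch

# 2 tokens, 4 experts, top-2 routing. hidden_states only needs a matching
# first dimension; routing is driven entirely by gating_output.
hidden_states = torch.randn(2, 8)
gating_output = torch.tensor([[1.0, 2.0, 0.5, 0.1],
                              [0.2, 0.1, 3.0, 2.5]])
weights, ids = fused_topk_native(hidden_states, gating_output, topk=2, renormalize=True)
# ids == [[1, 0], [2, 3]]: the two largest softmax probabilities per token.
# With renormalize=True, each row of weights sums to 1.
```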
# This is used by the DeepSeek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

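A small worked example of the two-stage selection (made-up values, assuming grouped_topk above is in scope):

```python
import torch

# 1 token, 4 experts in 2 groups of 2. topk_group=1 keeps only the
# strongest group before the final top-2 pick over experts.
gating_output = torch.tensor([[2.0, 1.5, 0.1, 0.2]])
weights, ids = grouped_topk(
    torch.randn(1, 8), gating_output, topk=2, renormalize=True,
    num_expert_group=2, topk_group=1,
)
# Group 0 wins (its best score beats group 1's best), so both selected
# ids come from experts {0, 1}; experts 2 and 3 were masked to 0.0.
```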
def select_experts_native(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
):
    # DeepSeek-V2 uses grouped_top_k
    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
    else:
        topk_weights, topk_ids = fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )
    return topk_weights, topk_ids

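Usage sketch for the dispatcher; the shapes and hyperparameters below are illustrative, not taken from any real model config:

```python
import torch

x = torch.randn(4, 16)      # 4 tokens, hidden size 16
logits = torch.randn(4, 8)  # router logits over 8 experts

# Mixtral-style: plain top-k over all experts.
w, ids = select_experts_native(x, logits, top_k=2,
                               use_grouped_topk=False, renormalize=True)

# DeepSeek-V2-style: keep the best 2 of 4 expert groups, then take the
# top-2 experts within the surviving groups.
w, ids = select_experts_native(x, logits, top_k=2, use_grouped_topk=True,
                               renormalize=True, topk_group=2, num_expert_group=4)
```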
def fused_moe_forward_native(
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
) -> torch.Tensor:
    topk_weights, topk_ids = select_experts_native(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
    )
    w13_weights = layer.w13_weight[topk_ids]
    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
    w2_weights = layer.w2_weight[topk_ids]
    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)

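A shape walkthrough of the einsum chain, using a stand-in layer object (all sizes hypothetical; only the w13_weight and w2_weight attributes the function reads are provided):

```python
import torch
from types import SimpleNamespace

E, H, INTER, T, K = 4, 8, 16, 3, 2  # experts, hidden, intermediate, tokens, top-k
layer = SimpleNamespace(
    # w13_weight stacks the gate (w1) and up (w3) projections: [E, 2*INTER, H]
    w13_weight=torch.randn(E, 2 * INTER, H),
    # w2_weight is the down projection: [E, H, INTER]
    w2_weight=torch.randn(E, H, INTER),
)
x = torch.randn(T, H)
out = fused_moe_forward_native(
    layer, x, use_grouped_topk=False, top_k=K,
    router_logits=torch.randn(T, E), renormalize=True,
)
# Indexing by topk_ids gathers per-token expert weights ([T, K, ...]);
# x1 * x3 is the SwiGLU activation ([T, K, INTER]); the final einsum mixes
# the K expert outputs with their routing weights back down to [T, H].
assert out.shape == (T, H)
```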
@@ -25,6 +25,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -41,14 +42,15 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
-            # NOTE: FusedMoE torch native implementation is not efficient
-            if "FusedMoE" in sub.__class__.__name__:
-                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
             else:
-                sub._forward_method = sub.forward_native
+                # NOTE: Temporary workaround for MoE
+                if "FusedMoE" in sub.__class__.__name__:
+                    sub._forward_method = fused_moe_forward_native
+                else:
+                    sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
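For readers unfamiliar with vLLM's CustomOp dispatch, a minimal sketch of the pattern _to_torch manipulates (a simplified toy class, not the real API):

```python
import torch

class TinyCustomOp(torch.nn.Module):
    # forward() defers to _forward_method, so swapping that attribute
    # reroutes the module between its fused and pure-PyTorch paths.
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_cuda(self, x):
        return x * 2  # stand-in for a fused CUDA kernel

    def forward_native(self, x):
        return x + x  # traceable pure-PyTorch equivalent

op = TinyCustomOp()
op._forward_method = op.forward_native  # what _to_torch does before compiling
```

Note that fused_moe_forward_native is a plain function, so assigning it to the instance attribute leaves it unbound; that fits its explicit layer first parameter.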
@@ -67,7 +69,9 @@ def patch_model(
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             tp_group.ca_comm = None
-            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+            yield torch.compile(
+                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+            )
         else:
             yield model.forward
     finally:
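torch.no_grad() also works as a callable wrapper here, so the compiled forward never records autograd state even if a caller has gradients enabled; a standalone illustration (not from the commit):

```python
import torch

def f(x):
    return x * x

g = torch.compile(torch.no_grad()(f), mode="max-autotune-no-cudagraphs")
y = g(torch.randn(4, requires_grad=True))
assert not y.requires_grad  # grad tracking is disabled inside the wrapper
```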