Improve torch compile for fused moe (#2327)

This commit is contained in:
Lianmin Zheng
2024-12-03 01:58:25 -08:00
committed by GitHub
parent 83b340e371
commit 07ec07ad1f
6 changed files with 45 additions and 24 deletions

View File

@@ -6,6 +6,7 @@ from torch.nn import functional as F
from transformers import AutoConfig
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
def get_model_config(model_name: str, tp_size: int):
@@ -64,7 +65,7 @@ def fused_topk_native(
return topk_weights, topk_ids
@torch.compile
@torch.compile(dynamic=False)
def fused_moe_torch(
x,
w1,
@@ -88,7 +89,8 @@ def fused_moe_torch(
w13_weights = w1[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = w2[topk_ids]
x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
set_torch_compile_config()
num_tokens = batch_size
num_experts = model_config["num_experts"]