Improve torch compile for fused moe (#2327)
This commit is contained in:
@@ -6,6 +6,7 @@ from torch.nn import functional as F
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
|
||||
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
|
||||
@@ -64,7 +65,7 @@ def fused_topk_native(
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
@torch.compile
|
||||
@torch.compile(dynamic=False)
|
||||
def fused_moe_torch(
|
||||
x,
|
||||
w1,
|
||||
@@ -88,7 +89,8 @@ def fused_moe_torch(
|
||||
w13_weights = w1[topk_ids]
|
||||
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||
w2_weights = w2[topk_ids]
|
||||
x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
|
||||
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||
x1 = F.silu(x1)
|
||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||
@@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
|
||||
print(f"benchmark {provider} with batch_size={batch_size}")
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
set_torch_compile_config()
|
||||
|
||||
num_tokens = batch_size
|
||||
num_experts = model_config["num_experts"]
|
||||
|
||||
Reference in New Issue
Block a user