Improve torch compile for fused moe (#2327)
@@ -6,6 +6,7 @@ from torch.nn import functional as F
 from transformers import AutoConfig
 
 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
 
 
 def get_model_config(model_name: str, tp_size: int):
@@ -64,7 +65,7 @@ def fused_topk_native(
     return topk_weights, topk_ids
 
 
-@torch.compile
+@torch.compile(dynamic=False)
 def fused_moe_torch(
     x,
     w1,
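Note on the decorator change: with `dynamic=False`, Dynamo specializes the compiled graph on the exact input shapes it sees, recompiling per new batch size instead of falling back to a single shape-generic graph. That matches the one-graph-per-batch-size strategy used for CUDA graph capture later in this PR. A minimal, self-contained sketch of the behavior (toy function, not from this repo):

    import torch

    @torch.compile(dynamic=False)
    def gated_mul(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating, the same pattern fused_moe_torch applies per expert.
        return torch.nn.functional.silu(gate) * x

    x = torch.randn(8, 16)
    gate = torch.randn(8, 16)
    y = gated_mul(x, gate)  # compiles a graph specialized to shape (8, 16)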
@@ -88,7 +89,8 @@ def fused_moe_torch(
     w13_weights = w1[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = w2[topk_ids]
-    x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
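For readers of the einsum reference path: `t` indexes tokens, `a` the top-k experts activated per token, `i` the hidden size, and `o` the per-expert intermediate size. The activation also changes from GELU to SiLU here so the benchmark's reference matches the gated SwiGLU computation in fused_moe_forward_native below. A hedged shape walk-through (illustrative sizes and names, not from the repo):

    import torch
    import torch.nn.functional as F

    t, e, a = 4, 8, 2              # tokens, experts, experts activated per token
    i, o = 16, 32                  # hidden size, intermediate size per expert

    x = torch.randn(t, i)
    w1 = torch.randn(e, 2 * o, i)  # gate and up projections fused along dim 1
    w2 = torch.randn(e, i, o)      # down projection
    topk_ids = torch.randint(0, e, (t, a))
    topk_weights = torch.rand(t, a)

    w13 = w1[topk_ids]                                    # (t, a, 2*o, i): gather each token's experts
    w1_w, w3_w = torch.chunk(w13, 2, dim=2)               # (t, a, o, i) each
    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_w))  # gate branch
    x3 = torch.einsum("ti,taoi -> tao", x, w3_w)          # up branch
    outs = torch.einsum("tao,taio -> tai", x1 * x3, w2[topk_ids])  # (t, a, i)
    y = torch.einsum("tai,ta -> ti", outs, topk_weights)  # weighted mix over experts
    assert y.shape == (t, i)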
@@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
+    set_torch_compile_config()
 
     num_tokens = batch_size
     num_experts = model_config["num_experts"]
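The benchmark now calls set_torch_compile_config() (imported in the first hunk) before compiling. An assumption, not confirmed by this diff: such a helper raises torch._dynamo/torch._inductor limits so that one specialized graph per batch size under `dynamic=False` does not trip the recompile cap. A hedged sketch of that kind of helper:

    import torch._dynamo.config as dynamo_config
    import torch._inductor.config as inductor_config

    def set_torch_compile_config_sketch():
        # Allow many specialized graphs (one per captured batch size);
        # the limit value is an assumption, not the repo's.
        dynamo_config.accumulated_cache_size_limit = 1024
        # Cache compiled FX graphs on disk to speed up restarts.
        inductor_config.fx_graph_cache = True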
@@ -105,20 +105,29 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
-    assert custom_routing_function is None
-    topk_weights, topk_ids = select_experts_native(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-    )
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            x,
+            router_logits,
+            top_k,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            x, router_logits, top_k, renormalize
+        )
 
     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
-    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
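The rewritten routing dispatches on the arguments instead of asserting them away: grouped top-k when use_grouped_topk is set, plain softmax top-k otherwise, and the user-supplied custom_routing_function as the final fallback. For orientation, a sketch of softmax top-k routing consistent with the fused_topk_native(x, router_logits, top_k, renormalize) call above (hedged; the repo's implementation may differ in details such as dtype handling):

    import torch

    def fused_topk_native_sketch(
        hidden_states: torch.Tensor,  # (num_tokens, hidden); unused by plain softmax routing
        gating_output: torch.Tensor,  # (num_tokens, num_experts) router logits
        topk: int,
        renormalize: bool,
    ):
        scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
        if renormalize:
            # Make the kept expert weights sum to 1 per token.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids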
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)
 
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
 
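patch_model follows a patch/compile/restore discipline: swap CustomOp forwards to compile-friendly ones, yield the compiled callable, and undo everything in `finally` so a failed capture cannot leave the model patched. A self-contained toy of the same pattern (illustrative class and names, not the repo's):

    import torch
    from contextlib import contextmanager

    class ToyOp(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)            # stands in for a custom-kernel path

        def forward_native(self, x):
            return torch.clamp(x, min=0.0)  # compile-friendly pure-PyTorch path

    @contextmanager
    def patch_toy(model: torch.nn.Module, enable_compile: bool):
        backup = None
        try:
            if enable_compile:
                backup = model.forward
                model.forward = model.forward_native  # patch to the native path
                yield torch.compile(model.forward, dynamic=False)
            else:
                yield model.forward
        finally:
            if enable_compile:
                model.forward = backup  # always restore, even if capture failed

    with patch_toy(ToyOp(), enable_compile=True) as fwd:
        print(fwd(torch.randn(4)))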
@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
@@ -622,7 +622,7 @@ class ModelRunner:
         tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s")
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -188,7 +188,7 @@ class TestSRTEngine(unittest.TestCase):
         )
         bench_args = BenchArgs(num_prompts=10)
         result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
+        self.assertGreater(result["total_throughput"], 3000)
 
 
 if __name__ == "__main__":
@@ -14,7 +14,7 @@ from sglang.test.test_utils import (
 )
 
 
-class TestTorchCompile(unittest.TestCase):
+class TestTorchCompileMoe(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
@@ -23,7 +23,7 @@ class TestTorchCompile(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"],
+            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"],
         )
 
     @classmethod