Improve torch compile for fused moe (#2327)

This commit is contained in:
Lianmin Zheng
2024-12-03 01:58:25 -08:00
committed by GitHub
parent 83b340e371
commit 07ec07ad1f
6 changed files with 45 additions and 24 deletions

View File

@@ -105,20 +105,29 @@ def fused_moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
assert custom_routing_function is None
topk_weights, topk_ids = select_experts_native(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
)
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(
x,
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
x, router_logits, top_k, renormalize
)
w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

View File

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
def _to_torch(model: torch.nn.Module, reverse: bool = False):
def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
else:
# NOTE: Temporarily workaround MoE
if "FusedMoE" in sub.__class__.__name__:
sub._forward_method = fused_moe_forward_native
if batch_size == 1:
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to skip it for now.
sub._forward_method = fused_moe_forward_native
else:
sub._forward_method = sub.forward_native
setattr(sub, "is_torch_compile", True)
if isinstance(sub, torch.nn.Module):
_to_torch(sub, reverse)
_to_torch(sub, reverse, batch_size)
@contextmanager
def patch_model(
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
model: torch.nn.Module,
enable_compile: bool,
batch_size: int,
tp_group: "GroupCoordinator",
):
"""Patch the model to make it compatible with torch.compile"""
backup_ca_comm = None
try:
if enable_compile:
_to_torch(model)
_to_torch(model, reverse=False, batch_size=batch_size)
monkey_patch_vllm_all_gather()
backup_ca_comm = tp_group.ca_comm
# Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
# even with ENABLE_INTRA_NODE_COMM=1.
# tp_group.ca_comm = None
yield torch.compile(
torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
torch.no_grad()(model.forward),
mode="max-autotune-no-cudagraphs",
dynamic=False,
)
else:
yield model.forward
finally:
if enable_compile:
_to_torch(model, reverse=True)
_to_torch(model, reverse=True, batch_size=batch_size)
monkey_patch_vllm_all_gather(reverse=True)
tp_group.ca_comm = backup_ca_comm
@@ -237,6 +245,7 @@ class CudaGraphRunner:
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
bs,
self.model_runner.tp_group,
) as forward:
(

View File

@@ -622,7 +622,7 @@ class ModelRunner:
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s")
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
def apply_torch_tp(self):
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")