Improve torch compile for fused moe (#2327)

This commit is contained in:
Lianmin Zheng
2024-12-03 01:58:25 -08:00
committed by GitHub
parent 83b340e371
commit 07ec07ad1f
6 changed files with 45 additions and 24 deletions

View File

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
def _to_torch(model: torch.nn.Module, reverse: bool = False):
def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
else:
# NOTE: Temporarily workaround MoE
if "FusedMoE" in sub.__class__.__name__:
sub._forward_method = fused_moe_forward_native
if batch_size == 1:
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to skip it for now.
sub._forward_method = fused_moe_forward_native
else:
sub._forward_method = sub.forward_native
setattr(sub, "is_torch_compile", True)
if isinstance(sub, torch.nn.Module):
_to_torch(sub, reverse)
_to_torch(sub, reverse, batch_size)
@contextmanager
def patch_model(
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
model: torch.nn.Module,
enable_compile: bool,
batch_size: int,
tp_group: "GroupCoordinator",
):
"""Patch the model to make it compatible with with torch.compile"""
backup_ca_comm = None
try:
if enable_compile:
_to_torch(model)
_to_torch(model, reverse=False, batch_size=batch_size)
monkey_patch_vllm_all_gather()
backup_ca_comm = tp_group.ca_comm
# Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
# even with ENABLE_INTRA_NODE_COMM=1.
# tp_group.ca_comm = None
yield torch.compile(
torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
torch.no_grad()(model.forward),
mode="max-autotune-no-cudagraphs",
dynamic=False,
)
else:
yield model.forward
finally:
if enable_compile:
_to_torch(model, reverse=True)
_to_torch(model, reverse=True, batch_size=batch_size)
monkey_patch_vllm_all_gather(reverse=True)
tp_group.ca_comm = backup_ca_comm
@@ -237,6 +245,7 @@ class CudaGraphRunner:
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
bs,
self.model_runner.tp_group,
) as forward:
(

View File

@@ -622,7 +622,7 @@ class ModelRunner:
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s")
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
def apply_torch_tp(self):
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")