Improve torch compile for fused moe (#2327)
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)
 
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with torch.compile"""
     backup_ca_comm = None
 
     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
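For context: `_to_torch` can swap implementations this way because sglang's `CustomOp` dispatches `forward` through a mutable `_forward_method` attribute. A minimal sketch of that dispatch pattern, using a hypothetical `ToySiluOp` stand-in rather than the real `CustomOp`:

import torch

class ToySiluOp(torch.nn.Module):
    # Hypothetical stand-in for the CustomOp dispatch pattern, not sglang code.
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda  # default: custom-kernel path

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch path that torch.compile can trace and fuse.
        return x * torch.sigmoid(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder for a hand-written fused kernel.
        return torch.nn.functional.silu(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

op = ToySiluOp()
op(torch.randn(4))                      # dispatches through _forward_method
op._forward_method = op.forward_native  # what _to_torch sets before compiling
op._forward_method = op.forward_cuda    # what the reverse=True pass restores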
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
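The added `dynamic=False` pins tensor shapes, so `torch.compile` builds one specialized graph per input shape, matching the one-graph-per-batch-size CUDA-graph capture below. A minimal sketch of that shape-specialization behavior (toy function, not sglang code):

import torch

def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x) * 2

compiled = torch.compile(toy_forward, mode="max-autotune-no-cudagraphs", dynamic=False)
compiled(torch.randn(1, 16))  # compiles a graph specialized for batch size 1
compiled(torch.randn(8, 16))  # triggers a second graph for batch size 8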
@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
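The call site now forwards both the compile decision (`bs in self.compile_bs`) and `bs` itself. A self-contained sketch of how a per-batch-size capture loop drives a `patch_model`-style context manager (`noop_patch` and the whitelist are illustrative stand-ins, not the real CudaGraphRunner code):

import torch
from contextlib import contextmanager

@contextmanager
def noop_patch(model, enable_compile, batch_size, tp_group=None):
    # Stand-in with patch_model's argument order; performs no real patching.
    yield model.forward

model = torch.nn.Linear(16, 16)
compile_bs = {1, 2}            # illustrative whitelist of compiled sizes
for bs in (8, 4, 2, 1):        # capture larger batch sizes first
    with noop_patch(model, bs in compile_bs, bs) as forward:
        forward(torch.randn(bs, 16))  # one warm-up/capture run per size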
@@ -622,7 +622,7 @@ class ModelRunner:
         tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f}s")
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")