diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
index 5e0f4bd1e..fe6176f4e 100644
--- a/python/sglang/srt/custom_op.py
+++ b/python/sglang/srt/custom_op.py
@@ -1,6 +1,3 @@
-from typing import Optional
-
-import torch
 from torch import nn
 
 from sglang.srt.utils import is_cuda, is_hip
@@ -14,6 +11,26 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()
 
+    def enter_torch_compile(self, num_tokens: int):
+        # NOTE: Temporary workaround for MoE
+        if "FusedMoE" in self.__class__.__name__:
+            if num_tokens == 1:
+                from sglang.srt.layers.moe.fused_moe_native import (
+                    fused_moe_forward_native,
+                )
+
+                # The performance of torch.compile on this layer is not always good when bs > 1,
+                # so we only use torch.compile when bs == 1.
+                self._forward_method = fused_moe_forward_native
+        else:
+            self._forward_method = self.forward_native
+        self.is_torch_compile = True
+
+    def leave_torch_compile(self):
+        self._forward_method = self.forward_cuda
+        self.is_torch_compile = False
+
+    # Please do not override this method, because `self._forward_method` can change when in torch compile mode
     def forward(self, *args, **kwargs):
         return self._forward_method(*args, **kwargs)
 
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index cc67ec1eb..990452539 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -28,7 +28,6 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
@@ -60,18 +59,9 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
-                sub._forward_method = sub.forward_cuda
-                setattr(sub, "is_torch_compile", False)
+                sub.leave_torch_compile()
             else:
-                # NOTE: Temporarily workaround MoE
-                if "FusedMoE" in sub.__class__.__name__:
-                    if num_tokens == 1:
-                        # The performance of torch.compile on this layer is not always good when bs > 1,
-                        # so we decide to only use torch.compile when bs =1
-                        sub._forward_method = fused_moe_forward_native
-                else:
-                    sub._forward_method = sub.forward_native
-                setattr(sub, "is_torch_compile", True)
+                sub.enter_torch_compile(num_tokens=num_tokens)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse, num_tokens)
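For context: the two new methods move the torch.compile mode switching that `_to_torch` previously did inline onto `CustomOp` itself, so the CUDA graph runner only calls `enter_torch_compile` / `leave_torch_compile`. Below is a minimal, self-contained sketch of that dispatch pattern; the `ToyOp` class, its `forward_*` bodies, and the example tensor are hypothetical stand-ins, not part of this patch or of the sglang API.

import torch
from torch import nn


class ToyOp(nn.Module):
    # Minimal stand-in for CustomOp: forward() delegates to a swappable bound method.
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda  # default path, analogous to dispatch_forward() on CUDA
        self.is_torch_compile = False

    def forward(self, x):
        # Not meant to be overridden: the bound method below is swapped at runtime.
        return self._forward_method(x)

    def forward_native(self, x):
        # Pure-PyTorch path that torch.compile can trace.
        return torch.nn.functional.silu(x)

    def forward_cuda(self, x):
        # Stand-in for a fused CUDA kernel in real CustomOp subclasses.
        return torch.nn.functional.silu(x)

    def enter_torch_compile(self, num_tokens: int):
        self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        self._forward_method = self.forward_cuda
        self.is_torch_compile = False


op = ToyOp()
op.enter_torch_compile(num_tokens=1)  # switch to the traceable path before capture
out = op(torch.randn(4))
op.leave_torch_compile()              # restore the default path afterwards

Because `forward` only ever calls `self._forward_method`, swapping that attribute is enough to change which implementation torch.compile sees, which is why the patch warns against overriding `forward` in subclasses.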