Refactor CustomOp to avoid confusing bugs (#5382)
@@ -1,6 +1,3 @@
-from typing import Optional
-
-import torch
 from torch import nn
 
 from sglang.srt.utils import is_cuda, is_hip
@@ -14,6 +11,26 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()
 
+    def enter_torch_compile(self, num_tokens: int):
+        # NOTE: Temporarily workaround MoE
+        if "FusedMoE" in self.__class__.__name__:
+            if num_tokens == 1:
+                from sglang.srt.layers.moe.fused_moe_native import (
+                    fused_moe_forward_native,
+                )
+
+                # The performance of torch.compile on this layer is not always good when bs > 1,
+                # so we decide to only use torch.compile when bs =1
+                self._forward_method = fused_moe_forward_native
+        else:
+            self._forward_method = self.forward_native
+        self.is_torch_compile = True
+
+    def leave_torch_compile(self):
+        self._forward_method = self.forward_cuda
+        self.is_torch_compile = False
+
+    # Please do not override this method, because `self._forward_method` can change when in torch compile mode
     def forward(self, *args, **kwargs):
         return self._forward_method(*args, **kwargs)
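For context (not part of the diff): the refactor keeps `forward` as a fixed entry point that always delegates to `self._forward_method`, and the new `enter_torch_compile` / `leave_torch_compile` hooks become the only places that swap that method. Below is a minimal sketch of the behavior, assuming sglang is installed; `DummyOp` and the tensor values are hypothetical and purely illustrative.

```python
import torch

from sglang.srt.custom_op import CustomOp


class DummyOp(CustomOp):
    """Hypothetical op, used only for illustration; not part of the commit."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1  # pure-PyTorch path, safe to trace with torch.compile

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1  # a real op would call its custom kernel here


op = DummyOp()
x = torch.zeros(4)

op.enter_torch_compile(num_tokens=1)
assert op.is_torch_compile
print(op(x))  # dispatches through forward_native

op.leave_torch_compile()
assert not op.is_torch_compile
print(op(x))  # dispatches through forward_cuda again
```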
@@ -28,7 +28,6 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
@@ -60,18 +59,9 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
-                sub._forward_method = sub.forward_cuda
-                setattr(sub, "is_torch_compile", False)
+                sub.leave_torch_compile()
             else:
-                # NOTE: Temporarily workaround MoE
-                if "FusedMoE" in sub.__class__.__name__:
-                    if num_tokens == 1:
-                        # The performance of torch.compile on this layer is not always good when bs > 1,
-                        # so we decide to only use torch.compile when bs =1
-                        sub._forward_method = fused_moe_forward_native
-                else:
-                    sub._forward_method = sub.forward_native
-                setattr(sub, "is_torch_compile", True)
+                sub.enter_torch_compile(num_tokens=num_tokens)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse, num_tokens)
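To illustrate why this second file gets simpler (again, not part of the diff): the MoE special-casing now lives inside `CustomOp`, so a module-tree walker only has to call the two hooks. A self-contained sketch, assuming sglang is installed; `set_compile_mode` and the stand-in model are hypothetical names that mirror `_to_torch` rather than reproduce it.

```python
import torch
from torch import nn

from sglang.srt.custom_op import CustomOp


def set_compile_mode(model: nn.Module, enable: bool, num_tokens: int) -> None:
    """Recursive walker in the spirit of the refactored _to_torch (sketch only)."""
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):
            if enable:
                sub.enter_torch_compile(num_tokens=num_tokens)
            else:
                sub.leave_torch_compile()
        if isinstance(sub, nn.Module):
            set_compile_mode(sub, enable, num_tokens)


# Hypothetical usage around a compile region:
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())  # stand-in; real sglang models contain CustomOp layers
set_compile_mode(model, enable=True, num_tokens=1)   # every CustomOp switches to its native forward
out = torch.compile(model)(torch.randn(2, 8))
set_compile_mode(model, enable=False, num_tokens=1)  # and back to forward_cuda afterwards
```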