Files
sglang/python/sglang/srt/custom_op.py
2025-09-07 22:34:20 -07:00

103 lines
3.2 KiB
Python

from torch import nn
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_cuda,
is_hip,
is_npu,
is_xpu,
)
# Platform probe results. Each helper from sglang.srt.utils is called exactly
# once, when this module is imported; CustomOp.dispatch_forward reads these
# module-level values to choose a backend-specific forward method.
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
# Only consulted together with `_is_cpu` (see CustomOp.dispatch_forward).
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_is_xpu = is_xpu()
class CustomOp(nn.Module):
    """Base class for operators that provide per-platform forward variants.

    `forward` delegates to `self._forward_method`. That attribute is chosen
    once in `__init__` via `dispatch_forward()`, and can be temporarily
    replaced by `enter_torch_compile()` and restored by
    `leave_torch_compile()`.

    Subclasses implement `forward_native` and whichever of the
    `forward_<platform>` variants they support.
    """

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

        # Bookkeeping for torch.compile mode: the forward method that was
        # active before entering, and whether we are currently inside it.
        self._original_forward_method = None
        self.is_torch_compile = False

    def enter_torch_compile(self, num_tokens: int):
        """Switch this op to the implementation used under torch.compile.

        Repeated calls are no-ops until `leave_torch_compile()` is called.

        Args:
            num_tokens: Only consulted for classes whose name contains
                "FusedMoE" or "TopK"; those switch implementation only when
                it equals 1. Every other class always switches to
                `forward_native`.
        """
        # NOTE(alcanderian): Some ops (for example RotaryEmbedding) are shared
        # among layers, so this method can be called many times for the same
        # instance. Bail out on repeated calls so that
        # `self._original_forward_method` is not overwritten with the
        # already-swapped method.
        if self.is_torch_compile:
            return

        self._original_forward_method = self._forward_method

        op_name = self.__class__.__name__
        is_fused_moe = "FusedMoE" in op_name
        if not is_fused_moe and "TopK" not in op_name:
            # Regular op: always use the native implementation.
            self._forward_method = self.forward_native
        elif num_tokens == 1:
            # NOTE: Temporary workaround for MoE.
            # torch.compile does not always perform well on these layers when
            # bs > 1, so they are only switched over when bs == 1; otherwise
            # the current forward method is kept.
            if is_fused_moe:
                from sglang.srt.layers.moe.fused_moe_native import (
                    fused_moe_forward_native,
                )

                # NOTE(review): stored as a plain function on the instance, so
                # `forward()` calls it without an implicit `self` - confirm
                # that callers pass the layer explicitly.
                self._forward_method = fused_moe_forward_native
            else:
                self._forward_method = self.forward_native

        self.is_torch_compile = True

    def leave_torch_compile(self):
        """Restore the forward method saved by `enter_torch_compile()`.

        Does nothing when the op is not currently in compile mode.
        """
        if self.is_torch_compile:
            self._forward_method = self._original_forward_method
            self._original_forward_method = None
            self.is_torch_compile = False

    # Please do not override this method: `self._forward_method` can change
    # while in torch compile mode, and all calls must go through it.
    def forward(self, *args, **kwargs):
        """Invoke the currently selected forward implementation."""
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Native implementation; must be provided by subclasses."""
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        """CUDA implementation; must be provided by subclasses."""
        raise NotImplementedError

    def forward_npu(self, *args, **kwargs):
        """NPU implementation; must be provided by subclasses."""
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        """HIP implementation; defaults to `forward_cuda`."""
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        """XPU implementation; defaults to `forward_native`."""
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        """HPU implementation; defaults to `forward_native`.

        Not selected by `dispatch_forward` in this class.
        """
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        """CPU implementation; defaults to `forward_native`."""
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        """Return the bound forward method matching the detected platform.

        Platforms are checked in the order CUDA, HIP, CPU (only when AMX
        support is also available), NPU, XPU; `forward_native` is the
        fallback when none of them match.
        """
        if _is_cuda:
            return self.forward_cuda
        if _is_hip:
            return self.forward_hip
        if _is_cpu and _is_cpu_amx_available:
            return self.forward_cpu
        if _is_npu:
            return self.forward_npu
        if _is_xpu:
            return self.forward_xpu
        return self.forward_native