Refactor CustomOp to avoid confusing bugs (#5382)

This commit is contained in:
fzyzcjy
2025-06-03 01:27:36 +08:00
committed by GitHub
parent a2cb5913a0
commit e05e29d178
2 changed files with 22 additions and 15 deletions

View File

@@ -28,7 +28,6 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
@@ -60,18 +59,9 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
if reverse:
sub._forward_method = sub.forward_cuda
setattr(sub, "is_torch_compile", False)
sub.leave_torch_compile()
else:
# NOTE: Temporarily workaround MoE
if "FusedMoE" in sub.__class__.__name__:
if num_tokens == 1:
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs =1
sub._forward_method = fused_moe_forward_native
else:
sub._forward_method = sub.forward_native
setattr(sub, "is_torch_compile", True)
sub.enter_torch_compile(num_tokens=num_tokens)
if isinstance(sub, torch.nn.Module):
_to_torch(sub, reverse, num_tokens)