From 4eb4b401cc552cab162165e22e1428086eb0f874 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 1 Feb 2025 18:56:44 +0800
Subject: [PATCH] update and simplify CustomOp (#3249)

---
 python/sglang/srt/custom_op.py               | 40 +++++++++++++++++++
 python/sglang/srt/layers/activation.py       |  6 +--
 python/sglang/srt/layers/custom_op_util.py   | 25 ------------
 python/sglang/srt/layers/layernorm.py        |  6 +--
 python/sglang/srt/layers/moe/ep_moe/layer.py |  4 +-
 .../srt/layers/moe/fused_moe_triton/layer.py |  4 +-
 python/sglang/srt/layers/rotary_embedding.py |  4 +-
 .../srt/model_executor/cuda_graph_runner.py  |  2 +-
 8 files changed, 46 insertions(+), 45 deletions(-)
 create mode 100644 python/sglang/srt/custom_op.py
 delete mode 100644 python/sglang/srt/layers/custom_op_util.py

diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
new file mode 100644
index 000000000..a702e8f82
--- /dev/null
+++ b/python/sglang/srt/custom_op.py
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+
+class CustomOp(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_xpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_hpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        if _is_cuda:
+            return self.forward_cuda
+        elif _is_rocm:
+            return self.forward_hip
+        else:
+            return self.forward_native
diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index d69d854ab..08ea91b9c 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -25,21 +25,18 @@ from sglang.srt.utils import is_cuda_available
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from vllm.model_executor.custom_op import CustomOp
-
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ class SiluAndMul(CustomOp):
         return out
 
 
-@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
diff --git a/python/sglang/srt/layers/custom_op_util.py b/python/sglang/srt/layers/custom_op_util.py
deleted file mode 100644
index 92e186cd2..000000000
--- a/python/sglang/srt/layers/custom_op_util.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from vllm.model_executor.custom_op import CustomOp
-
-
-def register_custom_op(op_name):
-    def decorator(cls):
-        if hasattr(CustomOp, "register"):
-            return CustomOp.register(op_name)(cls)
-        else:
-            return cls
-
-    return decorator
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 207ba8d1b..e3b23a2a9 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -29,14 +29,11 @@ if is_cuda_available():
         rmsnorm,
     )
 
-from vllm.model_executor.custom_op import CustomOp
-
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
         return x, residual
 
 
-@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index bc927621a..4d6040646 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.custom_op import CustomOp
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ class EPMoE(torch.nn.Module):
             param_data[expert_id] = loaded_weight
 
 
-@register_custom_op("sglang_unquantized_ep_moe")
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
     def create_weights(
         self,
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index b71a878a0..dc7152da9 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -5,14 +5,13 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.custom_op import CustomOp
 
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -67,7 +66,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError
 
 
-@register_custom_op("sglang_unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 7093bb90d..ef8a96c98 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -7,9 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 93b4d0ea5..69615b8ff 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
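
Note (editorial sketch, not part of the patch): the new sglang.srt.custom_op.CustomOp resolves the backend-specific forward method once, in __init__, via dispatch_forward(); nn.Module.__call__ then routes every call through forward() to that cached method. Subclasses such as SiluAndMul or RMSNorm therefore only override the variants they support, and the old vllm-based register_custom_op decorator is no longer needed. Below is a minimal hypothetical subclass to illustrate the usage; the Scale op is invented for this sketch and does not exist in sglang.

    import torch

    from sglang.srt.custom_op import CustomOp


    class Scale(CustomOp):
        """Illustrative op: multiply the input by a constant factor."""

        def __init__(self, factor: float):
            # dispatch_forward() runs here and caches forward_cuda,
            # forward_hip, or forward_native depending on the platform.
            super().__init__()
            self.factor = factor

        def forward_native(self, x: torch.Tensor) -> torch.Tensor:
            # Pure-PyTorch fallback; the base class also reuses it for
            # the CPU/XPU/HPU paths.
            return x * self.factor

        def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
            # A fused kernel would normally live here; this sketch just
            # falls back to the native implementation.
            return self.forward_native(x)


    op = Scale(2.0)
    y = op(torch.ones(4))  # __call__ -> forward() -> method chosen at construction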