update and simplify CustomOp (#3249)
python/sglang/srt/custom_op.py (new file, 40 additions)
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+
+class CustomOp(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_xpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_hpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        if _is_cuda:
+            return self.forward_cuda
+        elif _is_rocm:
+            return self.forward_hip
+        else:
+            return self.forward_native
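For context, here is a minimal sketch (not part of the commit) of how a subclass of this base is meant to be used: it overrides the per-backend forward_* methods, and dispatch_forward() picks one at construction time based on the detected platform. The ScaleOp class below and its contents are hypothetical.

# Hypothetical example, not from the sglang source tree.
import torch

from sglang.srt.custom_op import CustomOp


class ScaleOp(CustomOp):
    def __init__(self, scale: float):
        super().__init__()  # dispatch_forward() runs here and selects a backend
        self.scale = scale

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Fallback used on CPU/XPU/HPU and anywhere without a specialized path.
        return x * self.scale

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # On CUDA builds dispatch_forward() returns this method instead;
        # a real op would call a fused kernel here.
        return x * self.scale

    # forward_hip is not overridden, so on ROCm this sketch would hit the
    # base class's NotImplementedError.


op = ScaleOp(2.0)
y = op(torch.ones(4))  # nn.Module.__call__ -> forward -> the method chosen in __init__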
@@ -25,21 +25,18 @@ from sglang.srt.utils import is_cuda_available
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from vllm.model_executor.custom_op import CustomOp
-
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ class SiluAndMul(CustomOp):
         return out
 
 
-@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
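A brief sketch of what these activation ops compute (a reference formulation, not code copied from the diff): the last dimension of the input holds the gate and up projections side by side; forward_native splits it in half, applies SiLU (or GELU) to the first half, and multiplies by the second half, while forward_cuda routes to the fused silu_and_mul / gelu_and_mul kernels imported above. Variable names below are assumptions.

# Illustrative only: the native gated-activation pattern behind SiluAndMul.
import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2            # split [..., 2d] into gate and up halves
    gate, up = x[..., :d], x[..., d:]
    return F.silu(gate) * up        # SiLU(gate) * up, shape [..., d]


x = torch.randn(2, 8)
out = silu_and_mul_reference(x)     # shape (2, 4)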
@@ -1,25 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from vllm.model_executor.custom_op import CustomOp
-
-
-def register_custom_op(op_name):
-    def decorator(cls):
-        if hasattr(CustomOp, "register"):
-            return CustomOp.register(op_name)(cls)
-        else:
-            return cls
-
-    return decorator
@@ -29,14 +29,11 @@ if is_cuda_available():
         rmsnorm,
     )
 
-from vllm.model_executor.custom_op import CustomOp
-
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
         return x, residual
 
 
-@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
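For reference, a rough sketch of the math RMSNorm.forward_native performs (the standard RMSNorm formulation, reconstructed rather than copied from sglang.srt.layers.layernorm): normalize by the root mean square over the hidden dimension and scale by a learned weight; the fused rmsnorm / fused_add_rmsnorm kernels imported above implement the same computation on CUDA, including the optional residual add. Names below are assumptions.

# Illustrative reference implementation of the RMSNorm computation.
import torch


def rms_norm_reference(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Accumulate in float32 for stability, then cast back to the input dtype.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x.to(orig_dtype) * weight


hidden = torch.randn(2, 16)
w = torch.ones(16)
y = rms_norm_reference(hidden, w)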
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ class EPMoE(torch.nn.Module):
         param_data[expert_id] = loaded_weight
 
 
-@register_custom_op("sglang_unquantized_ep_moe")
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
     def create_weights(
         self,
@@ -5,14 +5,13 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -67,7 +66,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         raise NotImplementedError
 
 
-@register_custom_op("sglang_unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
@@ -7,9 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
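A compact sketch of the per-pair rotation behind _apply_rotary_emb in the interleaved layout, which is what produces the o1/o2 stacked in the return statement shown above (standard RoPE; the helper below is illustrative and its names are assumptions, not the function's actual body).

# Illustrative RoPE rotation for the interleaved (non-NEOX) layout.
import torch


def apply_rotary_emb_reference(
    x: torch.Tensor,    # [..., rotary_dim]
    cos: torch.Tensor,  # [..., rotary_dim // 2]
    sin: torch.Tensor,  # [..., rotary_dim // 2]
) -> torch.Tensor:
    # Even indices form x1, odd indices form x2.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x1 * sin + x2 * cos
    # Re-interleave the rotated halves, matching the return shown in the hunk.
    return torch.stack((o1, o2), dim=-1).flatten(-2)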
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput