update and simplify CustomOp (#3249)
This commit is contained in:
40
python/sglang/srt/custom_op.py
Normal file
40
python/sglang/srt/custom_op.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||||
|
_is_rocm = torch.cuda.is_available() and torch.version.hip
|
||||||
|
|
||||||
|
|
||||||
|
class CustomOp(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._forward_method = self.dispatch_forward()
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self._forward_method(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward_native(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward_cuda(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward_hip(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward_xpu(self, *args, **kwargs):
|
||||||
|
return self.forward_native(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward_hpu(self, *args, **kwargs):
|
||||||
|
return self.forward_native(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward_cpu(self, *args, **kwargs):
|
||||||
|
return self.forward_native(*args, **kwargs)
|
||||||
|
|
||||||
|
def dispatch_forward(self):
|
||||||
|
if _is_cuda:
|
||||||
|
return self.forward_cuda
|
||||||
|
elif _is_rocm:
|
||||||
|
return self.forward_hip
|
||||||
|
else:
|
||||||
|
return self.forward_native
|
||||||
@@ -25,21 +25,18 @@ from sglang.srt.utils import is_cuda_available
|
|||||||
if is_cuda_available():
|
if is_cuda_available():
|
||||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||||
|
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
divide,
|
divide,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@register_custom_op("sglang_silu_and_mul")
|
|
||||||
class SiluAndMul(CustomOp):
|
class SiluAndMul(CustomOp):
|
||||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
@@ -53,7 +50,6 @@ class SiluAndMul(CustomOp):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@register_custom_op("sglang_gelu_and_mul")
|
|
||||||
class GeluAndMul(CustomOp):
|
class GeluAndMul(CustomOp):
|
||||||
def __init__(self, approximate="tanh"):
|
def __init__(self, approximate="tanh"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
|
||||||
|
|
||||||
|
|
||||||
def register_custom_op(op_name):
|
|
||||||
def decorator(cls):
|
|
||||||
if hasattr(CustomOp, "register"):
|
|
||||||
return CustomOp.register(op_name)(cls)
|
|
||||||
else:
|
|
||||||
return cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
@@ -29,14 +29,11 @@ if is_cuda_available():
|
|||||||
rmsnorm,
|
rmsnorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@register_custom_op("sglang_rmsnorm")
|
|
||||||
class RMSNorm(CustomOp):
|
class RMSNorm(CustomOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
|
|||||||
return x, residual
|
return x, residual
|
||||||
|
|
||||||
|
|
||||||
@register_custom_op("sglang_gemma_rmsnorm")
|
|
||||||
class GemmaRMSNorm(CustomOp):
|
class GemmaRMSNorm(CustomOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
|
||||||
|
|
||||||
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
grouped_gemm_triton,
|
grouped_gemm_triton,
|
||||||
post_reorder_triton_kernel,
|
post_reorder_triton_kernel,
|
||||||
@@ -407,7 +406,6 @@ class EPMoE(torch.nn.Module):
|
|||||||
param_data[expert_id] = loaded_weight
|
param_data[expert_id] = loaded_weight
|
||||||
|
|
||||||
|
|
||||||
@register_custom_op("sglang_unquantized_ep_moe")
|
|
||||||
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -5,14 +5,13 @@ from enum import Enum
|
|||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
|
||||||
|
|
||||||
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
|
||||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
@@ -67,7 +66,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@register_custom_op("sglang_unquantized_fused_moe")
|
|
||||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
"""MoE method without quantization."""
|
"""MoE method without quantization."""
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
|
||||||
|
|
||||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import is_cuda_available
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
_is_cuda_available = is_cuda_available()
|
_is_cuda_available = is_cuda_available()
|
||||||
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
|
|||||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
@register_custom_op("sglang_rotary_embedding")
|
|
||||||
class RotaryEmbedding(CustomOp):
|
class RotaryEmbedding(CustomOp):
|
||||||
"""Original rotary positional embedding."""
|
"""Original rotary positional embedding."""
|
||||||
|
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
|
||||||
|
|
||||||
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
|||||||
Reference in New Issue
Block a user