diff --git a/python/sglang/srt/layers/fused_moe/layer.py b/python/sglang/srt/layers/fused_moe/layer.py
index 19012185d..df91ba117 100644
--- a/python/sglang/srt/layers/fused_moe/layer.py
+++ b/python/sglang/srt/layers/fused_moe/layer.py
@@ -153,12 +153,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int],
         topk_group: Optional[int],
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
-
-        assert not use_grouped_topk
-        assert num_expert_group is None
-        assert topk_group is None
-        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+        raise NotImplementedError("The TPU backend currently does not support MoE.")


 class FusedMoE(torch.nn.Module):
diff --git a/python/sglang/srt/layers/triton_fused_moe/fused_moe.py b/python/sglang/srt/layers/triton_fused_moe/fused_moe.py
index 8a2c7257b..86c189257 100644
--- a/python/sglang/srt/layers/triton_fused_moe/fused_moe.py
+++ b/python/sglang/srt/layers/triton_fused_moe/fused_moe.py
@@ -1,3 +1,5 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
+
 """Fused MoE kernel."""

 import functools
diff --git a/python/sglang/srt/layers/triton_fused_moe/layer.py b/python/sglang/srt/layers/triton_fused_moe/layer.py
index 3ec2f7a34..93a6e5506 100644
--- a/python/sglang/srt/layers/triton_fused_moe/layer.py
+++ b/python/sglang/srt/layers/triton_fused_moe/layer.py
@@ -1,3 +1,5 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
+
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -18,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import set_weight_attrs

 if torch.cuda.is_available() or torch.hip.is_available():
-    from .fused_moe import fused_experts
+    from sglang.srt.layers.triton_fused_moe.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore

@@ -512,7 +514,7 @@ class FusedMoE(torch.nn.Module):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ):
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
+        from sglang.srt.layers.triton_fused_moe.fused_moe import (
             fused_topk,
             grouped_topk,
         )
diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py
index ad09a0dbf..cfbf21c70 100644
--- a/python/sglang/srt/models/dbrx.py
+++ b/python/sglang/srt/models/dbrx.py
@@ -24,7 +24,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py
index 86c9cde79..e8e163dfc 100644
--- a/python/sglang/srt/models/deepseek.py
+++ b/python/sglang/srt/models/deepseek.py
@@ -26,7 +26,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

@@ -41,6 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index 454e6e348..46a6b6ac7 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from transformers import MixtralConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py
index d1e8e6027..984638d5b 100644
--- a/python/sglang/srt/models/olmoe.py
+++ b/python/sglang/srt/models/olmoe.py
@@ -27,7 +27,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -43,6 +42,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 07a29e687..d363ec6a0 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -26,7 +26,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

@@ -42,6 +41,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py
index ae02179a7..8cdd4c570 100644
--- a/python/sglang/srt/models/xverse_moe.py
+++ b/python/sglang/srt/models/xverse_moe.py
@@ -24,7 +24,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index f0d129a47..8bb6a5830 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -957,6 +957,21 @@ def direct_register_custom_op(
     fake_impl: Optional[Callable] = None,
     target_lib: Optional[Library] = None,
 ):
+    """
+    `torch.library.custom_op` can have significant overhead because it
+    needs to consider complicated dispatching logic. This function
+    directly registers a custom op and dispatches it to the CUDA backend.
+    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
+    for more details.
+
+    By default, the custom op is registered to the sglang library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+
+    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
+    """
     import torch.library

     if hasattr(torch.library, "infer_schema"):
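
The model-file hunks above all make the same mechanical swap: `fused_moe` and `FusedMoE` now resolve to sglang's vendored Triton kernels instead of vllm's. A minimal sketch of the new import surface, exactly as exercised by the files in this patch:

    # Functional entry point (dbrx, deepseek, xverse_moe):
    from sglang.srt.layers.triton_fused_moe import fused_moe

    # Layer-level entry point (mixtral, olmoe, qwen2_moe):
    from sglang.srt.layers.triton_fused_moe import FusedMoE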
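
The utils.py hunk only adds documentation for `direct_register_custom_op`; it shows the tail of the signature (`fake_impl`, `target_lib`) but not the leading parameters. Below is a minimal usage sketch, assuming the leading parameters mirror vLLM's version of this helper (`op_name`, `op_func`, `mutates_args`); the `my_ops` library and the `scale_add` op are hypothetical names for illustration:

    import torch
    from torch.library import Library

    from sglang.srt.utils import direct_register_custom_op

    # The library object owns the operator, so it must outlive every call
    # site (see the lifetime note in the docstring above).
    my_lib = Library("my_ops", "FRAGMENT")

    def scale_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
        # Real (CUDA) implementation of the op.
        return x + alpha * y

    def scale_add_fake(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
        # Shape/dtype-only implementation used under fake tensors (tracing).
        return torch.empty_like(x)

    direct_register_custom_op(
        op_name="scale_add",        # assumed parameter name
        op_func=scale_add,          # assumed parameter name
        mutates_args=[],            # assumed parameter name
        fake_impl=scale_add_fake,
        target_lib=my_lib,          # overrides the default sglang library
    )

    # The op is now reachable through the torch dispatcher:
    x = torch.randn(4, 8, device="cuda")
    y = torch.randn(4, 8, device="cuda")
    out = torch.ops.my_ops.scale_add(x, y, 0.5)

Registering through the dispatcher this way keeps the call graph-capturable (e.g. under torch.compile or CUDA graphs) while avoiding the per-call overhead the docstring attributes to `torch.library.custom_op`.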