feat: update other MoE models deps (#2156)
This commit is contained in:
@@ -153,12 +153,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
num_expert_group: Optional[int],
|
num_expert_group: Optional[int],
|
||||||
topk_group: Optional[int],
|
topk_group: Optional[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
|
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||||
|
|
||||||
assert not use_grouped_topk
|
|
||||||
assert num_expert_group is None
|
|
||||||
assert topk_group is None
|
|
||||||
return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoE(torch.nn.Module):
|
class FusedMoE(torch.nn.Module):
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
|
||||||
|
|
||||||
"""Fused MoE kernel."""
|
"""Fused MoE kernel."""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
@@ -18,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
if torch.cuda.is_available() or torch.hip.is_available():
|
if torch.cuda.is_available() or torch.hip.is_available():
|
||||||
from .fused_moe import fused_experts
|
from sglang.srt.layers.triton_fused_moe.fused_moe import fused_experts
|
||||||
else:
|
else:
|
||||||
fused_experts = None # type: ignore
|
fused_experts = None # type: ignore
|
||||||
|
|
||||||
@@ -512,7 +514,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
):
|
):
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from sglang.srt.layers.triton_fused_moe.fused_moe import (
|
||||||
fused_topk,
|
fused_topk,
|
||||||
grouped_topk,
|
grouped_topk,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from vllm.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||||
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.triton_fused_moe import fused_moe
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE,
|
DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from vllm.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
@@ -41,6 +40,7 @@ from sglang.srt.layers.linear import (
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.triton_fused_moe import fused_moe
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
|||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||||
|
from sglang.srt.layers.triton_fused_moe import FusedMoE
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from vllm.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -43,6 +42,7 @@ from sglang.srt.layers.layernorm import RMSNorm
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.triton_fused_moe import FusedMoE
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from vllm.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
@@ -42,6 +41,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
|||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
from sglang.srt.layers.torchao_utils import apply_torchao_config_
|
||||||
|
from sglang.srt.layers.triton_fused_moe import FusedMoE
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from vllm.distributed import (
|
|||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
@@ -38,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.triton_fused_moe import fused_moe
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
|
|||||||
@@ -957,6 +957,21 @@ def direct_register_custom_op(
|
|||||||
fake_impl: Optional[Callable] = None,
|
fake_impl: Optional[Callable] = None,
|
||||||
target_lib: Optional[Library] = None,
|
target_lib: Optional[Library] = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
`torch.library.custom_op` can have significant overhead because it
|
||||||
|
needs to consider complicated dispatching logic. This function
|
||||||
|
directly registers a custom op and dispatches it to the CUDA backend.
|
||||||
|
See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
|
||||||
|
for more details.
|
||||||
|
|
||||||
|
By default, the custom op is registered to the vLLM library. If you
|
||||||
|
want to register it to a different library, you can pass the library
|
||||||
|
object to the `target_lib` argument.
|
||||||
|
|
||||||
|
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
|
||||||
|
library object. If you want to bind the operator to a different library,
|
||||||
|
make sure the library object is alive when the operator is used.
|
||||||
|
"""
|
||||||
import torch.library
|
import torch.library
|
||||||
|
|
||||||
if hasattr(torch.library, "infer_schema"):
|
if hasattr(torch.library, "infer_schema"):
|
||||||
|
|||||||
Reference in New Issue
Block a user