Rename triton_fused_moe -> fused_moe_triton (#2163)

This commit is contained in:
Lianmin Zheng
2024-11-24 08:12:35 -08:00
committed by GitHub
parent fe5d3e818f
commit be0124bda0
76 changed files with 19 additions and 19 deletions

View File

@@ -1 +0,0 @@
-from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase

View File

@@ -0,0 +1 @@
+from sglang.srt.layers.fused_moe_grok.layer import FusedMoE, FusedMoEMethodBase

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.layers.fused_moe.fused_moe import padding_size
+from sglang.srt.layers.fused_moe_grok.fused_moe import padding_size
 from sglang.srt.utils import is_hip
 logger = init_logger(__name__)
@@ -123,7 +123,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 num_expert_group: Optional[int],
 topk_group: Optional[int],
 ) -> torch.Tensor:
-from sglang.srt.layers.fused_moe.fused_moe import fused_moe
+from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
 return fused_moe(
 x,
@@ -609,7 +609,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 topk_group: Optional[int] = None,
 ) -> torch.Tensor:
-from sglang.srt.layers.fused_moe.fused_moe import fused_moe
+from sglang.srt.layers.fused_moe_grok.fused_moe import fused_moe
 return fused_moe(
 x,

View File

@@ -1,14 +1,14 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
-import sglang.srt.layers.triton_fused_moe.fused_moe # noqa
+import sglang.srt.layers.fused_moe_triton.fused_moe # noqa
-from sglang.srt.layers.triton_fused_moe.fused_moe import (
+from sglang.srt.layers.fused_moe_triton.fused_moe import (
 fused_experts,
 fused_topk,
 get_config_file_name,
 grouped_topk,
 )
-from sglang.srt.layers.triton_fused_moe.layer import (
+from sglang.srt.layers.fused_moe_triton.layer import (
 FusedMoE,
 FusedMoEMethodBase,
 FusedMoeWeightScaleSupported,

View File

@@ -376,7 +376,7 @@ def try_get_optimal_moe_config(
 M: int,
 is_marlin: bool = False,
 ):
-from sglang.srt.layers.triton_fused_moe import get_config
+from sglang.srt.layers.fused_moe_triton import get_config
 override_config = get_config()
 if override_config:

View File

@@ -20,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import set_weight_attrs
 if torch.cuda.is_available() or torch.hip.is_available():
-from sglang.srt.layers.triton_fused_moe.fused_moe import fused_experts
+from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 else:
 fused_experts = None # type: ignore
@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
 num_expert_group: Optional[int] = None,
 custom_routing_function: Optional[Callable] = None,
 ):
-from sglang.srt.layers.triton_fused_moe.fused_moe import (
+from sglang.srt.layers.fused_moe_triton.fused_moe import (
 fused_topk,
 grouped_topk,
 )

View File

@@ -68,7 +68,7 @@ def fp8_get_quant_method(self, layer, prefix):
 is_layer_skipped,
 )
-from sglang.srt.layers.triton_fused_moe.layer import FusedMoE
+from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
 if isinstance(layer, LinearBase):
 if is_layer_skipped(prefix, self.ignored_layers):

View File

@@ -28,6 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
 QKVParallelLinear,
 ReplicatedLinear,
@@ -36,7 +37,6 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
 DEFAULT_VOCAB_PADDING_SIZE,
 ParallelLMHead,

View File

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
 MergedColumnParallelLinear,
@@ -40,7 +41,6 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
 ParallelLMHead,
 VocabParallelEmbedding,

View File

@@ -31,6 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
 ColumnParallelLinear,
@@ -41,7 +42,6 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
 ParallelLMHead,
 VocabParallelEmbedding,

View File

@@ -31,7 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from sglang.srt.layers.fused_moe import FusedMoE
+from sglang.srt.layers.fused_moe_grok import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
 QKVParallelLinear,

View File

@@ -25,6 +25,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
 QKVParallelLinear,
@@ -35,7 +36,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
-from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
 ParallelLMHead,
 VocabParallelEmbedding,

View File

@@ -38,11 +38,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.utils import print_warning_once
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
 ParallelLMHead,
 VocabParallelEmbedding,

View File

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
 MergedColumnParallelLinear,
@@ -41,7 +42,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
-from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
 ParallelLMHead,
 VocabParallelEmbedding,

View File

@@ -34,10 +34,10 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
 ParallelLMHead,
 VocabParallelEmbedding,