Reorg moe code (#2563)

This commit is contained in:
Ke Bao
2024-12-24 01:10:22 +08:00
committed by GitHub
parent 23e5e50fd5
commit e835a50021
88 changed files with 338 additions and 344 deletions

View File

@@ -5,7 +5,9 @@ import triton
from torch.nn import functional as F
from transformers import AutoConfig
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_triton,
)
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config

View File

@@ -5,7 +5,9 @@ import triton
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
def get_model_config(model_name: str, tp_size: int):

View File

@@ -11,7 +11,7 @@ import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from sglang.srt.layers.fused_moe_triton.fused_moe import (
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
get_config_file_name,
@@ -97,7 +97,7 @@ def benchmark_config(
input_gating.copy_(gating_output[i])
def run():
from sglang.srt.layers.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton import override_config
with override_config(config):
fused_moe(