Reorg moe code (#2563)
This commit is contained in:
@@ -5,7 +5,9 @@ import triton
|
||||
from torch.nn import functional as F
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_triton,
|
||||
)
|
||||
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@ import triton
|
||||
from transformers import AutoConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe as fused_moe_sglang,
|
||||
)
|
||||
|
||||
|
||||
def get_model_config(model_name: str, tp_size: int):
|
||||
|
||||
@@ -11,7 +11,7 @@ import triton
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton.fused_moe import (
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe,
|
||||
get_config_dtype_str,
|
||||
get_config_file_name,
|
||||
@@ -97,7 +97,7 @@ def benchmark_config(
|
||||
input_gating.copy_(gating_output[i])
|
||||
|
||||
def run():
|
||||
from sglang.srt.layers.fused_moe_triton import override_config
|
||||
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||
|
||||
with override_config(config):
|
||||
fused_moe(
|
||||
|
||||
Reference in New Issue
Block a user