Files
sglang/python/sglang/srt/layers/moe/moe_runner/runner.py
2025-09-06 01:34:17 -07:00

81 lines
2.6 KiB
Python

from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING
from sglang.srt.layers.moe.moe_runner.base import (
FusedOpPool,
MoeRunnerConfig,
PermuteMethodPool,
)
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
from sglang.srt.layers.moe.utils import MoeRunnerBackend
logger = logging.getLogger(__name__)
class MoeRunner:
    """Drive one MoE forward step for a given runner backend.

    A runner either delegates the whole step to a fused function obtained
    from ``FusedOpPool`` or, when no fused function is in use, chains
    pre-permute -> runner core -> post-permute using functions obtained
    from ``PermuteMethodPool``.
    """

    def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
        """Select the runner core and look up an optional fused function.

        Args:
            runner_backend: Backend selector; only a backend whose
                ``is_triton()`` is truthy is accepted here.
            config: Runner configuration, handed to the runner core and to
                every fused/permute function call.

        Raises:
            NotImplementedError: If ``runner_backend`` is not a Triton backend.
        """
        self.runner_backend = runner_backend
        self.config = config
        self.fused_func = None

        # Guard clause: Triton is the only core wired up in this class.
        if not runner_backend.is_triton():
            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
        self.runner_core = TritonRunnerCore(config)

        # The fused function is keyed by (a2a backend name, runner backend name).
        a2a_name = get_moe_a2a_backend().value
        backend_name = runner_backend.value
        self.fused_func = FusedOpPool.get_fused_func(a2a_name, backend_name)

        # Environment switch that forces the non-fused path in `run`.
        if os.environ.get("SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0") == "1":
            logger.info(
                "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
            )
            self.fused_func = None

    def run(
        self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
    ) -> CombineInput:
        """Run the MoE step on ``dispatch_output`` and return the combine input.

        When ``self.fused_func`` is set, its result is returned directly.
        Otherwise the pre-permute and post-permute functions are looked up
        on every call (and stored on ``self``), sharing one ``running_state``
        dict with the runner core.
        """
        fused = self.fused_func
        if fused is not None:
            return fused(dispatch_output, quant_info, self.config)

        # Stage 1: convert the dispatcher's output into the runner's input.
        source_format = dispatch_output.format.value
        core_format = self.runner_core.runner_backend.value
        self.pre_permute_func = PermuteMethodPool.get_pre_permute(
            source_format, core_format
        )
        state = {}
        core_input = self.pre_permute_func(
            dispatch_output, quant_info, self.config, state
        )

        # Stage 2: execute the backend-specific core.
        core_output = self.runner_core.run(core_input, quant_info, state)

        # Stage 3: convert the core's output back to the dispatcher's format.
        # Both format values are read again after the core has run.
        core_format = self.runner_core.runner_backend.value
        target_format = dispatch_output.format.value
        self.post_permute_func = PermuteMethodPool.get_post_permute(
            core_format, target_format
        )
        return self.post_permute_func(core_output, quant_info, self.config, state)