From a5a03209e9598c25f28adb29a70c4ab6dc205e61 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sat, 6 Sep 2025 01:34:17 -0700 Subject: [PATCH] Fix circular import (#10107) --- .../sglang/srt/layers/moe/moe_runner/base.py | 26 ++++++------------- .../srt/layers/moe/moe_runner/runner.py | 6 +---- .../srt/layers/moe/moe_runner/triton.py | 14 +++++++--- .../srt/layers/moe/token_dispatcher/deepep.py | 10 +++---- .../sglang/srt/layers/quantization/w4afp8.py | 2 +- 5 files changed, 25 insertions(+), 33 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py index c5c14bfea..4d95540e6 100644 --- a/python/sglang/srt/layers/moe/moe_runner/base.py +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -6,12 +6,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard import torch -from sglang.srt.layers.moe.token_dispatcher import ( - CombineInput, - CombineInputFormat, - DispatchOutput, - DispatchOutputFormat, -) from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend if TYPE_CHECKING: @@ -20,6 +14,12 @@ if TYPE_CHECKING: TritonRunnerInput, TritonRunnerOutput, ) + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + CombineInputFormat, + DispatchOutput, + DispatchOutputFormat, + ) @dataclass @@ -143,17 +143,12 @@ class PermuteMethodPool: :param runner_backend_name: The MoeRunnerBackend name. :param permute_func: The permute function to register. """ + # TODO: check if registration is valid key = (dispatch_output_name, runner_backend_name) if key in cls._pre_permute_methods: raise ValueError( f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered." ) - assert DispatchOutputFormat( - dispatch_output_name - ), f"Invalid dispatch output name: {dispatch_output_name}" - assert MoeRunnerBackend( - runner_backend_name - ), f"Invalid runner backend name: {runner_backend_name}" cls._pre_permute_methods[key] = permute_func @classmethod @@ -170,17 +165,12 @@ class PermuteMethodPool: :param combine_input_name: The CombineInputFormat name. :param permute_func: The permute function to register. """ + # TODO: check if registration is valid key = (runner_backend_name, combine_input_name) if key in cls._post_permute_methods: raise ValueError( f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered." ) - assert MoeRunnerBackend( - runner_backend_name - ), f"Invalid runner backend name: {runner_backend_name}" - assert CombineInputFormat( - combine_input_name - ), f"Invalid combine input name: {combine_input_name}" cls._post_permute_methods[key] = permute_func @classmethod diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py index 995813690..3b6fcd980 100644 --- a/python/sglang/srt/layers/moe/moe_runner/runner.py +++ b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -10,15 +10,11 @@ from sglang.srt.layers.moe.moe_runner.base import ( PermuteMethodPool, ) from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore -from sglang.srt.layers.moe.token_dispatcher.base import ( - CombineInput, - CombineInputFormat, - DispatchOutput, -) from sglang.srt.layers.moe.utils import get_moe_a2a_backend if TYPE_CHECKING: from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo + from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput from sglang.srt.layers.moe.utils import MoeRunnerBackend logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/moe/moe_runner/triton.py b/python/sglang/srt/layers/moe/moe_runner/triton.py index bc0476812..116fdcaa0 100644 --- a/python/sglang/srt/layers/moe/moe_runner/triton.py +++ b/python/sglang/srt/layers/moe/moe_runner/triton.py @@ -18,13 +18,16 @@ from sglang.srt.layers.moe.moe_runner.base import ( register_post_permute, register_pre_permute, ) -from sglang.srt.layers.moe.token_dispatcher import ( - StandardCombineInput, - StandardDispatchOutput, -) from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher.standard import ( + StandardCombineInput, + StandardDispatchOutput, + ) + + _is_hip = is_hip() _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() @@ -325,6 +328,7 @@ def fused_experts_none_to_triton( runner_config: MoeRunnerConfig, ) -> StandardCombineInput: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput output = fused_experts( hidden_states=dispatch_output.hidden_states, @@ -437,6 +441,8 @@ def post_permute_triton_to_standard( # NOTE: this is dead code as a fused func for standard format is registered. # This is left here for testing and examples. + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput + return StandardCombineInput( hidden_states=runner_output.hidden_states, ) diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index ccb13e50c..c9c9bb04f 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -42,11 +42,6 @@ from enum import Enum, IntEnum, auto import torch import torch.distributed as dist -from sglang.srt.layers.moe.ep_moe.kernels import ( - deepep_permute_triton_kernel, - deepep_post_reorder_triton_kernel, - deepep_run_moe_deep_preprocess, -) from sglang.srt.model_executor.forward_batch_info import ForwardBatch _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip() @@ -439,6 +434,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): topk_idx: torch.Tensor, topk_weights: torch.Tensor, ): + + from sglang.srt.layers.moe.ep_moe.kernels import ( + deepep_post_reorder_triton_kernel, + ) + if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter: output = hidden_states else: diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index f39acd8af..f8fad8bcb 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod -from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, @@ -297,6 +296,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): dispatch_output: StandardDispatchOutput, ) -> CombineInput: + from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput x = dispatch_output.hidden_states