Fix circular import (#10107)
This commit is contained in:
@@ -6,12 +6,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
CombineInputFormat,
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -20,6 +14,12 @@ if TYPE_CHECKING:
|
||||
TritonRunnerInput,
|
||||
TritonRunnerOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
CombineInputFormat,
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -143,17 +143,12 @@ class PermuteMethodPool:
|
||||
:param runner_backend_name: The MoeRunnerBackend name.
|
||||
:param permute_func: The permute function to register.
|
||||
"""
|
||||
# TODO: check if registration is valid
|
||||
key = (dispatch_output_name, runner_backend_name)
|
||||
if key in cls._pre_permute_methods:
|
||||
raise ValueError(
|
||||
f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
|
||||
)
|
||||
assert DispatchOutputFormat(
|
||||
dispatch_output_name
|
||||
), f"Invalid dispatch output name: {dispatch_output_name}"
|
||||
assert MoeRunnerBackend(
|
||||
runner_backend_name
|
||||
), f"Invalid runner backend name: {runner_backend_name}"
|
||||
cls._pre_permute_methods[key] = permute_func
|
||||
|
||||
@classmethod
|
||||
@@ -170,17 +165,12 @@ class PermuteMethodPool:
|
||||
:param combine_input_name: The CombineInputFormat name.
|
||||
:param permute_func: The permute function to register.
|
||||
"""
|
||||
# TODO: check if registration is valid
|
||||
key = (runner_backend_name, combine_input_name)
|
||||
if key in cls._post_permute_methods:
|
||||
raise ValueError(
|
||||
f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
|
||||
)
|
||||
assert MoeRunnerBackend(
|
||||
runner_backend_name
|
||||
), f"Invalid runner backend name: {runner_backend_name}"
|
||||
assert CombineInputFormat(
|
||||
combine_input_name
|
||||
), f"Invalid combine input name: {combine_input_name}"
|
||||
cls._post_permute_methods[key] = permute_func
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -10,15 +10,11 @@ from sglang.srt.layers.moe.moe_runner.base import (
|
||||
PermuteMethodPool,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
CombineInput,
|
||||
CombineInputFormat,
|
||||
DispatchOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
|
||||
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,13 +18,16 @@ from sglang.srt.layers.moe.moe_runner.base import (
|
||||
register_post_permute,
|
||||
register_pre_permute,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
StandardCombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
StandardCombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
@@ -325,6 +328,7 @@ def fused_experts_none_to_triton(
|
||||
runner_config: MoeRunnerConfig,
|
||||
) -> StandardCombineInput:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||
|
||||
output = fused_experts(
|
||||
hidden_states=dispatch_output.hidden_states,
|
||||
@@ -437,6 +441,8 @@ def post_permute_triton_to_standard(
|
||||
# NOTE: this is dead code as a fused func for standard format is registered.
|
||||
# This is left here for testing and examples.
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||
|
||||
return StandardCombineInput(
|
||||
hidden_states=runner_output.hidden_states,
|
||||
)
|
||||
|
||||
@@ -42,11 +42,6 @@ from enum import Enum, IntEnum, auto
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_permute_triton_kernel,
|
||||
deepep_post_reorder_triton_kernel,
|
||||
deepep_run_moe_deep_preprocess,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
||||
@@ -439,6 +434,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_post_reorder_triton_kernel,
|
||||
)
|
||||
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
|
||||
output = hidden_states
|
||||
else:
|
||||
|
||||
@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
QuantizationConfig,
|
||||
@@ -297,6 +296,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
|
||||
Reference in New Issue
Block a user