Fix circular import (#10107)
This commit is contained in:
@@ -6,12 +6,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
|
||||||
CombineInput,
|
|
||||||
CombineInputFormat,
|
|
||||||
DispatchOutput,
|
|
||||||
DispatchOutputFormat,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
|
from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -20,6 +14,12 @@ if TYPE_CHECKING:
|
|||||||
TritonRunnerInput,
|
TritonRunnerInput,
|
||||||
TritonRunnerOutput,
|
TritonRunnerOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
|
CombineInput,
|
||||||
|
CombineInputFormat,
|
||||||
|
DispatchOutput,
|
||||||
|
DispatchOutputFormat,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -143,17 +143,12 @@ class PermuteMethodPool:
|
|||||||
:param runner_backend_name: The MoeRunnerBackend name.
|
:param runner_backend_name: The MoeRunnerBackend name.
|
||||||
:param permute_func: The permute function to register.
|
:param permute_func: The permute function to register.
|
||||||
"""
|
"""
|
||||||
|
# TODO: check if registration is valid
|
||||||
key = (dispatch_output_name, runner_backend_name)
|
key = (dispatch_output_name, runner_backend_name)
|
||||||
if key in cls._pre_permute_methods:
|
if key in cls._pre_permute_methods:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
|
f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
|
||||||
)
|
)
|
||||||
assert DispatchOutputFormat(
|
|
||||||
dispatch_output_name
|
|
||||||
), f"Invalid dispatch output name: {dispatch_output_name}"
|
|
||||||
assert MoeRunnerBackend(
|
|
||||||
runner_backend_name
|
|
||||||
), f"Invalid runner backend name: {runner_backend_name}"
|
|
||||||
cls._pre_permute_methods[key] = permute_func
|
cls._pre_permute_methods[key] = permute_func
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -170,17 +165,12 @@ class PermuteMethodPool:
|
|||||||
:param combine_input_name: The CombineInputFormat name.
|
:param combine_input_name: The CombineInputFormat name.
|
||||||
:param permute_func: The permute function to register.
|
:param permute_func: The permute function to register.
|
||||||
"""
|
"""
|
||||||
|
# TODO: check if registration is valid
|
||||||
key = (runner_backend_name, combine_input_name)
|
key = (runner_backend_name, combine_input_name)
|
||||||
if key in cls._post_permute_methods:
|
if key in cls._post_permute_methods:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
|
f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
|
||||||
)
|
)
|
||||||
assert MoeRunnerBackend(
|
|
||||||
runner_backend_name
|
|
||||||
), f"Invalid runner backend name: {runner_backend_name}"
|
|
||||||
assert CombineInputFormat(
|
|
||||||
combine_input_name
|
|
||||||
), f"Invalid combine input name: {combine_input_name}"
|
|
||||||
cls._post_permute_methods[key] = permute_func
|
cls._post_permute_methods[key] = permute_func
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -10,15 +10,11 @@ from sglang.srt.layers.moe.moe_runner.base import (
|
|||||||
PermuteMethodPool,
|
PermuteMethodPool,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
|
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
|
||||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
|
||||||
CombineInput,
|
|
||||||
CombineInputFormat,
|
|
||||||
DispatchOutput,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
|
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
|
from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
|
||||||
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -18,13 +18,16 @@ from sglang.srt.layers.moe.moe_runner.base import (
|
|||||||
register_post_permute,
|
register_post_permute,
|
||||||
register_pre_permute,
|
register_pre_permute,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
|
||||||
StandardCombineInput,
|
|
||||||
StandardDispatchOutput,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
||||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
|
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||||
|
StandardCombineInput,
|
||||||
|
StandardDispatchOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
@@ -325,6 +328,7 @@ def fused_experts_none_to_triton(
|
|||||||
runner_config: MoeRunnerConfig,
|
runner_config: MoeRunnerConfig,
|
||||||
) -> StandardCombineInput:
|
) -> StandardCombineInput:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||||
|
|
||||||
output = fused_experts(
|
output = fused_experts(
|
||||||
hidden_states=dispatch_output.hidden_states,
|
hidden_states=dispatch_output.hidden_states,
|
||||||
@@ -437,6 +441,8 @@ def post_permute_triton_to_standard(
|
|||||||
# NOTE: this is dead code as a fused func for standard format is registered.
|
# NOTE: this is dead code as a fused func for standard format is registered.
|
||||||
# This is left here for testing and examples.
|
# This is left here for testing and examples.
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||||
|
|
||||||
return StandardCombineInput(
|
return StandardCombineInput(
|
||||||
hidden_states=runner_output.hidden_states,
|
hidden_states=runner_output.hidden_states,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -42,11 +42,6 @@ from enum import Enum, IntEnum, auto
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
|
||||||
deepep_permute_triton_kernel,
|
|
||||||
deepep_post_reorder_triton_kernel,
|
|
||||||
deepep_run_moe_deep_preprocess,
|
|
||||||
)
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
||||||
@@ -439,6 +434,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
|
deepep_post_reorder_triton_kernel,
|
||||||
|
)
|
||||||
|
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
|
||||||
output = hidden_states
|
output = hidden_states
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
@@ -297,6 +296,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
dispatch_output: StandardDispatchOutput,
|
dispatch_output: StandardDispatchOutput,
|
||||||
) -> CombineInput:
|
) -> CombineInput:
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||||
|
|
||||||
x = dispatch_output.hidden_states
|
x = dispatch_output.hidden_states
|
||||||
|
|||||||
Reference in New Issue
Block a user