diff --git a/benchmark/kernels/quantization/bench_fp4_quant.py b/benchmark/kernels/quantization/bench_fp4_quant.py index 318e820ad..9a5b69463 100644 --- a/benchmark/kernels/quantization/bench_fp4_quant.py +++ b/benchmark/kernels/quantization/bench_fp4_quant.py @@ -6,8 +6,8 @@ import triton from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant from sgl_kernel.elementwise import silu_and_mul +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd -from sglang.srt.layers.quantization import deep_gemm_wrapper def _test_accuracy_once(E, M, K, input_dtype, device): diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index c8a0c222a..2dce0623a 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -61,7 +61,6 @@ import torch.distributed as dist from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed.parallel_state import destroy_distributed_environment from sglang.srt.entrypoints.engine import _set_envs_and_config -from sglang.srt.environ import envs from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.scheduler import Scheduler diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index b9f399899..93d7b61a6 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -17,10 +17,10 @@ if is_cuda(): except ImportError as e: deep_gemm = e +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.linear import ReplicatedLinear -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.rotary_embedding import get_rope_wrapper from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py b/python/sglang/srt/layers/deep_gemm_wrapper/__init__.py similarity index 100% rename from python/sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py rename to python/sglang/srt/layers/deep_gemm_wrapper/__init__.py diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py similarity index 98% rename from python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py rename to python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py index 0f4aa9449..202801b4e 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py @@ -8,9 +8,7 @@ import torch from tqdm import tqdm from sglang.srt.environ import envs -from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( - ENABLE_JIT_DEEPGEMM, -) +from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ceil_div, get_bool_env_var diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py similarity index 100% rename from python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py rename to python/sglang/srt/layers/deep_gemm_wrapper/configurer.py diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py similarity index 94% rename from python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py rename to python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py index 1f2f4542a..bf2ab4800 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py @@ -4,8 +4,8 @@ from typing import Tuple import torch -from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils -from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( # noqa: F401 +from sglang.srt.layers.deep_gemm_wrapper import compile_utils +from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401 DEEPGEMM_BLACKWELL, DEEPGEMM_SCALE_UE8M0, ENABLE_JIT_DEEPGEMM, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 4ecc5535b..0aa24f461 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch from sglang.srt import single_batch_overlap +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe import ( get_deepep_mode, get_moe_a2a_backend, @@ -19,7 +20,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( tma_align_input_scale, ) from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8_kernel import ( diff --git a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py index 9bc3824b9..8977955d4 100644 --- a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py +++ b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py @@ -105,10 +105,10 @@ class DeepGemmRunnerCore(MoeRunnerCore): running_state: dict, ) -> torch.Tensor: + from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.ep_moe.kernels import ( silu_and_mul_masked_post_quant_fwd, ) - from sglang.srt.layers.quantization import deep_gemm_wrapper hidden_states = runner_input.hidden_states hidden_states_scale = runner_input.hidden_states_scale diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index c944ef679..dd94d4464 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, BaseDispatcherConfig, @@ -20,7 +21,6 @@ from sglang.srt.layers.moe.utils import ( get_moe_runner_backend, is_tbo_enabled, ) -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.utils import ( get_bool_env_var, get_int_env_var, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index a1a25102d..bad1e2b87 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1007,11 +1007,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): + from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.utils import ( get_moe_a2a_backend, get_moe_runner_backend, ) - from sglang.srt.layers.quantization import deep_gemm_wrapper self.moe_runner_config = moe_runner_config moe_runner_backend = get_moe_runner_backend() diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index bd9628916..9ac766a23 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -23,7 +23,7 @@ import torch import triton import triton.language as tl -from sglang.srt.layers.quantization import deep_gemm_wrapper +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.utils import ( align, direct_register_custom_op, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index fc50c1f54..dc70c53b3 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Tuple import torch -from sglang.srt.layers.quantization import deep_gemm_wrapper +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil from sglang.srt.utils import is_sm100_supported, offloader diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ef780899d..23faf600b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -64,6 +64,7 @@ from sglang.srt.eplb.expert_location import ( set_global_expert_location_metadata, ) from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.attention.attention_registry import ( ATTENTION_BACKENDS, attn_backend_wrapper, @@ -75,10 +76,7 @@ from sglang.srt.layers.dp_attention import ( initialize_dp_attention, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.quantization import ( - deep_gemm_wrapper, - monkey_patch_isinstance_for_vllm_base_layer, -) +from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5bf8da10e..327e04c65 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -28,7 +28,6 @@ import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from sglang.srt import single_batch_overlap from sglang.srt.configs.model_config import ( get_nsa_index_head_dim, get_nsa_index_n_heads, @@ -48,6 +47,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.amx_utils import PackWeightMethod from sglang.srt.layers.attention.npu_ops.mla_preprocess import ( @@ -82,7 +82,6 @@ from sglang.srt.layers.moe import ( from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py index ffca2bad0..88a8cad3a 100644 --- a/python/sglang/srt/models/longcat_flash.py +++ b/python/sglang/srt/models/longcat_flash.py @@ -44,6 +44,7 @@ from sglang.srt.distributed import ( ) from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( @@ -62,7 +63,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import ( diff --git a/python/sglang/srt/models/longcat_flash_nextn.py b/python/sglang/srt/models/longcat_flash_nextn.py index a6092785a..cae974815 100644 --- a/python/sglang/srt/models/longcat_flash_nextn.py +++ b/python/sglang/srt/models/longcat_flash_nextn.py @@ -39,6 +39,7 @@ from torch import nn from sglang.srt.configs import LongcatFlashConfig from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, @@ -48,7 +49,6 @@ from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import ( diff --git a/python/sglang/srt/single_batch_overlap.py b/python/sglang/srt/single_batch_overlap.py index dd2be4885..885f750ff 100644 --- a/python/sglang/srt/single_batch_overlap.py +++ b/python/sglang/srt/single_batch_overlap.py @@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import torch +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.layers.moe.utils import is_sbo_enabled -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import get_int_env_var diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index b09c72dae..69d3f03c1 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence import torch +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.communicator import ( CommunicateContext, @@ -24,7 +25,6 @@ from sglang.srt.layers.moe.token_dispatcher import ( DeepEPDispatcher, MooncakeEPDispatcher, ) -from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch,