[Refactor] move deep_gemm_wrapper out of quantization (#11784)
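The change is mechanical across every touched file: the `deep_gemm_wrapper` package moves from `sglang.srt.layers.quantization` up one level to `sglang.srt.layers`, and each import is rewritten to the new path. A minimal before/after sketch of the pattern (call sites are unaffected since only the module path changes; `ENABLE_JIT_DEEPGEMM` is one of the flags the wrapper actually exposes in this diff):

# Before this commit: the wrapper lived under the quantization package.
from sglang.srt.layers.quantization import deep_gemm_wrapper

# After this commit: it is a direct child of the layers package.
from sglang.srt.layers import deep_gemm_wrapper

# Usage stays identical either way, e.g. reading a configuration flag:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
    ...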
@@ -61,7 +61,6 @@ import torch.distributed as dist
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.environ import envs
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
@@ -17,10 +17,10 @@ if is_cuda():
     except ImportError as e:
         deep_gemm = e

+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.linear import ReplicatedLinear
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import get_rope_wrapper
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
@@ -8,9 +8,7 @@ import torch
 from tqdm import tqdm

 from sglang.srt.environ import envs
-from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
-    ENABLE_JIT_DEEPGEMM,
-)
+from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import ceil_div, get_bool_env_var

@@ -4,8 +4,8 @@ from typing import Tuple

 import torch

-from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
-from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (  # noqa: F401
+from sglang.srt.layers.deep_gemm_wrapper import compile_utils
+from sglang.srt.layers.deep_gemm_wrapper.configurer import (  # noqa: F401
     DEEPGEMM_BLACKWELL,
     DEEPGEMM_SCALE_UE8M0,
     ENABLE_JIT_DEEPGEMM,
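The `# noqa: F401` above marks imports that are unused in this module on purpose: the configurer constants are pulled in only to be re-exported, so consumers can read the flags from the `deep_gemm_wrapper` package itself. A hypothetical consumer, sketched under that assumption (this function does not appear in the commit):

from sglang.srt.layers import deep_gemm_wrapper

def describe_deep_gemm_config() -> str:
    # Branch on the re-exported flags; the labels here are illustrative.
    if not deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        return "JIT DeepGEMM disabled"
    if deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
        return "JIT DeepGEMM enabled (Blackwell kernels)"
    return "JIT DeepGEMM enabled"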
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 import torch

 from sglang.srt import single_batch_overlap
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
@@ -19,7 +20,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -105,10 +105,10 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         running_state: dict,
     ) -> torch.Tensor:

+        from sglang.srt.layers import deep_gemm_wrapper
         from sglang.srt.layers.moe.ep_moe.kernels import (
             silu_and_mul_masked_post_quant_fwd,
         )
-        from sglang.srt.layers.quantization import deep_gemm_wrapper

         hidden_states = runner_input.hidden_states
         hidden_states_scale = runner_input.hidden_states_scale
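Note that this hunk and the `Fp8MoEMethod` hunk further down rewrite imports that sit inside a method body rather than at module top level. A function-local import is resolved on first call instead of at import time, which is a common way to break an import cycle between a runner and the layers package; whether that is the motivation here is not stated in the commit. A minimal sketch of the pattern (the function is hypothetical):

def select_gemm_backend() -> str:
    # Deferred import: the module is loaded when this function first runs,
    # not when the enclosing module is imported, so a circular dependency
    # on deep_gemm_wrapper cannot fail at module-initialization time.
    from sglang.srt.layers import deep_gemm_wrapper

    return "deep_gemm" if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else "triton"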
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     BaseDispatcherConfig,
@@ -20,7 +21,6 @@ from sglang.srt.layers.moe.utils import (
     get_moe_runner_backend,
     is_tbo_enabled,
 )
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
@@ -1007,11 +1007,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):

+        from sglang.srt.layers import deep_gemm_wrapper
         from sglang.srt.layers.moe.utils import (
             get_moe_a2a_backend,
             get_moe_runner_backend,
         )
-        from sglang.srt.layers.quantization import deep_gemm_wrapper

         self.moe_runner_config = moe_runner_config
         moe_runner_backend = get_moe_runner_backend()
@@ -23,7 +23,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.utils import (
     align,
     direct_register_custom_op,
@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Tuple

 import torch

-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
 from sglang.srt.utils import is_sm100_supported, offloader
@@ -64,6 +64,7 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.attention.attention_registry import (
     ATTENTION_BACKENDS,
     attn_backend_wrapper,
@@ -75,10 +76,7 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import (
-    deep_gemm_wrapper,
-    monkey_patch_isinstance_for_vllm_base_layer,
-)
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -28,7 +28,6 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt import single_batch_overlap
 from sglang.srt.configs.model_config import (
     get_nsa_index_head_dim,
     get_nsa_index_n_heads,
@@ -48,6 +47,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
@@ -82,7 +82,6 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
@@ -44,6 +44,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
@@ -62,7 +63,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import (
@@ -39,6 +39,7 @@ from torch import nn

 from sglang.srt.configs import LongcatFlashConfig
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
@@ -48,7 +49,6 @@ from sglang.srt.layers.dp_attention import (
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import (
@@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Any, Callable, Optional

 import torch

+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import get_moe_runner_backend
 from sglang.srt.layers.moe.utils import is_sbo_enabled
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import get_int_env_var

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

 import torch

+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.communicator import (
     CommunicateContext,
@@ -24,7 +25,6 @@ from sglang.srt.layers.moe.token_dispatcher import (
     DeepEPDispatcher,
     MooncakeEPDispatcher,
 )
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
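After a move like this, one quick sanity check is to scan the tree for any remaining reference to the retired path. A minimal sketch, assuming sources live under `python/sglang` (the substring patterns cover the single-line import forms this diff retires; the grouped import removed in the model-runner hunk would need its own pattern):

import pathlib

# Import forms removed by this commit.
OLD_PATTERNS = (
    "from sglang.srt.layers.quantization import deep_gemm_wrapper",
    "sglang.srt.layers.quantization.deep_gemm_wrapper",
)

for path in pathlib.Path("python/sglang").rglob("*.py"):
    text = path.read_text(encoding="utf-8")
    for pattern in OLD_PATTERNS:
        if pattern in text:
            print(f"{path}: still references the old path")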