diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index acc8b0e68..153e147cd 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -229,6 +229,14 @@ class Envs: SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28) SGLANG_RESIZE_RESAMPLE = EnvStr("") + # Ktransformers + SGLANG_KT_MOE_NUM_GPU_EXPERTS = EnvInt(None) + SGLANG_KT_MOE_CPUINFER = EnvInt(None) + SGLANG_KT_THREADPOOL_COUNT = EnvInt(None) + SGLANG_KT_MOE_AMX_WEIGHT_PATH = EnvStr(None) + SGLANG_KT_AMX_METHOD = EnvStr(None) + SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None) + # fmt: on diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 84a35b96a..825ba46c6 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -33,6 +33,11 @@ from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, ) +from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import ( + CompressedTensorsWNA16AMXEPMoEMethod, + CompressedTensorsWNA16AMXMoEMethod, + CompressedTensorsWNA16MoEMethod, +) from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod @@ -150,7 +155,6 @@ class FusedMoE(torch.nn.Module): with_bias=False, ): super().__init__() - if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -227,6 +231,8 @@ class FusedMoE(torch.nn.Module): if not use_weight_loader_fused else self.weight_loader_fused ), + intermediate_size_full=intermediate_size, + top_k=top_k, with_bias=with_bias, ) @@ -542,6 +548,18 @@ class FusedMoE(torch.nn.Module): if expert_id == -1: return + if isinstance( + self.quant_method, + ( + CompressedTensorsWNA16MoEMethod, + CompressedTensorsWNA16AMXMoEMethod, + CompressedTensorsWNA16AMXEPMoEMethod, + ), + ): + if self.quant_method.num_gpu_experts != -1: + if expert_id >= self.quant_method.num_gpu_experts: + return + self._weight_loader_impl( param=param, loaded_weight=loaded_weight, @@ -568,7 +586,12 @@ class FusedMoE(torch.nn.Module): loaded_weight.t().contiguous() if ( self.quant_method.__class__.__name__ - == "CompressedTensorsWNA16MoEMethod" + in [ + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsWNA16AMXMoEMethod", + "CompressedTensorsWNA16AMXEPMoEMethod", + ] ) else loaded_weight ) @@ -827,7 +850,6 @@ class FusedMoE(torch.nn.Module): dispatch_output=dispatch_output, **kwargs, ) - final_hidden_states = self.dispatcher.combine(combine_input) # TODO: should we add some conditions here? 
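The weight-loader early-return above is what keeps CPU-resident experts off the GPU: in the hybrid KTransformers setup only the first `num_gpu_experts` experts are materialized on the device, and the rest are served by the AMX CPU backend. A minimal sketch of the partitioning rule, assuming the `-1` sentinel the patch uses to mean "no CPU offload" (the helper name is illustrative, not from the patch):

```python
# Sketch of the expert split implied by the weight-loader change above.
# Experts [0, num_gpu_experts) stay on GPU (Marlin kernels); the rest are
# loaded by the KTransformers AMX backend, so their shards are skipped here.
def is_gpu_resident(expert_id: int, num_gpu_experts: int) -> bool:
    if num_gpu_experts == -1:  # sentinel: every expert is GPU-resident
        return True
    return expert_id < num_gpu_experts
```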
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/__init__.py
index e69de29bb..8d8d1e8e1 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/__init__.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/__init__.py
@@ -0,0 +1,7 @@
+class scalar_types:
+    uint4b8 = "uint4b8"
+    uint8b128 = "uint8b128"
+
+
+WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
+WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
index 14822c9e7..fde541e19 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -19,11 +19,13 @@ from compressed_tensors.quantization import (
 )
 from pydantic import BaseModel
 
+from sglang.srt.environ import envs
 from sglang.srt.layers.quantization.base_config import (
     LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
     CompressedTensorsMoEMethod,
 )
@@ -38,6 +40,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 
 try:
@@ -76,6 +79,7 @@ class DeviceCapability(NamedTuple):
 
 
 class CompressedTensorsConfig(QuantizationConfig):
+    DeepSeekFP8Config = None
 
     def __init__(
         self,
@@ -129,6 +133,10 @@ class CompressedTensorsConfig(QuantizationConfig):
         ):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
+            if CompressedTensorsConfig.DeepSeekFP8Config is not None:
+                return Fp8LinearMethod(CompressedTensorsConfig.DeepSeekFP8Config)
+            if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
+                return UnquantizedLinearMethod()
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
             if scheme is None:
                 return UnquantizedLinearMethod()
@@ -137,7 +145,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, FusedMoE):
+            # KTransformers: uses CompressedTensorsWNA16AMXEPMoEMethod if AMX weights are provided
-            return CompressedTensorsMoEMethod.get_moe_method(self)
+            return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
         return None
 
     @classmethod
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 3517bc5e2..6a7696c1d 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -4,16 +4,34 @@ from __future__ import annotations
 
 import enum
 import logging
+import re
 from enum import Enum
 from typing import TYPE_CHECKING, List
 
+try:
+    from sgl_kernel import fused_marlin_moe
+
+    FUSED_MARLIN_MOE_AVAILABLE = True
+except ImportError:
+    FUSED_MARLIN_MOE_AVAILABLE = False
+
+try:
+    from kt_kernel import AMXMoEWrapper
+
+    KTRANSFORMERS_AVAILABLE = True
+except ImportError:
+    KTRANSFORMERS_AVAILABLE = False
+
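+# Availability flags: sgl_kernel provides the fused Marlin GPU MoE kernel and
+# kt_kernel provides the KTransformers AMX CPU backend. Both imports are
+# optional; a missing package raises an explicit error when the corresponding
+# method is constructed below.
+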
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.environ import envs
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
+from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
@@ -21,7 +39,12 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_compiler_backend,
+    is_hip,
+    set_weight_attrs,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -51,6 +74,18 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
+def _mask_topk_ids_cpu_experts(topk_ids: torch.Tensor, num_gpu_experts: int):
+    """Mask topk_ids >= num_gpu_experts by setting them to -1."""
+    topk_ids[topk_ids >= num_gpu_experts] = -1
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def mask_cpu_expert_ids(topk_ids: torch.Tensor, num_gpu_experts: int):
+    """Mask CPU expert IDs in-place and return topk_ids (compiled wrapper)."""
+    _mask_topk_ids_cpu_experts(topk_ids, num_gpu_experts)
+    return topk_ids
+
+
 class GPTQMarlinState(Enum):
     REPACK = enum.auto()
     READY = enum.auto()
@@ -60,6 +95,7 @@ __all__ = [
     "CompressedTensorsMoEMethod",
     "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsWNA16MoEMethod",
+    "CompressedTensorsWNA16AMXEPMoEMethod",  # for Ktransformers
 ]
 
 
@@ -72,12 +108,24 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: CompressedTensorsConfig,
+        layer: torch.nn.Module,
+        prefix: str,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
+
+        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
+            match = re.search(r"(\d+)\.mlp", prefix)
+            if not match:
+                raise ValueError(
+                    f"Unable to extract layer number from prefix '{prefix}'. "
+                    f"Expected the prefix to contain '<layer_number>.mlp'."
+                )
+            layer_number = int(match.group(1))
+            return CompressedTensorsWNA16AMXEPMoEMethod(quant_config, layer_number)
+
         weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
         input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
-
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
@@ -201,7 +249,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w13_input_scale = None
             layer.w2_input_scale = None
 
-    def process_weights_after_loading(self, layer: FusedMoE) -> None:
+    def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -349,7 +397,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
 
-    def __init__(self, quant_config: CompressedTensorsConfig):
+    def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1):
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -371,6 +419,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
                 "is supported for the following bits: ",
                 f"{WNA16_SUPPORTED_BITS}",
             )
+        self.num_gpu_experts = num_gpu_experts
 
     def create_weights(
         self,
@@ -381,10 +430,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-
-        assert (
-            params_dtype == torch.float16
-        ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+        if self.num_gpu_experts != -1:
+            num_experts = self.num_gpu_experts
+        # Note: the upstream float16-only assertion ("float16 is required for
+        # MoE compressed models") is intentionally relaxed here so that the
+        # KTransformers hybrid path can run with non-float16 dtypes.
 
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
@@ -683,3 +733,353 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             is_k_full=self.is_k_full,
         )
         return StandardCombineInput(hidden_states=output)
+
+
+class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
+    """AMX MoE method using AMXMoEWrapper for CPU inference."""
+
+    def __init__(
+        self,
+        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        layer_idx,
+        num_gpu_experts,
+        cpuinfer,
+        threadpool_count,
+        amx_weight_path,
+        chunked_prefill_size,
+    ):
+
+        if not KTRANSFORMERS_AVAILABLE:
+            raise ImportError(
+                "kt_kernel is not installed. To use CompressedTensorsWNA16AMXEPMoEMethod, please install kt_kernel."
+            )
+
+        if not FUSED_MARLIN_MOE_AVAILABLE:
+            raise ImportError("fused_marlin_moe is not available; it requires sgl_kernel.")
+
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.layer_idx = layer_idx
+        self.num_gpu_experts = num_gpu_experts
+        self.amx_weight_path = amx_weight_path
+        self.chunked_prefill_size = chunked_prefill_size
+        self.cpuinfer = cpuinfer
+        self.threadpool_count = threadpool_count
+        self.amx_wrapper = None
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        self.experts_num = num_experts
+        self.num_experts_per_tok = extra_weight_attrs.pop("top_k")
+        self.hidden_size = hidden_size
+        self.moe_intermediate_size = extra_weight_attrs.pop("intermediate_size_full")
+
+        if self.tp_rank != 0:
+            return
+        self.amx_wrapper = AMXMoEWrapper(
+            layer_idx=self.layer_idx,
+            num_experts=num_experts,
+            num_experts_per_tok=self.num_experts_per_tok,
+            hidden_size=hidden_size,
+            moe_intermediate_size=self.moe_intermediate_size,
+            num_gpu_experts=self.num_gpu_experts,
+            cpuinfer_threads=self.cpuinfer,
+            threadpool_count=self.threadpool_count,
+            amx_weight_path=self.amx_weight_path,
+            chunked_prefill_size=self.chunked_prefill_size,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.tp_rank != 0:
+            return
+
+        if self.amx_wrapper is None:
+            raise RuntimeError(
+                "AMXMoEWrapper not initialized. Call create_weights first."
+            )
+
+        torch.cuda.synchronize()
+        # Load weights using wrapper
+        from sglang.srt.eplb.expert_location_dispatch import (
+            get_global_expert_location_metadata,
+        )
+
+        physical_to_logical_map_cpu = (
+            get_global_expert_location_metadata()
+            .physical_to_logical_map_cpu[self.layer_idx]
+            .contiguous()
+        )
+        self.amx_wrapper.load_weights(physical_to_logical_map_cpu)
+
+    def submit(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> None:
+        """Submit AMX inference task asynchronously."""
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        topk_weights, topk_ids, _ = topk_output
+
+        if self.tp_rank != 0 or self.amx_wrapper is None:
+            return None
+
+        # Submit forward task using wrapper
+        self.amx_wrapper.submit_forward(
+            x, topk_ids, topk_weights, torch.cuda.current_stream(x.device).cuda_stream
+        )
+        return None
+
+    def sync(self, x):
+        """Synchronize and retrieve AMX inference results."""
+        if self.tp_rank != 0 or self.amx_wrapper is None:
+            return torch.zeros_like(x)
+
+        # Sync forward task using wrapper
+        return self.amx_wrapper.sync_forward(
+            x, torch.cuda.current_stream(x.device).cuda_stream
+        )
+
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        """Execute AMX MoE forward pass synchronously."""
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        topk_weights, topk_ids, _ = topk_output
+
+        if self.tp_rank != 0 or self.amx_wrapper is None:
+            return StandardCombineInput(hidden_states=torch.zeros_like(x))
+
+        # Execute forward using wrapper (submit + sync)
+        output = self.amx_wrapper.forward(
+            x, topk_ids, topk_weights, torch.cuda.current_stream(x.device).cuda_stream
+        )
+        return StandardCombineInput(hidden_states=output)
+
+
+def override_config(
+    cls,
+    num_gpu_experts,
+    cpuinfer,
+    threadpool_count,
+    amx_weight_path,
+    amx_method,
+    chunked_prefill_size,
+):
+    """Override MoE configuration via environment variables."""
+    # Set environment variables via the envs utility class
+    if num_gpu_experts is not None:
+        envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.set(num_gpu_experts)
+    if cpuinfer is not None:
+        envs.SGLANG_KT_MOE_CPUINFER.set(cpuinfer)
+    if threadpool_count is not None:
+        envs.SGLANG_KT_THREADPOOL_COUNT.set(threadpool_count)
+    if amx_weight_path is not None:
+        envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.set(amx_weight_path)
+    if amx_method is not None:
+        envs.SGLANG_KT_AMX_METHOD.set(amx_method)
+    if chunked_prefill_size is not None:
+        envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.set(chunked_prefill_size)
+
+
+class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+        self,
+        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        layer_idx,
+    ):
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        if (
+            not envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.is_set()
+            or not envs.SGLANG_KT_MOE_CPUINFER.is_set()
+            or not envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set()
+        ):
+            raise RuntimeError(
+                "The following arguments are required: --kt-amx-weight-path, --kt-cpuinfer, --kt-num-gpu-experts"
+            )
+
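+        # These values are populated from the --kt-* CLI flags by
+        # ServerArgs._handle_ktransformers_configs() via override_config().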
self.num_gpu_experts = envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.value + cpuinfer = envs.SGLANG_KT_MOE_CPUINFER.value + threadpool_count = envs.SGLANG_KT_THREADPOOL_COUNT.value + amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value + chunked_prefill_size = envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.value + + self.AMX_method = CompressedTensorsWNA16AMXMoEMethod( + quant_config, + layer_idx, + self.num_gpu_experts, + cpuinfer, + threadpool_count, + amx_weight_path, + chunked_prefill_size, + ) + self.marlin_method = CompressedTensorsWNA16MoEMethod( + quant_config, self.num_gpu_experts + ) + self.layer_id = layer_idx + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + self.global_num_experts = num_experts + self.AMX_method.create_weights( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + params_dtype, + **extra_weight_attrs, + ) + self.marlin_method.create_weights( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + params_dtype, + **extra_weight_attrs, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.AMX_method.process_weights_after_loading(layer) + self.marlin_method.process_weights_after_loading(layer) + + def submit( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + """Submit hybrid GPU+CPU MoE task (AMX submission + GPU execution).""" + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + assert ( + self.moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + topk_weights, topk_ids, router_logits = topk_output + + # Submit AMX task if on rank 0 + if self.tp_rank == 0: + self.AMX_method.submit(layer, dispatch_output) + + # Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU + topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts) + + # Execute GPU (Marlin) experts + output = fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_weight_g_idx, + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.marlin_method.num_bits, + is_k_full=self.marlin_method.is_k_full, + global_num_experts=self.global_num_experts, + expert_map=torch.empty(1, device=x.device), + ) + return StandardCombineInput(hidden_states=output) + + def sync(self, x): + """Synchronize and retrieve AMX results.""" + if self.tp_rank != 0: + return torch.zeros_like(x) + return self.AMX_method.sync(x) + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + """Execute hybrid GPU+CPU MoE forward pass with parallelism.""" + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + assert ( + self.moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." 
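+        # Same hybrid scheme as submit()/sync(), fused into a single call:
+        # the CPU (AMX) experts are kicked off first, the GPU (Marlin) experts
+        # run concurrently, and the two partial outputs are summed at the end.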
+ + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + topk_weights, topk_ids, router_logits = topk_output + + # Step 1: Submit AMX task (non-blocking) if on rank 0 + # This starts CPU computation in parallel + if self.tp_rank == 0: + self.AMX_method.submit(layer, dispatch_output) + + # Step 2: Execute GPU (Marlin) experts in parallel with CPU + + # Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU + topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts) + + # While GPU computes, CPU is also computing + output = fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_weight_g_idx, + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.marlin_method.num_bits, + is_k_full=self.marlin_method.is_k_full, + global_num_experts=self.global_num_experts, + expert_map=torch.empty(1, device=x.device), + ) + + # Step 3: Sync AMX results and combine with GPU results + if self.tp_rank == 0: + amx_output = self.AMX_method.sync(x) + output += amx_output + + return StandardCombineInput(hidden_states=output) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.AMX_method.create_moe_runner(layer, moe_runner_config) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 9eabae6d5..0652c6c5b 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -65,6 +65,13 @@ from sglang.srt.utils import ( ) from sglang.srt.utils.patch_torch import monkey_patch_torch_compile +try: + from kt_kernel import AMXMoEWrapper + + KTRANSFORMERS_AVAILABLE = True +except ImportError: + KTRANSFORMERS_AVAILABLE = False + _is_hip = is_hip() logger = logging.getLogger(__name__) @@ -248,6 +255,8 @@ class CudaGraphRunner: # Batch sizes to capture self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}") + if KTRANSFORMERS_AVAILABLE: + AMXMoEWrapper.set_capture_batch_sizes(self.capture_bs) self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 76a946757..5d4dd5325 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -44,6 +44,7 @@ from sglang.srt.distributed import ( from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.environ import envs from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo @@ -81,7 +82,12 @@ from sglang.srt.layers.moe import ( from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat +from sglang.srt.layers.quantization import CompressedTensorsConfig from sglang.srt.layers.quantization.base_config import 
QuantizationConfig +from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import ( + CompressedTensorsWNA16AMXEPMoEMethod, +) +from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, per_tensor_quant_mla_fp8, @@ -707,6 +713,10 @@ class DeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states, gemm_output_zero_allocator) topk_output = self.topk(hidden_states, router_logits) + if isinstance( + self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod + ): + topk_output.topk_weights.mul_(self.routed_scaling_factor) final_hidden_states = self.experts(hidden_states, topk_output) if not _is_cuda: final_hidden_states *= self.routed_scaling_factor @@ -2837,6 +2847,10 @@ class DeepseekV2ForCausalLM(nn.Module): self.config = config self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config + if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set(): + CompressedTensorsConfig.DeepSeekFP8Config = Fp8Config( + True, "dynamic", None, [128, 128] + ) self.determine_num_fused_shared_experts() self.model = DeepseekV2Model( config, quant_config, prefix=add_prefix("model", prefix) @@ -2976,11 +2990,13 @@ class DeepseekV2ForCausalLM(nn.Module): torch.float8_e4m3fn, torch.float8_e4m3fnuz, ): - if ( - hasattr(self.quant_config, "weight_block_size") - and self.quant_config.weight_block_size is not None - ): - weight_block_size = self.quant_config.weight_block_size + selected_quant_config = getattr( + self.quant_config, "DeepSeekFP8Config", self.quant_config + ) + weight_block_size = getattr( + selected_quant_config, "weight_block_size", None + ) + if weight_block_size is not None: assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") if _is_fp8_fnuz: weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 9fe9e7748..62cf15af7 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -520,6 +520,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module): config=config, quant_config=quant_config, alt_stream=alt_stream, + prefix=add_prefix("mlp", prefix), ) else: self.mlp = Qwen2MoeMLP( @@ -673,6 +674,7 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module): config=config, quant_config=quant_config, alt_stream=alt_stream, + prefix=add_prefix("mlp", prefix), ) else: self.mlp = Qwen2MoeMLP( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2e432de1b..ff5e58dc2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -91,6 +91,7 @@ QUANTIZATION_CHOICES = [ "qoq", "w4afp8", "mxfp4", + "compressed-tensors", # for Ktransformers ] ATTENTION_BACKEND_CHOICES = [ @@ -389,6 +390,13 @@ class ServerArgs: # LMCache enable_lmcache: bool = False + # Ktransformers + kt_amx_weight_path: Optional[str] = None + kt_amx_method: Optional[str] = None + kt_cpuinfer: Optional[int] = None + kt_threadpool_count: Optional[int] = None + kt_num_gpu_experts: Optional[int] = None + # Double Sparsity enable_double_sparsity: bool = False ds_channel_config_path: Optional[str] = None @@ -544,6 +552,9 @@ class ServerArgs: self._handle_amd_specifics() self._handle_grammar_backend() + # Handle Ktransformers specific configs + self._handle_ktransformers_configs() + # Handle data parallelism. 
         self._handle_data_parallelism()
 
@@ -595,6 +606,22 @@ class ServerArgs:
             )
             self.tool_call_parser = deprecated_tool_call_parsers[self.tool_call_parser]
 
+    def _handle_ktransformers_configs(self):
+        from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+            CompressedTensorsWNA16AMXEPMoEMethod,
+            override_config,
+        )
+
+        override_config(
+            CompressedTensorsWNA16AMXEPMoEMethod,
+            self.kt_num_gpu_experts,
+            self.kt_cpuinfer,
+            self.kt_threadpool_count,
+            self.kt_amx_weight_path,
+            self.kt_amx_method,
+            self.chunked_prefill_size,
+        )
+
     def _handle_missing_default_values(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -1518,6 +1545,7 @@ class ServerArgs:
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+
         # Model and tokenizer
         parser.add_argument(
             "--model-path",
@@ -2675,6 +2703,35 @@ class ServerArgs:
             help="Using LMCache as an alternative hierarchical cache solution",
         )
 
+        # Ktransformers server args
+        parser.add_argument(
+            "--kt-amx-weight-path",
+            type=str,
+            help="[ktransformers parameter] Path to a local folder containing the quantized expert weights for the AMX kernel.",
+        )
+        parser.add_argument(
+            "--kt-amx-method",
+            type=str,
+            default="AMXINT4",
+            help="[ktransformers parameter] Quantization format for CPU execution (default: AMXINT4).",
+        )
+        parser.add_argument(
+            "--kt-cpuinfer",
+            type=int,
+            help="[ktransformers parameter] The number of CPUInfer threads.",
+        )
+        parser.add_argument(
+            "--kt-threadpool-count",
+            type=int,
+            default=2,
+            help="[ktransformers parameter] The number of thread pools, one per NUMA node.",
+        )
+        parser.add_argument(
+            "--kt-num-gpu-experts",
+            type=int,
+            help="[ktransformers parameter] The number of experts to keep on GPU.",
+        )
+
         # Double Sparsity
         parser.add_argument(
             "--enable-double-sparsity",
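End to end, the routing contract is simple: any routed expert ID at or above `num_gpu_experts` belongs to the CPU backend, and `mask_cpu_expert_ids` hides it from the GPU kernel while the AMX wrapper computes it. A toy, self-contained illustration of that masking (the values are made up for the example):

```python
import torch

# num_gpu_experts = 4: experts 0-3 are GPU-resident (Marlin), 4+ live on CPU.
# Masking to -1 keeps those slots from being computed on GPU; the AMX backend
# computes them instead, and apply() sums the two partial outputs.
topk_ids = torch.tensor([[0, 5], [3, 7], [2, 1]])
topk_ids[topk_ids >= 4] = -1
print(topk_ids)
# tensor([[ 0, -1],
#         [ 3, -1],
#         [ 2,  1]])
```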