diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 0fa9c0f9c..cf8f405dd 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -25,7 +25,6 @@ try:
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
     )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
@@ -58,7 +57,11 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+from sglang.srt.layers.quantization.gptq import (
+    GPTQConfig,
+    GPTQMarlinConfig,
+    GPTQMarlinMoEMethod,
+)
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py
index af917a26d..b032ff4c1 100644
--- a/python/sglang/srt/layers/quantization/gptq.py
+++ b/python/sglang/srt/layers/quantization/gptq.py
@@ -1,21 +1,28 @@
 import logging
 from fractions import Fraction
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 
-from sglang.srt.layers.linear import LinearBase
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.utils import replace_parameter
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
 try:
-    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm import _custom_ops as ops
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
         GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
+        marlin_moe_permute_scales,
     )
     from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
     from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -27,7 +34,9 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+    GPTQLinearMethod = MarlinLinearMethod = Any
+
+    FusedMoEMethodBase = QuantizeMethodBase
 
     class scalar_types:
         uint4b8 = "uint4b8"
@@ -437,3 +446,286 @@ class MarlinConfig(QuantizationConfig):
         ):
             return MarlinLinearMethod(self)
         return None
+
+
+class GPTQMarlinMoEMethod(FusedMoEMethodBase):
+    """MoE Marlin method with quantization."""
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        intermediate_size = extra_weight_attrs.pop("intermediate_size")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size
+        )
+
+        if self.quant_config.group_size != -1:
+            scales_size13 = hidden_size // self.quant_config.group_size
+            w2_scales_size = (
+                intermediate_size
+                if self.quant_config.desc_act
+                else intermediate_size_per_partition
+            )
+            scales_size2 = w2_scales_size // self.quant_config.group_size
+            strategy = FusedMoeWeightScaleSupported.GROUP.value
+        else:
+            scales_size13 = 1
+            scales_size2 = 1
+            strategy = FusedMoeWeightScaleSupported.CHANNEL.value
+
+        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.quant_config.pack_factor,
+                2 * intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition // self.quant_config.pack_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        # gate_up_proj scales
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition,
+                dtype=torch.half,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+        # down_proj scales
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+        # don't shard the w2 scales when running act order
+        set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
+        # gate_up_proj qzeros
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+        # down_proj qzeros
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # don't shard the w2 qzeros when running act order
+        set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        # Process act_order
+        if self.quant_config.desc_act:
+            # Get sorting based on g_idx
+            num_experts = layer.w13_g_idx.shape[0]
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
+                    torch.int32
+                )
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
+                    torch.int32
+                )
+                w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
+            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        else:
+            # Reset g_idx related tensors
+            num_experts = layer.w13_g_idx.shape[0]
+            device = layer.w13_g_idx.device
+            layer.w13_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+        # Repack weights
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w13_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w2_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.w2_scales.shape[1]
+            * (
+                self.quant_config.group_size
+                if self.quant_config.group_size != -1
+                else self.quant_config.pack_factor
+            ),
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        # The input must currently be float16
+        orig_dtype = x.dtype
+        x = x.half()
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        return torch.ops.vllm.fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_g_idx,
+            g_idx2=layer.w2_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.quant_config.quant_type.size_bits,
+            is_k_full=self.is_k_full,
+        ).to(orig_dtype)
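
For context, a minimal sketch of how the relocated GPTQMarlinMoEMethod is expected to be selected, assuming GPTQMarlinConfig dispatches by layer type the same way the existing linear path does. The function name get_quant_method_sketch is hypothetical and not part of this patch; all imported names appear in the diff above.

    # Hypothetical dispatch sketch (illustrative only, not repo code): route
    # FusedMoE layers to the new GPTQMarlinMoEMethod, mirroring LinearBase dispatch.
    from typing import Optional

    import torch
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        FusedMoE,
        GPTQMarlinLinearMethod,
    )

    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.quantization.gptq import (
        GPTQMarlinConfig,
        GPTQMarlinMoEMethod,
    )


    def get_quant_method_sketch(
        config: GPTQMarlinConfig, layer: torch.nn.Module
    ) -> Optional[object]:
        # Dense projections keep the existing Marlin linear kernel.
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(config)
        # FusedMoE layers take the fused Marlin MoE path added by this diff:
        # create_weights() allocates the packed tensors,
        # process_weights_after_loading() repacks them into Marlin layout,
        # and apply() routes tokens and runs fused_marlin_moe.
        if isinstance(layer, FusedMoE):
            return GPTQMarlinMoEMethod(config)
        return None

Note that create_weights() pops an "intermediate_size" entry from extra_weight_attrs to decide is_k_full, so whichever FusedMoE layer constructs these weights must supply that attribute.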