diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index c9e7547bf..8418334b9 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -342,6 +342,7 @@ def fused_moe_kernel(
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -416,20 +417,7 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)
 
-    if use_fp8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-
-    if use_int8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,8 +426,7 @@
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
         # channel-wise
-        else:
-            # Load per-column scale for weights
+        elif per_channel_quant:
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
             )
@@ -447,6 +434,10 @@
             # Load per-token scale for activations
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
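The `per_channel_quant` branch above pairs one weight scale per output channel (a column of the GEMM) with one dynamically computed activation scale per token (a row), and rescales the quantized accumulator by both. A minimal PyTorch sketch of that reference math, assuming fp8 e4m3 weights and dynamic per-token activation quantization (the helper name and the float32 accumulation are illustrative, not the Triton kernel itself):

```python
import torch


def per_channel_w8a8_reference(
    a: torch.Tensor,        # [M, K] fp16/bf16 activations
    b_q: torch.Tensor,      # [N, K] fp8 weight, quantized per output channel
    b_scale: torch.Tensor,  # [N] one scale per output channel
) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Dynamic per-token quantization: one scale per row of `a`.
    a_scale = a.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    a_q = (a / a_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

    # The reference accumulates in float32, then undoes both quantizations:
    # rows are rescaled by the per-token scale, columns by the per-channel scale.
    acc = a_q.to(torch.float32) @ b_q.to(torch.float32).t()
    return (acc * a_scale * b_scale.view(1, -1)).to(a.dtype)
```

This is the same rescaling the kernel performs with `a_scale` loaded as a `[:, None]` column vector and `b_scale` as a `[None, :]` row vector.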
@@ -753,6 +744,7 @@ def invoke_fused_moe_kernel( use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + per_channel_quant: bool, block_shape: Optional[List[int]] = None, no_combine: bool = False, ) -> None: @@ -777,10 +769,15 @@ def invoke_fused_moe_kernel( if block_shape is None: # activation tensor-wise fp8 quantization, dynamic or static padded_size = padding_size + # activations apply per-token quantization when weights apply per-channel quantization by default if _is_cuda: - A, A_scale = sgl_scaled_fp8_quant(A, A_scale) + A, A_scale = sgl_scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_channel_quant + ) else: - A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale) + A, A_scale = vllm_ops.scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_channel_quant + ) else: # activation block-wise fp8 quantization assert len(block_shape) == 2 @@ -796,6 +793,9 @@ def invoke_fused_moe_kernel( assert B_scale is not None if block_shape is None: # activation channel-wise int8 quantization + assert ( + per_channel_quant + ), "int8 quantization only supports channel-wise quantization except for block-wise quantization" A, A_scale = per_token_quant_int8(A) else: # activation block-wise int8 quantization @@ -904,6 +904,7 @@ def invoke_fused_moe_kernel( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, + per_channel_quant=per_channel_quant, even_Ks=even_Ks, **config, ) @@ -1086,6 +1087,7 @@ def inplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1107,6 +1109,7 @@ def inplace_fused_experts( use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + per_channel_quant, w1_scale, w2_scale, w1_zp, @@ -1129,6 +1132,7 @@ def inplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1160,6 +1164,7 @@ def outplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1182,6 +1187,7 @@ def outplace_fused_experts( use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + per_channel_quant, w1_scale, w2_scale, w1_zp, @@ -1205,6 +1211,7 @@ def outplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1238,6 +1245,7 @@ def fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1261,6 +1269,7 @@ def fused_experts( use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + per_channel_quant, w1_scale, w2_scale, w1_zp, @@ -1283,6 +1292,7 @@ def fused_experts( use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + per_channel_quant, w1_scale, w2_scale, w1_zp, @@ -1307,6 +1317,7 @@ def fused_experts_impl( use_int8_w8a8: bool = False, 
use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1443,6 +1454,7 @@ def fused_experts_impl( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, block_shape=block_shape, ) if activation == "silu": @@ -1486,6 +1498,7 @@ def fused_experts_impl( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, block_shape=block_shape, ) @@ -1532,6 +1545,7 @@ def fused_moe( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, @@ -1608,6 +1622,7 @@ def fused_moe( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, w1_scale=w1_scale, w2_scale=w2_scale, w1_zp=w1_zp, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index c60e09be4..ce2155600 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_ignore_list: List[str], kv_cache_scheme: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, + packed_modules_mapping: Dict[str, List[str]] = {}, ): super().__init__() self.ignore = ignore @@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig): self.sparsity_scheme_map = sparsity_scheme_map self.sparsity_ignore_list = sparsity_ignore_list self.config = config + self.packed_modules_mapping = packed_modules_mapping def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( config=config ) + packed_modules_mapping = config.get("packed_modules_mapping", {}) return cls( target_scheme_map=target_scheme_map, @@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_scheme_map=sparsity_scheme_map, sparsity_ignore_list=sparsity_ignore_list, config=config, + packed_modules_mapping=packed_modules_mapping, ) @classmethod diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 393d6369c..569f2a2d6 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): "input_activations" ) - if not ( - self.weight_quant.strategy == QuantizationStrategy.TENSOR - and self.input_quant.strategy == QuantizationStrategy.TENSOR - ): - raise ValueError( - "For FP8 Fused MoE layers, only per-tensor scales " - "for weights and activations are supported. 
Found " - f"{self.weight_quant}, {self.input_quant}" - ) - self.static_input_scales = not self.input_quant.dynamic def create_weights( @@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) + # per-tensor quantization + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value + elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value + else: + raise ValueError( + f"Unsupported weight quantization strategy: {self.weight_quant.strategy}" + ) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) + extra_weight_attrs.update({"quant_method": weight_quant_method}) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.static_input_scales: + assert ( + self.input_quant.strategy == QuantizationStrategy.TENSOR + ), "Only per-tensor quantization is supported for static input scales" w13_input_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) @@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w2_input_scale = torch.nn.Parameter( w2_input_scale, requires_grad=False ) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. - assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id], - ) - - if _is_cuda: - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # Fp8 moe kernel needs single weight scale for w13 per expert. 
+ # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], ) - else: - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - ) - start += shard_size - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + if _is_cuda: + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + else: + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = vllm_ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id] + ) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) def apply( self, @@ -311,6 +330,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): inplace=inplace, activation=activation, use_fp8_w8a8=True, + per_channel_quant=self.weight_quant.strategy + == QuantizationStrategy.CHANNEL, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 63c318ba3..7d80f6e0d 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -217,6 +217,15 @@ def block_quant_to_tensor_quant( return x_q_tensor, scale +def channel_quant_to_tensor_quant( + x_q_channel: torch.Tensor, + x_s: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + x_dq_channel = x_q_channel.to(torch.float32) * x_s + x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype) + return x_q_tensor, scale + + def apply_fp8_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 77819d4fb..12d1eab19 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch from torch.nn.parameter import Parameter @@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) -from sglang.srt.utils import is_hip +from sglang.srt.utils import is_hip, set_weight_attrs _is_hip = is_hip() @@ -62,7 +62,9 @@ class W8A8Fp8Config(QuantizationConfig): @classmethod def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) - is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method + is_checkpoint_fp8_serialized = ( + "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method + ) return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized) def get_quant_method( @@ -71,9 +73,12 @@ class W8A8Fp8Config(QuantizationConfig): prefix: str, ) -> Optional["QuantizeMethodBase"]: from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): 
return W8A8Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return W8A8FP8MoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -131,7 +136,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase): input_size: int, output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs + **extra_weight_attrs, ): weight_dtype = ( torch.float8_e4m3fn @@ -177,3 +182,148 @@ class W8A8Fp8LinearMethod(LinearMethodBase): bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, ) + + +class W8A8FP8MoEMethod: + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w13_input_scale = None + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = None + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + 
router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + inplace: bool = True, + no_combine: bool = False, + ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + from sglang.srt.layers.moe.topk import select_experts + + # Expert selection + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + ) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + activation=activation, + use_fp8_w8a8=True, + per_channel_quant=True, + w1_scale=(layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + no_combine=no_combine, + ) diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 280a9a249..6df5693f8 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -260,6 +260,7 @@ class W8A8Int8MoEMethod: activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_int8_w8a8=True, + per_channel_quant=True, w1_scale=(layer.w13_weight_scale), w2_scale=(layer.w2_weight_scale), a1_scale=layer.w13_input_scale, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 7580967a6..4e42ee897 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -108,11 +108,15 @@ logger = logging.getLogger(__name__) def _get_quantization_config( - model_config: ModelConfig, load_config: LoadConfig + model_config: ModelConfig, + load_config: LoadConfig, + packed_modules_mapping: Dict[str, List[str]], ) -> Optional[QuantizationConfig]: """Get the quantization config.""" if model_config.quantization is not None: - quant_config = get_quant_config(model_config, load_config) + quant_config = get_quant_config( + model_config, load_config, packed_modules_mapping + ) major, minor = get_device_capability() if major is not None and minor is not None: @@ -142,7 +146,10 @@ def _initialize_model( ) -> nn.Module: """Initialize a model with the given configurations.""" model_class, _ = get_model_architecture(model_config) - quant_config = _get_quantization_config(model_config, load_config) + packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {}) + quant_config = _get_quantization_config( + model_config, load_config, packed_modules_mapping + ) return model_class( config=model_config.hf_config, quant_config=quant_config, diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 2f21d55c0..d1c44e4f7 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -129,7 +129,9 @@ def convert_bin_to_safetensor_file( # TODO(woosuk): Move this to other place. 
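The loader changes in this hunk only thread the model class's `packed_modules_mapping` into the quantization config. compressed-tensors checkpoints name the unfused checkpoint shards (`q_proj`, `gate_proj`, ...) while the runtime fuses them (`qkv_proj`, `gate_up_proj`), so ignore/target lists have to be checked against every constituent of a fused module. A rough, hypothetical sketch of that kind of matching (not the actual compressed-tensors API):

```python
from typing import Dict, List


def is_fused_module_ignored(
    module_name: str,
    ignore_list: List[str],
    packed_modules_mapping: Dict[str, List[str]],
) -> bool:
    """Illustrative only: a fused module counts as ignored when every
    checkpoint-side shard it packs appears in the config's ignore list."""
    for fused, shards in packed_modules_mapping.items():
        if module_name.endswith(fused):
            return all(
                any(entry.endswith(shard) for entry in ignore_list)
                for shard in shards
            )
    return any(module_name.endswith(entry) for entry in ignore_list)


# With {"gate_up_proj": ["gate_proj", "up_proj"]}, "model.layers.0.mlp.gate_up_proj"
# is treated as ignored only if both gate_proj and up_proj are listed.
```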
def get_quant_config( - model_config: ModelConfig, load_config: LoadConfig + model_config: ModelConfig, + load_config: LoadConfig, + packed_modules_mapping: Dict[str, List[str]], ) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) @@ -147,6 +149,7 @@ def get_quant_config( # compressed-tensors uses a compressions_config hf_quant_config = getattr(model_config.hf_config, "compression_config", None) if hf_quant_config is not None: + hf_quant_config["packed_modules_mapping"] = packed_modules_mapping return quant_cls.from_config(hf_quant_config) # In case of bitsandbytes/QLoRA, get quant config from the adapter model. if model_config.quantization == "bitsandbytes": diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d5f332126..c58431336 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -55,6 +55,7 @@ from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_utils import ( block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) @@ -1411,27 +1412,34 @@ class DeepseekV2ForCausalLM(nn.Module): w = self_attn.kv_b_proj.weight # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. # This may affect the accuracy of fp8 model. - if hasattr(self.quant_config, "weight_block_size") and w.dtype in ( + if w.dtype in ( torch.float8_e4m3fn, torch.float8_e4m3fnuz, ): - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale_inv, - input_scale=None, - ) - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv + if hasattr(self.quant_config, "weight_block_size"): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_hip: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) self_attn.w_scale = scale + if w.dtype == torch.int8: if hasattr(self.quant_config, "weight_block_size"): # block-wise int8 need it diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index 8015c18a0..0a46305b5 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -414,7 +414,7 @@ class Llama4Model(nn.Module): lambda idx, prefix: Llama4DecoderLayer( config=config, layer_id=idx, quant_config=quant_config, prefix=prefix ), - prefix="model.layers", + prefix=add_prefix("layers", prefix), ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index f254903a2..98fc80686 100644 --- 
a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -7,6 +7,7 @@ from torch import nn from transformers import Llama4Config from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -16,6 +17,7 @@ from sglang.srt.utils import add_prefix class Llama4ForConditionalGeneration(nn.Module): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__( @@ -96,6 +98,15 @@ class Llama4ForConditionalGeneration(nn.Module): num_experts = self.config.text_config.num_local_experts + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=num_experts, + ) + for name, loaded_weight in weights: if name.startswith("vision_model") or name.startswith( @@ -115,31 +126,54 @@ class Llama4ForConditionalGeneration(nn.Module): break else: if ".experts" in name: - if ".gate_up_proj" in name: - name_list = [ - name.replace(".experts.gate_up_proj", ".experts.w13_weight") - ] * 2 - loaded_weight_list = loaded_weight.chunk(2, dim=-1) - shard_id_list = ["w1", "w3"] - else: - name_list = [ - name.replace(".experts.down_proj", ".experts.w2_weight") - ] - shard_id_list = ["w2"] - loaded_weight_list = [loaded_weight] - for name, loaded_weight, shard_id in zip( - name_list, loaded_weight_list, shard_id_list + # NOTE: llama4 fp8 has different weight format for experts + if ( + "experts.gate_up_proj" not in name + and "experts.down_proj" not in name ): - param = params_dict[name] - weight_loader = param.weight_loader - for expert_id in range(num_experts): + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader weight_loader( param, - loaded_weight[expert_id].T, + loaded_weight, name, shard_id=shard_id, expert_id=expert_id, ) + break + else: + if ".gate_up_proj" in name: + name_list = [ + name.replace( + ".experts.gate_up_proj", ".experts.w13_weight" + ) + ] * 2 + loaded_weight_list = loaded_weight.chunk(2, dim=-1) + shard_id_list = ["w1", "w3"] + else: + name_list = [ + name.replace(".experts.down_proj", ".experts.w2_weight") + ] + shard_id_list = ["w2"] + loaded_weight_list = [loaded_weight] + for name, loaded_weight, shard_id in zip( + name_list, loaded_weight_list, shard_id_list + ): + param = params_dict[name] + weight_loader = param.weight_loader + for expert_id in range(num_experts): + weight_loader( + param, + loaded_weight[expert_id].T, + name, + shard_id=shard_id, + expert_id=expert_id, + ) else: # Skip loading extra bias for GPTQ models. 
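`FusedMoE.make_expert_params_mapping` drives the new per-expert branch above: it yields one `(param_name, weight_name, expert_id, shard_id)` tuple per expert and per projection, so each checkpoint tensor can be routed into the fused `w13_weight`/`w2_weight` parameters. A simplified sketch of what such a mapping looks like (the exact strings the real helper emits may differ):

```python
def build_expert_mapping_sketch(num_experts: int):
    # Illustrative approximation of FusedMoE.make_expert_params_mapping(...).
    mapping = []
    for expert_id in range(num_experts):
        for ckpt_proj, param_name, shard_id in (
            ("gate_proj", "experts.w13_weight", "w1"),
            ("up_proj", "experts.w13_weight", "w3"),
            ("down_proj", "experts.w2_weight", "w2"),
        ):
            mapping.append(
                (param_name, f"experts.{expert_id}.{ckpt_proj}.", expert_id, shard_id)
            )
    return mapping


# A name like "...feed_forward.experts.7.up_proj.weight" then matches the
# "experts.7.up_proj." entry and is loaded into "...experts.w13_weight" with
# expert_id=7, shard_id="w3"; the else branch above keeps handling checkpoints
# that store the fused "experts.gate_up_proj" format directly.
```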
if name.endswith(".bias") and name not in params_dict: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index c2de835b6..b037e7a92 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -76,6 +76,7 @@ suites = { TestFile("test_create_kvindices.py", 2), TestFile("test_hicache.py", 60), TestFile("test_hicache_mla.py", 90), + TestFile("test_triton_moe_channel_fp8_kernel.py", 25), ], "per-commit-2-gpu": [ TestFile("models/lora/test_lora_tp.py", 300), diff --git a/test/srt/test_int8_kernel.py b/test/srt/test_int8_kernel.py index 959aab900..3e9f7a7dd 100644 --- a/test/srt/test_int8_kernel.py +++ b/test/srt/test_int8_kernel.py @@ -124,6 +124,7 @@ class TestW8A8Int8FusedMoE(CustomTestCase): use_fp8_w8a8=False, # Not using fp8 use_int8_w8a16=False, # Not using int8-w8a16 use_int8_w8a8=True, # Using int8-w8a8 + per_channel_quant=True, w1_scale=w1_s, w2_scale=w2_s, block_shape=None, # Not using block quantization diff --git a/test/srt/test_triton_moe_channel_fp8_kernel.py b/test/srt/test_triton_moe_channel_fp8_kernel.py new file mode 100644 index 000000000..2de9a6790 --- /dev/null +++ b/test/srt/test_triton_moe_channel_fp8_kernel.py @@ -0,0 +1,177 @@ +import itertools +import unittest + +import torch + +from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.test.test_utils import CustomTestCase + + +def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): + """Matrix multiplication function that supports per-token input quantization and per-column weight quantization""" + A = A.to(torch.float32) + B = B.to(torch.float32) + + assert A.shape[-1] == B.shape[-1], "Dimension mismatch" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" + + # Reshape input + M = A.numel() // A.shape[-1] + B = B.t() # Transpose weight matrix + N, K = B.shape + origin_C_shape = A.shape[:-1] + (K,) + A = A.reshape(M, N) + + # As is per-token [M, 1], Bs is per-column [1, K] + C = torch.matmul(A, B) # [M, K] + C = As * C * Bs.view(1, -1) # Broadcast per-column scale + + return C.reshape(origin_C_shape).to(output_dtype) + + +def fp8_mask(a, mask): + dtype = a.dtype + return a.view(torch.int8)[mask].view(dtype) + + +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): + """This function performs fused moe with per-column int8 quantization using native torch.""" + + B, D = a.shape + # Perform per-token quantization + a_q, a_s = sgl_scaled_fp8_quant(a, use_per_token_if_dynamic=True) + # Repeat tokens to match topk + a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + # Also repeat the scale + a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] + + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + # Calculate routing + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + # Process each expert + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + # First MLP layer: note that a_s is now per-token + inter_out = native_w8a8_per_token_matmul( + fp8_mask(a_q, mask), + w1[i], + fp8_mask(a_s, mask), + w1_s[i], + output_dtype=a.dtype, + ) + # Activation function + act_out = SiluAndMul().forward_native(inter_out) + # Quantize activation output with per-token + act_out_q, act_out_s = sgl_scaled_fp8_quant( + 
act_out, use_per_token_if_dynamic=True + ) + + # Second MLP layer + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) + # Apply routing weights and sum + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +class TestW8A8FP8FusedMoE(CustomTestCase): + DTYPES = [torch.half, torch.bfloat16] + M = [1, 33] + N = [128, 1024] + K = [256, 4096] + E = [8] + TOP_KS = [2, 6] + BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + # Initialize int8 quantization parameters + factor_for_scale = 1e-2 + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = finfo.min + + # Input tensor + # M * K + a = torch.randn((M, K), dtype=dtype) / 10 + + # Generate int8 weights + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + # Generate scale for each column (per-column quantization) + w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale + w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale + score = torch.randn((M, E), dtype=dtype) + + with torch.inference_mode(): + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, # using fp8 + use_int8_w8a16=False, + use_int8_w8a8=False, + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) + + # Check results + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + < 0.05 + ) + + def test_w8a8_fp8_fused_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.TOP_KS, + self.BLOCK_SIZE, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + block_size=params[5], + dtype=params[6], + seed=params[7], + ): + self._w8a8_fp8_fused_moe(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2)
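The reference path in `torch_w8a8_per_column_moe` relies on the fact that per-token and per-column scales factor straight out of the GEMM: if `A ≈ A_q * As` row-wise and `B ≈ B_q * Bs` column-wise, then `A @ B ≈ As * (A_q @ B_q) * Bs`. A tiny standalone check of that identity (plain symmetric 8-bit rounding here purely for illustration, with the same 5% relative-error budget the test uses):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 64)
B = torch.randn(64, 8)

# One scale per row (token) of A, one per column (output channel) of B.
As = A.abs().amax(dim=1, keepdim=True) / 127.0   # [4, 1]
Bs = B.abs().amax(dim=0, keepdim=True) / 127.0   # [1, 8]
A_q = torch.round(A / As).clamp(-127, 127)
B_q = torch.round(B / Bs).clamp(-127, 127)

# The scales factor out of the matmul: rescale rows by As, columns by Bs.
approx = As * (A_q @ B_q) * Bs
rel_err = (approx - A @ B).norm() / (A @ B).norm()
assert rel_err < 0.05, rel_err
```

The new test applies the same idea with fp8 e4m3 instead of int8, with `SiluAndMul` between the two GEMMs, and compares the fused Triton kernel against that reference.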