diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index f1d323982..1804963eb 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -556,7 +556,8 @@ class FusedMoE(torch.nn.Module):
             loaded_weight = loaded_weight.to(param.data.device)

             if (
-                param.data[expert_id] != 1
+                "compressed" in self.quant_method.__class__.__name__.lower()
+                and param.data[expert_id] != 1
                 and (param.data[expert_id] - loaded_weight).abs() > 1e-5
             ):
                 raise ValueError(
@@ -580,6 +581,23 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+        if "ModelOpt" in self.quant_method.__class__.__name__:
+            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+                self._load_per_tensor_weight_scale(
+                    shard_id=shard_id,
+                    param=param,
+                    loaded_weight=loaded_weight,
+                    expert_id=expert_id,
+                )
+            elif "weight" in weight_name:
+                self._load_model_weight_or_group_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                )
+            return

         # Case weight scales and zero_points
         if "scale" in weight_name or "zero" in weight_name:
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index aebc45244..fed4d52dc 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -1,12 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py

 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.linear import LinearBase, LinearMethodBase
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -15,10 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
+    is_sm100_supported,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
+    is_layer_skipped,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -270,9 +277,16 @@ class ModelOptFp4Config(QuantizationConfig):
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
         kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        if not kv_cache_quant_algo:
+            kv_cache_quant_algo = "auto"
         group_size = quant_config["group_size"]
         exclude_modules = quant_config["exclude_modules"]
         if not (group_size and kv_cache_quant_algo and exclude_modules):
+            logger.warning(
+                f"group_size: {group_size},"
+                f"kv_cache_quant_algo: {kv_cache_quant_algo},"
+                f"exclude_modules: {exclude_modules}"
+            )
             raise ValueError(
                 "NVFP4 quantization requires group size and "
                 "kv_cache_quant_algo specified in "
@@ -285,19 +299,30 @@ class ModelOptFp4Config(QuantizationConfig):
             exclude_modules,
         )

+    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+        import regex as re
+
+        for pattern in exclude_modules:
+            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            if re.fullmatch(regex_str, prefix):
+                return True
+        return False
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
-        ):
-            return None
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
+                prefix, self.exclude_modules
+            ):
+                return UnquantizedLinearMethod()
             return ModelOptFp4LinearMethod(self)
         if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
-
+        elif isinstance(layer, FusedMoE):
+            return ModelOptNvFp4FusedMoEMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -461,3 +486,305 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if bias is not None:
             out = out + bias
         return out.view(*output_shape)
+
+
+class ModelOptNvFp4FusedMoEMethod:
+    """
+    MoE Method for FP4 Quantization with Blockscales and PerTensorScales
+    Args:
+        quant_config: NVFP4 Quant Config
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp4Config):
+        self.quant_config = quant_config
+        if not is_sm100_supported():
+            raise ValueError(
+                "Current platform does not support NVFP4"
+                " quantization. Please use Blackwell and"
+                " above."
+            )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, "
+                " dynamic quantization is not supported."
+            )
+
+        layer.num_experts = num_experts
+        layer.params_dtype = params_dtype
+        layer.quant_config = self.quant_config
+        weight_dtype = torch.uint8
+        weight_scale_dtype = torch.float8_e4m3fn
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        # GEMM 1
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                # 2 fp4 items are packed in the input dimension
+                hidden_size // 2,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        # GEMM 2
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                hidden_size,
+                # 2 fp4 items are packed in the input dimension
+                intermediate_size_per_partition // 2,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        w13_weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                # 2 fp4 items are packed in the input dimension
+                hidden_size // self.quant_config.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                hidden_size,
+                # 2 fp4 items are packed in the input dimension
+                intermediate_size_per_partition // self.quant_config.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+        )
+
+        w13_weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
+
+        w2_weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(num_experts, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+
+        w13_input_scale = PerTensorScaleParameter(
+            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+
+        w2_input_scale = PerTensorScaleParameter(
+            data=torch.empty(num_experts, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def swizzle_blockscale(self, scale: torch.tensor):
+        assert scale.dtype == torch.float8_e4m3fn
+        # Pad and blockwise interleave weight_scale
+        scale_ndim = scale.ndim
+        if scale.ndim == 2:
+            scale = scale.unsqueeze(0)
+        assert scale.ndim == 3
+        B, M, K = scale.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
+        padded_scale[:B, :M, :K] = scale
+        batches, rows, cols = padded_scale.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
+        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
+        swizzled_scale = swizzled_scale.contiguous().cuda()
+        return (
+            swizzled_scale.reshape(M, K)
+            if scale_ndim == 2
+            else swizzled_scale.reshape(B, M, K)
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        # GEMM 1
+        if not torch.allclose(
+            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
+        ):
+            logger.warning_once(
+                "w1_weight_scale_2 must match w3_weight_scale_2. "
+                "Accuracy may be affected."
+            )
+
+        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
+        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
+
+        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        layer.g1_alphas = Parameter(
+            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
+            requires_grad=False,
+        )
+
+        assert (
+            layer.w13_weight_scale.shape[2] % 16 == 0
+        ), "Expected weight_scale.dim(1) to be divisible by 16"
+        assert (
+            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
+        ), "Weight Blockscale must be represented as FP8-E4M3"
+        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+
+        layer.w13_blockscale_swizzled = Parameter(
+            w13_blockscale_swizzled, requires_grad=False
+        )
+
+        # This is for quantization, so we need to invert it.
+        layer.w13_input_scale_quant = Parameter(
+            (1 / w13_input_scale).to(torch.float32), requires_grad=False
+        )
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+        # GEMM 2
+        layer.g2_alphas = Parameter(
+            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            requires_grad=False,
+        )
+
+        # This is for quantization, so we need to invert it.
+        layer.w2_input_scale_quant = Parameter(
+            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+        )
+
+        assert (
+            layer.w2_weight_scale.shape[2] % 16 == 0
+        ), "Expected weight_scale.dim(1) to be divisible by 16"
+        assert (
+            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
+        ), "Weight Blockscale must be represented as FP8-E4M3"
+        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+
+        layer.w2_blockscale_swizzled = Parameter(
+            w2_blockscale_swizzled, requires_grad=False
+        )
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        device = layer.w13_weight.device
+        layer.cutlass_moe_params = CutlassMoEParams(
+            CutlassMoEType.BlockscaledFP4,
+            device,
+            num_experts=layer.num_experts,
+            intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
+            hidden_size=layer.w13_weight.shape[2] * 2,
+        )  # k
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
+
+        return cutlass_moe_fp4(
+            a=x,
+            a1_gscale=layer.w13_input_scale_quant,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            a2_gscale=layer.w2_input_scale_quant,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            params=layer.cutlass_moe_params,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        ).to(x.dtype)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 7970d3503..9a53a1c77 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1746,7 +1746,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             global_server_args_dict["disable_shared_experts_fusion"] = False
             log_info_on_rank0(
                 logger,
-                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
+                "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
             )

     def get_input_embeddings(self) -> nn.Embedding:
@@ -1926,6 +1926,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     self_attn.use_deep_gemm_bmm = True

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1982,6 +1983,21 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "up_proj.qzeros",
                     "up_proj.scales",
                 ]
+            elif self.quant_config.get_name() == "modelopt_fp4":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "down_proj.weight_scale_2",
+                    "down_proj.input_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "gate_proj.weight_scale_2",
+                    "gate_proj.input_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                    "up_proj.weight_scale_2",
+                    "up_proj.input_scale",
+                ]
             else:
                 raise ValueError(
                     f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
                 )
@@ -2125,7 +2141,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                     # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
-
                    if fuse_qkv_a_proj and (
                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                    ):
@@ -2151,9 +2166,12 @@

                            fused_weight = torch.cat(
                                [q_a_proj_weight, kv_a_proj_weight], dim=0
                            )
-
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                            param_name = (
+                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                if "q_a_proj" in name
+                                else name.replace(
+                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                                )
                             )
                             param = params_dict[param_name]
@@ -2164,6 +2182,16 @@
                             cached_a_proj.pop(q_a_proj_name)
                             cached_a_proj.pop(kv_a_proj_name)
                     else:
+                        if (
+                            "k_scale" in name or "v_scale" in name
+                        ) and name not in params_dict:
+                            # modelopt attn kv scale is named differently
+                            if any(scale in name for scale in ["k_scale", "v_scale"]):
+                                name = name.replace("_proj", "attn_mqa")
+                            else:
+                                logger.warning(
+                                    f"Unknown scale found in checkpoint: {name}"
+                                )
                         param = params_dict[name]
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader