From b498cd21d76aa0ec039b10b0e2d97607c0776a56 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 18 Aug 2025 04:26:02 +0800 Subject: [PATCH] Tiny make fp4 moe method parameters more static (#8520) --- .../srt/layers/quantization/modelopt_quant.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 7647ec30b..ccc6ebffb 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -812,6 +812,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ) layer.register_parameter("w13_weight_scale", w13_weight_scale) + # Only use `swizzle_blockscale` for shapes, not for real content + layer.w13_blockscale_swizzled = Parameter( + self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False + ) + w2_weight_scale = ModelWeightParameter( data=torch.empty( layer.num_local_experts, @@ -826,6 +831,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ) layer.register_parameter("w2_weight_scale", w2_weight_scale) + layer.w2_blockscale_swizzled = Parameter( + self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + ) + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported extra_weight_attrs.update( @@ -1129,16 +1138,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): # Process w13 weights w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale) - layer.w13_blockscale_swizzled = Parameter( - w13_blockscale_swizzled, requires_grad=False - ) + layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled) layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) # Process w2 weights w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) - layer.w2_blockscale_swizzled = Parameter( - w2_blockscale_swizzled, requires_grad=False - ) + layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) # Both flashinfer cutlass and regular cutlass use same processing for w2