diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 698ce2df2..103f675d2 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -737,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): " above." ) self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe() + self._cache_permute_indices = {} @property def enable_flashinfer_cutlass_moe(self) -> bool: @@ -900,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): e2m1_and_ufp8sf_scale_to_float, fp4_quantize, next_positive_power_of_2, + nvfp4_block_scale_interleave, reorder_rows_for_gated_act_gemm, shuffle_matrix_a, shuffle_matrix_sf_a, ) + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w2_permute_indices, + _maybe_get_cached_w3_w1_permute_indices, + ) """Prepare quantized weights for kernel (done offline with weights).""" epilogue_tile_m = 128 # FIXME: this depends on the kernel internals @@ -927,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): num_experts, hidden_size, intermediate_size // 16 ) # fp8 scaling factors - # Reorder rows of W1 and scales for fused gated activation - gemm1_weights_fp4_interleaved = [] - gemm1_scales_fp4_interleaved = [] - for i in range(num_experts): - gemm1_weights_fp4_interleaved.append( - reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()) - ) - gemm1_scales_fp4_interleaved.append( - reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone()) - ) - - # Stack weights and scales for all experts - gemm1_weights_fp4_interleaved = torch.stack( - gemm1_weights_fp4_interleaved - ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2) - gemm1_scales_fp4_interleaved = torch.stack( - gemm1_scales_fp4_interleaved - ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16) - - # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp4_shuffled = [] gemm1_scales_fp4_shuffled = [] gemm2_weights_fp4_shuffled = [] gemm2_scales_fp4_shuffled = [] for i in range(num_experts): + # Calculate the permute indices for the following: + # 1. Reorder rows of W1 and scales for fused gated activation + # 2. Shuffle weights and scaling factors for transposed mma output + # for both w3_w1 and w2 weights and scale factors + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + gemm1_weights_fp4[i].view(torch.uint8), + epilogue_tile_m, + ) gemm1_weights_fp4_shuffled.append( - shuffle_matrix_a( - gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m - ) + gemm1_weights_fp4[i] + .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)] + .contiguous() + ) + + permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + gemm1_scales_linear_fp4[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, ) gemm1_scales_fp4_shuffled.append( - shuffle_matrix_sf_a( - gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m + nvfp4_block_scale_interleave( + gemm1_scales_linear_fp4[i] + .view(torch.uint8)[ + permute_sf_indices.to(gemm1_scales_linear_fp4.device) + ] + .contiguous() ) ) + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + gemm2_weights_fp4[i].view(torch.uint8), + epilogue_tile_m, + ) gemm2_weights_fp4_shuffled.append( - shuffle_matrix_a( - gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m - ) + gemm2_weights_fp4[i] + .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)] + .contiguous() + ) + + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + gemm2_scales_linear_fp4[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, ) gemm2_scales_fp4_shuffled.append( - shuffle_matrix_sf_a( - gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m + nvfp4_block_scale_interleave( + gemm2_scales_linear_fp4[i] + .view(torch.uint8)[ + permute_sf_indices.to(gemm2_scales_linear_fp4.device) + ] + .contiguous() ) )