Faster weight processing (trtllm-gen moe nvfp4) (#9162)
This commit is contained in:
@@ -737,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
" above."
|
||||
)
|
||||
self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
||||
self._cache_permute_indices = {}
|
||||
|
||||
@property
|
||||
def enable_flashinfer_cutlass_moe(self) -> bool:
|
||||
@@ -900,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
e2m1_and_ufp8sf_scale_to_float,
|
||||
fp4_quantize,
|
||||
next_positive_power_of_2,
|
||||
nvfp4_block_scale_interleave,
|
||||
reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a,
|
||||
)
|
||||
from flashinfer.fused_moe.core import (
|
||||
_maybe_get_cached_w2_permute_indices,
|
||||
_maybe_get_cached_w3_w1_permute_indices,
|
||||
)
|
||||
|
||||
"""Prepare quantized weights for kernel (done offline with weights)."""
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
@@ -927,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
num_experts, hidden_size, intermediate_size // 16
|
||||
) # fp8 scaling factors
|
||||
|
||||
# Reorder rows of W1 and scales for fused gated activation
|
||||
gemm1_weights_fp4_interleaved = []
|
||||
gemm1_scales_fp4_interleaved = []
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_fp4_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
|
||||
)
|
||||
gemm1_scales_fp4_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
|
||||
)
|
||||
|
||||
# Stack weights and scales for all experts
|
||||
gemm1_weights_fp4_interleaved = torch.stack(
|
||||
gemm1_weights_fp4_interleaved
|
||||
).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
|
||||
gemm1_scales_fp4_interleaved = torch.stack(
|
||||
gemm1_scales_fp4_interleaved
|
||||
).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_fp4_shuffled = []
|
||||
gemm1_scales_fp4_shuffled = []
|
||||
gemm2_weights_fp4_shuffled = []
|
||||
gemm2_scales_fp4_shuffled = []
|
||||
for i in range(num_experts):
|
||||
# Calculate the permute indices for the following:
|
||||
# 1. Reorder rows of W1 and scales for fused gated activation
|
||||
# 2. Shuffle weights and scaling factors for transposed mma output
|
||||
# for both w3_w1 and w2 weights and scale factors
|
||||
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
gemm1_weights_fp4[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_weights_fp4_shuffled.append(
|
||||
shuffle_matrix_a(
|
||||
gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
gemm1_weights_fp4[i]
|
||||
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
gemm1_scales_linear_fp4[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm1_scales_fp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(
|
||||
gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
|
||||
nvfp4_block_scale_interleave(
|
||||
gemm1_scales_linear_fp4[i]
|
||||
.view(torch.uint8)[
|
||||
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
|
||||
]
|
||||
.contiguous()
|
||||
)
|
||||
)
|
||||
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
gemm2_weights_fp4[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_weights_fp4_shuffled.append(
|
||||
shuffle_matrix_a(
|
||||
gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
gemm2_weights_fp4[i]
|
||||
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
self._cache_permute_indices,
|
||||
gemm2_scales_linear_fp4[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm2_scales_fp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(
|
||||
gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
|
||||
nvfp4_block_scale_interleave(
|
||||
gemm2_scales_linear_fp4[i]
|
||||
.view(torch.uint8)[
|
||||
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
|
||||
]
|
||||
.contiguous()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user