Faster weight processing (trtllm-gen moe nvfp4) (#9162)
This commit is contained in:
@@ -737,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
" above."
|
" above."
|
||||||
)
|
)
|
||||||
self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
|
||||||
|
self._cache_permute_indices = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enable_flashinfer_cutlass_moe(self) -> bool:
|
def enable_flashinfer_cutlass_moe(self) -> bool:
|
||||||
@@ -900,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
e2m1_and_ufp8sf_scale_to_float,
|
e2m1_and_ufp8sf_scale_to_float,
|
||||||
fp4_quantize,
|
fp4_quantize,
|
||||||
next_positive_power_of_2,
|
next_positive_power_of_2,
|
||||||
|
nvfp4_block_scale_interleave,
|
||||||
reorder_rows_for_gated_act_gemm,
|
reorder_rows_for_gated_act_gemm,
|
||||||
shuffle_matrix_a,
|
shuffle_matrix_a,
|
||||||
shuffle_matrix_sf_a,
|
shuffle_matrix_sf_a,
|
||||||
)
|
)
|
||||||
|
from flashinfer.fused_moe.core import (
|
||||||
|
_maybe_get_cached_w2_permute_indices,
|
||||||
|
_maybe_get_cached_w3_w1_permute_indices,
|
||||||
|
)
|
||||||
|
|
||||||
"""Prepare quantized weights for kernel (done offline with weights)."""
|
"""Prepare quantized weights for kernel (done offline with weights)."""
|
||||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||||
@@ -927,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
num_experts, hidden_size, intermediate_size // 16
|
num_experts, hidden_size, intermediate_size // 16
|
||||||
) # fp8 scaling factors
|
) # fp8 scaling factors
|
||||||
|
|
||||||
# Reorder rows of W1 and scales for fused gated activation
|
|
||||||
gemm1_weights_fp4_interleaved = []
|
|
||||||
gemm1_scales_fp4_interleaved = []
|
|
||||||
for i in range(num_experts):
|
|
||||||
gemm1_weights_fp4_interleaved.append(
|
|
||||||
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
|
|
||||||
)
|
|
||||||
gemm1_scales_fp4_interleaved.append(
|
|
||||||
reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stack weights and scales for all experts
|
|
||||||
gemm1_weights_fp4_interleaved = torch.stack(
|
|
||||||
gemm1_weights_fp4_interleaved
|
|
||||||
).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
|
|
||||||
gemm1_scales_fp4_interleaved = torch.stack(
|
|
||||||
gemm1_scales_fp4_interleaved
|
|
||||||
).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
|
|
||||||
|
|
||||||
# Shuffle weights and scaling factors for transposed mma output
|
|
||||||
gemm1_weights_fp4_shuffled = []
|
gemm1_weights_fp4_shuffled = []
|
||||||
gemm1_scales_fp4_shuffled = []
|
gemm1_scales_fp4_shuffled = []
|
||||||
gemm2_weights_fp4_shuffled = []
|
gemm2_weights_fp4_shuffled = []
|
||||||
gemm2_scales_fp4_shuffled = []
|
gemm2_scales_fp4_shuffled = []
|
||||||
for i in range(num_experts):
|
for i in range(num_experts):
|
||||||
|
# Calculate the permute indices for the following:
|
||||||
|
# 1. Reorder rows of W1 and scales for fused gated activation
|
||||||
|
# 2. Shuffle weights and scaling factors for transposed mma output
|
||||||
|
# for both w3_w1 and w2 weights and scale factors
|
||||||
|
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
|
||||||
|
self._cache_permute_indices,
|
||||||
|
gemm1_weights_fp4[i].view(torch.uint8),
|
||||||
|
epilogue_tile_m,
|
||||||
|
)
|
||||||
gemm1_weights_fp4_shuffled.append(
|
gemm1_weights_fp4_shuffled.append(
|
||||||
shuffle_matrix_a(
|
gemm1_weights_fp4[i]
|
||||||
gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
|
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
|
||||||
)
|
.contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
|
||||||
|
self._cache_permute_indices,
|
||||||
|
gemm1_scales_linear_fp4[i].view(torch.uint8),
|
||||||
|
epilogue_tile_m,
|
||||||
|
num_elts_per_sf=16,
|
||||||
)
|
)
|
||||||
gemm1_scales_fp4_shuffled.append(
|
gemm1_scales_fp4_shuffled.append(
|
||||||
shuffle_matrix_sf_a(
|
nvfp4_block_scale_interleave(
|
||||||
gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
|
gemm1_scales_linear_fp4[i]
|
||||||
|
.view(torch.uint8)[
|
||||||
|
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
|
||||||
|
]
|
||||||
|
.contiguous()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||||
|
self._cache_permute_indices,
|
||||||
|
gemm2_weights_fp4[i].view(torch.uint8),
|
||||||
|
epilogue_tile_m,
|
||||||
|
)
|
||||||
gemm2_weights_fp4_shuffled.append(
|
gemm2_weights_fp4_shuffled.append(
|
||||||
shuffle_matrix_a(
|
gemm2_weights_fp4[i]
|
||||||
gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
|
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
|
||||||
)
|
.contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||||
|
self._cache_permute_indices,
|
||||||
|
gemm2_scales_linear_fp4[i].view(torch.uint8),
|
||||||
|
epilogue_tile_m,
|
||||||
|
num_elts_per_sf=16,
|
||||||
)
|
)
|
||||||
gemm2_scales_fp4_shuffled.append(
|
gemm2_scales_fp4_shuffled.append(
|
||||||
shuffle_matrix_sf_a(
|
nvfp4_block_scale_interleave(
|
||||||
gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
|
gemm2_scales_linear_fp4[i]
|
||||||
|
.view(torch.uint8)[
|
||||||
|
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
|
||||||
|
]
|
||||||
|
.contiguous()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user