Change tensor alignment method to mn major (#9844)

This commit is contained in:
Mohammad Miadh Angkad
2025-09-02 16:23:13 +08:00
committed by GitHub
parent 1fbfdebe6b
commit b9eb0d9c2b

View File

@@ -229,7 +229,7 @@ class EPMoE(FusedMoE):
(
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
gateup_input_scale
)
),
@@ -286,9 +286,7 @@ class EPMoE(FusedMoE):
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-down_input_scale
-)
+else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
),
)
down_output = torch.empty(