Change tensor alignment method to mn major (#9844)
This commit is contained in:
committed by
GitHub
parent
1fbfdebe6b
commit
b9eb0d9c2b
@@ -229,7 +229,7 @@ class EPMoE(FusedMoE):
|
||||
(
|
||||
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
- else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
|
||||
+ else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
|
||||
gateup_input_scale
|
||||
)
|
||||
),
|
||||
@@ -286,9 +286,7 @@ class EPMoE(FusedMoE):
|
||||
(
|
||||
down_input_scale
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
- else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
|
||||
-     down_input_scale
|
||||
- )
|
||||
+ else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
|
||||
),
|
||||
)
|
||||
down_output = torch.empty(
|
||||
|
||||
Reference in New Issue
Block a user