Change tensor alignment method from get_col_major_tma_aligned_tensor to get_mn_major_tma_aligned_tensor (#9844)
This commit is contained in:
committed by GitHub
parent 1fbfdebe6b
commit b9eb0d9c2b
@@ -229,7 +229,7 @@ class EPMoE(FusedMoE):
|
|||||||
(
|
(
|
||||||
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
|
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
|
||||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||||
else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
|
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
|
||||||
gateup_input_scale
|
gateup_input_scale
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -286,9 +286,7 @@ class EPMoE(FusedMoE):
|
|||||||
(
|
(
|
||||||
down_input_scale
|
down_input_scale
|
||||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||||
else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
|
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
|
||||||
down_input_scale
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
down_output = torch.empty(
|
down_output = torch.empty(
|
||||||
|
|||||||
Reference in New Issue
Block a user