[fix] Fix mxfp4 triton MoE tp bug (#9473)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
hlu1
2025-08-23 01:48:40 -07:00
committed by GitHub
parent c9dd70fbde
commit ccd3fb946e
3 changed files with 14 additions and 12 deletions

View File

@@ -309,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size, 64
)
elif has_triton_kernels:
# TODO: this is a hack: round up to mxfp4_block so that
# intermediate_size_per_partition_after_pad is the same as the
# per_rank_intermediate_size during weight loading
intermediate_size_per_partition_after_pad = round_up(
intermediate_size, mxfp4_block
)
self.intermediate_size = intermediate_size_per_partition_after_pad