[fix] Fix mxfp4 triton MoE tp bug (#9473)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
@@ -309,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, 64
|
||||
)
|
||||
elif has_triton_kernels:
|
||||
# TODO: this is a hack to make
|
||||
# intermediate_size_per_partition_after_pad the same as the
|
||||
# per_rank_intermediate_size during weight loading
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, mxfp4_block
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
|
||||
|
||||
Reference in New Issue
Block a user