Fix mismatch between padded_scales shape and reshape dimensions in modelopt quantization (#8766)
This commit is contained in:
@@ -677,9 +677,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
|
||||
padded_scales = padded_scales.contiguous().cuda()
|
||||
padded_scales = (
|
||||
padded_scales.reshape(M, K)
|
||||
padded_scales.reshape(M_padded, K_padded)
|
||||
if scale_ndim == 2
|
||||
else padded_scales.reshape(B, M, K)
|
||||
else padded_scales.reshape(B, M_padded, K_padded)
|
||||
)
|
||||
layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
|
||||
|
||||
@@ -878,9 +878,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (
|
||||
swizzled_scale.reshape(M, K)
|
||||
swizzled_scale.reshape(M_padded, K_padded)
|
||||
if scale_ndim == 2
|
||||
else swizzled_scale.reshape(B, M, K)
|
||||
else swizzled_scale.reshape(B, M_padded, K_padded)
|
||||
)
|
||||
|
||||
def prepare_static_weights_for_kernel(
|
||||
|
||||
Reference in New Issue
Block a user