Fix mismatch between padded_scales shape and reshape dimensions in modelopt quantization (#8766)
This commit is contained in:
@@ -677,9 +677,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
 padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
 padded_scales = padded_scales.contiguous().cuda()
 padded_scales = (
-padded_scales.reshape(M, K)
+padded_scales.reshape(M_padded, K_padded)
 if scale_ndim == 2
-else padded_scales.reshape(B, M, K)
+else padded_scales.reshape(B, M_padded, K_padded)
 )
 layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
 
@@ -878,9 +878,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
 swizzled_scale = swizzled_scale.contiguous().cuda()
 return (
-swizzled_scale.reshape(M, K)
+swizzled_scale.reshape(M_padded, K_padded)
 if scale_ndim == 2
-else swizzled_scale.reshape(B, M, K)
+else swizzled_scale.reshape(B, M_padded, K_padded)
 )
 
 def prepare_static_weights_for_kernel(
Reference in New Issue
Block a user