[HotFix] fix fp8 scale load failed in tp>1 (#2837)

This commit is contained in:
Xiaoyu Zhang
2025-01-11 14:34:26 +08:00
committed by GitHub
parent f1769586d6
commit f0e15dc6ab

View File

@@ -437,7 +437,7 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
load_column_parallel_weight(param, loaded_weight, self.tp_rank)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
@@ -1247,12 +1247,7 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
load_row_parallel_weight(
param,
loaded_weight,
self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_):
if self.input_is_parallel: