[HotFix] fix fp8 scale load failed in tp>1 (#2837)
This commit is contained in:
@@ -437,7 +437,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
if len(loaded_weight.shape) == 0:
|
||||
assert loaded_weight.numel() == 1
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
load_column_parallel_weight(param, loaded_weight, self.tp_rank)
|
||||
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
@@ -1247,12 +1247,7 @@ class RowParallelLinear(LinearBase):
|
||||
assert loaded_weight.numel() == 1
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
load_row_parallel_weight(
|
||||
param,
|
||||
loaded_weight,
|
||||
self.tp_rank,
|
||||
use_presharded_weights=self.use_presharded_weights,
|
||||
)
|
||||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.input_is_parallel:
|
||||
|
||||
Reference in New Issue
Block a user