[HotFix] fix fp8 scale load failed in tp>1 (#2837)
This commit is contained in:
@@ -437,7 +437,7 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
assert loaded_weight.numel() == 1
|
assert loaded_weight.numel() == 1
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
load_column_parallel_weight(param, loaded_weight, self.tp_rank)
|
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
@@ -1247,12 +1247,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
assert loaded_weight.numel() == 1
|
assert loaded_weight.numel() == 1
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
|
|
||||||
load_row_parallel_weight(
|
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
||||||
param,
|
|
||||||
loaded_weight,
|
|
||||||
self.tp_rank,
|
|
||||||
use_presharded_weights=self.use_presharded_weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
if self.input_is_parallel:
|
if self.input_is_parallel:
|
||||||
|
|||||||
Reference in New Issue
Block a user