From f0e15dc6ab6766a8fcdeedb5432b92a18e14979f Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Sat, 11 Jan 2025 14:34:26 +0800
Subject: [PATCH] [HotFix] fix fp8 scale load failed in tp>1 (#2837)

---
 python/sglang/srt/layers/linear.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index 9edfa7394..b839deeb3 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -437,7 +437,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        load_column_parallel_weight(param, loaded_weight, self.tp_rank)
+        param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -1247,12 +1247,7 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        load_row_parallel_weight(
-            param,
-            loaded_weight,
-            self.tp_rank,
-            use_presharded_weights=self.use_presharded_weights,
-        )
+        param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
         if self.input_is_parallel: