diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 9edfa7394..b839deeb3 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -437,7 +437,7 @@ class ColumnParallelLinear(LinearBase): if len(loaded_weight.shape) == 0: assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - load_column_parallel_weight(param, loaded_weight, self.tp_rank) + param.load_column_parallel_weight(loaded_weight=loaded_weight) def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -1247,12 +1247,7 @@ class RowParallelLinear(LinearBase): assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - load_row_parallel_weight( - param, - loaded_weight, - self.tp_rank, - use_presharded_weights=self.use_presharded_weights, - ) + param.load_row_parallel_weight(loaded_weight=loaded_weight) def forward(self, input_): if self.input_is_parallel: