Fix linear.py and improve weight loading (#2851)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Author: Lianmin Zheng
Date: 2025-01-13 01:39:14 -08:00
Committed by: GitHub
Parent: 4093aa4660
Commit: 72c7776355
12 changed files with 113 additions and 125 deletions


@@ -1,7 +1,4 @@
-"""
-Adapted from vLLM (0.6.4.post1).
-https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
-"""
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py"""
 import logging
 from fractions import Fraction
@@ -88,12 +85,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     def output_dim(self):
         return self._output_dim
 
-    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.output_dim]
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, tp_rank * shard_size, shard_size
-        )
+    def load_column_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
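
The new signature threads `tp_rank` in from the caller instead of querying
`get_tensor_model_parallel_rank()` inside the loader, and
`use_presharded_weights` lets checkpoints that already store one shard per
rank skip the slice entirely. A minimal standalone sketch of the slicing
math (shapes and names here are illustrative, not from the repository):

    import torch

    full_weight = torch.randn(8, 4)  # full tensor; output_dim is dim 0 here
    tp_size, tp_rank = 2, 1          # two tensor-parallel ranks, this is rank 1
    shard_size = full_weight.shape[0] // tp_size

    # the same call the loader makes: take this rank's slice along output_dim
    shard = full_weight.narrow(0, tp_rank * shard_size, shard_size)
    assert shard.shape == (4, 4)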
@@ -121,7 +123,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -137,7 +139,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         )
 
         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
         loaded_weight = loaded_weight.narrow(
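
`load_qkv_weight` likewise now receives `tp_rank` explicitly. The `shard_id`
line maps a rank to a slice index: `q` heads shard one-to-one with ranks,
while `k`/`v` slices repeat every `num_heads` ranks. The diff alone does not
show the caller, but in the upstream vLLM code this kwarg carries the number
of KV-head replicas under grouped-query attention. An illustrative sketch
with made-up numbers:

    tp_rank = 5
    num_heads = 4  # replica ratio passed via kwargs; value is made up

    for shard_id in ("q", "k", "v"):
        idx = tp_rank if shard_id == "q" else tp_rank // num_heads
        print(shard_id, idx)  # prints: q 5, k 1, v 1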
@@ -164,11 +165,14 @@ class RowvLLMParameter(BasevLLMParameter):
     def input_dim(self):
         return self._input_dim
 
-    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
-        use_presharded_weights = kwargs.get("use_presharded_weights")
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.input_dim]
+    def load_row_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
         if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
             loaded_weight = loaded_weight.narrow(
                 self.input_dim, tp_rank * shard_size, shard_size
             )
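
Note that `shard_size` is now computed only when a slice is actually taken.
A sketch of the presharded branch (shapes illustrative): when the checkpoint
already stores this rank's shard, the tensor is copied as-is, with a shape
assert guarding against mismatches as in the column path.

    import torch

    param = torch.empty(4, 4)
    loaded = torch.randn(4, 4)  # checkpoint already stores this rank's shard
    use_presharded_weights = True

    if not use_presharded_weights:
        # only the unsharded path would narrow along input_dim
        loaded = loaded.narrow(1, 0, 4)
    assert param.shape == loaded.shape
    param.copy_(loaded)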
@@ -238,6 +242,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
     # For row parallel layers, no sharding needed
     # load weight into parameter as is
     def load_row_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
         super().load_row_parallel_weight(*args, **kwargs)
 
     def load_merged_column_weight(self, *args, **kwargs):
@@ -247,6 +253,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_column_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
         super().load_row_parallel_weight(*args, **kwargs)
 
     def _load_into_shard_id(
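
Per-tensor scales are replicated rather than sharded, so
`PerTensorScaleParameter` strips the new sharding kwargs before delegating;
this keeps every parameter class callable through one uniform signature.
A self-contained sketch of the pattern (class names are illustrative):

    class ReplicatedParam:
        def load_row_parallel_weight(self, loaded_weight):
            # a per-tensor scale is identical on every rank: no narrow()
            self.value = loaded_weight

    class ScaleParam(ReplicatedParam):
        def load_row_parallel_weight(self, *args, **kwargs):
            # drop kwargs that only sharded parameters understand
            kwargs.pop("tp_rank", None)
            kwargs.pop("use_presharded_weights", None)
            super().load_row_parallel_weight(*args, **kwargs)

    p = ScaleParam()
    p.load_row_parallel_weight(3.14, tp_rank=0, use_presharded_weights=False)
    assert p.value == 3.14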