Fix linear.py and improve weight loading (#2851)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
@@ -1,7 +1,4 @@
-"""
-Adapted from vLLM (0.6.4.post1).
-https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
-"""
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py"""
 
 import logging
 from fractions import Fraction
@@ -88,12 +85,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     def output_dim(self):
         return self._output_dim
 
-    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.output_dim]
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, tp_rank * shard_size, shard_size
-        )
+    def load_column_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
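
The hunk above makes tensor-parallel sharding explicit in the signature: the caller supplies tp_rank, and the narrow() slice is skipped when the checkpoint already stores per-rank shards. A minimal, self-contained sketch of that behaviour with plain torch and made-up shapes (shard_column_weight is a hypothetical helper, not sglang code):

import torch

def shard_column_weight(
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    output_dim: int,
    tp_rank: int,
    use_presharded_weights: bool = False,
) -> None:
    # Slice this rank's shard out of the full checkpoint tensor unless the
    # checkpoint is already stored per rank.
    if not use_presharded_weights:
        shard_size = param.shape[output_dim]
        loaded_weight = loaded_weight.narrow(
            output_dim, tp_rank * shard_size, shard_size
        )
    assert param.shape == loaded_weight.shape
    param.copy_(loaded_weight)

# Full weight is [out_features=8, in_features=4]; with tp_size=4 each rank
# holds a [2, 4] slice along the output dimension (dim 0).
full = torch.randn(8, 4)
param = torch.empty(2, 4)
shard_column_weight(param, full, output_dim=0, tp_rank=1)
assert torch.equal(param, full[2:4])
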
@@ -121,7 +123,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -137,7 +139,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
             )
 
         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
         loaded_weight = loaded_weight.narrow(
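
The removed line above is the other half of the same change: load_qkv_weight no longer looks up the rank itself but receives tp_rank from its caller and maps it to a q or kv shard index. A small worked example of that arithmetic, assuming (as in vLLM's QKVParallelLinear) that num_heads plays the role of the kv replication factor:

# tp_size = 8 ranks, 2 kv heads -> each kv head is shared by 4 ranks.
num_heads = 4  # kv replicas per kv shard (illustrative value)
for tp_rank in range(8):
    for shard_id in ("q", "k", "v"):
        idx = tp_rank if shard_id == "q" else tp_rank // num_heads
        print(f"rank {tp_rank}: {shard_id} shard {idx}")
# e.g. rank 5 reads q shard 5 but k/v shard 1, since one kv head
# serves ranks 4-7.
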
@@ -164,11 +165,14 @@ class RowvLLMParameter(BasevLLMParameter):
     def input_dim(self):
         return self._input_dim
 
-    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
-        use_presharded_weights = kwargs.get("use_presharded_weights")
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.input_dim]
+    def load_row_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.input_dim, tp_rank * shard_size, shard_size
+            )
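
Row-parallel loading follows the same pattern as the column case but slices along the input dimension. A plain-torch sketch with hypothetical names, mirroring the new signature:

import torch

def shard_row_weight(param, loaded_weight, input_dim, tp_rank, use_presharded_weights=False):
    # Row-parallel layers shard the weight along in_features.
    if not use_presharded_weights:
        shard_size = param.shape[input_dim]
        loaded_weight = loaded_weight.narrow(input_dim, tp_rank * shard_size, shard_size)
    assert param.shape == loaded_weight.shape
    param.copy_(loaded_weight)

full = torch.randn(4, 8)   # [out_features, in_features]
param = torch.empty(4, 2)  # this rank's slice along in_features
shard_row_weight(param, full, input_dim=1, tp_rank=3)                # slices columns 6:8
shard_row_weight(param, full[:, 6:8], input_dim=1, tp_rank=3,
                 use_presharded_weights=True)                        # no narrow needed
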
@@ -238,6 +242,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
     # For row parallel layers, no sharding needed
     # load weight into parameter as is
     def load_row_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
         super().load_row_parallel_weight(*args, **kwargs)
 
     def load_merged_column_weight(self, *args, **kwargs):
@@ -247,6 +253,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_column_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
         super().load_row_parallel_weight(*args, **kwargs)
 
     def _load_into_shard_id(
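
PerTensorScaleParameter holds a single scale replicated on every rank, so the sharding arguments that the column/row loaders now require have to be dropped before delegating to the base loader. A toy illustration of the pop-then-delegate pattern, using hypothetical stand-in classes rather than the real parameter hierarchy:

class _Base:
    def load_row_parallel_weight(self, value):
        self.value = value

class _PerTensorScale(_Base):
    def load_row_parallel_weight(self, *args, **kwargs):
        kwargs.pop("tp_rank", None)                 # the scale is not sharded
        kwargs.pop("use_presharded_weights", None)
        super().load_row_parallel_weight(*args, **kwargs)

p = _PerTensorScale()
p.load_row_parallel_weight(0.5, tp_rank=3, use_presharded_weights=True)
assert p.value == 0.5
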