Improve weight loading and code style (#3174)
This commit is contained in:
@@ -124,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
|
||||
def load_qkv_weight(
|
||||
self,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
use_presharded_weights: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
shard_offset = kwargs.get("shard_offset")
|
||||
shard_size = kwargs.get("shard_size")
|
||||
@@ -142,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
param_data = self.data
|
||||
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
|
||||
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, shard_id * shard_size, shard_size
|
||||
)
|
||||
if not use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, shard_id * shard_size, shard_size
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
assert (
|
||||
param_data.shape == loaded_weight.shape
|
||||
), f"{param_data.shape=}, {loaded_weight.shape=}"
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
@@ -292,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
@@ -336,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
|
||||
Reference in New Issue
Block a user