Improve weight loading and code style (#3174)

This commit is contained in:
Lianmin Zheng
2025-01-27 03:00:41 -08:00
committed by GitHub
parent 351a72d40b
commit 53cef81587
11 changed files with 171 additions and 65 deletions

View File

@@ -124,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
def load_qkv_weight(
self,
loaded_weight: torch.Tensor,
tp_rank: int,
use_presharded_weights: bool = False,
**kwargs,
):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
@@ -142,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
param_data.copy_(loaded_weight)
@@ -292,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs
**kwargs,
):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
@@ -336,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs
**kwargs,
):
self._packed_factor = packed_factor
self._packed_dim = packed_dim