Fix linear.py and improve weight loading (#2851)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-01-13 01:39:14 -08:00
committed by GitHub
parent 4093aa4660
commit 72c7776355
12 changed files with 113 additions and 125 deletions

View File

@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_tp: bool = True,
use_presharded_weights: bool = False,
):
super().__init__()
self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.use_presharded_weights = use_presharded_weights
if use_presharded_weights:
assert (
num_added_embeddings == 0
), "Lora is not supported with presharded weights."
self.org_vocab_size_padded = pad_vocab_size(
self.org_vocab_size, self.padding_size
)
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
start_idx = start_idx // packed_factor
shard_size = shard_size // packed_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size
assert loaded_weight.shape[output_dim] == (
self.org_vocab_size
// (self.tp_size if self.use_presharded_weights else 1)
)
# Copy the data.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0] :].data.fill_(0)
@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_presharded_weights: bool = False,
):
super().__init__(
num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size,
quant_config,
prefix,
use_presharded_weights=use_presharded_weights,
)
self.quant_config = quant_config
if bias: