Fix linear.py and improve weight loading (#2851)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.use_presharded_weights = use_presharded_weights
+        if use_presharded_weights:
+            assert (
+                num_added_embeddings == 0
+            ), "Lora is not supported with presharded weights."
+
         self.org_vocab_size_padded = pad_vocab_size(
             self.org_vocab_size, self.padding_size
         )
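
The new assert fails fast rather than changing behavior: extra (e.g. LoRA) token embeddings are appended after org_vocab_size, but a presharded checkpoint stores only org_vocab_size // tp_size rows per rank, so the added rows have no source to load from. A small illustration with hypothetical sizes:

# Hypothetical sizes: 32000 base tokens plus 16 extra (LoRA) tokens.
num_embeddings, org_num_embeddings = 32016, 32000
num_added_embeddings = num_embeddings - org_num_embeddings  # 16 != 0
# With use_presharded_weights=True the constructor now rejects this,
# since each rank's shard file holds only org_num_embeddings // tp_size rows.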
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             start_idx = start_idx // packed_factor
             shard_size = shard_size // packed_factor
         else:
-            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+            assert loaded_weight.shape[output_dim] == (
+                self.org_vocab_size
+                // (self.tp_size if self.use_presharded_weights else 1)
+            )
 
         # Copy the data.
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0] :].data.fill_(0)
 
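
This hunk is the core of the change: when the checkpoint on disk is already tensor-parallel-sharded, each rank's file holds org_vocab_size // tp_size rows, so the loader must relax the shape assert and skip the narrow() that would otherwise slice this rank's rows out of a full vocabulary. A minimal standalone sketch of the resulting logic, using a hypothetical helper weight_loader_sketch rather than the real VocabParallelEmbedding.weight_loader:

import torch

def weight_loader_sketch(
    param: torch.Tensor,          # this rank's (padded) parameter shard
    loaded_weight: torch.Tensor,  # tensor read from the checkpoint file
    org_vocab_size: int,
    tp_size: int,
    start_idx: int,               # first vocab row owned by this rank
    shard_size: int,              # number of rows owned by this rank
    use_presharded_weights: bool,
    output_dim: int = 0,          # sketch assumes the vocab dim is 0 throughout
) -> None:
    # Presharded files already contain only this rank's slice, so the
    # expected size along the vocab dim shrinks by tp_size.
    expected = org_vocab_size // (tp_size if use_presharded_weights else 1)
    assert loaded_weight.shape[output_dim] == expected

    # Only a full (unsharded) checkpoint needs narrowing to this rank.
    if not use_presharded_weights:
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

    # Copy the shard and zero the padding rows, mirroring the diff above.
    param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
    param[loaded_weight.shape[0] :].data.fill_(0)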
@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
             padding_size,
             quant_config,
             prefix,
+            use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
         if bias:
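
ParallelLMHead itself needs no new logic; it only accepts the flag and forwards it to VocabParallelEmbedding, so the embedding table and the LM head share the loading path above. As a sanity check on the sketch (hypothetical sizes, reusing weight_loader_sketch from above), the presharded path should land exactly the rows that narrowing a full checkpoint would:

import torch

vocab, dim, tp_size, rank = 32000, 64, 4, 1
shard = vocab // tp_size
full = torch.randn(vocab, dim)                     # unsharded checkpoint
pre = full.narrow(0, rank * shard, shard).clone()  # rank 1's presharded file

a = torch.empty(shard + 16, dim)  # 16 padding rows, purely illustrative
b = torch.empty(shard + 16, dim)
weight_loader_sketch(a, full, vocab, tp_size, rank * shard, shard, False)
weight_loader_sketch(b, pre, vocab, tp_size, rank * shard, shard, True)
assert torch.equal(a, b)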