[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
This commit is contained in:
committed by
GitHub
parent
583d6af71b
commit
56a724eba3
@@ -261,26 +261,27 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
linear_method = None
|
||||
quant_method = None
|
||||
if quant_config is not None:
|
||||
linear_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedEmbeddingMethod()
|
||||
quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
print("quant_method", quant_method)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedEmbeddingMethod()
|
||||
|
||||
# If we are making an embedding layer, then our quantization linear
|
||||
# method must implement the embedding operation. If we are another
|
||||
# layer type like ParallelLMHead, this is not important.
|
||||
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
|
||||
linear_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(linear_method)
|
||||
quant_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(quant_method)
|
||||
)
|
||||
if is_embedding_layer and not linear_method_implements_embedding:
|
||||
if is_embedding_layer and not quant_method_implements_embedding:
|
||||
raise NotImplementedError(
|
||||
f"The class {type(linear_method).__name__} must implement "
|
||||
f"The class {type(quant_method).__name__} must implement "
|
||||
"the 'embedding' method, see UnquantizedEmbeddingMethod."
|
||||
)
|
||||
|
||||
self.linear_method: QuantizeMethodBase = linear_method
|
||||
self.quant_method: QuantizeMethodBase = quant_method
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@@ -301,7 +302,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
- self.shard_indices.added_vocab_start_index
|
||||
)
|
||||
|
||||
self.linear_method.create_weights(
|
||||
self.quant_method.create_weights(
|
||||
self,
|
||||
self.embedding_dim,
|
||||
[self.num_embeddings_per_partition],
|
||||
@@ -446,7 +447,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
packed_factor = (
|
||||
param.packed_factor
|
||||
if isinstance(param, BasevLLMParameter)
|
||||
else param.pack_factor
|
||||
else param.packed_factor
|
||||
)
|
||||
assert loaded_weight.shape[output_dim] == (
|
||||
self.org_vocab_size // param.packed_factor
|
||||
@@ -479,7 +480,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.linear_method.embedding(self, masked_input.long())
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
|
||||
Reference in New Issue
Block a user