[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)

Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
This commit is contained in:
Qubitium-ModelCloud
2025-03-05 17:11:00 +08:00
committed by GitHub
parent 583d6af71b
commit 56a724eba3
56 changed files with 1988 additions and 282 deletions

View File

@@ -261,26 +261,27 @@ class VocabParallelEmbedding(torch.nn.Module):
)
self.embedding_dim = embedding_dim
linear_method = None
quant_method = None
if quant_config is not None:
linear_method = quant_config.get_quant_method(self, prefix=prefix)
if linear_method is None:
linear_method = UnquantizedEmbeddingMethod()
quant_method = quant_config.get_quant_method(self, prefix=prefix)
print("quant_method", quant_method)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
linear_method_implements_embedding = method_has_implemented_embedding(
type(linear_method)
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method)
)
if is_embedding_layer and not linear_method_implements_embedding:
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(linear_method).__name__} must implement "
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod."
)
self.linear_method: QuantizeMethodBase = linear_method
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -301,7 +302,7 @@ class VocabParallelEmbedding(torch.nn.Module):
- self.shard_indices.added_vocab_start_index
)
self.linear_method.create_weights(
self.quant_method.create_weights(
self,
self.embedding_dim,
[self.num_embeddings_per_partition],
@@ -446,7 +447,7 @@ class VocabParallelEmbedding(torch.nn.Module):
packed_factor = (
param.packed_factor
if isinstance(param, BasevLLMParameter)
else param.pack_factor
else param.packed_factor
)
assert loaded_weight.shape[output_dim] == (
self.org_vocab_size // param.packed_factor
@@ -479,7 +480,7 @@ class VocabParallelEmbedding(torch.nn.Module):
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.linear_method.embedding(self, masked_input.long())
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)