Revert "Revert "[FEAT] Support GGUF format"" (#2287)
This commit is contained in:
@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -276,7 +277,12 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
@@ -292,7 +298,7 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
@@ -306,6 +312,7 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
@@ -335,11 +342,6 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if (
|
||||
self.config.tie_word_embeddings
|
||||
and name == "model.embed_tokens.weight"
|
||||
):
|
||||
weight_loader(params_dict["lm_head.weight"], loaded_weight)
|
||||
|
||||
|
||||
EntryClass = Qwen2ForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user