Revert "Revert "[FEAT] Support GGUF format"" (#2287)
@@ -396,7 +396,10 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

         # turning off autotune for fp8dq since it doesn't give speedup and
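Note (illustrative, not part of this commit): tying the output head to the input embedding at construction time makes both attributes refer to the same parameter, so a checkpoint where lm_head.weight is missing (the case the removed load-time workaround below mentions) still leaves the output projection populated. A minimal sketch with plain torch.nn stand-ins for embed_tokens and ParallelLMHead:

# Minimal sketch of weight tying with vanilla PyTorch modules (hypothetical,
# simplified stand-ins for the embed_tokens / ParallelLMHead used above).
import torch
import torch.nn as nn

vocab_size, hidden_size = 8, 4
embed_tokens = nn.Embedding(vocab_size, hidden_size)

tie_word_embeddings = True
if tie_word_embeddings:
    lm_head = embed_tokens  # shared parameter, not a copy
else:
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(2, hidden_size)
logits = hidden_states @ lm_head.weight.t()  # shape: (2, vocab_size)
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()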
@@ -413,7 +416,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def get_hidden_dim(self, module_name):
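Note (illustrative; the sketch below is not this repository's actual LogitsProcessor API): passing the lm_head module itself rather than its raw .weight tensor lets the logits computation dispatch on the module, which matters once the head may be a quantized layer (e.g. loaded from a GGUF checkpoint) or a tied embedding rather than a plain dense weight. A hedged sketch of the idea, using a hypothetical quant_method attribute:

# Hedged sketch: dispatch on the head module instead of a bare weight tensor.
# The quant_method attribute is hypothetical, not sglang's actual interface.
import torch
import torch.nn as nn

def compute_logits(hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
    if hasattr(lm_head, "quant_method"):
        # A quantized head can run its own dequantize/matmul path here.
        return lm_head.quant_method.apply(lm_head, hidden_states)
    # Plain dense fallback: project hidden states onto the vocabulary.
    return hidden_states @ lm_head.weight.t()

head = nn.Linear(16, 32, bias=False)  # stands in for ParallelLMHead
print(compute_logits(torch.randn(2, 16), head).shape)  # torch.Size([2, 32])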
@@ -501,14 +504,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

-        if (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-        ):
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, self.model.embed_tokens.weight)
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
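Note (illustrative, not part of this commit): the load-time tying block above is removed because the constructor now assigns self.lm_head = self.model.embed_tokens when tie_word_embeddings is set, so loading the embedding weight once populates both views of the same parameter. A minimal sketch with hypothetical stand-in names:

# Minimal sketch: with the head tied at construction time, copying the
# embedding weight from a checkpoint also "loads" lm_head; no extra copy needed.
import torch
import torch.nn as nn

embed_tokens = nn.Embedding(8, 4)
lm_head = embed_tokens  # tied in __init__
checkpoint = {"model.embed_tokens.weight": torch.randn(8, 4)}
with torch.no_grad():
    embed_tokens.weight.copy_(checkpoint["model.embed_tokens.weight"])
assert torch.equal(lm_head.weight, checkpoint["model.embed_tokens.weight"])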