Revert "[FEAT] Support GGUF format" (#2285)
This commit is contained in:
@@ -338,12 +338,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = BaiChuanModel(config, position_embedding, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def forward(
|
||||
@@ -354,7 +353,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module):
|
||||
forward_batch,
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module):
|
||||
input_ids, positions, forward_batch, input_embeds
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
||||
)
|
||||
|
||||
def get_attention_sliding_window_size(self):
|
||||
|
||||
@@ -247,7 +247,7 @@ class GPT2LMHeadModel(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.output, forward_batch
|
||||
input_ids, hidden_states, self.output.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -258,7 +258,6 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -306,12 +305,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
self.quant_config = quant_config
|
||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
self.stacked_params_mapping = [
|
||||
@@ -335,7 +329,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
@@ -379,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
return len(params_dict)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
embed_tokens_weight = None
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -390,6 +385,12 @@ class LlamaForCausalLM(nn.Module):
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
load_tie_word_embeddings = (
|
||||
hasattr(self.config, "tie_word_embeddings")
|
||||
and self.config.tie_word_embeddings
|
||||
and "lm_head.weight" in params_dict
|
||||
)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
continue
|
||||
@@ -422,6 +423,16 @@ class LlamaForCausalLM(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if load_tie_word_embeddings and name == "model.embed_tokens.weight":
|
||||
embed_tokens_weight = loaded_weight
|
||||
|
||||
if load_tie_word_embeddings:
|
||||
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
||||
param = self.lm_head.weight
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
if embed_tokens_weight is not None:
|
||||
weight_loader(param, embed_tokens_weight)
|
||||
|
||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||
|
||||
def get_weights_by_name(
|
||||
|
||||
@@ -308,10 +308,12 @@ class MiniCPMForCausalLM(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
hidden_states = hidden_states / self.scale_width
|
||||
if self.config.tie_word_embeddings:
|
||||
lm_head = self.model.embed_tokens
|
||||
lm_head_weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
lm_head = self.lm_head
|
||||
return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
|
||||
lm_head_weight = self.lm_head.weight
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, lm_head_weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
@@ -585,10 +585,12 @@ class MiniCPM3ForCausalLM(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
hidden_states = hidden_states / self.scale_width
|
||||
if self.config.tie_word_embeddings:
|
||||
lm_head = self.model.embed_tokens
|
||||
lm_head_weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
lm_head = self.lm_head
|
||||
return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
|
||||
lm_head_weight = self.lm_head.weight
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, lm_head_weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
@@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
skip_cross_attention=skip_cross_attention,
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.language_model.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module):
|
||||
input_embeds=input_embeds,
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
@@ -326,6 +326,11 @@ class OlmoForCausalLM(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
# With tie_word_embeddings, we can skip lm_head.weight
|
||||
# The weight might appear unnecessarily in the files if the model is
|
||||
# processed with quantization, LoRA, fine-tuning, etc.
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -397,13 +397,10 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(
|
||||
input_ids, self.lm_head, hidden_states, sampling_metadata
|
||||
)
|
||||
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
|
||||
if self.dummy_token_indices is not None and logits is not None:
|
||||
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
|
||||
return logits
|
||||
@@ -425,7 +422,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module):
|
||||
):
|
||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -230,7 +230,6 @@ class Qwen2Model(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -277,12 +276,7 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
@@ -298,7 +292,7 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
@@ -312,7 +306,6 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
@@ -342,6 +335,11 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
if (
|
||||
self.config.tie_word_embeddings
|
||||
and name == "model.embed_tokens.weight"
|
||||
):
|
||||
weight_loader(params_dict["lm_head.weight"], loaded_weight)
|
||||
|
||||
|
||||
EntryClass = Qwen2ForCausalLM
|
||||
|
||||
@@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
@@ -686,6 +686,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -396,10 +396,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||
self.supports_torch_tp = True
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
# turning off autotune for fp8dq since it doesn't give speedup and
|
||||
@@ -416,7 +413,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def get_hidden_dim(self, module_name):
|
||||
@@ -504,6 +501,14 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if (
|
||||
hasattr(self.config, "tie_word_embeddings")
|
||||
and self.config.tie_word_embeddings
|
||||
):
|
||||
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
||||
param = self.lm_head.weight
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, self.model.embed_tokens.weight)
|
||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||
|
||||
|
||||
|
||||
@@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
|
||||
@@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
Reference in New Issue
Block a user