Revert "Revert "[FEAT] Support GGUF format"" (#2287)

This commit is contained in:
Lianmin Zheng
2024-11-30 22:14:48 -08:00
committed by GitHub
parent 1bfa511b95
commit 4936be8acc
41 changed files with 229 additions and 132 deletions

View File

@@ -338,11 +338,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
def forward(
@@ -353,7 +354,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module):
forward_batch,
)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch)
if not forward_batch.forward_mode.is_idle():
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module):
input_ids, positions, forward_batch, input_embeds
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)
def get_attention_sliding_window_size(self):

View File

@@ -247,7 +247,7 @@ class GPT2LMHeadModel(nn.Module):
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.output.weight, forward_batch
input_ids, hidden_states, self.output, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -258,6 +259,7 @@ class LlamaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.layers = make_layers(
config.num_hidden_layers,
@@ -305,7 +307,12 @@ class LlamaForCausalLM(nn.Module):
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.stacked_params_mapping = [
@@ -329,7 +336,7 @@ class LlamaForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
@@ -373,7 +380,6 @@ class LlamaForCausalLM(nn.Module):
return len(params_dict)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
embed_tokens_weight = None
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -385,12 +391,6 @@ class LlamaForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
load_tie_word_embeddings = (
hasattr(self.config, "tie_word_embeddings")
and self.config.tie_word_embeddings
and "lm_head.weight" in params_dict
)
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
@@ -423,16 +423,6 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if load_tie_word_embeddings and name == "model.embed_tokens.weight":
embed_tokens_weight = loaded_weight
if load_tie_word_embeddings:
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if embed_tokens_weight is not None:
weight_loader(param, embed_tokens_weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
def get_weights_by_name(
@@ -444,6 +434,17 @@ class LlamaForCausalLM(nn.Module):
For optimized performance, please use torch.save and torch.load.
"""
try:
if name == "lm_head.weight" and self.config.tie_word_embeddings:
logger.info(
"word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
)
return (
self.model.embed_tokens.weight.cpu()
.to(torch.float32)
.numpy()
.tolist()[:truncate_size]
)
mapped_name = name
mapped_shard_id = None
for param_name, weight_name, shard_id in self.stacked_params_mapping:
@@ -452,54 +453,48 @@ class LlamaForCausalLM(nn.Module):
mapped_shard_id = shard_id
break
params_dict = dict(self.named_parameters())
if mapped_name in params_dict:
param = params_dict[mapped_name]
if mapped_shard_id is not None:
if mapped_shard_id in ["q", "k", "v"]:
num_heads = self.config.num_attention_heads // tp_size
num_kv_heads = self.config.num_key_value_heads // tp_size
head_dim = (
self.config.hidden_size // self.config.num_attention_heads
)
if mapped_shard_id == "q":
offset = 0
size = num_heads * head_dim
elif mapped_shard_id == "k":
offset = num_heads * head_dim
size = num_kv_heads * head_dim
elif mapped_shard_id == "v":
offset = (num_heads + num_kv_heads) * head_dim
size = num_kv_heads * head_dim
weight = param.data.narrow(0, offset, size)
elif mapped_shard_id in [0, 1]:
intermediate_size = self.config.intermediate_size
hidden_size = self.config.hidden_size
slice_size = intermediate_size // tp_size
if mapped_shard_id == 0: # gate_proj
offset = 0
size = slice_size
elif mapped_shard_id == 1: # up_proj
offset = slice_size
size = slice_size
param = params_dict[mapped_name]
if mapped_shard_id is not None:
if mapped_shard_id in ["q", "k", "v"]:
num_heads = self.config.num_attention_heads // tp_size
num_kv_heads = self.config.num_key_value_heads // tp_size
head_dim = (
self.config.hidden_size // self.config.num_attention_heads
)
if mapped_shard_id == "q":
offset = 0
size = num_heads * head_dim
elif mapped_shard_id == "k":
offset = num_heads * head_dim
size = num_kv_heads * head_dim
elif mapped_shard_id == "v":
offset = (num_heads + num_kv_heads) * head_dim
size = num_kv_heads * head_dim
weight = param.data.narrow(0, offset, size)
elif mapped_shard_id in [0, 1]:
intermediate_size = self.config.intermediate_size
slice_size = intermediate_size // tp_size
if mapped_shard_id == 0: # gate_proj
offset = 0
size = slice_size
elif mapped_shard_id == 1: # up_proj
offset = slice_size
size = slice_size
weight = param.data.narrow(0, offset, size)
else:
weight = param.data
weight = param.data.narrow(0, offset, size)
else:
weight = param.data
if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
gathered_weights = [
torch.zeros_like(weight) for _ in range(tp_size)
]
torch.distributed.all_gather(gathered_weights, weight)
weight = torch.cat(gathered_weights, dim=1)
return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
else:
return None
weight = param.data
if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
torch.distributed.all_gather(gathered_weights, weight)
weight = torch.cat(gathered_weights, dim=1)
return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
except Exception as e:
except Exception:
logger.error(
f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
)
return None

View File

@@ -308,12 +308,10 @@ class MiniCPMForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
lm_head = self.model.embed_tokens
else:
lm_head_weight = self.lm_head.weight
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, forward_batch
)
lm_head = self.lm_head
return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -585,12 +585,10 @@ class MiniCPM3ForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
lm_head = self.model.embed_tokens
else:
lm_head_weight = self.lm_head.weight
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, forward_batch
)
lm_head = self.lm_head
return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
skip_cross_attention=skip_cross_attention,
)
return self.logits_processor(
input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
input_ids, hidden_states, self.language_model.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module):
input_embeds=input_embeds,
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,11 +326,6 @@ class OlmoForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -397,10 +397,13 @@ class Phi3SmallForCausalLM(nn.Module):
def compute_logits(
self,
input_ids: torch.LongTensor,
hidden_states: torch.Tensor,
sampling_metadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
logits = self.logits_processor(
input_ids, self.lm_head, hidden_states, sampling_metadata
)
if self.dummy_token_indices is not None and logits is not None:
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
return logits
@@ -422,7 +425,7 @@ class Phi3SmallForCausalLM(nn.Module):
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
else:

View File

@@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module):
):
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.layers = make_layers(
config.num_hidden_layers,
@@ -276,7 +277,12 @@ class Qwen2ForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -292,7 +298,7 @@ class Qwen2ForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
@@ -306,6 +312,7 @@ class Qwen2ForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -335,11 +342,6 @@ class Qwen2ForCausalLM(nn.Module):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if (
self.config.tie_word_embeddings
and name == "model.embed_tokens.weight"
):
weight_loader(params_dict["lm_head.weight"], loaded_weight)
EntryClass = Qwen2ForCausalLM

View File

@@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
@@ -686,8 +686,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -396,7 +396,10 @@ class TorchNativeLlamaForCausalLM(nn.Module):
self.torchao_config = global_server_args_dict["torchao_config"]
self.supports_torch_tp = True
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
# turning off autotune for fp8dq since it doesn't give speedup and
@@ -413,7 +416,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def get_hidden_dim(self, module_name):
@@ -501,14 +504,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if (
hasattr(self.config, "tie_word_embeddings")
and self.config.tie_word_embeddings
):
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, self.model.embed_tokens.weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))

View File

@@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(

View File

@@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):