Revert "[FEAT] Support GGUF format" (#2285)

Lianmin Zheng
2024-11-30 19:03:26 -08:00
committed by GitHub
parent d622851dc9
commit 7e4c6dd8da
39 changed files with 89 additions and 180 deletions

View File

@@ -15,4 +15,3 @@ sphinx-copybutton
 sphinx-tabs
 sphinxcontrib-mermaid
 urllib3<2.0.0
-gguf>=0.10.0

View File

@@ -16,7 +16,6 @@
 import contextlib
 import os
 import warnings
-from pathlib import Path
 from typing import Dict, Optional, Type, Union

 from huggingface_hub import snapshot_download
@@ -28,7 +27,6 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
-from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -62,29 +60,15 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     model_override_args: Optional[dict] = None,
-    **kwargs,
 ):
-    is_gguf = check_gguf_file(model)
-    if is_gguf:
-        kwargs["gguf_file"] = model
-        model = Path(model).parent
-
     config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        model, trust_remote_code=trust_remote_code, revision=revision
     )
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
         if model_override_args:
             config.update(model_override_args)
-
-    # Special architecture mapping check for GGUF models
-    if is_gguf:
-        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
-        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
-        config.update({"architectures": [model_type]})
-
     return config
@@ -139,11 +123,6 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False

-    is_gguf = check_gguf_file(tokenizer_name)
-    if is_gguf:
-        kwargs["gguf_file"] = tokenizer_name
-        tokenizer_name = Path(tokenizer_name).parent
-
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -216,16 +195,3 @@ def attach_additional_stop_token_ids(tokenizer):
         )
     else:
         tokenizer.additional_stop_token_ids = None
-
-
-def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
-    """Check if the file is a GGUF model."""
-    model = Path(model)
-    if not model.is_file():
-        return False
-    elif model.suffix == ".gguf":
-        return True
-
-    with open(model, "rb") as f:
-        header = f.read(4)
-        return header == b"GGUF"
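
Note: the hunks above remove GGUF auto-detection from both get_config and get_tokenizer. For reference, a minimal sketch of the pattern being reverted, assuming a local .gguf file (the path is illustrative; the gguf_file kwarg is the transformers entry point for GGUF metadata and requires the gguf package):

    from pathlib import Path
    from transformers import AutoConfig, AutoTokenizer

    model = "/models/qwen2-1_5b-instruct-q4_k_m.gguf"  # illustrative path
    if Path(model).suffix == ".gguf":
        # Point transformers at the containing directory and let it parse
        # the config/tokenizer metadata out of the GGUF file itself.
        config = AutoConfig.from_pretrained(Path(model).parent, gguf_file=model)
        tokenizer = AutoTokenizer.from_pretrained(Path(model).parent, gguf_file=model)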

View File

@@ -23,7 +23,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -164,7 +163,7 @@ class LogitsProcessor(nn.Module):
         self,
         input_ids,
         hidden_states,
-        lm_head: VocabParallelEmbedding,
+        weight,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -179,7 +178,7 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]

-        last_logits = self._get_logits(last_hidden, lm_head)
+        last_logits = torch.matmul(last_hidden, weight.T)
         if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -230,7 +229,7 @@ class LogitsProcessor(nn.Module):
             # Compute the logits and logprobs for all required tokens
             states = torch.cat(states, dim=0)
-            all_logits = self._get_logits(states, lm_head)
+            all_logits = torch.matmul(states, weight.T)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -277,19 +276,6 @@ class LogitsProcessor(nn.Module):
                 output_top_logprobs=output_top_logprobs,
             )

-    def _get_logits(
-        self,
-        hidden_states: torch.Tensor,
-        lm_head: VocabParallelEmbedding,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if hasattr(lm_head, "weight"):
-            logits = torch.matmul(hidden_states, lm_head.weight.T)
-        else:
-            # GGUF models
-            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
-        return logits


 def test():
     all_logprobs = torch.tensor(
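
Note: with _get_logits removed, the logits processor is back to taking a plain weight tensor instead of the lm_head module, which is why every model call site in this commit changes from self.lm_head to self.lm_head.weight. A minimal sketch of the restored projection (shapes inferred from the surrounding code):

    import torch

    def project_to_vocab(hidden_states: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size]
        # weight:        [vocab_size, hidden_size]  (ParallelLMHead weight)
        return torch.matmul(hidden_states, weight.T)  # -> [num_tokens, vocab_size]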

View File

@@ -222,7 +222,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         enable_tp: bool = True,
     ):
         super().__init__()
-        self.quant_config = quant_config
         self.enable_tp = enable_tp

         if self.enable_tp:

View File

@@ -59,7 +59,6 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     is_hip,
-    monkey_patch_vllm_gguf_config,
    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
@@ -298,8 +297,6 @@ class ModelRunner:
                 download_dir=self.server_args.download_dir,
             )
         monkey_patch_vllm_model_config()
-        if self.server_args.load_format == "gguf":
-            monkey_patch_vllm_gguf_config()
         self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(

View File

@@ -338,12 +338,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, quant_config=quant_config
+        )
         if self.config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
-            )
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config)

     def forward(
@@ -354,7 +353,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module):
             forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch)
         if not forward_batch.forward_mode.is_idle():
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
             )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module):
             input_ids, positions, forward_batch, input_embeds
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )

     def get_attention_sliding_window_size(self):

View File

@@ -247,7 +247,7 @@ class GPT2LMHeadModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output, forward_batch
+            input_ids, hidden_states, self.output.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -258,7 +258,6 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -306,12 +305,7 @@ class LlamaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-        if self.config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
-            )
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.stacked_params_mapping = [
@@ -335,7 +329,7 @@ class LlamaForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -379,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
         return len(params_dict)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        embed_tokens_weight = None
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -390,6 +385,12 @@ class LlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())

+        load_tie_word_embeddings = (
+            hasattr(self.config, "tie_word_embeddings")
+            and self.config.tie_word_embeddings
+            and "lm_head.weight" in params_dict
+        )
+
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
@@ -422,6 +423,16 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

+            if load_tie_word_embeddings and name == "model.embed_tokens.weight":
+                embed_tokens_weight = loaded_weight
+
+        if load_tie_word_embeddings:
+            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
+            param = self.lm_head.weight
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            if embed_tokens_weight is not None:
+                weight_loader(param, embed_tokens_weight)
+
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))

     def get_weights_by_name(
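
Note: the load_weights changes above restore explicit weight tying: when tie_word_embeddings is set and the checkpoint has no lm_head.weight, the input embedding weight is loaded into the separate lm_head module rather than aliasing the two modules. A minimal sketch of the idea, with hypothetical sizes and module names:

    import torch

    embed_tokens = torch.nn.Embedding(32000, 4096)           # input embedding
    lm_head = torch.nn.Linear(4096, 32000, bias=False)       # output head
    # Copy (rather than alias) the embedding into the head, mirroring
    # weight_loader(param, embed_tokens_weight) above.
    with torch.no_grad():
        lm_head.weight.copy_(embed_tokens.weight)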

View File

@@ -308,10 +308,12 @@ class MiniCPMForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head = self.model.embed_tokens
+            lm_head_weight = self.model.embed_tokens.weight
         else:
-            lm_head = self.lm_head
-        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
+            lm_head_weight = self.lm_head.weight
+        return self.logits_processor(
+            input_ids, hidden_states, lm_head_weight, forward_batch
+        )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [

View File

@@ -585,10 +585,12 @@ class MiniCPM3ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head = self.model.embed_tokens
+            lm_head_weight = self.model.embed_tokens.weight
         else:
-            lm_head = self.lm_head
-        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
+            lm_head_weight = self.lm_head.weight
+        return self.logits_processor(
+            input_ids, hidden_states, lm_head_weight, forward_batch
+        )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [

View File

@@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
             skip_cross_attention=skip_cross_attention,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.language_model.lm_head, forward_batch
+            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,6 +326,11 @@ class OlmoForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            # With tie_word_embeddings, we can skip lm_head.weight
+            # The weight might appear unnecessarily in the files if the model is
+            # processed with quantization, LoRA, fine-tuning, etc.
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

View File

@@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -397,13 +397,10 @@ class Phi3SmallForCausalLM(nn.Module):
     def compute_logits(
         self,
-        input_ids: torch.LongTensor,
         hidden_states: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(
-            input_ids, self.lm_head, hidden_states, sampling_metadata
-        )
+        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
         if self.dummy_token_indices is not None and logits is not None:
             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
         return logits
@@ -425,7 +422,7 @@ class Phi3SmallForCausalLM(nn.Module):
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
             )
         else:

View File

@@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module):
     ):
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -230,7 +230,6 @@ class Qwen2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -277,12 +276,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(config, quant_config=quant_config)
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
-            )
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -298,7 +292,7 @@ class Qwen2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -312,7 +306,6 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -342,6 +335,11 @@ class Qwen2ForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+                if (
+                    self.config.tie_word_embeddings
+                    and name == "model.embed_tokens.weight"
+                ):
+                    weight_loader(params_dict["lm_head.weight"], loaded_weight)

 EntryClass = Qwen2ForCausalLM

View File

@@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -686,6 +686,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

View File

@@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -396,10 +396,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
-        if self.config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

         # turning off autotune for fp8dq since it doesn't give speedup and
@@ -416,7 +413,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def get_hidden_dim(self, module_name):
@@ -504,6 +501,14 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

+        if (
+            hasattr(self.config, "tie_word_embeddings")
+            and self.config.tie_word_embeddings
+        ):
+            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
+            param = self.lm_head.weight
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, self.model.embed_tokens.weight)
+
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))

View File

@@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(

View File

@@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -20,7 +20,6 @@ import random
 import tempfile
 from typing import List, Optional

-from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_nvgpu_memory_capacity,
@@ -205,12 +204,6 @@ class ServerArgs:
                 "Overlap schedule is disabled."
             )

-        # GGUF
-        if (
-            self.load_format == "auto" or self.load_format == "gguf"
-        ) and check_gguf_file(self.model_path):
-            self.quantization = self.load_format = "gguf"
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
@@ -250,7 +243,7 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
+            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -260,8 +253,7 @@ class ServerArgs:
             '"npcache" will load the weights in pytorch format and store '
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling."
-            '"gguf" will load the weights in the gguf format. ',
+            "which is mainly for profiling.",
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -301,7 +293,6 @@ class ServerArgs:
                 "gptq_marlin",
                 "awq_marlin",
                 "bitsandbytes",
-                "gguf",
             ],
             help="The quantization method.",
         )
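
Note: with this revert, "gguf" disappears from both --load-format and --quantization, and the removed auto-detection block no longer promotes a .gguf model path to load_format = quantization = "gguf". In practice, a launch command like the following (illustrative local path) is no longer supported:

    python -m sglang.launch_server --model-path ./qwen2-1_5b-instruct-q4_k_m.gguf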

View File

@@ -557,29 +557,6 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
     setattr(GroupCoordinator, "all_gather", all_gather)


-def monkey_patch_vllm_gguf_config():
-    from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.gguf import (
-        GGUFConfig,
-        GGUFEmbeddingMethod,
-        GGUFLinearMethod,
-    )
-
-    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-
-    def get_quant_method_with_embedding_replaced(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        if isinstance(layer, LinearBase):
-            return GGUFLinearMethod(self)
-        elif isinstance(layer, VocabParallelEmbedding):
-            # patch to own VocabParallelEmbedding
-            return GGUFEmbeddingMethod(self)
-        return None
-
-    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
-
-
 def maybe_set_triton_cache_manager() -> None:
     """Set environment variable to tell Triton to use a
     custom cache manager"""

View File

@@ -15,7 +15,6 @@ suites = {
     "test_double_sparsity.py",
     "test_embedding_openai_server.py",
     "test_eval_accuracy_mini.py",
-    "test_gguf.py",
     "test_input_embeddings.py",
     "test_json_constrained.py",
     "test_large_max_new_tokens.py",

View File

@@ -1,26 +0,0 @@
-import unittest
-
-from huggingface_hub import hf_hub_download
-
-import sglang as sgl
-
-
-class TestGGUF(unittest.TestCase):
-    def test_models(self):
-        prompt = "Today is a sunny day and I like"
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
-
-        model_path = hf_hub_download(
-            "Qwen/Qwen2-1.5B-Instruct-GGUF",
-            filename="qwen2-1_5b-instruct-q4_k_m.gguf",
-        )
-        engine = sgl.Engine(model_path=model_path, random_seed=42)
-        outputs = engine.generate(prompt, sampling_params)["text"]
-        engine.shutdown()
-
-        self.assertEqual(outputs, " it. I have a lot of work")
-
-
-if __name__ == "__main__":
-    unittest.main()