diff --git a/docs/requirements.txt b/docs/requirements.txt index 948d5427c..171d60e0a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,4 +15,3 @@ sphinx-copybutton sphinx-tabs sphinxcontrib-mermaid urllib3<2.0.0 -gguf>=0.10.0 diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index ac475cf34..a5b566fb3 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -16,7 +16,6 @@ import contextlib import os import warnings -from pathlib import Path from typing import Dict, Optional, Type, Union from huggingface_hub import snapshot_download @@ -28,7 +27,6 @@ from transformers import ( PreTrainedTokenizer, PreTrainedTokenizerFast, ) -from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES try: from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig @@ -62,29 +60,15 @@ def get_config( trust_remote_code: bool, revision: Optional[str] = None, model_override_args: Optional[dict] = None, - **kwargs, ): - is_gguf = check_gguf_file(model) - if is_gguf: - kwargs["gguf_file"] = model - model = Path(model).parent - config = AutoConfig.from_pretrained( - model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + model, trust_remote_code=trust_remote_code, revision=revision ) if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) if model_override_args: config.update(model_override_args) - - # Special architecture mapping check for GGUF models - if is_gguf: - if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - raise RuntimeError(f"Can't get gguf config for {config.model_type}.") - model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] - config.update({"architectures": [model_type]}) - return config @@ -139,11 +123,6 @@ def get_tokenizer( raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False - is_gguf = check_gguf_file(tokenizer_name) - if is_gguf: - kwargs["gguf_file"] = tokenizer_name - tokenizer_name = Path(tokenizer_name).parent - try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, @@ -216,16 +195,3 @@ def attach_additional_stop_token_ids(tokenizer): ) else: tokenizer.additional_stop_token_ids = None - - -def check_gguf_file(model: Union[str, os.PathLike]) -> bool: - """Check if the file is a GGUF model.""" - model = Path(model) - if not model.is_file(): - return False - elif model.suffix == ".gguf": - return True - - with open(model, "rb") as f: - header = f.read(4) - return header == b"GGUF" diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 274c4c311..eedd7fe01 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -23,7 +23,6 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) -from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -164,7 +163,7 @@ class LogitsProcessor(nn.Module): self, input_ids, hidden_states, - lm_head: VocabParallelEmbedding, + weight, logits_metadata: Union[LogitsMetadata, ForwardBatch], ): if isinstance(logits_metadata, ForwardBatch): @@ -179,7 +178,7 @@ class LogitsProcessor(nn.Module): last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_hidden = hidden_states[last_index] - last_logits = self._get_logits(last_hidden, lm_head) + last_logits = torch.matmul(last_hidden, weight.T) if self.do_tensor_parallel_all_gather: last_logits = tensor_model_parallel_all_gather(last_logits) last_logits = last_logits[:, : self.config.vocab_size].float() @@ -230,7 +229,7 @@ class LogitsProcessor(nn.Module): # Compute the logits and logprobs for all required tokens states = torch.cat(states, dim=0) - all_logits = self._get_logits(states, lm_head) + all_logits = torch.matmul(states, weight.T) if self.do_tensor_parallel_all_gather: all_logits = tensor_model_parallel_all_gather(all_logits) all_logits = all_logits[:, : self.config.vocab_size].float() @@ -277,19 +276,6 @@ class LogitsProcessor(nn.Module): output_top_logprobs=output_top_logprobs, ) - def _get_logits( - self, - hidden_states: torch.Tensor, - lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if hasattr(lm_head, "weight"): - logits = torch.matmul(hidden_states, lm_head.weight.T) - else: - # GGUF models - logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) - return logits - def test(): all_logprobs = torch.tensor( diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index effea1c6c..a2d15fc78 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -222,7 +222,6 @@ class VocabParallelEmbedding(torch.nn.Module): enable_tp: bool = True, ): super().__init__() - self.quant_config = quant_config self.enable_tp = enable_tp if self.enable_tp: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3cc90c19d..0542b7b0b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -59,7 +59,6 @@ from sglang.srt.utils import ( enable_show_time_cost, get_available_gpu_memory, is_hip, - monkey_patch_vllm_gguf_config, monkey_patch_vllm_model_config, monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, @@ -298,8 +297,6 @@ class ModelRunner: download_dir=self.server_args.download_dir, ) monkey_patch_vllm_model_config() - if self.server_args.load_format == "gguf": - monkey_patch_vllm_gguf_config() self.vllm_model_config = VllmModelConfig(**self.get_model_config_params()) if self.model_config.model_override_args is not None: self.vllm_model_config.hf_config.update( diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index d3b0fd9ae..0e5e3b9ad 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -338,12 +338,11 @@ class BaiChuanBaseForCausalLM(nn.Module): self.quant_config = quant_config self.model = BaiChuanModel(config, position_embedding, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config) def forward( @@ -354,7 +353,7 @@ class BaiChuanBaseForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index ced6859c7..05ce17a6b 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 8769d49db..d4018be88 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module): forward_batch, ) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens, forward_batch + input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index e9b4ff141..b8dad0248 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 43dfc50a4..cdebafa2f 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 55a458c20..85467c12c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch) if not forward_batch.forward_mode.is_idle(): return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 8c244419f..c097e00ad 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module): input_ids, positions, forward_batch, input_embeds ) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index f6d301546..a53fad958 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens, forward_batch + input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 104205648..0fa6a5393 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens, forward_batch + input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch ) def get_attention_sliding_window_size(self): diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 6fbfe9edd..8d988fe8e 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -247,7 +247,7 @@ class GPT2LMHeadModel(nn.Module): ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 5af127320..03597fa73 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index d5c303d13..1e49eb59a 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index d217fd71f..59ff6d1e2 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.output, forward_batch + input_ids, hidden_states, self.output.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index ba5cfc90e..5f472ef3b 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -258,7 +258,6 @@ class LlamaModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - quant_config=quant_config, ) self.layers = make_layers( config.num_hidden_layers, @@ -306,12 +305,7 @@ class LlamaForCausalLM(nn.Module): self.quant_config = quant_config self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.stacked_params_mapping = [ @@ -335,7 +329,7 @@ class LlamaForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) else: return self.pooler(hidden_states, forward_batch) @@ -379,6 +373,7 @@ class LlamaForCausalLM(nn.Module): return len(params_dict) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + embed_tokens_weight = None stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -390,6 +385,12 @@ class LlamaForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) + load_tie_word_embeddings = ( + hasattr(self.config, "tie_word_embeddings") + and self.config.tie_word_embeddings + and "lm_head.weight" in params_dict + ) + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: continue @@ -422,6 +423,16 @@ class LlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if load_tie_word_embeddings and name == "model.embed_tokens.weight": + embed_tokens_weight = loaded_weight + + if load_tie_word_embeddings: + # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing + param = self.lm_head.weight + weight_loader = getattr(param, "weight_loader", default_weight_loader) + if embed_tokens_weight is not None: + weight_loader(param, embed_tokens_weight) + apply_torchao_config_(self, params_dict, set(["proj.weight"])) def get_weights_by_name( diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 0d668fe5d..239cfb6fc 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -308,10 +308,12 @@ class MiniCPMForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: - lm_head = self.model.embed_tokens + lm_head_weight = self.model.embed_tokens.weight else: - lm_head = self.lm_head - return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch) + lm_head_weight = self.lm_head.weight + return self.logits_processor( + input_ids, hidden_states, lm_head_weight, forward_batch + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index e6bf118ed..6f53f2974 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -585,10 +585,12 @@ class MiniCPM3ForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: - lm_head = self.model.embed_tokens + lm_head_weight = self.model.embed_tokens.weight else: - lm_head = self.lm_head - return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch) + lm_head_weight = self.lm_head.weight + return self.logits_processor( + input_ids, hidden_states, lm_head_weight, forward_batch + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index b2e895f56..98d5ab332 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 8dba2b722..d15a389a8 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 2a0cf4ea3..63bbfdb7e 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module): skip_cross_attention=skip_cross_attention, ) return self.logits_processor( - input_ids, hidden_states, self.language_model.lm_head, forward_batch + input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 2ef6532ce..80fd64a53 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module): input_embeds=input_embeds, ) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -326,6 +326,11 @@ class OlmoForCausalLM(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 549e2d032..407eb98cb 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index e310dfcea..656115321 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -397,13 +397,10 @@ class Phi3SmallForCausalLM(nn.Module): def compute_logits( self, - input_ids: torch.LongTensor, hidden_states: torch.Tensor, sampling_metadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor( - input_ids, self.lm_head, hidden_states, sampling_metadata - ) + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) if self.dummy_token_indices is not None and logits is not None: logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) return logits @@ -425,7 +422,7 @@ class Phi3SmallForCausalLM(nn.Module): if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) else: diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index fb4b67ff5..4c1829026 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module): ): hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 4c8ddd4b9..634ce1cf1 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -230,7 +230,6 @@ class Qwen2Model(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - quant_config=quant_config, ) self.layers = make_layers( config.num_hidden_layers, @@ -277,12 +276,7 @@ class Qwen2ForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = Qwen2Model(config, quant_config=quant_config) - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @@ -298,7 +292,7 @@ class Qwen2ForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) else: return self.pooler(hidden_states, forward_batch) @@ -312,7 +306,6 @@ class Qwen2ForCausalLM(nn.Module): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: @@ -342,6 +335,11 @@ class Qwen2ForCausalLM(nn.Module): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if ( + self.config.tie_word_embeddings + and name == "model.embed_tokens.weight" + ): + weight_loader(params_dict["lm_head.weight"], loaded_weight) EntryClass = Qwen2ForCausalLM diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 256993269..febd6d748 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 155bde015..dc58383ee 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) else: return self.pooler(hidden_states, forward_batch) @@ -686,6 +686,8 @@ class Qwen2VLForConditionalGeneration(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 38f2be13a..9fa2ab343 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 68982eebf..b9451d591 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -396,10 +396,7 @@ class TorchNativeLlamaForCausalLM(nn.Module): self.torchao_config = global_server_args_dict["torchao_config"] self.supports_torch_tp = True self.model = LlamaModel(config, quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) # turning off autotune for fp8dq since it doesn't give speedup and @@ -416,7 +413,7 @@ class TorchNativeLlamaForCausalLM(nn.Module): ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def get_hidden_dim(self, module_name): @@ -504,6 +501,14 @@ class TorchNativeLlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if ( + hasattr(self.config, "tie_word_embeddings") + and self.config.tie_word_embeddings + ): + # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing + param = self.lm_head.weight + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, self.model.embed_tokens.weight) apply_torchao_config_(self, params_dict, set(["proj.weight"])) diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index 42f51a7fa..fb7e14a0e 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights( diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 3a8b9a9e4..c6458f7f5 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head.weight, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e52350490..be470dac3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,7 +20,6 @@ import random import tempfile from typing import List, Optional -from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.utils import ( get_amdgpu_memory_capacity, get_nvgpu_memory_capacity, @@ -205,12 +204,6 @@ class ServerArgs: "Overlap schedule is disabled." ) - # GGUF - if ( - self.load_format == "auto" or self.load_format == "gguf" - ) and check_gguf_file(self.model_path): - self.quantization = self.load_format = "gguf" - @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args @@ -250,7 +243,7 @@ class ServerArgs: "--load-format", type=str, default=ServerArgs.load_format, - choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"], + choices=["auto", "pt", "safetensors", "npcache", "dummy"], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' "and fall back to the pytorch bin format if safetensors format " @@ -260,8 +253,7 @@ class ServerArgs: '"npcache" will load the weights in pytorch format and store ' "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ', + "which is mainly for profiling.", ) parser.add_argument( "--trust-remote-code", @@ -301,7 +293,6 @@ class ServerArgs: "gptq_marlin", "awq_marlin", "bitsandbytes", - "gguf", ], help="The quantization method.", ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 89044c8b2..46b4db8e8 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -557,29 +557,6 @@ def monkey_patch_vllm_all_gather(reverse: bool = False): setattr(GroupCoordinator, "all_gather", all_gather) -def monkey_patch_vllm_gguf_config(): - from vllm.model_executor.layers.linear import LinearBase - from vllm.model_executor.layers.quantization.gguf import ( - GGUFConfig, - GGUFEmbeddingMethod, - GGUFLinearMethod, - ) - - from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding - - def get_quant_method_with_embedding_replaced( - self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: - if isinstance(layer, LinearBase): - return GGUFLinearMethod(self) - elif isinstance(layer, VocabParallelEmbedding): - # patch to own VocabParallelEmbedding - return GGUFEmbeddingMethod(self) - return None - - setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced) - - def maybe_set_triton_cache_manager() -> None: """Set environment variable to tell Triton to use a custom cache manager""" diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index b055df1a1..d441cf9b2 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -15,7 +15,6 @@ suites = { "test_double_sparsity.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", - "test_gguf.py", "test_input_embeddings.py", "test_json_constrained.py", "test_large_max_new_tokens.py", diff --git a/test/srt/test_gguf.py b/test/srt/test_gguf.py deleted file mode 100644 index 89572c45f..000000000 --- a/test/srt/test_gguf.py +++ /dev/null @@ -1,26 +0,0 @@ -import unittest - -from huggingface_hub import hf_hub_download - -import sglang as sgl - - -class TestGGUF(unittest.TestCase): - def test_models(self): - prompt = "Today is a sunny day and I like" - sampling_params = {"temperature": 0, "max_new_tokens": 8} - - model_path = hf_hub_download( - "Qwen/Qwen2-1.5B-Instruct-GGUF", - filename="qwen2-1_5b-instruct-q4_k_m.gguf", - ) - - engine = sgl.Engine(model_path=model_path, random_seed=42) - outputs = engine.generate(prompt, sampling_params)["text"] - engine.shutdown() - - self.assertEqual(outputs, " it. I have a lot of work") - - -if __name__ == "__main__": - unittest.main()