diff --git a/docs/requirements.txt b/docs/requirements.txt index 171d60e0a..948d5427c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,3 +15,4 @@ sphinx-copybutton sphinx-tabs sphinxcontrib-mermaid urllib3<2.0.0 +gguf>=0.10.0 diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index a5b566fb3..ac475cf34 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -16,6 +16,7 @@ import contextlib import os import warnings +from pathlib import Path from typing import Dict, Optional, Type, Union from huggingface_hub import snapshot_download @@ -27,6 +28,7 @@ from transformers import ( PreTrainedTokenizer, PreTrainedTokenizerFast, ) +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES try: from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig @@ -60,15 +62,29 @@ def get_config( trust_remote_code: bool, revision: Optional[str] = None, model_override_args: Optional[dict] = None, + **kwargs, ): + is_gguf = check_gguf_file(model) + if is_gguf: + kwargs["gguf_file"] = model + model = Path(model).parent + config = AutoConfig.from_pretrained( - model, trust_remote_code=trust_remote_code, revision=revision + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) if model_override_args: config.update(model_override_args) + + # Special architecture mapping check for GGUF models + if is_gguf: + if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + raise RuntimeError(f"Can't get gguf config for {config.model_type}.") + model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] + config.update({"architectures": [model_type]}) + return config @@ -123,6 +139,11 @@ def get_tokenizer( raise ValueError("Cannot use the fast tokenizer 
in slow tokenizer mode.") kwargs["use_fast"] = False + is_gguf = check_gguf_file(tokenizer_name) + if is_gguf: + kwargs["gguf_file"] = tokenizer_name + tokenizer_name = Path(tokenizer_name).parent + try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, @@ -195,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer): ) else: tokenizer.additional_stop_token_ids = None + + +def check_gguf_file(model: Union[str, os.PathLike]) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + with open(model, "rb") as f: + header = f.read(4) + return header == b"GGUF" diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index eedd7fe01..274c4c311 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -23,6 +23,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module): self, input_ids, hidden_states, - weight, + lm_head: VocabParallelEmbedding, logits_metadata: Union[LogitsMetadata, ForwardBatch], ): if isinstance(logits_metadata, ForwardBatch): @@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module): last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_hidden = hidden_states[last_index] - last_logits = torch.matmul(last_hidden, weight.T) + last_logits = self._get_logits(last_hidden, lm_head) if self.do_tensor_parallel_all_gather: last_logits = tensor_model_parallel_all_gather(last_logits) last_logits = last_logits[:, : self.config.vocab_size].float() @@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module): # Compute the logits and logprobs for all required tokens states = torch.cat(states, dim=0) - all_logits = 
torch.matmul(states, weight.T) + all_logits = self._get_logits(states, lm_head) if self.do_tensor_parallel_all_gather: all_logits = tensor_model_parallel_all_gather(all_logits) all_logits = all_logits[:, : self.config.vocab_size].float() @@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module): output_top_logprobs=output_top_logprobs, ) + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if hasattr(lm_head, "weight"): + logits = torch.matmul(hidden_states, lm_head.weight.T) + else: + # GGUF models + logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) + return logits + def test(): all_logprobs = torch.tensor( diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index a2d15fc78..effea1c6c 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -222,6 +222,7 @@ class VocabParallelEmbedding(torch.nn.Module): enable_tp: bool = True, ): super().__init__() + self.quant_config = quant_config self.enable_tp = enable_tp if self.enable_tp: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0542b7b0b..3cc90c19d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -59,6 +59,7 @@ from sglang.srt.utils import ( enable_show_time_cost, get_available_gpu_memory, is_hip, + monkey_patch_vllm_gguf_config, monkey_patch_vllm_model_config, monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, @@ -297,6 +298,8 @@ class ModelRunner: download_dir=self.server_args.download_dir, ) monkey_patch_vllm_model_config() + if self.server_args.load_format == "gguf": + monkey_patch_vllm_gguf_config() self.vllm_model_config = VllmModelConfig(**self.get_model_config_params()) if 
self.model_config.model_override_args is not None: self.vllm_model_config.hf_config.update( diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 0e5e3b9ad..d3b0fd9ae 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -338,11 +338,12 @@ class BaiChuanBaseForCausalLM(nn.Module): self.quant_config = quant_config self.model = BaiChuanModel(config, position_embedding, quant_config) - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) def forward( @@ -353,7 +354,7 @@ class BaiChuanBaseForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 05ce17a6b..ced6859c7 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index d4018be88..8769d49db 100644 --- a/python/sglang/srt/models/commandr.py +++ 
b/python/sglang/srt/models/commandr.py @@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module): forward_batch, ) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch + input_ids, hidden_states, self.model.embed_tokens, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index b8dad0248..e9b4ff141 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index cdebafa2f..43dfc50a4 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 85467c12c..55a458c20 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch) if not forward_batch.forward_mode.is_idle(): return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, 
self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index c097e00ad..8c244419f 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module): input_ids, positions, forward_batch, input_embeds ) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index a53fad958..f6d301546 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch + input_ids, hidden_states, self.model.embed_tokens, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 0fa6a5393..104205648 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch + input_ids, hidden_states, self.model.embed_tokens, forward_batch ) def get_attention_sliding_window_size(self): diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 8d988fe8e..6fbfe9edd 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -247,7 
+247,7 @@ class GPT2LMHeadModel(nn.Module): ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 03597fa73..5af127320 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 1e49eb59a..d5c303d13 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 59ff6d1e2..d217fd71f 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.output.weight, forward_batch + input_ids, hidden_states, 
self.output, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 5f472ef3b..68809c9c2 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import make_layers +from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -258,6 +259,7 @@ class LlamaModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, ) self.layers = make_layers( config.num_hidden_layers, @@ -305,7 +307,12 @@ class LlamaForCausalLM(nn.Module): self.quant_config = quant_config self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.stacked_params_mapping = [ @@ -329,7 +336,7 @@ class LlamaForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) else: return self.pooler(hidden_states, forward_batch) @@ -373,7 +380,6 @@ class LlamaForCausalLM(nn.Module): return len(params_dict) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - 
embed_tokens_weight = None stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -385,12 +391,6 @@ class LlamaForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) - load_tie_word_embeddings = ( - hasattr(self.config, "tie_word_embeddings") - and self.config.tie_word_embeddings - and "lm_head.weight" in params_dict - ) - for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: continue @@ -423,16 +423,6 @@ class LlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if load_tie_word_embeddings and name == "model.embed_tokens.weight": - embed_tokens_weight = loaded_weight - - if load_tie_word_embeddings: - # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing - param = self.lm_head.weight - weight_loader = getattr(param, "weight_loader", default_weight_loader) - if embed_tokens_weight is not None: - weight_loader(param, embed_tokens_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) def get_weights_by_name( @@ -444,6 +434,17 @@ class LlamaForCausalLM(nn.Module): For optimized performance, please use torch.save and torch.load. """ try: + if name == "lm_head.weight" and self.config.tie_word_embeddings: + logger.info( + "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight." 
+ ) + return ( + self.model.embed_tokens.weight.cpu() + .to(torch.float32) + .numpy() + .tolist()[:truncate_size] + ) + mapped_name = name mapped_shard_id = None for param_name, weight_name, shard_id in self.stacked_params_mapping: @@ -452,54 +453,48 @@ class LlamaForCausalLM(nn.Module): mapped_shard_id = shard_id break params_dict = dict(self.named_parameters()) - if mapped_name in params_dict: - param = params_dict[mapped_name] - if mapped_shard_id is not None: - if mapped_shard_id in ["q", "k", "v"]: - num_heads = self.config.num_attention_heads // tp_size - num_kv_heads = self.config.num_key_value_heads // tp_size - head_dim = ( - self.config.hidden_size // self.config.num_attention_heads - ) - if mapped_shard_id == "q": - offset = 0 - size = num_heads * head_dim - elif mapped_shard_id == "k": - offset = num_heads * head_dim - size = num_kv_heads * head_dim - elif mapped_shard_id == "v": - offset = (num_heads + num_kv_heads) * head_dim - size = num_kv_heads * head_dim - weight = param.data.narrow(0, offset, size) - elif mapped_shard_id in [0, 1]: - intermediate_size = self.config.intermediate_size - hidden_size = self.config.hidden_size - slice_size = intermediate_size // tp_size - if mapped_shard_id == 0: # gate_proj - offset = 0 - size = slice_size - elif mapped_shard_id == 1: # up_proj - offset = slice_size - size = slice_size + param = params_dict[mapped_name] + if mapped_shard_id is not None: + if mapped_shard_id in ["q", "k", "v"]: + num_heads = self.config.num_attention_heads // tp_size + num_kv_heads = self.config.num_key_value_heads // tp_size + head_dim = ( + self.config.hidden_size // self.config.num_attention_heads + ) + if mapped_shard_id == "q": + offset = 0 + size = num_heads * head_dim + elif mapped_shard_id == "k": + offset = num_heads * head_dim + size = num_kv_heads * head_dim + elif mapped_shard_id == "v": + offset = (num_heads + num_kv_heads) * head_dim + size = num_kv_heads * head_dim + weight = param.data.narrow(0, offset, size) + elif 
mapped_shard_id in [0, 1]: + intermediate_size = self.config.intermediate_size + slice_size = intermediate_size // tp_size + if mapped_shard_id == 0: # gate_proj + offset = 0 + size = slice_size + elif mapped_shard_id == 1: # up_proj + offset = slice_size + size = slice_size - weight = param.data.narrow(0, offset, size) - else: - weight = param.data + weight = param.data.narrow(0, offset, size) else: weight = param.data - if tp_size > 1 and ("o_proj" in name or "down_proj" in name): - gathered_weights = [ - torch.zeros_like(weight) for _ in range(tp_size) - ] - torch.distributed.all_gather(gathered_weights, weight) - weight = torch.cat(gathered_weights, dim=1) - return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size] else: - return None + weight = param.data + if tp_size > 1 and ("o_proj" in name or "down_proj" in name): + gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)] + torch.distributed.all_gather(gathered_weights, weight) + weight = torch.cat(gathered_weights, dim=1) + return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size] - except Exception as e: + except Exception: logger.error( - f"Error getting weights by name {name} in LlamaForCausalLM: {e}" + f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}" ) return None diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 239cfb6fc..0d668fe5d 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -308,12 +308,10 @@ class MiniCPMForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: - lm_head_weight = self.model.embed_tokens.weight + lm_head = self.model.embed_tokens else: - lm_head_weight = self.lm_head.weight - return self.logits_processor( - input_ids, hidden_states, lm_head_weight, forward_batch - ) + lm_head = 
self.lm_head + return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 6f53f2974..e6bf118ed 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -585,12 +585,10 @@ class MiniCPM3ForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: - lm_head_weight = self.model.embed_tokens.weight + lm_head = self.model.embed_tokens else: - lm_head_weight = self.lm_head.weight - return self.logits_processor( - input_ids, hidden_states, lm_head_weight, forward_batch - ) + lm_head = self.lm_head + return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 98d5ab332..b2e895f56 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index d15a389a8..8dba2b722 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, 
forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 63bbfdb7e..2a0cf4ea3 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module): skip_cross_attention=skip_cross_attention, ) return self.logits_processor( - input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch + input_ids, hidden_states, self.language_model.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 80fd64a53..2ef6532ce 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module): input_embeds=input_embeds, ) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -326,11 +326,6 @@ class OlmoForCausalLM(nn.Module): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. 
- if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 407eb98cb..549e2d032 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index 656115321..e310dfcea 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -397,10 +397,13 @@ class Phi3SmallForCausalLM(nn.Module): def compute_logits( self, + input_ids: torch.LongTensor, hidden_states: torch.Tensor, sampling_metadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + logits = self.logits_processor( + input_ids, self.lm_head, hidden_states, sampling_metadata + ) if self.dummy_token_indices is not None and logits is not None: logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) return logits @@ -422,7 +425,7 @@ class Phi3SmallForCausalLM(nn.Module): if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) else: diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 4c1829026..fb4b67ff5 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module): ): hidden_states = 
self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 634ce1cf1..4c8ddd4b9 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -230,6 +230,7 @@ class Qwen2Model(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, ) self.layers = make_layers( config.num_hidden_layers, @@ -276,7 +277,12 @@ class Qwen2ForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.model = Qwen2Model(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @@ -292,7 +298,7 @@ class Qwen2ForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) else: return self.pooler(hidden_states, forward_batch) @@ -306,6 +312,7 @@ class Qwen2ForCausalLM(nn.Module): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: @@ -335,11 +342,6 @@ class Qwen2ForCausalLM(nn.Module): param = params_dict[name] weight_loader = getattr(param, "weight_loader", 
default_weight_loader) weight_loader(param, loaded_weight) - if ( - self.config.tie_word_embeddings - and name == "model.embed_tokens.weight" - ): - weight_loader(params_dict["lm_head.weight"], loaded_weight) EntryClass = Qwen2ForCausalLM diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index febd6d748..256993269 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index dc58383ee..155bde015 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) else: return self.pooler(hidden_states, forward_batch) @@ -686,8 +686,6 @@ class Qwen2VLForConditionalGeneration(nn.Module): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 9fa2ab343..38f2be13a 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, 
forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index b9451d591..68982eebf 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -396,7 +396,10 @@ class TorchNativeLlamaForCausalLM(nn.Module): self.torchao_config = global_server_args_dict["torchao_config"] self.supports_torch_tp = True self.model = LlamaModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) # turning off autotune for fp8dq since it doesn't give speedup and @@ -413,7 +416,7 @@ class TorchNativeLlamaForCausalLM(nn.Module): ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def get_hidden_dim(self, module_name): @@ -501,14 +504,6 @@ class TorchNativeLlamaForCausalLM(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if ( - hasattr(self.config, "tie_word_embeddings") - and self.config.tie_word_embeddings - ): - # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing - param = self.lm_head.weight - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, self.model.embed_tokens.weight) apply_torchao_config_(self, params_dict, 
set(["proj.weight"])) diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index fb7e14a0e..42f51a7fa 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights( diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index c6458f7f5..3a8b9a9e4 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index be470dac3..e52350490 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,6 +20,7 @@ import random import tempfile from typing import List, Optional +from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.utils import ( get_amdgpu_memory_capacity, get_nvgpu_memory_capacity, @@ -204,6 +205,12 @@ class ServerArgs: "Overlap schedule is disabled." 
) + # GGUF + if ( + self.load_format == "auto" or self.load_format == "gguf" + ) and check_gguf_file(self.model_path): + self.quantization = self.load_format = "gguf" + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args @@ -243,7 +250,7 @@ class ServerArgs: "--load-format", type=str, default=ServerArgs.load_format, - choices=["auto", "pt", "safetensors", "npcache", "dummy"], + choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' "and fall back to the pytorch bin format if safetensors format " @@ -253,7 +260,8 @@ class ServerArgs: '"npcache" will load the weights in pytorch format and store ' "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling.", + "which is mainly for profiling. " + '"gguf" will load the weights in the gguf format. ', ) parser.add_argument( "--trust-remote-code", @@ -293,6 +301,7 @@ class ServerArgs: "gptq_marlin", "awq_marlin", "bitsandbytes", + "gguf", ], help="The quantization method.", ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 46b4db8e8..89044c8b2 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -557,6 +557,29 @@ def monkey_patch_vllm_all_gather(reverse: bool = False): setattr(GroupCoordinator, "all_gather", all_gather) +def monkey_patch_vllm_gguf_config(): + from vllm.model_executor.layers.linear import LinearBase + from vllm.model_executor.layers.quantization.gguf import ( + GGUFConfig, + GGUFEmbeddingMethod, + GGUFLinearMethod, + ) + + from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding + + def get_quant_method_with_embedding_replaced( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return GGUFLinearMethod(self) + elif isinstance(layer,
VocabParallelEmbedding): + # patch to own VocabParallelEmbedding + return GGUFEmbeddingMethod(self) + return None + + setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced) + + def maybe_set_triton_cache_manager() -> None: """Set environment variable to tell Triton to use a custom cache manager""" diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index d441cf9b2..b055df1a1 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -15,6 +15,7 @@ suites = { "test_double_sparsity.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", + "test_gguf.py", "test_input_embeddings.py", "test_json_constrained.py", "test_large_max_new_tokens.py", diff --git a/test/srt/test_get_parameter_by_name.py b/test/srt/test_get_weights_by_name.py similarity index 98% rename from test/srt/test_get_parameter_by_name.py rename to test/srt/test_get_weights_by_name.py index 8dce1ac2c..6579646f4 100644 --- a/test/srt/test_get_parameter_by_name.py +++ b/test/srt/test_get_weights_by_name.py @@ -16,7 +16,7 @@ from sglang.test.test_utils import ( from sglang.utils import terminate_process -class TestGetParameterByName(unittest.TestCase): +class TestGetWeightsByName(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST diff --git a/test/srt/test_gguf.py b/test/srt/test_gguf.py new file mode 100644 index 000000000..89572c45f --- /dev/null +++ b/test/srt/test_gguf.py @@ -0,0 +1,26 @@ +import unittest + +from huggingface_hub import hf_hub_download + +import sglang as sgl + + +class TestGGUF(unittest.TestCase): + def test_models(self): + prompt = "Today is a sunny day and I like" + sampling_params = {"temperature": 0, "max_new_tokens": 8} + + model_path = hf_hub_download( + "Qwen/Qwen2-1.5B-Instruct-GGUF", + filename="qwen2-1_5b-instruct-q4_k_m.gguf", + ) + + engine = sgl.Engine(model_path=model_path, random_seed=42) + outputs = engine.generate(prompt, sampling_params)["text"] + engine.shutdown() + 
+ self.assertEqual(outputs, " it. I have a lot of work") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights_from_disk.py similarity index 100% rename from test/srt/test_update_weights.py rename to test/srt/test_update_weights_from_disk.py