Revert "[FEAT] Support GGUF format" (#2285)

This commit is contained in:
Lianmin Zheng
2024-11-30 19:03:26 -08:00
committed by GitHub
parent d622851dc9
commit 7e4c6dd8da
39 changed files with 89 additions and 180 deletions

View File

@@ -16,7 +16,6 @@
import contextlib
import os
import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union
from huggingface_hub import snapshot_download
@@ -28,7 +27,6 @@ from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
try:
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -62,29 +60,15 @@ def get_config(
trust_remote_code: bool,
revision: Optional[str] = None,
model_override_args: Optional[dict] = None,
**kwargs,
):
is_gguf = check_gguf_file(model)
if is_gguf:
kwargs["gguf_file"] = model
model = Path(model).parent
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
model, trust_remote_code=trust_remote_code, revision=revision
)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
if model_override_args:
config.update(model_override_args)
# Special architecture mapping check for GGUF models
if is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})
return config
@@ -139,11 +123,6 @@ def get_tokenizer(
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
is_gguf = check_gguf_file(tokenizer_name)
if is_gguf:
kwargs["gguf_file"] = tokenizer_name
tokenizer_name = Path(tokenizer_name).parent
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
@@ -216,16 +195,3 @@ def attach_additional_stop_token_ids(tokenizer):
)
else:
tokenizer.additional_stop_token_ids = None
def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
    """Report whether ``model`` points at a GGUF model file.

    Args:
        model: Filesystem path (string or path-like) to inspect.

    Returns:
        ``True`` when the path is an existing regular file that either
        carries the ``.gguf`` suffix or starts with the 4-byte magic
        ``GGUF``; ``False`` otherwise, including when the path is not a
        regular file.
    """
    candidate = Path(model)
    if candidate.is_file():
        # A matching suffix is accepted as-is; the content is not inspected.
        if candidate.suffix == ".gguf":
            return True
        # Otherwise fall back to sniffing the first four bytes of the file.
        with open(candidate, "rb") as stream:
            magic = stream.read(4)
        return magic == b"GGUF"
    return False

View File

@@ -23,7 +23,6 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -164,7 +163,7 @@ class LogitsProcessor(nn.Module):
self,
input_ids,
hidden_states,
lm_head: VocabParallelEmbedding,
weight,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
):
if isinstance(logits_metadata, ForwardBatch):
@@ -179,7 +178,7 @@ class LogitsProcessor(nn.Module):
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
last_hidden = hidden_states[last_index]
last_logits = self._get_logits(last_hidden, lm_head)
last_logits = torch.matmul(last_hidden, weight.T)
if self.do_tensor_parallel_all_gather:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -230,7 +229,7 @@ class LogitsProcessor(nn.Module):
# Compute the logits and logprobs for all required tokens
states = torch.cat(states, dim=0)
all_logits = self._get_logits(states, lm_head)
all_logits = torch.matmul(states, weight.T)
if self.do_tensor_parallel_all_gather:
all_logits = tensor_model_parallel_all_gather(all_logits)
all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -277,19 +276,6 @@ class LogitsProcessor(nn.Module):
output_top_logprobs=output_top_logprobs,
)
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Project ``hidden_states`` to logits through ``lm_head``.

    Args:
        hidden_states: Activations to project.
        lm_head: Output embedding module. When it exposes a ``weight``
            attribute the logits are ``hidden_states @ lm_head.weight.T``;
            otherwise the work is delegated to ``lm_head.linear_method``.
        embedding_bias: Optional bias, forwarded only on the
            ``linear_method`` path (it is not applied on the matmul path).

    Returns:
        The logits tensor produced by whichever path was taken.
    """
    if not hasattr(lm_head, "weight"):
        # GGUF models
        return lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
    return torch.matmul(hidden_states, lm_head.weight.T)
def test():
all_logprobs = torch.tensor(

View File

@@ -222,7 +222,6 @@ class VocabParallelEmbedding(torch.nn.Module):
enable_tp: bool = True,
):
super().__init__()
self.quant_config = quant_config
self.enable_tp = enable_tp
if self.enable_tp:

View File

@@ -59,7 +59,6 @@ from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
is_hip,
monkey_patch_vllm_gguf_config,
monkey_patch_vllm_model_config,
monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
@@ -298,8 +297,6 @@ class ModelRunner:
download_dir=self.server_args.download_dir,
)
monkey_patch_vllm_model_config()
if self.server_args.load_format == "gguf":
monkey_patch_vllm_gguf_config()
self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
if self.model_config.model_override_args is not None:
self.vllm_model_config.hf_config.update(

View File

@@ -338,12 +338,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config)
def forward(
@@ -354,7 +353,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module):
forward_batch,
)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens, forward_batch
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch)
if not forward_batch.forward_mode.is_idle():
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module):
input_ids, positions, forward_batch, input_embeds
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens, forward_batch
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens, forward_batch
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def get_attention_sliding_window_size(self):

View File

@@ -247,7 +247,7 @@ class GPT2LMHeadModel(nn.Module):
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.output, forward_batch
input_ids, hidden_states, self.output.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -258,7 +258,6 @@ class LlamaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.layers = make_layers(
config.num_hidden_layers,
@@ -306,12 +305,7 @@ class LlamaForCausalLM(nn.Module):
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = LlamaModel(config, quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.stacked_params_mapping = [
@@ -335,7 +329,7 @@ class LlamaForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
@@ -379,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
return len(params_dict)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
embed_tokens_weight = None
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -390,6 +385,12 @@ class LlamaForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
load_tie_word_embeddings = (
hasattr(self.config, "tie_word_embeddings")
and self.config.tie_word_embeddings
and "lm_head.weight" in params_dict
)
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
@@ -422,6 +423,16 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if load_tie_word_embeddings and name == "model.embed_tokens.weight":
embed_tokens_weight = loaded_weight
if load_tie_word_embeddings:
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if embed_tokens_weight is not None:
weight_loader(param, embed_tokens_weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
def get_weights_by_name(

View File

@@ -308,10 +308,12 @@ class MiniCPMForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head = self.model.embed_tokens
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head = self.lm_head
return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
lm_head_weight = self.lm_head.weight
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -585,10 +585,12 @@ class MiniCPM3ForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head = self.model.embed_tokens
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head = self.lm_head
return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
lm_head_weight = self.lm_head.weight
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
skip_cross_attention=skip_cross_attention,
)
return self.logits_processor(
input_ids, hidden_states, self.language_model.lm_head, forward_batch
input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module):
input_embeds=input_embeds,
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,6 +326,11 @@ class OlmoForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -397,13 +397,10 @@ class Phi3SmallForCausalLM(nn.Module):
def compute_logits(
self,
input_ids: torch.LongTensor,
hidden_states: torch.Tensor,
sampling_metadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(
input_ids, self.lm_head, hidden_states, sampling_metadata
)
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
if self.dummy_token_indices is not None and logits is not None:
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
return logits
@@ -425,7 +422,7 @@ class Phi3SmallForCausalLM(nn.Module):
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
else:

View File

@@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module):
):
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -230,7 +230,6 @@ class Qwen2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.layers = make_layers(
config.num_hidden_layers,
@@ -277,12 +276,7 @@ class Qwen2ForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Qwen2Model(config, quant_config=quant_config)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -298,7 +292,7 @@ class Qwen2ForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
@@ -312,7 +306,6 @@ class Qwen2ForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -342,6 +335,11 @@ class Qwen2ForCausalLM(nn.Module):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if (
self.config.tie_word_embeddings
and name == "model.embed_tokens.weight"
):
weight_loader(params_dict["lm_head.weight"], loaded_weight)
EntryClass = Qwen2ForCausalLM

View File

@@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
@@ -686,6 +686,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -396,10 +396,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
self.torchao_config = global_server_args_dict["torchao_config"]
self.supports_torch_tp = True
self.model = LlamaModel(config, quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
# turning off autotune for fp8dq since it doesn't give speedup and
@@ -416,7 +413,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def get_hidden_dim(self, module_name):
@@ -504,6 +501,14 @@ class TorchNativeLlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if (
hasattr(self.config, "tie_word_embeddings")
and self.config.tie_word_embeddings
):
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, self.model.embed_tokens.weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))

View File

@@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(

View File

@@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -20,7 +20,6 @@ import random
import tempfile
from typing import List, Optional
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_nvgpu_memory_capacity,
@@ -205,12 +204,6 @@ class ServerArgs:
"Overlap schedule is disabled."
)
# GGUF
if (
self.load_format == "auto" or self.load_format == "gguf"
) and check_gguf_file(self.model_path):
self.quantization = self.load_format = "gguf"
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
@@ -250,7 +243,7 @@ class ServerArgs:
"--load-format",
type=str,
default=ServerArgs.load_format,
choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
choices=["auto", "pt", "safetensors", "npcache", "dummy"],
help="The format of the model weights to load. "
'"auto" will try to load the weights in the safetensors format '
"and fall back to the pytorch bin format if safetensors format "
@@ -260,8 +253,7 @@ class ServerArgs:
'"npcache" will load the weights in pytorch format and store '
"a numpy cache to speed up the loading. "
'"dummy" will initialize the weights with random values, '
"which is mainly for profiling."
'"gguf" will load the weights in the gguf format. ',
"which is mainly for profiling.",
)
parser.add_argument(
"--trust-remote-code",
@@ -301,7 +293,6 @@ class ServerArgs:
"gptq_marlin",
"awq_marlin",
"bitsandbytes",
"gguf",
],
help="The quantization method.",
)

View File

@@ -557,29 +557,6 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
setattr(GroupCoordinator, "all_gather", all_gather)
def monkey_patch_vllm_gguf_config():
    """Replace ``GGUFConfig.get_quant_method`` with a sglang-aware version.

    The replacement returns ``GGUFLinearMethod`` for ``LinearBase`` layers,
    ``GGUFEmbeddingMethod`` for sglang's own ``VocabParallelEmbedding``
    layers, and ``None`` for any other layer. The patch is installed on the
    ``GGUFConfig`` class itself. Imports are done inside the function body.
    """
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.gguf import (
        GGUFConfig,
        GGUFEmbeddingMethod,
        GGUFLinearMethod,
    )

    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding

    def get_quant_method_with_embedding_replaced(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        # Linear layers are checked first, matching the original order.
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        # patch to own VocabParallelEmbedding
        if isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        return None

    GGUFConfig.get_quant_method = get_quant_method_with_embedding_replaced
def maybe_set_triton_cache_manager() -> None:
"""Set environment variable to tell Triton to use a
custom cache manager"""