Revert "Revert "[FEAT] Support GGUF format"" (#2287)
This commit is contained in:
@@ -15,3 +15,4 @@ sphinx-copybutton
|
|||||||
sphinx-tabs
|
sphinx-tabs
|
||||||
sphinxcontrib-mermaid
|
sphinxcontrib-mermaid
|
||||||
urllib3<2.0.0
|
urllib3<2.0.0
|
||||||
|
gguf>=0.10.0
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Type, Union
|
from typing import Dict, Optional, Type, Union
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
@@ -27,6 +28,7 @@ from transformers import (
|
|||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerFast,
|
PreTrainedTokenizerFast,
|
||||||
)
|
)
|
||||||
|
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
||||||
@@ -60,15 +62,29 @@ def get_config(
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
model_override_args: Optional[dict] = None,
|
model_override_args: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
is_gguf = check_gguf_file(model)
|
||||||
|
if is_gguf:
|
||||||
|
kwargs["gguf_file"] = model
|
||||||
|
model = Path(model).parent
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model, trust_remote_code=trust_remote_code, revision=revision
|
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||||
)
|
)
|
||||||
if config.model_type in _CONFIG_REGISTRY:
|
if config.model_type in _CONFIG_REGISTRY:
|
||||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||||
config = config_class.from_pretrained(model, revision=revision)
|
config = config_class.from_pretrained(model, revision=revision)
|
||||||
if model_override_args:
|
if model_override_args:
|
||||||
config.update(model_override_args)
|
config.update(model_override_args)
|
||||||
|
|
||||||
|
# Special architecture mapping check for GGUF models
|
||||||
|
if is_gguf:
|
||||||
|
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
|
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
|
||||||
|
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
|
||||||
|
config.update({"architectures": [model_type]})
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@@ -123,6 +139,11 @@ def get_tokenizer(
|
|||||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||||
kwargs["use_fast"] = False
|
kwargs["use_fast"] = False
|
||||||
|
|
||||||
|
is_gguf = check_gguf_file(tokenizer_name)
|
||||||
|
if is_gguf:
|
||||||
|
kwargs["gguf_file"] = tokenizer_name
|
||||||
|
tokenizer_name = Path(tokenizer_name).parent
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
@@ -195,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
tokenizer.additional_stop_token_ids = None
|
tokenizer.additional_stop_token_ids = None
|
||||||
|
|
||||||
|
|
||||||
|
def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
|
||||||
|
"""Check if the file is a GGUF model."""
|
||||||
|
model = Path(model)
|
||||||
|
if not model.is_file():
|
||||||
|
return False
|
||||||
|
elif model.suffix == ".gguf":
|
||||||
|
return True
|
||||||
|
|
||||||
|
with open(model, "rb") as f:
|
||||||
|
header = f.read(4)
|
||||||
|
return header == b"GGUF"
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from vllm.distributed import (
|
|||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
|
|
||||||
@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
weight,
|
lm_head: VocabParallelEmbedding,
|
||||||
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
||||||
):
|
):
|
||||||
if isinstance(logits_metadata, ForwardBatch):
|
if isinstance(logits_metadata, ForwardBatch):
|
||||||
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||||
last_hidden = hidden_states[last_index]
|
last_hidden = hidden_states[last_index]
|
||||||
|
|
||||||
last_logits = torch.matmul(last_hidden, weight.T)
|
last_logits = self._get_logits(last_hidden, lm_head)
|
||||||
if self.do_tensor_parallel_all_gather:
|
if self.do_tensor_parallel_all_gather:
|
||||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||||
last_logits = last_logits[:, : self.config.vocab_size].float()
|
last_logits = last_logits[:, : self.config.vocab_size].float()
|
||||||
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
# Compute the logits and logprobs for all required tokens
|
# Compute the logits and logprobs for all required tokens
|
||||||
states = torch.cat(states, dim=0)
|
states = torch.cat(states, dim=0)
|
||||||
all_logits = torch.matmul(states, weight.T)
|
all_logits = self._get_logits(states, lm_head)
|
||||||
if self.do_tensor_parallel_all_gather:
|
if self.do_tensor_parallel_all_gather:
|
||||||
all_logits = tensor_model_parallel_all_gather(all_logits)
|
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||||
all_logits = all_logits[:, : self.config.vocab_size].float()
|
all_logits = all_logits[:, : self.config.vocab_size].float()
|
||||||
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
|
|||||||
output_top_logprobs=output_top_logprobs,
|
output_top_logprobs=output_top_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
lm_head: VocabParallelEmbedding,
|
||||||
|
embedding_bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if hasattr(lm_head, "weight"):
|
||||||
|
logits = torch.matmul(hidden_states, lm_head.weight.T)
|
||||||
|
else:
|
||||||
|
# GGUF models
|
||||||
|
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
all_logprobs = torch.tensor(
|
all_logprobs = torch.tensor(
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
enable_tp: bool = True,
|
enable_tp: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
self.enable_tp = enable_tp
|
self.enable_tp = enable_tp
|
||||||
if self.enable_tp:
|
if self.enable_tp:
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ from sglang.srt.utils import (
|
|||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
is_hip,
|
is_hip,
|
||||||
|
monkey_patch_vllm_gguf_config,
|
||||||
monkey_patch_vllm_model_config,
|
monkey_patch_vllm_model_config,
|
||||||
monkey_patch_vllm_p2p_access_check,
|
monkey_patch_vllm_p2p_access_check,
|
||||||
set_cpu_offload_max_bytes,
|
set_cpu_offload_max_bytes,
|
||||||
@@ -297,6 +298,8 @@ class ModelRunner:
|
|||||||
download_dir=self.server_args.download_dir,
|
download_dir=self.server_args.download_dir,
|
||||||
)
|
)
|
||||||
monkey_patch_vllm_model_config()
|
monkey_patch_vllm_model_config()
|
||||||
|
if self.server_args.load_format == "gguf":
|
||||||
|
monkey_patch_vllm_gguf_config()
|
||||||
self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
|
self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
|
||||||
if self.model_config.model_override_args is not None:
|
if self.model_config.model_override_args is not None:
|
||||||
self.vllm_model_config.hf_config.update(
|
self.vllm_model_config.hf_config.update(
|
||||||
|
|||||||
@@ -338,11 +338,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = BaiChuanModel(config, position_embedding, quant_config)
|
self.model = BaiChuanModel(config, position_embedding, quant_config)
|
||||||
self.lm_head = ParallelLMHead(
|
|
||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
|
||||||
)
|
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
self.lm_head.weight = self.model.embed_tokens.weight
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -353,7 +354,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -378,7 +378,7 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -339,7 +339,7 @@ class CohereForCausalLM(nn.Module):
|
|||||||
forward_batch,
|
forward_batch,
|
||||||
)
|
)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -390,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -394,7 +394,7 @@ class DeepseekForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -763,7 +763,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||||
if not forward_batch.forward_mode.is_idle():
|
if not forward_batch.forward_mode.is_idle():
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ class ExaoneForCausalLM(nn.Module):
|
|||||||
input_ids, positions, forward_batch, input_embeds
|
input_ids, positions, forward_batch, input_embeds
|
||||||
)
|
)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
|
input_ids, hidden_states, self.model.embed_tokens, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_attention_sliding_window_size(self):
|
def get_attention_sliding_window_size(self):
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -270,7 +270,7 @@ class InternLM2ForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.output.weight, forward_batch
|
input_ids, hidden_states, self.output, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import make_layers
|
from sglang.srt.utils import make_layers
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -258,6 +259,7 @@ class LlamaModel(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.layers = make_layers(
|
self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
@@ -305,7 +307,12 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||||
self.model = LlamaModel(config, quant_config=quant_config)
|
self.model = LlamaModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
if self.config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
self.stacked_params_mapping = [
|
self.stacked_params_mapping = [
|
||||||
@@ -329,7 +336,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
if not get_embedding:
|
if not get_embedding:
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.pooler(hidden_states, forward_batch)
|
return self.pooler(hidden_states, forward_batch)
|
||||||
@@ -373,7 +380,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
return len(params_dict)
|
return len(params_dict)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
embed_tokens_weight = None
|
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
(".qkv_proj", ".q_proj", "q"),
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
@@ -385,12 +391,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
|
||||||
load_tie_word_embeddings = (
|
|
||||||
hasattr(self.config, "tie_word_embeddings")
|
|
||||||
and self.config.tie_word_embeddings
|
|
||||||
and "lm_head.weight" in params_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||||
continue
|
continue
|
||||||
@@ -423,16 +423,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
if load_tie_word_embeddings and name == "model.embed_tokens.weight":
|
|
||||||
embed_tokens_weight = loaded_weight
|
|
||||||
|
|
||||||
if load_tie_word_embeddings:
|
|
||||||
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
|
||||||
param = self.lm_head.weight
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
if embed_tokens_weight is not None:
|
|
||||||
weight_loader(param, embed_tokens_weight)
|
|
||||||
|
|
||||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||||
|
|
||||||
def get_weights_by_name(
|
def get_weights_by_name(
|
||||||
@@ -444,6 +434,17 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
For optimized performance, please use torch.save and torch.load.
|
For optimized performance, please use torch.save and torch.load.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
if name == "lm_head.weight" and self.config.tie_word_embeddings:
|
||||||
|
logger.info(
|
||||||
|
"word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
self.model.embed_tokens.weight.cpu()
|
||||||
|
.to(torch.float32)
|
||||||
|
.numpy()
|
||||||
|
.tolist()[:truncate_size]
|
||||||
|
)
|
||||||
|
|
||||||
mapped_name = name
|
mapped_name = name
|
||||||
mapped_shard_id = None
|
mapped_shard_id = None
|
||||||
for param_name, weight_name, shard_id in self.stacked_params_mapping:
|
for param_name, weight_name, shard_id in self.stacked_params_mapping:
|
||||||
@@ -452,54 +453,48 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
mapped_shard_id = shard_id
|
mapped_shard_id = shard_id
|
||||||
break
|
break
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
if mapped_name in params_dict:
|
param = params_dict[mapped_name]
|
||||||
param = params_dict[mapped_name]
|
if mapped_shard_id is not None:
|
||||||
if mapped_shard_id is not None:
|
if mapped_shard_id in ["q", "k", "v"]:
|
||||||
if mapped_shard_id in ["q", "k", "v"]:
|
num_heads = self.config.num_attention_heads // tp_size
|
||||||
num_heads = self.config.num_attention_heads // tp_size
|
num_kv_heads = self.config.num_key_value_heads // tp_size
|
||||||
num_kv_heads = self.config.num_key_value_heads // tp_size
|
head_dim = (
|
||||||
head_dim = (
|
self.config.hidden_size // self.config.num_attention_heads
|
||||||
self.config.hidden_size // self.config.num_attention_heads
|
)
|
||||||
)
|
if mapped_shard_id == "q":
|
||||||
if mapped_shard_id == "q":
|
offset = 0
|
||||||
offset = 0
|
size = num_heads * head_dim
|
||||||
size = num_heads * head_dim
|
elif mapped_shard_id == "k":
|
||||||
elif mapped_shard_id == "k":
|
offset = num_heads * head_dim
|
||||||
offset = num_heads * head_dim
|
size = num_kv_heads * head_dim
|
||||||
size = num_kv_heads * head_dim
|
elif mapped_shard_id == "v":
|
||||||
elif mapped_shard_id == "v":
|
offset = (num_heads + num_kv_heads) * head_dim
|
||||||
offset = (num_heads + num_kv_heads) * head_dim
|
size = num_kv_heads * head_dim
|
||||||
size = num_kv_heads * head_dim
|
weight = param.data.narrow(0, offset, size)
|
||||||
weight = param.data.narrow(0, offset, size)
|
elif mapped_shard_id in [0, 1]:
|
||||||
elif mapped_shard_id in [0, 1]:
|
intermediate_size = self.config.intermediate_size
|
||||||
intermediate_size = self.config.intermediate_size
|
slice_size = intermediate_size // tp_size
|
||||||
hidden_size = self.config.hidden_size
|
if mapped_shard_id == 0: # gate_proj
|
||||||
slice_size = intermediate_size // tp_size
|
offset = 0
|
||||||
if mapped_shard_id == 0: # gate_proj
|
size = slice_size
|
||||||
offset = 0
|
elif mapped_shard_id == 1: # up_proj
|
||||||
size = slice_size
|
offset = slice_size
|
||||||
elif mapped_shard_id == 1: # up_proj
|
size = slice_size
|
||||||
offset = slice_size
|
|
||||||
size = slice_size
|
|
||||||
|
|
||||||
weight = param.data.narrow(0, offset, size)
|
weight = param.data.narrow(0, offset, size)
|
||||||
else:
|
|
||||||
weight = param.data
|
|
||||||
else:
|
else:
|
||||||
weight = param.data
|
weight = param.data
|
||||||
if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
|
|
||||||
gathered_weights = [
|
|
||||||
torch.zeros_like(weight) for _ in range(tp_size)
|
|
||||||
]
|
|
||||||
torch.distributed.all_gather(gathered_weights, weight)
|
|
||||||
weight = torch.cat(gathered_weights, dim=1)
|
|
||||||
return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
|
|
||||||
else:
|
else:
|
||||||
return None
|
weight = param.data
|
||||||
|
if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
|
||||||
|
gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
|
||||||
|
torch.distributed.all_gather(gathered_weights, weight)
|
||||||
|
weight = torch.cat(gathered_weights, dim=1)
|
||||||
|
return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
|
f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -308,12 +308,10 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
hidden_states = hidden_states / self.scale_width
|
hidden_states = hidden_states / self.scale_width
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
lm_head_weight = self.model.embed_tokens.weight
|
lm_head = self.model.embed_tokens
|
||||||
else:
|
else:
|
||||||
lm_head_weight = self.lm_head.weight
|
lm_head = self.lm_head
|
||||||
return self.logits_processor(
|
return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
|
||||||
input_ids, hidden_states, lm_head_weight, forward_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -585,12 +585,10 @@ class MiniCPM3ForCausalLM(nn.Module):
|
|||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
hidden_states = hidden_states / self.scale_width
|
hidden_states = hidden_states / self.scale_width
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
lm_head_weight = self.model.embed_tokens.weight
|
lm_head = self.model.embed_tokens
|
||||||
else:
|
else:
|
||||||
lm_head_weight = self.lm_head.weight
|
lm_head = self.lm_head
|
||||||
return self.logits_processor(
|
return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
|
||||||
input_ids, hidden_states, lm_head_weight, forward_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -343,7 +343,7 @@ class QuantMixtralForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -966,7 +966,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
|||||||
skip_cross_attention=skip_cross_attention,
|
skip_cross_attention=skip_cross_attention,
|
||||||
)
|
)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.language_model.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -306,7 +306,7 @@ class OlmoForCausalLM(nn.Module):
|
|||||||
input_embeds=input_embeds,
|
input_embeds=input_embeds,
|
||||||
)
|
)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
@@ -326,11 +326,6 @@ class OlmoForCausalLM(nn.Module):
|
|||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
# With tie_word_embeddings, we can skip lm_head.weight
|
|
||||||
# The weight might appear unnecessarily in the files if the model is
|
|
||||||
# processed with quantization, LoRA, fine-tuning, etc.
|
|
||||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
|
||||||
continue
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -321,7 +321,7 @@ class OlmoeForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -397,10 +397,13 @@ class Phi3SmallForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def compute_logits(
|
def compute_logits(
|
||||||
self,
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
|
logits = self.logits_processor(
|
||||||
|
input_ids, self.lm_head, hidden_states, sampling_metadata
|
||||||
|
)
|
||||||
if self.dummy_token_indices is not None and logits is not None:
|
if self.dummy_token_indices is not None and logits is not None:
|
||||||
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
|
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
|
||||||
return logits
|
return logits
|
||||||
@@ -422,7 +425,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
|||||||
|
|
||||||
if not get_embedding:
|
if not get_embedding:
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
):
|
):
|
||||||
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
hidden_states = self.transformer(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
|
|||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.layers = make_layers(
|
self.layers = make_layers(
|
||||||
config.num_hidden_layers,
|
config.num_hidden_layers,
|
||||||
@@ -276,7 +277,12 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
@@ -292,7 +298,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
if not get_embedding:
|
if not get_embedding:
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.pooler(hidden_states, forward_batch)
|
return self.pooler(hidden_states, forward_batch)
|
||||||
@@ -306,6 +312,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||||
@@ -335,11 +342,6 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
if (
|
|
||||||
self.config.tie_word_embeddings
|
|
||||||
and name == "model.embed_tokens.weight"
|
|
||||||
):
|
|
||||||
weight_loader(params_dict["lm_head.weight"], loaded_weight)
|
|
||||||
|
|
||||||
|
|
||||||
EntryClass = Qwen2ForCausalLM
|
EntryClass = Qwen2ForCausalLM
|
||||||
|
|||||||
@@ -376,7 +376,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -668,7 +668,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
if not get_embedding:
|
if not get_embedding:
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.pooler(hidden_states, forward_batch)
|
return self.pooler(hidden_states, forward_batch)
|
||||||
@@ -686,8 +686,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
|
||||||
continue
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ class StableLmForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -396,7 +396,10 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|||||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||||
self.supports_torch_tp = True
|
self.supports_torch_tp = True
|
||||||
self.model = LlamaModel(config, quant_config=quant_config)
|
self.model = LlamaModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
if self.config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
# turning off autotune for fp8dq since it doesn't give speedup and
|
# turning off autotune for fp8dq since it doesn't give speedup and
|
||||||
@@ -413,7 +416,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|||||||
) -> LogitsProcessorOutput:
|
) -> LogitsProcessorOutput:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name):
|
||||||
@@ -501,14 +504,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
if (
|
|
||||||
hasattr(self.config, "tie_word_embeddings")
|
|
||||||
and self.config.tie_word_embeddings
|
|
||||||
):
|
|
||||||
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
|
|
||||||
param = self.lm_head.weight
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
weight_loader(param, self.model.embed_tokens.weight)
|
|
||||||
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -315,7 +315,7 @@ class XverseForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(
|
def load_weights(
|
||||||
|
|||||||
@@ -390,7 +390,7 @@ class XverseMoeForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_amdgpu_memory_capacity,
|
get_amdgpu_memory_capacity,
|
||||||
get_nvgpu_memory_capacity,
|
get_nvgpu_memory_capacity,
|
||||||
@@ -204,6 +205,12 @@ class ServerArgs:
|
|||||||
"Overlap schedule is disabled."
|
"Overlap schedule is disabled."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# GGUF
|
||||||
|
if (
|
||||||
|
self.load_format == "auto" or self.load_format == "gguf"
|
||||||
|
) and check_gguf_file(self.model_path):
|
||||||
|
self.quantization = self.load_format = "gguf"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
# Model and port args
|
# Model and port args
|
||||||
@@ -243,7 +250,7 @@ class ServerArgs:
|
|||||||
"--load-format",
|
"--load-format",
|
||||||
type=str,
|
type=str,
|
||||||
default=ServerArgs.load_format,
|
default=ServerArgs.load_format,
|
||||||
choices=["auto", "pt", "safetensors", "npcache", "dummy"],
|
choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
|
||||||
help="The format of the model weights to load. "
|
help="The format of the model weights to load. "
|
||||||
'"auto" will try to load the weights in the safetensors format '
|
'"auto" will try to load the weights in the safetensors format '
|
||||||
"and fall back to the pytorch bin format if safetensors format "
|
"and fall back to the pytorch bin format if safetensors format "
|
||||||
@@ -253,7 +260,8 @@ class ServerArgs:
|
|||||||
'"npcache" will load the weights in pytorch format and store '
|
'"npcache" will load the weights in pytorch format and store '
|
||||||
"a numpy cache to speed up the loading. "
|
"a numpy cache to speed up the loading. "
|
||||||
'"dummy" will initialize the weights with random values, '
|
'"dummy" will initialize the weights with random values, '
|
||||||
"which is mainly for profiling.",
|
"which is mainly for profiling."
|
||||||
|
'"gguf" will load the weights in the gguf format. ',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trust-remote-code",
|
"--trust-remote-code",
|
||||||
@@ -293,6 +301,7 @@ class ServerArgs:
|
|||||||
"gptq_marlin",
|
"gptq_marlin",
|
||||||
"awq_marlin",
|
"awq_marlin",
|
||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
|
"gguf",
|
||||||
],
|
],
|
||||||
help="The quantization method.",
|
help="The quantization method.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -557,6 +557,29 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
|
|||||||
setattr(GroupCoordinator, "all_gather", all_gather)
|
setattr(GroupCoordinator, "all_gather", all_gather)
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_patch_vllm_gguf_config():
|
||||||
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
|
from vllm.model_executor.layers.quantization.gguf import (
|
||||||
|
GGUFConfig,
|
||||||
|
GGUFEmbeddingMethod,
|
||||||
|
GGUFLinearMethod,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
|
||||||
|
def get_quant_method_with_embedding_replaced(
|
||||||
|
self, layer: torch.nn.Module, prefix: str
|
||||||
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
|
if isinstance(layer, LinearBase):
|
||||||
|
return GGUFLinearMethod(self)
|
||||||
|
elif isinstance(layer, VocabParallelEmbedding):
|
||||||
|
# patch to own VocabParallelEmbedding
|
||||||
|
return GGUFEmbeddingMethod(self)
|
||||||
|
return None
|
||||||
|
|
||||||
|
setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
|
||||||
|
|
||||||
|
|
||||||
def maybe_set_triton_cache_manager() -> None:
|
def maybe_set_triton_cache_manager() -> None:
|
||||||
"""Set environment variable to tell Triton to use a
|
"""Set environment variable to tell Triton to use a
|
||||||
custom cache manager"""
|
custom cache manager"""
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ suites = {
|
|||||||
"test_double_sparsity.py",
|
"test_double_sparsity.py",
|
||||||
"test_embedding_openai_server.py",
|
"test_embedding_openai_server.py",
|
||||||
"test_eval_accuracy_mini.py",
|
"test_eval_accuracy_mini.py",
|
||||||
|
"test_gguf.py",
|
||||||
"test_input_embeddings.py",
|
"test_input_embeddings.py",
|
||||||
"test_json_constrained.py",
|
"test_json_constrained.py",
|
||||||
"test_large_max_new_tokens.py",
|
"test_large_max_new_tokens.py",
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from sglang.test.test_utils import (
|
|||||||
from sglang.utils import terminate_process
|
from sglang.utils import terminate_process
|
||||||
|
|
||||||
|
|
||||||
class TestGetParameterByName(unittest.TestCase):
|
class TestGetWeightsByName(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
26
test/srt/test_gguf.py
Normal file
26
test/srt/test_gguf.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
|
||||||
|
class TestGGUF(unittest.TestCase):
|
||||||
|
def test_models(self):
|
||||||
|
prompt = "Today is a sunny day and I like"
|
||||||
|
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||||
|
|
||||||
|
model_path = hf_hub_download(
|
||||||
|
"Qwen/Qwen2-1.5B-Instruct-GGUF",
|
||||||
|
filename="qwen2-1_5b-instruct-q4_k_m.gguf",
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = sgl.Engine(model_path=model_path, random_seed=42)
|
||||||
|
outputs = engine.generate(prompt, sampling_params)["text"]
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
self.assertEqual(outputs, " it. I have a lot of work")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user