Update model_loader deps and qqq quantization deps (#2220) (#2318)

Co-authored-by: HandH1998 <1335248067@qq.com>
This commit is contained in:
Yineng Zhang
2024-12-02 23:22:13 +08:00
committed by GitHub
parent 33deca81b5
commit 85e1a6f3aa
58 changed files with 2363 additions and 366 deletions

View File

@@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank):
model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
model_runner = ModelRunner(
model_config=model_config,

View File

@@ -0,0 +1,17 @@
import logging
from typing import Optional
import torch
logger = logging.getLogger(__name__)
class DeviceConfig:
    """Holds the torch device the model will be placed on."""

    device: Optional[torch.device]

    def __init__(self, device: str = "cuda") -> None:
        # Only these accelerator backends are recognized.
        if device not in ("cuda", "xpu", "hpu"):
            raise RuntimeError(f"Not supported device type: {device}")
        self.device_type = device
        self.device = torch.device(self.device_type)

View File

@@ -0,0 +1,84 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
class LoadFormat(str, enum.Enum):
    """Supported formats for loading model weights.

    Inherits from ``str`` so members compare equal to their plain string
    values (e.g. ``LoadFormat.AUTO == "auto"``).
    """

    AUTO = "auto"  # try safetensors first, fall back to pytorch .bin
    PT = "pt"  # pytorch .bin checkpoints
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"  # pytorch weights plus a numpy cache to speed up loading
    DUMMY = "dummy"  # random-initialized weights, mainly for profiling
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"  # nf4 type weights
    MISTRAL = "mistral"
@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
    default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
        fall back to the pytorch bin format if safetensors format is
        not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
        a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
        mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        # Accept the extra config as either a JSON string or a dict, and
        # always store a dict. (Fix: previously a `None` value was left as
        # `None` — the `or {}` fallback was only kept in a local variable —
        # so every consumer had to re-check for `None`.)
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        self.model_loader_extra_config = extra_config

        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Normalize `load_format` to a LoadFormat member; reject formats
        unsupported on ROCm.

        Raises ValueError for unknown format strings (via the enum call) and
        for ROCm-unsupported formats.
        """
        if not isinstance(self.load_format, str):
            return
        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Currently empty: every format is supported on ROCm. Kept as a hook
        # for future restrictions.
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f
                for f in LoadFormat.__members__
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )

View File

@@ -15,12 +15,14 @@
import json
import logging
from enum import IntEnum, auto
from typing import List, Optional
from typing import List, Optional, Union
import torch
from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.utils import get_bool_env_var
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.utils import get_bool_env_var, is_hip
logger = logging.getLogger(__name__)
@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
class ModelConfig:
def __init__(
self,
path: str,
model_path: str,
trust_remote_code: bool = True,
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_override_args: Optional[dict] = None,
is_embedding: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
) -> None:
self.model_path = model_path
self.revision = revision
self.quantization = quantization
# Parse args
self.model_override_args = json.loads(model_override_args)
self.hf_config = get_config(
path,
model_path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
@@ -56,6 +63,7 @@ class ModelConfig:
)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
@@ -116,6 +124,8 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size
self._verify_quantization()
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
@@ -174,6 +184,86 @@ class ModelConfig:
# parallel size so each GPU has at least one KV head.
return max(1, total_num_kv_heads // tensor_parallel_size)
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
# compressed-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None)
return quant_cfg
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None:
    """Validate `self.quantization` and reconcile it with the checkpoint.

    Normalizes the user-supplied method to lowercase, lets registered
    quantization methods override the detected method, and raises if the
    user-requested method conflicts with the checkpoint or is unsupported.
    """
    supported_quantization = [*QUANTIZATION_METHODS]
    rocm_supported_quantization = [
        "awq",
        "gptq",
        "fp8",
        "compressed_tensors",
        "compressed-tensors",
        "fbgemm_fp8",
    ]
    optimized_quantization_methods = [
        "fp8",
        "marlin",
        "modelopt",
        "gptq_marlin_24",
        "gptq_marlin",
        "awq_marlin",
        "fbgemm_fp8",
        "compressed_tensors",
        "compressed-tensors",
        "experts_int8",
    ]

    if self.quantization is not None:
        self.quantization = self.quantization.lower()

    # Parse quantization method from the HF model config, if available.
    quant_cfg = self._parse_quant_hf_config()
    if quant_cfg is not None:
        quant_method = quant_cfg.get("quant_method", "").lower()

        # Give each registered method a chance to claim this checkpoint.
        for method in QUANTIZATION_METHODS.values():
            quantization_override = method.override_quantization_method(
                quant_cfg, self.quantization
            )
            if quantization_override:
                quant_method = quantization_override
                self.quantization = quantization_override
                break

        # Verify quantization configurations.
        if self.quantization is None:
            self.quantization = quant_method
        elif self.quantization != quant_method:
            raise ValueError(
                "Quantization method specified in the model config "
                f"({quant_method}) does not match the quantization "
                f"method specified in the `quantization` argument "
                f"({self.quantization})."
            )

    if self.quantization is None:
        return

    if self.quantization not in supported_quantization:
        raise ValueError(
            f"Unknown quantization method: {self.quantization}. Must "
            f"be one of {supported_quantization}."
        )
    if is_hip() and self.quantization not in rocm_supported_quantization:
        raise ValueError(
            f"{self.quantization} quantization is currently not "
            f"supported in ROCm."
        )
    if self.quantization not in optimized_quantization_methods:
        logger.warning(
            "%s quantization is not fully "
            "optimized yet. The speed can be slower than "
            "non-quantized models.",
            self.quantization,
        )
def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
# We support non-hf version of llava models, so we do not want to
# read the wrong values from the unused default text_config.
# NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
# `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
setattr(config, "torch_dtype", torch.float16)
return config
if hasattr(config, "text_config"):
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
return config
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
# Maps user-facing dtype name strings (as accepted by the `dtype` argument of
# `_get_and_verify_dtype`) to the corresponding torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16."
)
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
return torch_dtype
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
# We have two ways to determine whether a model is a generative model.
# 1. Check the model architectue

View File

@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
# NOTE: the following section from original transformers config
# for Qwen2-VL is commented out to address rope config loading issue
#
# if self.rope_scaling is not None and "type" in self.rope_scaling:
# if self.rope_scaling["type"] == "mrope":
# self.rope_scaling["type"] = "default"
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
# rope_config_validation(self)
# NOTE(HandH1998): This is necessary for configuring the `rope_type`` of qwen2vl models after removing dependencies on vllm.
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

View File

@@ -75,6 +75,8 @@ def get_config(
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
setattr(config, "_name_or_path", model)
if model_override_args:
config.update(model_override_args)

View File

@@ -42,6 +42,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"Fp8LinearMethod",
"MarlinLinearMethod",
"GPTQLinearMethod",
"QQQLinearMethod",
]

View File

@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.loader import DefaultModelLoader
class BaseLayerWithLoRA(nn.Module):

View File

@@ -147,9 +147,12 @@ class Scheduler:
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.is_generation = self.model_config.is_generation

View File

@@ -109,9 +109,12 @@ class TokenizerManager:
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.is_generation = self.model_config.is_generation

View File

@@ -52,9 +52,12 @@ class TpModelWorker:
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.model_runner = ModelRunner(
model_config=self.model_config,

View File

@@ -14,22 +14,12 @@
"""ModelRunner runs the forward passes of the models."""
import gc
import importlib
import importlib.resources
import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from tokenize import tabsize
from typing import Any, Optional, Type, Union
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
get_tp_group,
init_distributed_environment,
@@ -37,9 +27,9 @@ from vllm.distributed import (
set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
@@ -56,16 +46,15 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
is_hip,
monkey_patch_vllm_gguf_config,
monkey_patch_vllm_model_config,
monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
)
@@ -228,49 +217,6 @@ class ModelRunner:
return min_per_gpu_memory
def setup_model(self):
try:
from vllm.config import VllmConfig
vllm_config = VllmConfig()
vllm_config.model_config = self.vllm_model_config
vllm_config.load_config = self.load_config
vllm_config.device_config = DeviceConfig(self.device)
vllm_config.quant_config = VllmConfig._get_quantization_config(
vllm_config.model_config, vllm_config.load_config
)
return get_model(vllm_config=vllm_config)
except ImportError:
pass
return get_model(
model_config=self.vllm_model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
parallel_config=None,
scheduler_config=None,
lora_config=None,
cache_config=None,
)
def get_model_config_params(self):
sig = inspect.signature(VllmModelConfig.__init__)
params = {
"model": self.server_args.model_path,
"quantization": self.server_args.quantization,
"tokenizer": None,
"tokenizer_mode": None,
"trust_remote_code": self.server_args.trust_remote_code,
"dtype": self.server_args.dtype,
"seed": self.server_args.random_seed,
"skip_tokenizer_init": True,
}
if "task" in sig.parameters:
params["task"] = ""
return params
def load_model(self):
logger.info(
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -284,6 +230,7 @@ class ModelRunner:
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
self.model_config.dtype = torch.float16
if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.")
@@ -292,23 +239,21 @@ class ModelRunner:
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
)
monkey_patch_vllm_model_config()
if self.server_args.load_format == "gguf":
monkey_patch_vllm_gguf_config()
self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
if self.model_config.model_override_args is not None:
self.vllm_model_config.hf_config.update(
self.model_config.model_override_args
)
self.model = self.setup_model()
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.dtype = self.vllm_model_config.dtype
self.dtype = self.model_config.dtype
logger.info(
f"Load weight end. "
@@ -319,12 +264,12 @@ class ModelRunner:
def update_weights_from_disk(self, model_path: str, load_format: str):
"""Update engine weights online from disk."""
from vllm.model_executor.model_loader.loader import (
from sglang.srt.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
get_model_loader,
)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.utils import set_default_torch_dtype
logger.info(
f"Update engine weights online from disk begin. "
@@ -332,15 +277,7 @@ class ModelRunner:
)
target_device = torch.device(self.device)
try:
model_config_params = self.get_model_config_params()
model_config_params["model"] = model_path
vllm_model_config = VllmModelConfig(**model_config_params)
except Exception as e:
message = f"Failed to load model config: {e}."
return False, message
self.model_config.model_path = model_path
load_config = LoadConfig(load_format=load_format)
# Only support vllm DefaultModelLoader for now
@@ -352,7 +289,7 @@ class ModelRunner:
def get_weight_iter(config):
iter = loader._get_weights_iterator(
DefaultModelLoader.Source(
config.model,
config.model_path,
revision=config.revision,
fall_back_to_pt=getattr(
self.model, "fall_back_to_pt_during_load", True
@@ -370,9 +307,9 @@ class ModelRunner:
quant_method.process_weights_after_loading(module)
return model
with set_default_torch_dtype(vllm_model_config.dtype):
with set_default_torch_dtype(self.model_config.dtype):
try:
iter = get_weight_iter(vllm_model_config)
iter = get_weight_iter(self.model_config)
except Exception as e:
message = f"Failed to get weights iterator: {e}."
return False, message
@@ -384,16 +321,14 @@ class ModelRunner:
)
del iter
gc.collect()
iter = get_weight_iter(self.vllm_model_config)
iter = get_weight_iter(self.model_config)
self.model = model_load_weights(self.model, iter)
return False, message
self.model = model
self.server_args.model_path = model_path
self.server_args.load_format = load_format
self.vllm_model_config = vllm_model_config
self.load_config = load_config
self.model_config.path = model_path
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
@@ -794,55 +729,3 @@ class ModelRunner:
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}. {e}")
if crash_on_warnings():
raise ValueError(f"Ignore import error when loading {name}. {e}")
continue
if hasattr(module, "EntryClass"):
entry = module.EntryClass
if isinstance(
entry, list
): # To support multiple model classes in one module
for tmp in entry:
assert (
tmp.__name__ not in model_arch_name_to_cls
), f"Duplicated model implementation for {tmp.__name__}"
model_arch_name_to_cls[tmp.__name__] = tmp
else:
assert (
entry.__name__ not in model_arch_name_to_cls
), f"Duplicated model implementation for {entry.__name__}"
model_arch_name_to_cls[entry.__name__] = entry
return model_arch_name_to_cls
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
model_arch_name_to_cls = import_model_classes()
if model_arch not in model_arch_name_to_cls:
raise ValueError(
f"Unsupported architectures: {model_arch}. "
f"Supported list: {list(model_arch_name_to_cls.keys())}"
)
return model_arch_name_to_cls[model_arch]
# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)

View File

@@ -0,0 +1,34 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
from torch import nn
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
from sglang.srt.model_loader.utils import (
get_architecture_class_name,
get_model_architecture,
)
def get_model(
    *,
    model_config: ModelConfig,
    load_config: LoadConfig,
    device_config: DeviceConfig,
) -> nn.Module:
    """Build and return a model: select the loader that matches
    `load_config`, then have it instantiate the model on the target device.
    """
    loader = get_model_loader(load_config)
    model = loader.load_model(
        model_config=model_config,
        device_config=device_config,
    )
    return model


__all__ = [
    "get_model",
    "get_model_loader",
    "BaseModelLoader",
    "get_architecture_class_name",
    "get_model_architecture",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,41 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/utils.py
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
import torch
from torch import nn
from sglang.srt.configs.model_config import ModelConfig
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype.

    The previous default is restored even if the managed block raises
    (fix: the original skipped restoration on exception, leaking the
    temporary dtype to all subsequent tensor construction).
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class and architecture name for `model_config`."""
    from sglang.srt.models.registry import ModelRegistry

    architectures = getattr(model_config.hf_config, "architectures", [])
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = ("fp8", "compressed-tensors", "gptq_marlin", "awq_marlin")
    quantization = model_config.quantization
    if (
        "MixtralForCausalLM" in architectures
        and quantization is not None
        and quantization not in mixtral_supported
    ):
        architectures = ["QuantMixtralForCausalLM"]
    return ModelRegistry.resolve_model_cls(architectures)
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the architecture name string for `model_config`."""
    _, architecture_name = get_model_architecture(model_config)
    return architecture_name

View File

@@ -0,0 +1,640 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import fnmatch
import glob
import hashlib
import json
import logging
import os
import tempfile
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import filelock
import gguf
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.distributed import get_tensor_model_parallel_rank
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.utils import print_warning_once
logger = logging.getLogger(__name__)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = tempfile.gettempdir()
def enable_hf_transfer():
    """automatically activates hf_transfer"""
    # Respect an explicit user setting of the env var, whatever it is.
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        return
    try:
        # enable hf hub transfer if available
        import hf_transfer  # type: ignore # noqa
    except ImportError:
        return
    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True


enable_hf_transfer()
class DisabledTqdm(tqdm):
    """A tqdm progress bar that is always disabled (renders nothing)."""

    def __init__(self, *args, **kwargs):
        # Force-disable regardless of any caller-supplied options.
        super().__init__(*args, **kwargs, disable=True)
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return an inter-process file lock scoped to a single model repo."""
    lock_dir = cache_dir or temp_dir
    # NOTE(review): this creates the *parent* of lock_dir, not lock_dir
    # itself (inherited from vllm) — confirm lock_dir always exists.
    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
    sanitized_name = model_name_or_path.replace("/", "-")
    # add hash to avoid conflict with old users' lock files
    digest = hashlib.sha256(sanitized_name.encode()).hexdigest()
    lock_path = os.path.join(lock_dir, digest + sanitized_name + ".lock")
    # mode 0o666 is required for the filelock to be shared across users
    return filelock.FileLock(lock_path, mode=0o666)
def _shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
failing = []
for _, names in ptrs.items():
if len(names) > 1:
failing.append(names)
return failing
def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a pytorch .bin checkpoint to a safetensors file.

    Drops aliased tensors (safetensors cannot store shared storage), then
    verifies the output by file size and by tensor-wise equality.

    Raises RuntimeError if the size differs by more than 1% or any tensor
    round-trips inexactly.
    """
    # NOTE(review): torch.load unpickles arbitrary code — only run this on
    # trusted checkpoint files.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    # Keep only the first name of each group of tensors sharing storage.
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)
    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}
    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})
    # check file size
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    # A >1% size difference suggests the conversion dropped or added data.
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(
            f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """
        )
    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to other place.
def get_quant_config(
    model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig:
    """Build the QuantizationConfig for the given model.

    Resolution order: the HF config's `quantization_config` (or the
    text_config / `compression_config` fallbacks), then — for bitsandbytes
    QLoRA — the adapter repo, otherwise the quantization JSON file shipped
    with the checkpoint (downloading it if needed).

    Raises ValueError if no config file is found, several candidates match,
    or a modelopt config has an unexpected producer.
    """
    quant_cls = get_quantization_config(model_config.quantization)

    # GGUF doesn't have config file
    if model_config.quantization == "gguf":
        return quant_cls.from_config({})

    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
    # some vision model may keep quantization_config in their text_config
    hf_text_config = getattr(model_config.hf_config, "text_config", None)
    if hf_quant_config is None and hf_text_config is not None:
        hf_quant_config = getattr(hf_text_config, "quantization_config", None)
    if hf_quant_config is None:
        # compressed-tensors uses a compressions_config
        hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)

    # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
    if model_config.quantization == "bitsandbytes":
        if (
            not load_config.model_loader_extra_config
            or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config
        ):
            return quant_cls.from_config({"adapter_name_or_path": ""})
        model_name_or_path = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"
        ]
    else:
        model_name_or_path = model_config.model_path
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )
    else:
        hf_folder = model_name_or_path

    possible_config_filenames = quant_cls.get_config_filenames()

    # If the quantization config is not found, use the default config.
    if not possible_config_filenames:
        return quant_cls()

    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
    quant_config_files = [
        f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
    ]
    if len(quant_config_files) == 0:
        raise ValueError(f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}"
        )

    quant_config_file = quant_config_files[0]
    with open(quant_config_file) as f:
        config = json.load(f)

        if model_config.quantization == "bitsandbytes":
            config["adapter_name_or_path"] = model_name_or_path
        elif model_config.quantization == "modelopt":
            if config["producer"]["name"] == "modelopt":
                return quant_cls.from_config(config)
            else:
                # Fix: the original interpolated the (closed) file object `f`
                # here, producing an unreadable repr instead of the file path.
                raise ValueError(
                    f"Unsupported quantization config"
                    f" found for {model_config.quantization} in "
                    f"{quant_config_file}."
                )

    return quant_cls.from_config(config)
def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
    ignore_patterns: Optional[Union[str, List[str]]] = None,
) -> str:
    """Download model weights from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (List[str]): The allowed patterns for the
            weight files. Files matched by any of the patterns will be
            downloaded.
        revision (Optional[str]): The revision of the model.
        ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
            filter out the weight files. Files matched by any of the patterns
            will be ignored.

    Returns:
        str: The path to the downloaded model weights.
    """
    if not huggingface_hub.constants.HF_HUB_OFFLINE:
        # Before we download we look at that is available:
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

        # Narrow allow_patterns down to the first pattern with any match,
        # so we only download one weight format.
        for pattern in allow_patterns:
            if fnmatch.filter(file_list, pattern):
                allow_patterns = [pattern]
                break

    logger.info("Using model weights format %s", allow_patterns)
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        hf_folder = snapshot_download(
            model_name_or_path,
            revision=revision,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            cache_dir=cache_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            tqdm_class=DisabledTqdm,
        )
    return hf_folder
def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    index_file: str,
    cache_dir: Optional[str],
    revision: Optional[str] = None,
) -> None:
    """Download hf safetensors index file from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        index_file (str): Name of the safetensors index file to fetch.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # Hold the per-model file lock so concurrent processes do not race on
    # the shared cache directory.
    with get_lock(model_name_or_path, cache_dir):
        download_kwargs = dict(
            repo_id=model_name_or_path,
            filename=index_file,
            cache_dir=cache_dir,
            revision=revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )
        try:
            # Fetch the shard manifest for sharded safetensors checkpoints.
            hf_hub_download(**download_kwargs)
        # A missing index file is not an error: only some models ship one.
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", index_file)
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.", index_file)
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(
    hf_weights_files: List[str], hf_folder: str, index_file: str
) -> List[str]:
    """Keep only the safetensors files referenced by the index file.

    The index file (``model.safetensors.index.json``-style) maps each
    state-dict key to the shard file holding it; any file not referenced
    there (e.g. a consolidated duplicate) is dropped. If the index file
    does not exist, the input list is returned unchanged.
    """
    index_path = os.path.join(hf_folder, index_file)
    if not os.path.isfile(index_path):
        return hf_weights_files
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    # Absolute paths of every shard the index refers to.
    referenced_files = {
        os.path.join(hf_folder, shard) for shard in weight_map.values()
    }
    return [path for path in hf_weights_files if path in referenced_files]
def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]:
    """
    Exclude files that are not needed for inference.

    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    # Training-only artifacts that may live alongside the weight shards.
    training_only_suffixes = (
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    )
    return [
        path
        for path in hf_weights_files
        if not path.endswith(training_only_suffixes)
    ]
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
# (Passed as `bar_format=` to every tqdm progress bar in this module.)
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
def np_cache_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str],
    hf_folder: str,
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model np files.

    Will dump the model weights to numpy files if they are not already dumped.

    Args:
        model_name_or_path: Model name or path; used only to key the
            inter-process file lock.
        cache_dir: Directory used by ``get_lock`` for the lock file; None
            uses the default.
        hf_folder: Local folder holding the checkpoint files; the ``np``
            cache directory is created inside it.
        hf_weights_files: The ``.bin``/``.pt`` checkpoint shards to convert.

    Yields:
        ``(parameter name, tensor)`` pairs reconstructed from the numpy cache.
    """
    # Only rank 0 (or a non-distributed run) shows the progress bar.
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )
    # Convert the model weights from torch tensors to numpy arrays for
    # faster loading.
    np_folder = os.path.join(hf_folder, "np")
    os.makedirs(np_folder, exist_ok=True)
    # Manifest listing every dumped parameter; its existence marks the
    # cache as complete.
    weight_names_file = os.path.join(np_folder, "weight_names.json")
    # Use file lock to prevent multiple processes from
    # dumping the same model weights to numpy at the same time.
    with get_lock(model_name_or_path, cache_dir):
        if not os.path.exists(weight_names_file):
            weight_names: List[str] = []
            for bin_file in tqdm(
                hf_weights_files,
                desc="Loading np_cache checkpoint shards",
                disable=not enable_tqdm,
                bar_format=_BAR_FORMAT,
            ):
                state = torch.load(bin_file, map_location="cpu")
                for name, param in state.items():
                    # One .npy file per parameter, named after the parameter.
                    param_path = os.path.join(np_folder, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())
                    weight_names.append(name)
            # Written last: a crash mid-dump leaves no manifest, so the
            # cache is rebuilt on the next call.
            with open(weight_names_file, "w") as f:
                json.dump(weight_names, f)
    with open(weight_names_file) as f:
        weight_names = json.load(f)
    for name in weight_names:
        param_path = os.path.join(np_folder, name)
        with open(param_path, "rb") as f:
            param = np.load(f)
        yield name, torch.from_numpy(param)
def safetensors_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    # Only rank 0 (or a non-distributed run) shows the progress bar, to
    # keep multi-process logs readable.
    show_progress = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )
    progress = tqdm(
        hf_weights_files,
        desc="Loading safetensors checkpoint shards",
        disable=not show_progress,
        bar_format=_BAR_FORMAT,
    )
    for shard_path in progress:
        with safe_open(shard_path, framework="pt") as shard:
            for tensor_name in shard.keys():  # noqa: SIM118
                yield tensor_name, shard.get_tensor(tensor_name)
def pt_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files."""
    show_progress = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )
    for shard_path in tqdm(
        hf_weights_files,
        desc="Loading pt checkpoint shards",
        disable=not show_progress,
        bar_format=_BAR_FORMAT,
    ):
        # Stage each shard on CPU so no GPU memory is used for loading.
        # NOTE(review): torch.load without weights_only=True can execute
        # arbitrary code from an untrusted checkpoint — confirm checkpoints
        # are trusted before loading.
        state_dict = torch.load(shard_path, map_location="cpu")
        yield from state_dict.items()
        # Free the shard and return cached blocks before the next one.
        del state_dict
        torch.cuda.empty_cache()
def get_gguf_extra_tensor_names(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> List[str]:
    """Return HF names for mapped tensors that are absent from the GGUF file."""
    reader = gguf.GGUFReader(gguf_file)
    present = {tensor.name for tensor in reader.tensors}
    # Keys the mapping expects but the file does not provide.
    missing = set(gguf_to_hf_name_map) - present
    return [gguf_to_hf_name_map[key] for key in missing]
def gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """
    Iterate over the quant weights in the model gguf files and convert
    them to torch tensors
    """
    reader = gguf.GGUFReader(gguf_file)

    # First pass: emit a scalar "qweight_type" for every quantized tensor,
    # so consumers learn each weight's GGUF quantization before its data.
    for tensor in reader.tensors:
        if tensor.name not in gguf_to_hf_name_map:
            continue
        hf_name = gguf_to_hf_name_map[tensor.name]
        quant_type = tensor.tensor_type
        if quant_type.name != "F32":
            type_param_name = hf_name.replace("weight", "qweight_type")
            yield type_param_name, torch.tensor(quant_type)

    # Second pass: emit the tensor data itself; quantized weights are
    # renamed weight -> qweight, F32 weights keep their mapped name.
    for tensor in reader.tensors:
        if tensor.name not in gguf_to_hf_name_map:
            continue
        hf_name = gguf_to_hf_name_map[tensor.name]
        if tensor.tensor_type.name != "F32":
            hf_name = hf_name.replace("weight", "qweight")
        yield hf_name, torch.tensor(tensor.data)
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """convert PySafeSlice object from safetensors to torch.Tensor

    PySafeSlice supports lazy indexing (done before the tensor is read
    from disk, which saves memory), but not richer operations such as
    `.view()` or `.t()`. Callers that need those must materialize the
    slice into a real tensor first, which is what this helper does.
    """
    if isinstance(x, torch.Tensor):
        return x
    # A full slice materializes the PySafeSlice into an in-memory tensor.
    return x[:]
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader: copy `loaded_weight` into `param` in place."""
    try:
        both_scalar = param.numel() == 1 and loaded_weight.numel() == 1
        if both_scalar:
            # Scalars may arrive shapeless, so "broadcast" the single
            # value with fill_ rather than relying on copy_.
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})"
            )
            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: re-raised unchanged; this handler exists only as a
        # convenient breakpoint location for debugging loading issues.
        raise
def row_parallel_weight_loader(
    param: torch.Tensor, loaded_weight: torch.Tensor
) -> None:
    """Load weights that are row-parallelized."""
    tp_rank = get_tensor_model_parallel_rank()
    # 1-D params (biases, norms) are replicated rather than sharded.
    if param.dim() == 1:
        return default_weight_loader(param, loaded_weight)
    shard_size = param.data.shape[0]
    shard = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
    return default_weight_loader(param, shard)
# Signature shared by all weight-loader callables in this module.
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
    """Create a weight loader that shards the weights along the given axis"""

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        rank = get_tensor_model_parallel_rank()
        size = param.data.shape[shard_axis]
        # This rank's contiguous slice of the full weight.
        shard = loaded_weight.narrow(shard_axis, rank * size, size)
        return default_weight_loader(param, shard)

    return loader


def composed_weight_loader(
    loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor]
) -> LoaderFunction:
    """Create a weight loader that post-processes the weights after loading"""

    def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        loader(param, loaded_weight)
        # Overwrite the freshly loaded data with its post-processed form.
        param.data.copy_(fn(param))

    return composed_loader
def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
    seed: int = 1234,
) -> None:
    """Initialize model weights with random values.

    Random (non-NaN-producing) weights are needed for accurate performance
    measurement; values in [-1e-3, 1e-3] empirically work for most models.
    Each parameter gets its own generator seeded with `seed`, so the dummy
    values depend only on a parameter's element count and dtype — sharded
    copies of a model therefore stay consistent across devices.
    """
    for weight in model.state_dict().values():
        if not torch.is_floating_point(weight):
            continue
        rng = torch.Generator(device=weight.data.device)
        rng.manual_seed(seed)
        if torch.finfo(weight.data.dtype).bits >= 16:
            weight.uniform_(low, high, generator=rng)
        else:
            # uniform_ doesn't support < 16-bit datatypes (FP8): sample in
            # fp16 and cast the result back.
            original_dtype = weight.data.dtype
            staged = weight.data.to(torch.float16)
            staged = staged.uniform_(low, high, generator=rng).to(original_dtype)
            weight.data.copy_(staged)
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    """Remap the name of FP8 k/v_scale parameters.

    Checkpoint scale names (".kv_scale", ".k_scale", ".v_scale") are mapped
    to the model's expected ".attn.*" form. If the remapped name is missing
    from `params_dict`, a warning is emitted and None is returned so the
    caller skips loading that tensor.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
            if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        print_warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale"
        )
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped = name.replace(".kv_scale", ".attn.k_scale")
        if remapped in params_dict:
            return remapped
        print_warning_once(
            f"Found kv_scale in the checkpoint (e.g. {name}), "
            "but not found the expected name in the model "
            f"(e.g. {remapped}). kv_scale is "
            "not loaded."
        )
        return None

    for scale_suffix in (".k_scale", ".v_scale"):
        if not name.endswith(scale_suffix):
            continue
        remapped = name.replace(scale_suffix, f".attn{scale_suffix}")
        if remapped in params_dict:
            return remapped
        print_warning_once(
            f"Found {scale_suffix} in the checkpoint (e.g. {name}), "
            "but not found the expected name in the model "
            f"(e.g. {remapped}). {scale_suffix} is "
            "not loaded."
        )
        return None

    # No scale suffix matched: return the untouched param name.
    return name

View File

@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
self,
config: PretrainedConfig,
position_embedding: str,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -404,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", cache_config, quant_config)
super().__init__(config, "ROPE", quant_config)
else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", cache_config, quant_config)
super().__init__(config, "ALIBI", quant_config)
EntryClass = [BaichuanForCausalLM]

View File

@@ -23,7 +23,6 @@ from torch import nn
from torch.nn import LayerNorm
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.activation import SiluAndMul
@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
LoraConfig = None
@@ -50,7 +50,6 @@ class GLMAttention(nn.Module):
self,
config,
layer_id: int = 0,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -186,7 +185,6 @@ class GLMBlock(nn.Module):
self,
config,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -203,7 +201,7 @@ class GLMBlock(nn.Module):
)
# Self attention.
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
self.self_attention = GLMAttention(config, layer_id, quant_config)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module):
def __init__(
self,
config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self.layers = nn.ModuleList(
[
GLMBlock(config, i, cache_config, quant_config)
for i in range(self.num_layers)
]
[GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
)
if self.post_layer_norm:
@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module):
def __init__(
self,
config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, cache_config, quant_config)
self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module):
def __init__(
self,
config: ChatGLMConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoraConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
self.transformer = ChatGLMM(config, cache_config, quant_config)
self.transformer = ChatGLMM(config, quant_config)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config)

View File

@@ -49,7 +49,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import (
@@ -62,6 +61,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import set_weight_attrs
@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -25,7 +25,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import set_weight_attrs
@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module):
self,
config: DbrxConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
):
super().__init__()
self.config = config

View File

@@ -27,7 +27,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class DeepseekMLP(nn.Module):
@@ -184,7 +184,6 @@ class DeepseekAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -261,7 +260,6 @@ class DeepseekDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -277,7 +275,6 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
if (
@@ -330,7 +327,6 @@ class DeepseekModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -343,9 +339,7 @@ class DeepseekModel(nn.Module):
)
self.layers = nn.ModuleList(
[
DeepseekDecoderLayer(
config, layer_id, cache_config, quant_config=quant_config
)
DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
for layer_id in range(config.num_hidden_layers)
]
)
@@ -373,13 +367,12 @@ class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = DeepseekModel(config, cache_config, quant_config)
self.model = DeepseekModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)

View File

@@ -28,7 +28,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
@@ -189,7 +189,6 @@ class DeepseekV2Attention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
) -> None:
@@ -337,7 +336,6 @@ class DeepseekV2AttentionMLA(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
use_dp=False,
@@ -568,7 +566,6 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -599,7 +596,6 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
layer_id=layer_id,
use_dp=self.enable_dp_attention,
@@ -619,7 +615,6 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
layer_id=layer_id,
)
@@ -685,7 +680,6 @@ class DeepseekV2Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -702,7 +696,6 @@ class DeepseekV2Model(nn.Module):
DeepseekV2DecoderLayer(
config,
layer_id,
cache_config=cache_config,
quant_config=quant_config,
)
for layer_id in range(config.num_hidden_layers)
@@ -733,13 +726,12 @@ class DeepseekV2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = DeepseekV2Model(config, cache_config, quant_config)
self.model = DeepseekV2Model(config, quant_config)
if global_server_args_dict["enable_dp_attention"]:
self.lm_head = ReplicatedLinear(
config.hidden_size,

View File

@@ -22,7 +22,6 @@ import torch
from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class ExaoneGatedMLP(nn.Module):
@@ -293,7 +293,6 @@ class ExaoneForCausalLM(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -21,10 +21,8 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +36,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class GemmaMLP(nn.Module):
@@ -278,10 +277,7 @@ class GemmaForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
cache_config=None,
) -> None:
del lora_config # Unused.
super().__init__()
self.config = config
self.quant_config = quant_config

View File

@@ -20,12 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.linear import (
@@ -38,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
@@ -106,7 +103,6 @@ class Gemma2Attention(nn.Module):
head_dim: int,
max_position_embeddings: int,
rope_theta: float,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -191,7 +187,6 @@ class Gemma2DecoderLayer(nn.Module):
self,
layer_id: int,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -205,7 +200,6 @@ class Gemma2DecoderLayer(nn.Module):
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
cache_config=cache_config,
quant_config=quant_config,
)
self.hidden_size = config.hidden_size
@@ -258,7 +252,6 @@ class Gemma2Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -273,7 +266,6 @@ class Gemma2Model(nn.Module):
lambda idx, prefix: Gemma2DecoderLayer(
layer_id=idx,
config=config,
cache_config=cache_config,
quant_config=quant_config,
),
prefix="",
@@ -342,15 +334,12 @@ class Gemma2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config)
self.model = Gemma2Model(config, quant_config)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()

View File

@@ -29,7 +29,6 @@ class Gemma2ForSequenceClassification(nn.Module):
self,
config: Gemma2Config,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -22,11 +22,9 @@ from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPT2Config
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
@@ -39,6 +37,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class GPT2Attention(nn.Module):
@@ -47,7 +46,6 @@ class GPT2Attention(nn.Module):
self,
layer_id: int,
config: GPT2Config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -140,7 +138,6 @@ class GPT2Block(nn.Module):
self,
layer_id: int,
config: GPT2Config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -150,7 +147,7 @@ class GPT2Block(nn.Module):
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(
layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
layer_id, config, quant_config, prefix=f"{prefix}.attn"
)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
@@ -182,7 +179,6 @@ class GPT2Model(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -196,7 +192,7 @@ class GPT2Model(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[
GPT2Block(i, config, cache_config, quant_config)
GPT2Block(i, config, quant_config)
for i in range(config.num_hidden_layers)
]
)
@@ -226,15 +222,12 @@ class GPT2LMHeadModel(nn.Module):
def __init__(
self,
config: GPT2Config,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(
config, cache_config, quant_config, prefix="transformer"
)
self.transformer = GPT2Model(config, quant_config, prefix="transformer")
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config)

View File

@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
self,
layer_id: int,
config: GPTBigCodeConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
self,
layer_id: int,
config: GPTBigCodeConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
lora_vocab = 0
self.vocab_size = config.vocab_size + lora_vocab
self.wte = VocabParallelEmbedding(
self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[
GPTBigCodeBlock(i, config, cache_config, quant_config)
GPTBigCodeBlock(i, config, quant_config)
for i in range(config.num_hidden_layers)
]
)
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(
config, cache_config, quant_config, lora_config
)
self.transformer = GPTBigCodeModel(config, quant_config)
self.lm_head = self.transformer.wte
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()

View File

@@ -24,7 +24,6 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
@@ -43,6 +42,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
class Grok1MoE(nn.Module):
@@ -285,7 +286,6 @@ class Grok1ForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -21,7 +21,6 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -29,7 +29,6 @@ class InternLM2ForRewardModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -24,7 +24,6 @@ from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -44,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
from sglang.utils import get_exception_traceback
@@ -300,7 +300,6 @@ class LlamaForCausalLM(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaModel
@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
self,
config: LlamaConfig,
quant_config=None,
cache_config=None,
) -> None:
super().__init__()
self.model = LlamaModel(config, quant_config=quant_config)

View File

@@ -21,6 +21,7 @@ from transformers import LlamaConfig
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__(config, quant_config, cache_config)
super().__init__(config, quant_config)
self.weights = self.Weights(config.hidden_size, self.num_labels)
@torch.no_grad()

View File

@@ -29,7 +29,6 @@ from transformers import (
SiglipVisionModel,
)
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
@@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
@@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()

View File

@@ -20,11 +20,11 @@ import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -20,7 +20,6 @@ import torch
from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
layer_id=layer_id,
)
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
layer_id=layer_id,
)
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
)
self.layers = nn.ModuleList(
[
MiniCPM3DecoderLayer(
config, i, cache_config=cache_config, quant_config=quant_config
)
MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
for i in range(config.num_hidden_layers)
]
)
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):
self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config
self.model = MiniCPM3Model(
config, cache_config=cache_config, quant_config=quant_config
)
self.model = MiniCPM3Model(config, quant_config=quant_config)
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if not self.config.tie_word_embeddings:
self.lm_head = ParallelLMHead(

View File

@@ -23,7 +23,6 @@ from torch import nn
from transformers import MixtralConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class MixtralMoE(nn.Module):
@@ -291,7 +291,6 @@ class MixtralForCausalLM(nn.Module):
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -29,7 +29,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
@@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class MixtralMLP(nn.Module):
@@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module):
self,
config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
_prepare_aspect_ratio_attention_mask,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.layernorm import RMSNorm
@@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
@@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module):
self,
config: config_mllama.MllamaTextConfig,
quant_config: Optional[QuantizationConfig],
cache_config=None,
):
super().__init__()
self.padding_id = config.pad_token_id
@@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module):
self,
config: config_mllama.MllamaTextConfig,
quant_config: Optional[QuantizationConfig],
cache_config=None,
):
super().__init__()
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config, cache_config, quant_config)
self.model = MllamaTextModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
@@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module):
self,
config: config_mllama.MllamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
):
super().__init__()
self.vocab_size = config.text_config.vocab_size
@@ -787,7 +784,6 @@ class MllamaForConditionalGeneration(nn.Module):
self.vision_model = MllamaVisionModel(config.vision_config)
self.language_model = MllamaForCausalLM(
config.text_config,
cache_config=cache_config,
quant_config=quant_config,
)
self.multi_modal_projector = nn.Linear(

View File

@@ -22,7 +22,6 @@ from torch import nn
from transformers import OlmoConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import (
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
@@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module):
def __init__(
self,
config: OlmoConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()

View File

@@ -312,7 +312,6 @@ class Olmo2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()

View File

@@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.utils import print_warning_once
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers, print_warning_once
class OlmoeMoE(nn.Module):
@@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()

View File

@@ -7,8 +7,6 @@ from transformers import Phi3Config
from transformers.configuration_utils import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import make_layers
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
@@ -27,6 +25,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
@torch.jit.script
@@ -235,7 +235,6 @@ class Phi3SmallDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -286,7 +285,6 @@ class Phi3SmallModel(nn.Module):
super().__init__()
self.config = config
cache_config = None
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.hidden_size
)
@@ -294,7 +292,7 @@ class Phi3SmallModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Phi3SmallDecoderLayer(
config, int(prefix.split(".")[-1]), cache_config, quant_config
config, int(prefix.split(".")[-1]), quant_config
),
prefix=f"{prefix}.layers",
)
@@ -339,7 +337,6 @@ class Phi3SmallForCausalLM(nn.Module):
self,
config: Phi3Config,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
):
super().__init__()

View File

@@ -22,7 +22,6 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class QWenMLP(nn.Module):
@@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
):
super().__init__()
self.config = config

View File

@@ -22,7 +22,6 @@ import torch
from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers
Qwen2Config = None
@@ -271,7 +271,6 @@ class Qwen2ForCausalLM(nn.Module):
self,
config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -27,7 +27,6 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class Qwen2MoeMLP(nn.Module):
@@ -158,7 +158,6 @@ class Qwen2MoeAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -234,7 +233,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -250,7 +248,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
@@ -304,7 +301,6 @@ class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -317,9 +313,7 @@ class Qwen2MoeModel(nn.Module):
)
self.layers = nn.ModuleList(
[
Qwen2MoeDecoderLayer(
config, layer_id, cache_config, quant_config=quant_config
)
Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
for layer_id in range(config.num_hidden_layers)
]
)
@@ -353,14 +347,13 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)

View File

@@ -30,12 +30,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
@@ -49,6 +47,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
logger = init_logger(__name__)
@@ -536,7 +535,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
def __init__(
self,
config: Qwen2VLConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()

View File

@@ -0,0 +1,99 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py
import importlib
import logging
import pkgutil
from dataclasses import dataclass, field
from functools import lru_cache
from typing import AbstractSet, Dict, List, Optional, Tuple, Type, Union
import torch.nn as nn
logger = logging.getLogger(__name__)
@dataclass
class _ModelRegistry:
    """Registry resolving model-architecture names to model classes.

    Adapted from vllm's model registry; populated once at import time by
    ``import_model_classes`` below.
    """

    # Keyed by model_arch. Values are the model classes discovered at import
    # time (the annotation also admits a string placeholder, as in vllm).
    models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return the architecture names this registry knows about."""
        return self.models.keys()

    def _raise_for_unsupported(self, architectures: List[str]):
        """Raise a ValueError explaining why *architectures* did not resolve."""
        supported = self.get_supported_archs()
        if any(arch in supported for arch in architectures):
            # The arch is registered, yet no class was produced for it.
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {supported}"
        )

    def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]:
        """Look up *model_arch*; return None when it is not registered."""
        return self.models.get(model_arch)

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        """Coerce a single architecture name into a one-element list."""
        archs = [architectures] if isinstance(architectures, str) else architectures
        if not archs:
            logger.warning("No model architectures are specified")
        return archs

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        """Return ``(model_class, arch_name)`` for the first resolvable arch.

        Raises:
            ValueError: when none of the given architectures is registered.
        """
        archs = self._normalize_archs(architectures)
        for arch in archs:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)
        return self._raise_for_unsupported(archs)
@lru_cache()
def import_model_classes():
    """Discover model classes under ``sglang.srt.models``.

    Walks every non-package module in the ``sglang.srt.models`` package and
    collects its ``EntryClass`` attribute (a single class or a list of
    classes) into a dict keyed by class name. Cached so the package scan
    runs only once per process.

    Returns:
        Dict mapping architecture (class) name -> model class.

    Raises:
        AssertionError: if two modules register the same class name.
    """
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if ispkg:
            continue
        try:
            module = importlib.import_module(name)
        except Exception as e:
            # Best-effort: a model file with missing optional deps must not
            # break registry construction; skip it and log lazily (no f-string
            # so formatting cost is paid only if the record is emitted).
            logger.warning("Ignore import error when loading %s. %s", name, e)
            continue
        if not hasattr(module, "EntryClass"):
            continue
        entry = module.EntryClass
        # EntryClass may be one class or a list of classes in one module.
        entries = entry if isinstance(entry, list) else [entry]
        for tmp in entries:
            # Explicit raise instead of `assert`: the duplicate check must
            # survive `python -O`, which strips assert statements.
            if tmp.__name__ in model_arch_name_to_cls:
                raise AssertionError(
                    f"Duplicated model implementation for {tmp.__name__}"
                )
            model_arch_name_to_cls[tmp.__name__] = tmp
    return model_arch_name_to_cls
# Module-level singleton: architecture name -> model class, built eagerly
# from every module under sglang.srt.models when this file is imported.
ModelRegistry = _ModelRegistry(import_model_classes())

View File

@@ -26,7 +26,6 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import (
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class StablelmMLP(nn.Module):
@@ -242,7 +242,6 @@ class StableLmForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -52,7 +52,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
@@ -66,6 +65,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
@@ -388,7 +388,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__()
self.config = config

View File

@@ -30,7 +30,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class XverseMLP(nn.Module):
@@ -295,8 +295,6 @@ class XverseForCausalLM(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
efficient_weight_load=False,
) -> None:
super().__init__()
self.config = config

View File

@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.fused_moe_triton import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
class XverseMLP(nn.Module):
@@ -181,7 +181,6 @@ class XverseAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -258,7 +257,6 @@ class XverseDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_id: int,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -277,7 +275,6 @@ class XverseDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
if config.num_experts is not None:
@@ -326,7 +323,6 @@ class XverseModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -339,9 +335,7 @@ class XverseModel(nn.Module):
)
self.layers = nn.ModuleList(
[
XverseDecoderLayer(
config, layer_id, cache_config, quant_config=quant_config
)
XverseDecoderLayer(config, layer_id, quant_config=quant_config)
for layer_id in range(config.num_hidden_layers)
]
)
@@ -369,13 +363,12 @@ class XverseMoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config=None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config)
self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)

View File

@@ -18,9 +18,9 @@ from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llava import LlavaLlamaForCausalLM
@@ -29,9 +29,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None:
super().__init__(config, quant_config, cache_config)
super().__init__(config, quant_config)
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
self.vision_tower_subfolder = self.config.mm_vision_tower.replace(

View File

@@ -50,6 +50,7 @@ class ServerArgs:
served_model_name: Optional[str] = None
chat_template: Optional[str] = None
is_embedding: bool = False
revision: Optional[str] = None
# Port
host: str = "127.0.0.1"
@@ -341,6 +342,14 @@ class ServerArgs:
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="The specific model version to use. It can be a branch "
"name, a tag name, or a commit id. If unspecified, will use "
"the default version.",
)
# Memory and scheduling
parser.add_argument(

View File

@@ -430,16 +430,12 @@ def suppress_other_loggers():
from vllm.logger import logger as vllm_default_logger
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.config").setLevel(logging.ERROR)
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
logging.WARN
)
logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
warnings.filterwarnings(
"ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -492,27 +488,6 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
pass
def monkey_patch_vllm_model_config():
from vllm.config import ModelConfig
if not hasattr(ModelConfig, "_resolve_task"):
return
def _resolve_task(
self,
task_option,
hf_config,
):
supported_tasks = {
"generate": True,
"embedding": False,
}
selected_task = "generate"
return supported_tasks, selected_task
setattr(ModelConfig, "_resolve_task", _resolve_task)
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
"""
Monkey patch the slow p2p access check in vllm.
@@ -1041,6 +1016,11 @@ def crash_on_warnings():
return get_bool_env_var("SGLANG_IS_IN_CI")
def print_warning_once(msg: str) -> None:
# Set the stacklevel to 2 to print the caller's line info
logger.warning(msg, stacklevel=2)
def get_device_name(device_id: int = 0) -> str:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return torch.cuda.get_device_name(device_id)
@@ -1055,6 +1035,33 @@ def get_device_name(device_id: int = 0) -> str:
return torch.hpu.get_device_name(device_id)
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
major, minor = None, None
if hasattr(torch, "cuda") and torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability(device_id)
if hasattr(torch, "hip") and torch.hip.is_available():
major, minor = torch.cuda.get_device_capability(device_id)
if hasattr(torch, "xpu") and torch.xpu.is_available():
major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
"."
)
major, minor = int(major), int(minor)
# TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
# Update this once the support is available.
if hasattr(torch, "hpu") and torch.hpu.is_available():
try:
major, minor = torch.hpu.get_device_capability(device_id)
except Exception as e:
raise RuntimeError(
f"An error occurred while getting device capability of hpu: {e}."
) from e
return major, minor
sglang_lib = Library("sglang", "FRAGMENT") # noqa