Co-authored-by: HandH1998 <1335248067@qq.com>
This commit is contained in:
@@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank):
|
||||
model_config = ModelConfig(
|
||||
server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
model_runner = ModelRunner(
|
||||
model_config=model_config,
|
||||
|
||||
17
python/sglang/srt/configs/device_config.py
Normal file
17
python/sglang/srt/configs/device_config.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceConfig:
|
||||
device: Optional[torch.device]
|
||||
|
||||
def __init__(self, device: str = "cuda") -> None:
|
||||
if device in ["cuda", "xpu", "hpu"]:
|
||||
self.device_type = device
|
||||
else:
|
||||
raise RuntimeError(f"Not supported device type: {device}")
|
||||
self.device = torch.device(self.device_type)
|
||||
84
python/sglang/srt/configs/load_config.py
Normal file
84
python/sglang/srt/configs/load_config.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoadFormat(str, enum.Enum):
|
||||
AUTO = "auto"
|
||||
PT = "pt"
|
||||
SAFETENSORS = "safetensors"
|
||||
NPCACHE = "npcache"
|
||||
DUMMY = "dummy"
|
||||
SHARDED_STATE = "sharded_state"
|
||||
GGUF = "gguf"
|
||||
BITSANDBYTES = "bitsandbytes"
|
||||
MISTRAL = "mistral"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadConfig:
|
||||
"""
|
||||
download_dir: Directory to download and load the weights, default to the
|
||||
default cache directory of huggingface.
|
||||
load_format: The format of the model weights to load:
|
||||
"auto" will try to load the weights in the safetensors format and
|
||||
fall back to the pytorch bin format if safetensors format is
|
||||
not available.
|
||||
"pt" will load the weights in the pytorch bin format.
|
||||
"safetensors" will load the weights in the safetensors format.
|
||||
"npcache" will load the weights in pytorch format and store
|
||||
a numpy cache to speed up the loading.
|
||||
"dummy" will initialize the weights with random values, which is
|
||||
mainly for profiling.
|
||||
"bitsandbytes" will load nf4 type weights.
|
||||
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||
Default to "original/**/*" to avoid repeated loading of llama's
|
||||
checkpoints.
|
||||
|
||||
"""
|
||||
|
||||
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
||||
download_dir: Optional[str] = None
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
if isinstance(model_loader_extra_config, str):
|
||||
self.model_loader_extra_config = json.loads(model_loader_extra_config)
|
||||
self._verify_load_format()
|
||||
|
||||
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
||||
logger.info(
|
||||
"Ignoring the following patterns when downloading weights: %s",
|
||||
self.ignore_patterns,
|
||||
)
|
||||
else:
|
||||
self.ignore_patterns = ["original/**/*"]
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
if not isinstance(self.load_format, str):
|
||||
return
|
||||
|
||||
load_format = self.load_format.lower()
|
||||
self.load_format = LoadFormat(load_format)
|
||||
|
||||
rocm_not_supported_load_format: List[str] = []
|
||||
if is_hip() and load_format in rocm_not_supported_load_format:
|
||||
rocm_supported_load_format = [
|
||||
f
|
||||
for f in LoadFormat.__members__
|
||||
if (f not in rocm_not_supported_load_format)
|
||||
]
|
||||
raise ValueError(
|
||||
f"load format '{load_format}' is not supported in ROCm. "
|
||||
f"Supported load formats are "
|
||||
f"{rocm_supported_load_format}"
|
||||
)
|
||||
@@ -15,12 +15,14 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import IntEnum, auto
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
|
||||
class ModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
model_path: str,
|
||||
trust_remote_code: bool = True,
|
||||
revision: Optional[str] = None,
|
||||
context_length: Optional[int] = None,
|
||||
model_override_args: Optional[dict] = None,
|
||||
is_embedding: Optional[bool] = None,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
# Parse args
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
self.hf_config = get_config(
|
||||
path,
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
model_override_args=self.model_override_args,
|
||||
@@ -56,6 +63,7 @@ class ModelConfig:
|
||||
)
|
||||
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
# Derive context length
|
||||
derived_context_len = get_context_length(self.hf_text_config)
|
||||
@@ -116,6 +124,8 @@ class ModelConfig:
|
||||
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
||||
self.vocab_size = self.hf_text_config.vocab_size
|
||||
|
||||
self._verify_quantization()
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||
def get_total_num_kv_heads(self) -> int:
|
||||
"""Returns the total number of KV heads."""
|
||||
@@ -174,6 +184,86 @@ class ModelConfig:
|
||||
# parallel size so each GPU has at least one KV head.
|
||||
return max(1, total_num_kv_heads // tensor_parallel_size)
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _parse_quant_hf_config(self):
|
||||
quant_cfg = getattr(self.hf_config, "quantization_config", None)
|
||||
if quant_cfg is None:
|
||||
# compressed-tensors uses a "compression_config" key
|
||||
quant_cfg = getattr(self.hf_config, "compression_config", None)
|
||||
return quant_cfg
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = [*QUANTIZATION_METHODS]
|
||||
rocm_supported_quantization = [
|
||||
"awq",
|
||||
"gptq",
|
||||
"fp8",
|
||||
"compressed_tensors",
|
||||
"compressed-tensors",
|
||||
"fbgemm_fp8",
|
||||
]
|
||||
optimized_quantization_methods = [
|
||||
"fp8",
|
||||
"marlin",
|
||||
"modelopt",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
"fbgemm_fp8",
|
||||
"compressed_tensors",
|
||||
"compressed-tensors",
|
||||
"experts_int8",
|
||||
]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
# Parse quantization method from the HF model config, if available.
|
||||
quant_cfg = self._parse_quant_hf_config()
|
||||
|
||||
if quant_cfg is not None:
|
||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Detect which checkpoint is it
|
||||
for _, method in QUANTIZATION_METHODS.items():
|
||||
quantization_override = method.override_quantization_method(
|
||||
quant_cfg, self.quantization
|
||||
)
|
||||
if quantization_override:
|
||||
quant_method = quantization_override
|
||||
self.quantization = quantization_override
|
||||
break
|
||||
|
||||
# Verify quantization configurations.
|
||||
if self.quantization is None:
|
||||
self.quantization = quant_method
|
||||
elif self.quantization != quant_method:
|
||||
raise ValueError(
|
||||
"Quantization method specified in the model config "
|
||||
f"({quant_method}) does not match the quantization "
|
||||
f"method specified in the `quantization` argument "
|
||||
f"({self.quantization})."
|
||||
)
|
||||
|
||||
if self.quantization is not None:
|
||||
if self.quantization not in supported_quantization:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}."
|
||||
)
|
||||
if is_hip() and self.quantization not in rocm_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not "
|
||||
f"supported in ROCm."
|
||||
)
|
||||
if self.quantization not in optimized_quantization_methods:
|
||||
logger.warning(
|
||||
"%s quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.",
|
||||
self.quantization,
|
||||
)
|
||||
|
||||
|
||||
def get_hf_text_config(config: PretrainedConfig):
|
||||
"""Get the "sub" config relevant to llm for multi modal models.
|
||||
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
|
||||
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
|
||||
# We support non-hf version of llava models, so we do not want to
|
||||
# read the wrong values from the unused default text_config.
|
||||
# NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
|
||||
# `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
|
||||
setattr(config, "torch_dtype", torch.float16)
|
||||
return config
|
||||
|
||||
if hasattr(config, "text_config"):
|
||||
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
|
||||
return config
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.float16,
|
||||
"float16": torch.float16,
|
||||
"float": torch.float32,
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _get_and_verify_dtype(
|
||||
config: PretrainedConfig,
|
||||
dtype: Union[str, torch.dtype],
|
||||
) -> torch.dtype:
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
config_dtype = getattr(config, "torch_dtype", None)
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
|
||||
if isinstance(dtype, str):
|
||||
dtype = dtype.lower()
|
||||
if dtype == "auto":
|
||||
if config_dtype == torch.float32:
|
||||
if config.model_type == "gemma2":
|
||||
logger.info(
|
||||
"For Gemma 2, we downcast float32 to bfloat16 instead "
|
||||
"of float16 by default. Please specify `dtype` if you "
|
||||
"want to use float16."
|
||||
)
|
||||
torch_dtype = torch.bfloat16
|
||||
else:
|
||||
# Following the common practice, we use float16 for float32
|
||||
# models.
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
torch_dtype = dtype
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
# Verify the dtype.
|
||||
if torch_dtype != config_dtype:
|
||||
if torch_dtype == torch.float32:
|
||||
# Upcasting to float32 is allowed.
|
||||
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
elif config_dtype == torch.float32:
|
||||
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
|
||||
|
||||
return torch_dtype
|
||||
|
||||
|
||||
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
|
||||
# We have two ways to determine whether a model is a generative model.
|
||||
# 1. Check the model architectue
|
||||
|
||||
@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
# NOTE: the following section from original transformers config
|
||||
# for Qwen2-VL is commented out to address rope config loading issue
|
||||
#
|
||||
# if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
# if self.rope_scaling["type"] == "mrope":
|
||||
# self.rope_scaling["type"] = "default"
|
||||
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
# rope_config_validation(self)
|
||||
# NOTE(HandH1998): This is necessary for configuring the `rope_type`` of qwen2vl models after removing dependencies on vllm.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
if self.rope_scaling["type"] == "mrope":
|
||||
self.rope_scaling["type"] = "default"
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
@@ -75,6 +75,8 @@ def get_config(
|
||||
if config.model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||
config = config_class.from_pretrained(model, revision=revision)
|
||||
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
|
||||
setattr(config, "_name_or_path", model)
|
||||
if model_override_args:
|
||||
config.update(model_override_args)
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"Fp8LinearMethod",
|
||||
"MarlinLinearMethod",
|
||||
"GPTQLinearMethod",
|
||||
"QQQLinearMethod",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
|
||||
@@ -147,9 +147,12 @@ class Scheduler:
|
||||
self.model_config = ModelConfig(
|
||||
server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
self.is_generation = self.model_config.is_generation
|
||||
|
||||
|
||||
@@ -109,9 +109,12 @@ class TokenizerManager:
|
||||
self.model_config = ModelConfig(
|
||||
server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
|
||||
self.is_generation = self.model_config.is_generation
|
||||
|
||||
@@ -52,9 +52,12 @@ class TpModelWorker:
|
||||
self.model_config = ModelConfig(
|
||||
server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
self.model_runner = ModelRunner(
|
||||
model_config=self.model_config,
|
||||
|
||||
@@ -14,22 +14,12 @@
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
import gc
|
||||
import importlib
|
||||
import importlib.resources
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from tokenize import tabsize
|
||||
from typing import Any, Optional, Type, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from vllm.config import DeviceConfig, LoadConfig
|
||||
from vllm.config import ModelConfig as VllmModelConfig
|
||||
from vllm.distributed import (
|
||||
get_tp_group,
|
||||
init_distributed_environment,
|
||||
@@ -37,9 +27,9 @@ from vllm.distributed import (
|
||||
set_custom_all_reduce,
|
||||
)
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
|
||||
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
||||
@@ -56,16 +46,15 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
crash_on_warnings,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
is_hip,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
monkey_patch_vllm_model_config,
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
set_cpu_offload_max_bytes,
|
||||
)
|
||||
@@ -228,49 +217,6 @@ class ModelRunner:
|
||||
|
||||
return min_per_gpu_memory
|
||||
|
||||
def setup_model(self):
|
||||
try:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.model_config = self.vllm_model_config
|
||||
vllm_config.load_config = self.load_config
|
||||
vllm_config.device_config = DeviceConfig(self.device)
|
||||
vllm_config.quant_config = VllmConfig._get_quantization_config(
|
||||
vllm_config.model_config, vllm_config.load_config
|
||||
)
|
||||
return get_model(vllm_config=vllm_config)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return get_model(
|
||||
model_config=self.vllm_model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=DeviceConfig(self.device),
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
lora_config=None,
|
||||
cache_config=None,
|
||||
)
|
||||
|
||||
def get_model_config_params(self):
|
||||
sig = inspect.signature(VllmModelConfig.__init__)
|
||||
params = {
|
||||
"model": self.server_args.model_path,
|
||||
"quantization": self.server_args.quantization,
|
||||
"tokenizer": None,
|
||||
"tokenizer_mode": None,
|
||||
"trust_remote_code": self.server_args.trust_remote_code,
|
||||
"dtype": self.server_args.dtype,
|
||||
"seed": self.server_args.random_seed,
|
||||
"skip_tokenizer_init": True,
|
||||
}
|
||||
|
||||
if "task" in sig.parameters:
|
||||
params["task"] = ""
|
||||
|
||||
return params
|
||||
|
||||
def load_model(self):
|
||||
logger.info(
|
||||
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
@@ -284,6 +230,7 @@ class ModelRunner:
|
||||
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
|
||||
)
|
||||
self.server_args.dtype = "float16"
|
||||
self.model_config.dtype = torch.float16
|
||||
if torch.cuda.get_device_capability()[1] < 5:
|
||||
raise RuntimeError("SGLang only supports sm75 and above.")
|
||||
|
||||
@@ -292,23 +239,21 @@ class ModelRunner:
|
||||
load_format=self.server_args.load_format,
|
||||
download_dir=self.server_args.download_dir,
|
||||
)
|
||||
monkey_patch_vllm_model_config()
|
||||
|
||||
if self.server_args.load_format == "gguf":
|
||||
monkey_patch_vllm_gguf_config()
|
||||
self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
|
||||
if self.model_config.model_override_args is not None:
|
||||
self.vllm_model_config.hf_config.update(
|
||||
self.model_config.model_override_args
|
||||
)
|
||||
|
||||
self.model = self.setup_model()
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=DeviceConfig(self.device),
|
||||
)
|
||||
|
||||
self.sliding_window_size = (
|
||||
self.model.get_attention_sliding_window_size()
|
||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||
else None
|
||||
)
|
||||
self.dtype = self.vllm_model_config.dtype
|
||||
self.dtype = self.model_config.dtype
|
||||
|
||||
logger.info(
|
||||
f"Load weight end. "
|
||||
@@ -319,12 +264,12 @@ class ModelRunner:
|
||||
|
||||
def update_weights_from_disk(self, model_path: str, load_format: str):
|
||||
"""Update engine weights online from disk."""
|
||||
from vllm.model_executor.model_loader.loader import (
|
||||
from sglang.srt.model_loader.loader import (
|
||||
DefaultModelLoader,
|
||||
device_loading_context,
|
||||
get_model_loader,
|
||||
)
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
|
||||
logger.info(
|
||||
f"Update engine weights online from disk begin. "
|
||||
@@ -332,15 +277,7 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
target_device = torch.device(self.device)
|
||||
|
||||
try:
|
||||
model_config_params = self.get_model_config_params()
|
||||
model_config_params["model"] = model_path
|
||||
vllm_model_config = VllmModelConfig(**model_config_params)
|
||||
except Exception as e:
|
||||
message = f"Failed to load model config: {e}."
|
||||
return False, message
|
||||
|
||||
self.model_config.model_path = model_path
|
||||
load_config = LoadConfig(load_format=load_format)
|
||||
|
||||
# Only support vllm DefaultModelLoader for now
|
||||
@@ -352,7 +289,7 @@ class ModelRunner:
|
||||
def get_weight_iter(config):
|
||||
iter = loader._get_weights_iterator(
|
||||
DefaultModelLoader.Source(
|
||||
config.model,
|
||||
config.model_path,
|
||||
revision=config.revision,
|
||||
fall_back_to_pt=getattr(
|
||||
self.model, "fall_back_to_pt_during_load", True
|
||||
@@ -370,9 +307,9 @@ class ModelRunner:
|
||||
quant_method.process_weights_after_loading(module)
|
||||
return model
|
||||
|
||||
with set_default_torch_dtype(vllm_model_config.dtype):
|
||||
with set_default_torch_dtype(self.model_config.dtype):
|
||||
try:
|
||||
iter = get_weight_iter(vllm_model_config)
|
||||
iter = get_weight_iter(self.model_config)
|
||||
except Exception as e:
|
||||
message = f"Failed to get weights iterator: {e}."
|
||||
return False, message
|
||||
@@ -384,16 +321,14 @@ class ModelRunner:
|
||||
)
|
||||
del iter
|
||||
gc.collect()
|
||||
iter = get_weight_iter(self.vllm_model_config)
|
||||
iter = get_weight_iter(self.model_config)
|
||||
self.model = model_load_weights(self.model, iter)
|
||||
return False, message
|
||||
|
||||
self.model = model
|
||||
self.server_args.model_path = model_path
|
||||
self.server_args.load_format = load_format
|
||||
self.vllm_model_config = vllm_model_config
|
||||
self.load_config = load_config
|
||||
self.model_config.path = model_path
|
||||
|
||||
logger.info("Update weights end.")
|
||||
return True, "Succeeded to update model weights."
|
||||
@@ -794,55 +729,3 @@ class ModelRunner:
|
||||
if rope_scaling is None:
|
||||
return False
|
||||
return rope_scaling.get("type", None) == "mrope"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_model_classes():
|
||||
model_arch_name_to_cls = {}
|
||||
package_name = "sglang.srt.models"
|
||||
package = importlib.import_module(package_name)
|
||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||
if not ispkg:
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Ignore import error when loading {name}. {e}")
|
||||
if crash_on_warnings():
|
||||
raise ValueError(f"Ignore import error when loading {name}. {e}")
|
||||
continue
|
||||
if hasattr(module, "EntryClass"):
|
||||
entry = module.EntryClass
|
||||
if isinstance(
|
||||
entry, list
|
||||
): # To support multiple model classes in one module
|
||||
for tmp in entry:
|
||||
assert (
|
||||
tmp.__name__ not in model_arch_name_to_cls
|
||||
), f"Duplicated model implementation for {tmp.__name__}"
|
||||
model_arch_name_to_cls[tmp.__name__] = tmp
|
||||
else:
|
||||
assert (
|
||||
entry.__name__ not in model_arch_name_to_cls
|
||||
), f"Duplicated model implementation for {entry.__name__}"
|
||||
model_arch_name_to_cls[entry.__name__] = entry
|
||||
|
||||
return model_arch_name_to_cls
|
||||
|
||||
|
||||
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
model_arch_name_to_cls = import_model_classes()
|
||||
|
||||
if model_arch not in model_arch_name_to_cls:
|
||||
raise ValueError(
|
||||
f"Unsupported architectures: {model_arch}. "
|
||||
f"Supported list: {list(model_arch_name_to_cls.keys())}"
|
||||
)
|
||||
return model_arch_name_to_cls[model_arch]
|
||||
|
||||
|
||||
# Monkey patch model loader
|
||||
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
|
||||
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
|
||||
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
|
||||
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
|
||||
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
|
||||
|
||||
34
python/sglang/srt/model_loader/__init__.py
Normal file
34
python/sglang/srt/model_loader/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
|
||||
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
|
||||
from sglang.srt.model_loader.utils import (
|
||||
get_architecture_class_name,
|
||||
get_model_architecture,
|
||||
)
|
||||
|
||||
|
||||
def get_model(
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
device_config: DeviceConfig,
|
||||
) -> nn.Module:
|
||||
loader = get_model_loader(load_config)
|
||||
return loader.load_model(
|
||||
model_config=model_config,
|
||||
device_config=device_config,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_model",
|
||||
"get_model_loader",
|
||||
"BaseModelLoader",
|
||||
"get_architecture_class_name",
|
||||
"get_model_architecture",
|
||||
]
|
||||
1139
python/sglang/srt/model_loader/loader.py
Normal file
1139
python/sglang/srt/model_loader/loader.py
Normal file
File diff suppressed because it is too large
Load Diff
41
python/sglang/srt/model_loader/utils.py
Normal file
41
python/sglang/srt/model_loader/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/utils.py
|
||||
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||
from sglang.srt.models.registry import ModelRegistry
|
||||
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
# Special handling for quantized Mixtral.
|
||||
# FIXME(woosuk): This is a temporary hack.
|
||||
mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"]
|
||||
|
||||
if (
|
||||
model_config.quantization is not None
|
||||
and model_config.quantization not in mixtral_supported
|
||||
and "MixtralForCausalLM" in architectures
|
||||
):
|
||||
architectures = ["QuantMixtralForCausalLM"]
|
||||
|
||||
return ModelRegistry.resolve_model_cls(architectures)
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return get_model_architecture(model_config)[1]
|
||||
640
python/sglang/srt/model_loader/weight_utils.py
Normal file
640
python/sglang/srt/model_loader/weight_utils.py
Normal file
@@ -0,0 +1,640 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
|
||||
|
||||
"""Utilities for downloading and initializing model weights."""
|
||||
import fnmatch
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import filelock
|
||||
import gguf
|
||||
import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
||||
from sglang.srt.utils import print_warning_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# use system-level temp directory for file locks, so that multiple users
|
||||
# can share the same lock without error.
|
||||
# lock files in the temp directory will be automatically deleted when the
|
||||
# system reboots, so users will not complain about annoying lock files
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
|
||||
def enable_hf_transfer():
|
||||
"""automatically activates hf_transfer"""
|
||||
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
||||
try:
|
||||
# enable hf hub transfer if available
|
||||
import hf_transfer # type: ignore # noqa
|
||||
|
||||
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
enable_hf_transfer()
|
||||
|
||||
|
||||
class DisabledTqdm(tqdm):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
|
||||
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
||||
lock_dir = cache_dir or temp_dir
|
||||
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
||||
model_name = model_name_or_path.replace("/", "-")
|
||||
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
||||
# add hash to avoid conflict with old users' lock files
|
||||
lock_file_name = hash_name + model_name + ".lock"
|
||||
# mode 0o666 is required for the filelock to be shared across users
|
||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
|
||||
return lock
|
||||
|
||||
|
||||
def _shared_pointers(tensors):
|
||||
ptrs = defaultdict(list)
|
||||
for k, v in tensors.items():
|
||||
ptrs[v.data_ptr()].append(k)
|
||||
failing = []
|
||||
for _, names in ptrs.items():
|
||||
if len(names) > 1:
|
||||
failing.append(names)
|
||||
return failing
|
||||
|
||||
|
||||
def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
) -> None:
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
shared = _shared_pointers(loaded)
|
||||
for shared_weights in shared:
|
||||
for name in shared_weights[1:]:
|
||||
loaded.pop(name)
|
||||
|
||||
# For tensors to be contiguous
|
||||
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||
|
||||
dirname = os.path.dirname(sf_filename)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
|
||||
# check file size
|
||||
sf_size = os.stat(sf_filename).st_size
|
||||
pt_size = os.stat(pt_filename).st_size
|
||||
if (sf_size - pt_size) / pt_size > 0.01:
|
||||
raise RuntimeError(
|
||||
f"""The file size different is more than 1%:
|
||||
- {sf_filename}: {sf_size}
|
||||
- {pt_filename}: {pt_size}
|
||||
"""
|
||||
)
|
||||
|
||||
# check if the tensors are the same
|
||||
reloaded = load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(
|
||||
model_config: ModelConfig, load_config: LoadConfig
|
||||
) -> QuantizationConfig:
|
||||
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
|
||||
# GGUF doesn't have config file
|
||||
if model_config.quantization == "gguf":
|
||||
return quant_cls.from_config({})
|
||||
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
|
||||
# some vision model may keep quantization_config in their text_config
|
||||
hf_text_config = getattr(model_config.hf_config, "text_config", None)
|
||||
if hf_quant_config is None and hf_text_config is not None:
|
||||
hf_quant_config = getattr(hf_text_config, "quantization_config", None)
|
||||
if hf_quant_config is None:
|
||||
# compressed-tensors uses a compressions_config
|
||||
hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
|
||||
if hf_quant_config is not None:
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
if (
|
||||
not load_config.model_loader_extra_config
|
||||
or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config
|
||||
):
|
||||
return quant_cls.from_config({"adapter_name_or_path": ""})
|
||||
model_name_or_path = load_config.model_loader_extra_config[
|
||||
"qlora_adapter_name_or_path"
|
||||
]
|
||||
|
||||
else:
|
||||
model_name_or_path = model_config.model_path
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, load_config.download_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=load_config.download_dir,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
tqdm_class=DisabledTqdm,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
possible_config_filenames = quant_cls.get_config_filenames()
|
||||
|
||||
# If the quantization config is not found, use the default config.
|
||||
if not possible_config_filenames:
|
||||
return quant_cls()
|
||||
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
|
||||
]
|
||||
if len(quant_config_files) == 0:
|
||||
raise ValueError(f"Cannot find the config file for {model_config.quantization}")
|
||||
if len(quant_config_files) > 1:
|
||||
raise ValueError(
|
||||
f"Found multiple config files for {model_config.quantization}: "
|
||||
f"{quant_config_files}"
|
||||
)
|
||||
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
config["adapter_name_or_path"] = model_name_or_path
|
||||
elif model_config.quantization == "modelopt":
|
||||
if config["producer"]["name"] == "modelopt":
|
||||
return quant_cls.from_config(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization config"
|
||||
f" found for {model_config.quantization} in {f}."
|
||||
)
|
||||
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
allow_patterns: List[str],
|
||||
revision: Optional[str] = None,
|
||||
ignore_patterns: Optional[Union[str, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Download model weights from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
allow_patterns (List[str]): The allowed patterns for the
|
||||
weight files. Files matched by any of the patterns will be
|
||||
downloaded.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
|
||||
filter out the weight files. Files matched by any of the patterns
|
||||
will be ignored.
|
||||
|
||||
Returns:
|
||||
str: The path to the downloaded model weights.
|
||||
"""
|
||||
if not huggingface_hub.constants.HF_HUB_OFFLINE:
|
||||
# Before we download we look at that is available:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
|
||||
# depending on what is available we download different things
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
|
||||
logger.info("Using model weights format %s", allow_patterns)
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=DisabledTqdm,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
return hf_folder
|
||||
|
||||
|
||||
def download_safetensors_index_file_from_hf(
|
||||
model_name_or_path: str,
|
||||
index_file: str,
|
||||
cache_dir: Optional[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Download hf safetensors index file from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
"""
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
try:
|
||||
# Download the safetensors index file.
|
||||
hf_hub_download(
|
||||
repo_id=model_name_or_path,
|
||||
filename=index_file,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
# If file not found on remote or locally, we should not fail since
|
||||
# only some models will have index_file.
|
||||
except huggingface_hub.utils.EntryNotFoundError:
|
||||
logger.info("No %s found in remote.", index_file)
|
||||
except huggingface_hub.utils.LocalEntryNotFoundError:
|
||||
logger.info("No %s found in local cache.", index_file)
|
||||
|
||||
|
||||
# For models like Mistral-7B-v0.3, there are both sharded
|
||||
# safetensors files and a consolidated safetensors file.
|
||||
# Passing both of these to the weight loader functionality breaks.
|
||||
# So, we use the index_file to
|
||||
# look up which safetensors files should be used.
|
||||
def filter_duplicate_safetensors_files(
|
||||
hf_weights_files: List[str], hf_folder: str, index_file: str
|
||||
) -> List[str]:
|
||||
# model.safetensors.index.json is a mapping from keys in the
|
||||
# torch state_dict to safetensors file holding that weight.
|
||||
index_file_name = os.path.join(hf_folder, index_file)
|
||||
if not os.path.isfile(index_file_name):
|
||||
return hf_weights_files
|
||||
|
||||
# Iterate through the weight_map (weight_name: safetensors files)
|
||||
# to identify weights that we should use.
|
||||
with open(index_file_name) as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
weight_files_in_index = set()
|
||||
for weight_name in weight_map:
|
||||
weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))
|
||||
# Filter out any fields that are not found in the index file.
|
||||
hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]:
|
||||
"""
|
||||
Exclude files that are not needed for inference.
|
||||
|
||||
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
"""
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
def np_cache_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
hf_folder: str,
|
||||
hf_weights_files: List[str],
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model np files.
|
||||
|
||||
Will dump the model weights to numpy files if they are not already dumped.
|
||||
"""
|
||||
enable_tqdm = (
|
||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||
)
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names: List[str] = []
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading np_cache checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
weight_names.append(name)
|
||||
with open(weight_names_file, "w") as f:
|
||||
json.dump(weight_names, f)
|
||||
|
||||
with open(weight_names_file) as f:
|
||||
weight_names = json.load(f)
|
||||
|
||||
for name in weight_names:
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
enable_tqdm = (
|
||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||
)
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
enable_tqdm = (
|
||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||
)
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading pt checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
yield from state.items()
|
||||
del state
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def get_gguf_extra_tensor_names(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
||||
) -> List[str]:
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
expected_gguf_keys = set(gguf_to_hf_name_map.keys())
|
||||
exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
|
||||
extra_keys = expected_gguf_keys - exact_gguf_keys
|
||||
return [gguf_to_hf_name_map[key] for key in extra_keys]
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator(
|
||||
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""
|
||||
Iterate over the quant weights in the model gguf files and convert
|
||||
them to torch tensors
|
||||
"""
|
||||
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
|
||||
for tensor in reader.tensors:
|
||||
if tensor.name in gguf_to_hf_name_map:
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
weight_type_name = name.replace("weight", "qweight_type")
|
||||
weight_type = torch.tensor(weight_type)
|
||||
yield weight_type_name, weight_type
|
||||
|
||||
for tensor in reader.tensors:
|
||||
if tensor.name in gguf_to_hf_name_map:
|
||||
weight = tensor.data
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
name = name.replace("weight", "qweight")
|
||||
param = torch.tensor(weight)
|
||||
yield name, param
|
||||
|
||||
|
||||
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||
|
||||
PySafeSlice object supports indexing, which is done before loading the
|
||||
actual tensor and can reduce the amount of memory being read into the
|
||||
memory. However, it does not support more advanced functionalities
|
||||
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||
tensor with these more complicated operators, we need to convert to
|
||||
tensor first.
|
||||
"""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = x[:]
|
||||
return x
|
||||
|
||||
|
||||
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
"""Default weight loader."""
|
||||
try:
|
||||
if param.numel() == 1 and loaded_weight.numel() == 1:
|
||||
# Sometimes scalar values aren't considered tensors with shapes
|
||||
# so if both param and loaded_weight are a scalar,
|
||||
# "broadcast" instead of copy
|
||||
param.data.fill_(loaded_weight.item())
|
||||
else:
|
||||
assert param.size() == loaded_weight.size(), (
|
||||
f"Attempted to load weight ({loaded_weight.size()}) "
|
||||
f"into parameter ({param.size()})"
|
||||
)
|
||||
|
||||
param.data.copy_(loaded_weight)
|
||||
except Exception:
|
||||
# NOTE: This exception is added for the purpose of setting breakpoint to
|
||||
# debug weight loading issues.
|
||||
raise
|
||||
|
||||
|
||||
def row_parallel_weight_loader(
|
||||
param: torch.Tensor, loaded_weight: torch.Tensor
|
||||
) -> None:
|
||||
"""Load weights that are row-parallelized."""
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_dim = 0 if param.dim() != 1 else None
|
||||
|
||||
if shard_dim is not None:
|
||||
shard_size = param.data.shape[shard_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
|
||||
|
||||
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
|
||||
"""Create a weight loader that shards the weights along the given axis"""
|
||||
|
||||
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = param.data.shape[shard_axis]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def composed_weight_loader(
|
||||
loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor]
|
||||
) -> LoaderFunction:
|
||||
"""Create a weight loader that post-processes the weights after loading"""
|
||||
|
||||
def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
loader(param, loaded_weight)
|
||||
param.data.copy_(fn(param))
|
||||
return
|
||||
|
||||
return composed_loader
|
||||
|
||||
|
||||
def initialize_dummy_weights(
|
||||
model: torch.nn.Module,
|
||||
low: float = -1e-3,
|
||||
high: float = 1e-3,
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
"""Initialize model weights with random values.
|
||||
|
||||
The model weights must be randomly initialized for accurate performance
|
||||
measurements. Additionally, the model weights should not cause NaNs in the
|
||||
forward pass. We empirically found that initializing the weights with
|
||||
values between -1e-3 and 1e-3 works well for most models.
|
||||
|
||||
We use per-parameter random seed, so that dummy weights are consistent,
|
||||
even if the model is partitioned across multiple devices. When the seed
|
||||
is fixed, the random values generated by this function only depends on
|
||||
the parameter's number of elements and its data type.
|
||||
"""
|
||||
for param in model.state_dict().values():
|
||||
if torch.is_floating_point(param):
|
||||
generator = torch.Generator(device=param.data.device)
|
||||
generator.manual_seed(seed)
|
||||
if torch.finfo(param.data.dtype).bits < 16:
|
||||
# uniform_ doesn't support < 16-bit datatypes (FP8)
|
||||
dtype = param.data.dtype
|
||||
tmp_param = param.data.to(torch.float16)
|
||||
tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
|
||||
param.data.copy_(tmp_param)
|
||||
else:
|
||||
param.uniform_(low, high, generator=generator)
|
||||
|
||||
|
||||
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
||||
"""Remap the name of FP8 k/v_scale parameters.
|
||||
|
||||
This function handles the remapping of FP8 k/v_scale parameter names.
|
||||
It detects if the given name ends with a suffix and attempts to remap
|
||||
it to the expected name format in the model. If the remapped name is not
|
||||
found in the params_dict, a warning is printed and None is returned.
|
||||
|
||||
Args:
|
||||
name (str): The original loaded checkpoint parameter name.
|
||||
params_dict (dict): Dictionary containing the model's named parameters.
|
||||
|
||||
Returns:
|
||||
str: The remapped parameter name if successful, or the original name
|
||||
if no remapping is needed.
|
||||
None: If the remapped name is not found in params_dict.
|
||||
"""
|
||||
if name.endswith(".kv_scale"):
|
||||
print_warning_once(
|
||||
"DEPRECATED. Found kv_scale in the checkpoint. "
|
||||
"This format is deprecated in favor of separate k_scale and "
|
||||
"v_scale tensors and will be removed in a future release. "
|
||||
"Functionally, we will remap kv_scale to k_scale and duplicate "
|
||||
"k_scale to v_scale"
|
||||
)
|
||||
# NOTE: we remap the deprecated kv_scale to k_scale
|
||||
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
|
||||
if remapped_name not in params_dict:
|
||||
print_warning_once(
|
||||
f"Found kv_scale in the checkpoint (e.g. {name}), "
|
||||
"but not found the expected name in the model "
|
||||
f"(e.g. {remapped_name}). kv_scale is "
|
||||
"not loaded."
|
||||
)
|
||||
return None
|
||||
return remapped_name
|
||||
|
||||
possible_scale_names = [".k_scale", ".v_scale"]
|
||||
for scale_name in possible_scale_names:
|
||||
if name.endswith(scale_name):
|
||||
remapped_name = name.replace(scale_name, f".attn{scale_name}")
|
||||
if remapped_name not in params_dict:
|
||||
print_warning_once(
|
||||
f"Found {scale_name} in the checkpoint (e.g. {name}), "
|
||||
"but not found the expected name in the model "
|
||||
f"(e.g. {remapped_name}). {scale_name} is "
|
||||
"not loaded."
|
||||
)
|
||||
return None
|
||||
return remapped_name
|
||||
|
||||
# If there were no matches, return the untouched param name
|
||||
return name
|
||||
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -404,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
if config.hidden_size == 4096: # baichuan2 7b
|
||||
super().__init__(config, "ROPE", cache_config, quant_config)
|
||||
super().__init__(config, "ROPE", quant_config)
|
||||
else: # baichuan 13b, baichuan2 13b
|
||||
super().__init__(config, "ALIBI", cache_config, quant_config)
|
||||
super().__init__(config, "ALIBI", quant_config)
|
||||
|
||||
|
||||
EntryClass = [BaichuanForCausalLM]
|
||||
|
||||
@@ -23,7 +23,6 @@ from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
LoraConfig = None
|
||||
|
||||
@@ -50,7 +50,6 @@ class GLMAttention(nn.Module):
|
||||
self,
|
||||
config,
|
||||
layer_id: int = 0,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -186,7 +185,6 @@ class GLMBlock(nn.Module):
|
||||
self,
|
||||
config,
|
||||
layer_id: int,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -203,7 +201,7 @@ class GLMBlock(nn.Module):
|
||||
)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
|
||||
self.self_attention = GLMAttention(config, layer_id, quant_config)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module):
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GLMBlock(config, i, cache_config, quant_config)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
[GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
|
||||
)
|
||||
|
||||
if self.post_layer_norm:
|
||||
@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module):
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
||||
self.encoder = GLMTransformer(config, quant_config)
|
||||
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
|
||||
|
||||
@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoraConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config: ChatGLMConfig = config
|
||||
self.quant_config = quant_config
|
||||
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
|
||||
self.transformer = ChatGLMM(config, cache_config, quant_config)
|
||||
self.transformer = ChatGLMM(config, quant_config)
|
||||
self.lm_head = self.transformer.output_layer
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -49,7 +49,6 @@ from vllm.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -62,6 +61,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
|
||||
@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -25,7 +25,6 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
||||
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
|
||||
@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module):
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -27,7 +27,6 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
||||
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class DeepseekMLP(nn.Module):
|
||||
@@ -184,7 +184,6 @@ class DeepseekAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -261,7 +260,6 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -277,7 +275,6 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if (
|
||||
@@ -330,7 +327,6 @@ class DeepseekModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -343,9 +339,7 @@ class DeepseekModel(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DeepseekDecoderLayer(
|
||||
config, layer_id, cache_config, quant_config=quant_config
|
||||
)
|
||||
DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -373,13 +367,12 @@ class DeepseekForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekModel(config, cache_config, quant_config)
|
||||
self.model = DeepseekModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
|
||||
@@ -28,7 +28,6 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
@@ -189,7 +189,6 @@ class DeepseekV2Attention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
) -> None:
|
||||
@@ -337,7 +336,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
use_dp=False,
|
||||
@@ -568,7 +566,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -599,7 +596,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
use_dp=self.enable_dp_attention,
|
||||
@@ -619,7 +615,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
@@ -685,7 +680,6 @@ class DeepseekV2Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -702,7 +696,6 @@ class DeepseekV2Model(nn.Module):
|
||||
DeepseekV2DecoderLayer(
|
||||
config,
|
||||
layer_id,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
@@ -733,13 +726,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekV2Model(config, cache_config, quant_config)
|
||||
self.model = DeepseekV2Model(config, quant_config)
|
||||
if global_server_args_dict["enable_dp_attention"]:
|
||||
self.lm_head = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
|
||||
@@ -22,7 +22,6 @@ import torch
|
||||
from torch import nn
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class ExaoneGatedMLP(nn.Module):
|
||||
@@ -293,7 +293,6 @@ class ExaoneForCausalLM(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -21,10 +21,8 @@ from typing import Iterable, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -38,6 +36,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
@@ -278,10 +277,7 @@ class GemmaForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
del lora_config # Unused.
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
@@ -20,12 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
|
||||
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
from sglang.srt.layers.layernorm import GemmaRMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -38,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers
|
||||
|
||||
|
||||
@@ -106,7 +103,6 @@ class Gemma2Attention(nn.Module):
|
||||
head_dim: int,
|
||||
max_position_embeddings: int,
|
||||
rope_theta: float,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -191,7 +187,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
self,
|
||||
layer_id: int,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -205,7 +200,6 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
head_dim=config.head_dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
rope_theta=config.rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -258,7 +252,6 @@ class Gemma2Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -273,7 +266,6 @@ class Gemma2Model(nn.Module):
|
||||
lambda idx, prefix: Gemma2DecoderLayer(
|
||||
layer_id=idx,
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
prefix="",
|
||||
@@ -342,15 +334,12 @@ class Gemma2ForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
del lora_config # Unused.
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Gemma2Model(config, cache_config, quant_config)
|
||||
self.model = Gemma2Model(config, quant_config)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -29,7 +29,6 @@ class Gemma2ForSequenceClassification(nn.Module):
|
||||
self,
|
||||
config: Gemma2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -22,11 +22,9 @@ from typing import Iterable, List, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPT2Config
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
# from sglang.srt.layers.activation import get_act_fn
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -39,6 +37,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class GPT2Attention(nn.Module):
|
||||
@@ -47,7 +46,6 @@ class GPT2Attention(nn.Module):
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GPT2Config,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
@@ -140,7 +138,6 @@ class GPT2Block(nn.Module):
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GPT2Config,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
@@ -150,7 +147,7 @@ class GPT2Block(nn.Module):
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(
|
||||
layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
|
||||
layer_id, config, quant_config, prefix=f"{prefix}.attn"
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
|
||||
@@ -182,7 +179,6 @@ class GPT2Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
@@ -196,7 +192,7 @@ class GPT2Model(nn.Module):
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
GPT2Block(i, config, cache_config, quant_config)
|
||||
GPT2Block(i, config, quant_config)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -226,15 +222,12 @@ class GPT2LMHeadModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPT2Model(
|
||||
config, cache_config, quant_config, prefix="transformer"
|
||||
)
|
||||
self.transformer = GPT2Model(config, quant_config, prefix="transformer")
|
||||
self.lm_head = self.transformer.wte
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPTBigCodeConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import get_act_fn
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
|
||||
self,
|
||||
layer_id: int,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
|
||||
self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
||||
|
||||
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
lora_vocab = 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.wte = VocabParallelEmbedding(
|
||||
self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
|
||||
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
GPTBigCodeBlock(i, config, cache_config, quant_config)
|
||||
GPTBigCodeBlock(i, config, quant_config)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPTBigCodeModel(
|
||||
config, cache_config, quant_config, lora_config
|
||||
)
|
||||
self.transformer = GPTBigCodeModel(config, quant_config)
|
||||
self.lm_head = self.transformer.wte
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -24,7 +24,6 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -43,6 +42,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class Grok1MoE(nn.Module):
|
||||
@@ -285,7 +286,6 @@ class Grok1ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -21,7 +21,6 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class InternLM2MLP(nn.Module):
|
||||
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -29,7 +29,6 @@ class InternLM2ForRewardModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -24,7 +24,6 @@ from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -44,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
@@ -300,7 +300,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaModel
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config=None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
|
||||
@@ -21,6 +21,7 @@ from transformers import LlamaConfig
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
|
||||
|
||||
@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__(config, quant_config, cache_config)
|
||||
super().__init__(config, quant_config)
|
||||
self.weights = self.Weights(config.hidden_size, self.num_labels)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -29,7 +29,6 @@ from transformers import (
|
||||
SiglipVisionModel,
|
||||
)
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
|
||||
unpad_image_shape,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
@@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -20,11 +20,11 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import CLIPVisionModel, LlavaConfig
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -20,7 +20,6 @@ import torch
|
||||
from torch import nn
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class MiniCPMMLP(nn.Module):
|
||||
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
) -> None:
|
||||
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
) -> None:
|
||||
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MiniCPM3DecoderLayer(
|
||||
config, i, cache_config=cache_config, quant_config=quant_config
|
||||
)
|
||||
MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):
|
||||
|
||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||
self.quant_config = quant_config
|
||||
self.model = MiniCPM3Model(
|
||||
config, cache_config=cache_config, quant_config=quant_config
|
||||
)
|
||||
self.model = MiniCPM3Model(config, quant_config=quant_config)
|
||||
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
if not self.config.tie_word_embeddings:
|
||||
self.lm_head = ParallelLMHead(
|
||||
|
||||
@@ -23,7 +23,6 @@ from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
@@ -291,7 +291,6 @@ class MixtralForCausalLM(nn.Module):
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -29,7 +29,6 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
|
||||
@@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module):
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
|
||||
_prepare_aspect_ratio_attention_mask,
|
||||
)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import get_act_fn
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
|
||||
|
||||
|
||||
@@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module):
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.padding_id = config.pad_token_id
|
||||
@@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module):
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model = MllamaTextModel(config, cache_config, quant_config)
|
||||
self.model = MllamaTextModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
@@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
self,
|
||||
config: config_mllama.MllamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
@@ -787,7 +784,6 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
self.vision_model = MllamaVisionModel(config.vision_config)
|
||||
self.language_model = MllamaForCausalLM(
|
||||
config.text_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.multi_modal_projector = nn.Linear(
|
||||
|
||||
@@ -22,7 +22,6 @@ from torch import nn
|
||||
from transformers import OlmoConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers
|
||||
|
||||
|
||||
@@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OlmoConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -312,7 +312,6 @@ class Olmo2ForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
@@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers, print_warning_once
|
||||
|
||||
|
||||
class OlmoeMoE(nn.Module):
|
||||
@@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -7,8 +7,6 @@ from transformers import Phi3Config
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import make_layers
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -27,6 +25,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -235,7 +235,6 @@ class Phi3SmallDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -286,7 +285,6 @@ class Phi3SmallModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
cache_config = None
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size
|
||||
)
|
||||
@@ -294,7 +292,7 @@ class Phi3SmallModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Phi3SmallDecoderLayer(
|
||||
config, int(prefix.split(".")[-1]), cache_config, quant_config
|
||||
config, int(prefix.split(".")[-1]), quant_config
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
@@ -339,7 +337,6 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
self,
|
||||
config: Phi3Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
@@ -22,7 +22,6 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
|
||||
@@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -22,7 +22,6 @@ import torch
|
||||
from torch import nn
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers
|
||||
|
||||
Qwen2Config = None
|
||||
@@ -271,7 +271,6 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -27,7 +27,6 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class Qwen2MoeMLP(nn.Module):
|
||||
@@ -158,7 +158,6 @@ class Qwen2MoeAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -234,7 +233,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -250,7 +248,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
@@ -304,7 +301,6 @@ class Qwen2MoeModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -317,9 +313,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Qwen2MoeDecoderLayer(
|
||||
config, layer_id, cache_config, quant_config=quant_config
|
||||
)
|
||||
Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -353,14 +347,13 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||
self.model = Qwen2MoeModel(config, cache_config, quant_config)
|
||||
self.model = Qwen2MoeModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
|
||||
@@ -30,12 +30,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
@@ -49,6 +47,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -536,7 +535,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2VLConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
99
python/sglang/srt/models/registry.py
Normal file
99
python/sglang/srt/models/registry.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import pkgutil
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import AbstractSet, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ModelRegistry:
|
||||
# Keyed by model_arch
|
||||
models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
|
||||
|
||||
def get_supported_archs(self) -> AbstractSet[str]:
|
||||
return self.models.keys()
|
||||
|
||||
def _raise_for_unsupported(self, architectures: List[str]):
|
||||
all_supported_archs = self.get_supported_archs()
|
||||
|
||||
if any(arch in all_supported_archs for arch in architectures):
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} failed "
|
||||
"to be inspected. Please check the logs for more details."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {all_supported_archs}"
|
||||
)
|
||||
|
||||
def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch not in self.models:
|
||||
return None
|
||||
|
||||
return self.models[model_arch]
|
||||
|
||||
def _normalize_archs(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> List[str]:
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
if not architectures:
|
||||
logger.warning("No model architectures are specified")
|
||||
|
||||
return architectures
|
||||
|
||||
def resolve_model_cls(
|
||||
self,
|
||||
architectures: Union[str, List[str]],
|
||||
) -> Tuple[Type[nn.Module], str]:
|
||||
architectures = self._normalize_archs(architectures)
|
||||
|
||||
for arch in architectures:
|
||||
model_cls = self._try_load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return (model_cls, arch)
|
||||
|
||||
return self._raise_for_unsupported(architectures)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_model_classes():
|
||||
model_arch_name_to_cls = {}
|
||||
package_name = "sglang.srt.models"
|
||||
package = importlib.import_module(package_name)
|
||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||
if not ispkg:
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Ignore import error when loading {name}. " f"{e}")
|
||||
continue
|
||||
if hasattr(module, "EntryClass"):
|
||||
entry = module.EntryClass
|
||||
if isinstance(
|
||||
entry, list
|
||||
): # To support multiple model classes in one module
|
||||
for tmp in entry:
|
||||
assert (
|
||||
tmp.__name__ not in model_arch_name_to_cls
|
||||
), f"Duplicated model implementation for {tmp.__name__}"
|
||||
model_arch_name_to_cls[tmp.__name__] = tmp
|
||||
else:
|
||||
assert (
|
||||
entry.__name__ not in model_arch_name_to_cls
|
||||
), f"Duplicated model implementation for {entry.__name__}"
|
||||
model_arch_name_to_cls[entry.__name__] = entry
|
||||
|
||||
return model_arch_name_to_cls
|
||||
|
||||
|
||||
ModelRegistry = _ModelRegistry(import_model_classes())
|
||||
@@ -26,7 +26,6 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class StablelmMLP(nn.Module):
|
||||
@@ -242,7 +242,6 @@ class StableLmForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -52,7 +52,6 @@ from vllm.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
@@ -66,6 +65,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
@@ -388,7 +388,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class XverseMLP(nn.Module):
|
||||
@@ -295,8 +295,6 @@ class XverseForCausalLM(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
efficient_weight_load=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
|
||||
class XverseMLP(nn.Module):
|
||||
@@ -181,7 +181,6 @@ class XverseAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -258,7 +257,6 @@ class XverseDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -277,7 +275,6 @@ class XverseDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if config.num_experts is not None:
|
||||
@@ -326,7 +323,6 @@ class XverseModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -339,9 +335,7 @@ class XverseModel(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
XverseDecoderLayer(
|
||||
config, layer_id, cache_config, quant_config=quant_config
|
||||
)
|
||||
XverseDecoderLayer(config, layer_id, quant_config=quant_config)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -369,13 +363,12 @@ class XverseMoeForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config=None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = XverseModel(config, cache_config, quant_config)
|
||||
self.model = XverseModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
|
||||
@@ -18,9 +18,9 @@ from typing import Iterable, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPVisionModel, LlavaConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llava import LlavaLlamaForCausalLM
|
||||
|
||||
|
||||
@@ -29,9 +29,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config=None,
|
||||
) -> None:
|
||||
super().__init__(config, quant_config, cache_config)
|
||||
super().__init__(config, quant_config)
|
||||
|
||||
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
|
||||
self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
|
||||
|
||||
@@ -50,6 +50,7 @@ class ServerArgs:
|
||||
served_model_name: Optional[str] = None
|
||||
chat_template: Optional[str] = None
|
||||
is_embedding: bool = False
|
||||
revision: Optional[str] = None
|
||||
|
||||
# Port
|
||||
host: str = "127.0.0.1"
|
||||
@@ -341,6 +342,14 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Whether to use a CausalLM as an embedding model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The specific model version to use. It can be a branch "
|
||||
"name, a tag name, or a commit id. If unspecified, will use "
|
||||
"the default version.",
|
||||
)
|
||||
|
||||
# Memory and scheduling
|
||||
parser.add_argument(
|
||||
|
||||
@@ -430,16 +430,12 @@ def suppress_other_loggers():
|
||||
from vllm.logger import logger as vllm_default_logger
|
||||
|
||||
vllm_default_logger.setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.config").setLevel(logging.ERROR)
|
||||
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
|
||||
logging.WARN
|
||||
)
|
||||
logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
|
||||
logging.WARN
|
||||
)
|
||||
logging.getLogger("vllm.selector").setLevel(logging.WARN)
|
||||
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
|
||||
logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, message="The given NumPy array is not writable"
|
||||
@@ -492,27 +488,6 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
|
||||
pass
|
||||
|
||||
|
||||
def monkey_patch_vllm_model_config():
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
if not hasattr(ModelConfig, "_resolve_task"):
|
||||
return
|
||||
|
||||
def _resolve_task(
|
||||
self,
|
||||
task_option,
|
||||
hf_config,
|
||||
):
|
||||
supported_tasks = {
|
||||
"generate": True,
|
||||
"embedding": False,
|
||||
}
|
||||
selected_task = "generate"
|
||||
return supported_tasks, selected_task
|
||||
|
||||
setattr(ModelConfig, "_resolve_task", _resolve_task)
|
||||
|
||||
|
||||
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
||||
"""
|
||||
Monkey patch the slow p2p access check in vllm.
|
||||
@@ -1041,6 +1016,11 @@ def crash_on_warnings():
|
||||
return get_bool_env_var("SGLANG_IS_IN_CI")
|
||||
|
||||
|
||||
def print_warning_once(msg: str) -> None:
|
||||
# Set the stacklevel to 2 to print the caller's line info
|
||||
logger.warning(msg, stacklevel=2)
|
||||
|
||||
|
||||
def get_device_name(device_id: int = 0) -> str:
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return torch.cuda.get_device_name(device_id)
|
||||
@@ -1055,6 +1035,33 @@ def get_device_name(device_id: int = 0) -> str:
|
||||
return torch.hpu.get_device_name(device_id)
|
||||
|
||||
|
||||
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
|
||||
major, minor = None, None
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
major, minor = torch.cuda.get_device_capability(device_id)
|
||||
|
||||
if hasattr(torch, "hip") and torch.hip.is_available():
|
||||
major, minor = torch.cuda.get_device_capability(device_id)
|
||||
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
|
||||
"."
|
||||
)
|
||||
major, minor = int(major), int(minor)
|
||||
|
||||
# TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
|
||||
# Update this once the support is available.
|
||||
if hasattr(torch, "hpu") and torch.hpu.is_available():
|
||||
try:
|
||||
major, minor = torch.hpu.get_device_capability(device_id)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"An error occurred while getting device capability of hpu: {e}."
|
||||
) from e
|
||||
|
||||
return major, minor
|
||||
|
||||
|
||||
sglang_lib = Library("sglang", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user