Co-authored-by: HandH1998 <1335248067@qq.com>
--- /dev/null
+++ b/python/sglang/srt/configs/device_config.py
@@ -0,0 +1,17 @@
+import logging
+from typing import Optional
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceConfig:
+    device: Optional[torch.device]
+
+    def __init__(self, device: str = "cuda") -> None:
+        if device in ["cuda", "xpu", "hpu"]:
+            self.device_type = device
+        else:
+            raise RuntimeError(f"Not supported device type: {device}")
+        self.device = torch.device(self.device_type)
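Review note: a minimal usage sketch of the new DeviceConfig (the tensor and the failing device string are illustrative, not part of the commit):

    import torch

    from sglang.srt.configs.device_config import DeviceConfig

    device_config = DeviceConfig("cuda")  # also accepts "xpu" or "hpu"
    x = torch.zeros(4, device=device_config.device)  # placed on torch.device("cuda")
    DeviceConfig("tpu")  # raises RuntimeError: Not supported device type: tpu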
--- /dev/null
+++ b/python/sglang/srt/configs/load_config.py
@@ -0,0 +1,84 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+import enum
+import json
+import logging
+from dataclasses import dataclass, field
+from typing import List, Optional, Union
+
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
+
+class LoadFormat(str, enum.Enum):
+    AUTO = "auto"
+    PT = "pt"
+    SAFETENSORS = "safetensors"
+    NPCACHE = "npcache"
+    DUMMY = "dummy"
+    SHARDED_STATE = "sharded_state"
+    GGUF = "gguf"
+    BITSANDBYTES = "bitsandbytes"
+    MISTRAL = "mistral"
+
+
+@dataclass
+class LoadConfig:
+    """
+    download_dir: Directory in which to download and load the weights;
+        defaults to the default Hugging Face cache directory.
+    load_format: The format of the model weights to load:
+        "auto" will try to load the weights in the safetensors format and
+            fall back to the pytorch bin format if the safetensors format
+            is not available.
+        "pt" will load the weights in the pytorch bin format.
+        "safetensors" will load the weights in the safetensors format.
+        "npcache" will load the weights in pytorch format and store
+            a numpy cache to speed up the loading.
+        "dummy" will initialize the weights with random values, which is
+            mainly for profiling.
+        "bitsandbytes" will load nf4-type weights.
+    ignore_patterns: The list of patterns to ignore when loading the model.
+        Defaults to "original/**/*" to avoid repeated loading of llama's
+        checkpoints.
+
+    """
+
+    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
+    download_dir: Optional[str] = None
+    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
+    ignore_patterns: Optional[Union[List[str], str]] = None
+
+    def __post_init__(self):
+        model_loader_extra_config = self.model_loader_extra_config or {}
+        if isinstance(model_loader_extra_config, str):
+            self.model_loader_extra_config = json.loads(model_loader_extra_config)
+        self._verify_load_format()
+
+        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+            logger.info(
+                "Ignoring the following patterns when downloading weights: %s",
+                self.ignore_patterns,
+            )
+        else:
+            self.ignore_patterns = ["original/**/*"]
+
+    def _verify_load_format(self) -> None:
+        if not isinstance(self.load_format, str):
+            return
+
+        load_format = self.load_format.lower()
+        self.load_format = LoadFormat(load_format)
+
+        rocm_not_supported_load_format: List[str] = []
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f
+                for f in LoadFormat.__members__
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format '{load_format}' is not supported in ROCm. "
+                f"Supported load formats are "
+                f"{rocm_supported_load_format}"
+            )
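Review note: a short sketch of the __post_init__ behavior defined above (the extra-config key is hypothetical, used only to show the JSON parsing):

    from sglang.srt.configs.load_config import LoadConfig, LoadFormat

    cfg = LoadConfig(
        load_format="safetensors",
        model_loader_extra_config='{"enable_multithread_load": true}',
    )
    assert cfg.load_format is LoadFormat.SAFETENSORS  # lowered, coerced to the enum
    assert cfg.model_loader_extra_config == {"enable_multithread_load": True}  # JSON parsed
    assert cfg.ignore_patterns == ["original/**/*"]  # default applied when none given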
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -15,12 +15,14 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
 class ModelConfig:
     def __init__(
         self,
-        path: str,
+        model_path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
         is_embedding: Optional[bool] = None,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
     ) -> None:
+        self.model_path = model_path
+        self.revision = revision
+        self.quantization = quantization
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            path,
+            model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
@@ -56,6 +63,7 @@ class ModelConfig:
         )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
@@ -116,6 +124,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        self._verify_quantization()
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
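Review note: with the hunks above applied, the two new constructor arguments flow into the checks added below. A sketch of the resulting call (the checkpoint name is illustrative; model_override_args is a JSON string because the constructor passes it through json.loads):

    model_config = ModelConfig(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative checkpoint
        model_override_args="{}",  # JSON string, parsed by json.loads above
        dtype="auto",  # resolved by _get_and_verify_dtype (added further below)
        quantization="fp8",  # validated by _verify_quantization (added below)
    )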
@@ -174,6 +184,86 @@ class ModelConfig:
         # parallel size so each GPU has at least one KV head.
         return max(1, total_num_kv_heads // tensor_parallel_size)
 
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
+        return quant_cfg
+
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _verify_quantization(self) -> None:
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = [
+            "awq",
+            "gptq",
+            "fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "fbgemm_fp8",
+        ]
+        optimized_quantization_methods = [
+            "fp8",
+            "marlin",
+            "modelopt",
+            "gptq_marlin_24",
+            "gptq_marlin",
+            "awq_marlin",
+            "fbgemm_fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "experts_int8",
+        ]
+        if self.quantization is not None:
+            self.quantization = self.quantization.lower()
+
+        # Parse quantization method from the HF model config, if available.
+        quant_cfg = self._parse_quant_hf_config()
+
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+
+            # Detect which kind of checkpoint it is.
+            for _, method in QUANTIZATION_METHODS.items():
+                quantization_override = method.override_quantization_method(
+                    quant_cfg, self.quantization
+                )
+                if quantization_override:
+                    quant_method = quantization_override
+                    self.quantization = quantization_override
+                    break
+
+            # Verify quantization configurations.
+            if self.quantization is None:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
+                raise ValueError(
+                    "Quantization method specified in the model config "
+                    f"({quant_method}) does not match the quantization "
+                    "method specified in the `quantization` argument "
+                    f"({self.quantization})."
+                )
+
+        if self.quantization is not None:
+            if self.quantization not in supported_quantization:
+                raise ValueError(
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}."
+                )
+            if is_hip() and self.quantization not in rocm_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    "supported in ROCm."
+                )
+            if self.quantization not in optimized_quantization_methods:
+                logger.warning(
+                    "%s quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.",
+                    self.quantization,
+                )
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
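Review note: a sketch of how the detection logic above resolves a hypothetical GPTQ checkpoint (the config dict and the override outcome are illustrative):

    # hf_config.quantization_config as found in a GPTQ checkpoint:
    quant_cfg = {"quant_method": "gptq", "bits": 4, "group_size": 128}

    # quant_method starts as "gptq"; each method class registered in
    # QUANTIZATION_METHODS may upgrade it via override_quantization_method
    # (e.g. "gptq" -> "gptq_marlin" on GPUs with Marlin kernel support).
    # A user-supplied `quantization` argument that disagrees with the detected
    # method raises ValueError; if unset, the detected method is adopted.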
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
     if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
         # We support non-hf version of llava models, so we do not want to
         # read the wrong values from the unused default text_config.
+        # NOTE(HandH1998): We set the config's `torch_dtype` to `torch.float16` for the weights,
+        # since `torch.float16` is used by default for image features in `python/sglang/srt/models/llava.py`.
+        setattr(config, "torch_dtype", torch.float16)
         return config
 
     if hasattr(config, "text_config"):
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
     return config
 
 
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+def _get_and_verify_dtype(
+    config: PretrainedConfig,
+    dtype: Union[str, torch.dtype],
+) -> torch.dtype:
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, "torch_dtype", None)
+    if config_dtype is None:
+        config_dtype = torch.float32
+
+    if isinstance(dtype, str):
+        dtype = dtype.lower()
+        if dtype == "auto":
+            if config_dtype == torch.float32:
+                if config.model_type == "gemma2":
+                    logger.info(
+                        "For Gemma 2, we downcast float32 to bfloat16 instead "
+                        "of float16 by default. Please specify `dtype` if you "
+                        "want to use float16."
+                    )
+                    torch_dtype = torch.bfloat16
+                else:
+                    # Following the common practice, we use float16 for
+                    # float32 models.
+                    torch_dtype = torch.float16
+            else:
+                torch_dtype = config_dtype
+        else:
+            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                raise ValueError(f"Unknown dtype: {dtype}")
+            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    elif isinstance(dtype, torch.dtype):
+        torch_dtype = dtype
+    else:
+        raise ValueError(f"Unknown dtype: {dtype}")
+
+    # Verify the dtype.
+    if torch_dtype != config_dtype:
+        if torch_dtype == torch.float32:
+            # Upcasting to float32 is allowed.
+            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
+        elif config_dtype == torch.float32:
+            # Downcasting from float32 to float16 or bfloat16 is allowed.
+            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
+        else:
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
+
+    return torch_dtype
 
 
 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
     # 1. Check the model architecture
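Review note: how the new dtype resolution behaves on an fp32 checkpoint (illustrative; _get_and_verify_dtype is module-private, shown here for review only):

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("gpt2")  # torch_dtype unset -> treated as float32
    _get_and_verify_dtype(cfg, "auto")      # -> torch.float16 (fp32 checkpoints downcast)
    _get_and_verify_dtype(cfg, "bfloat16")  # -> torch.bfloat16 (downcast from fp32, logged)
    _get_and_verify_dtype(cfg, "float32")   # -> torch.float32 (matches the config)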
--- a/python/sglang/srt/configs/qwen2vl.py
+++ b/python/sglang/srt/configs/qwen2vl.py
@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_scaling = rope_scaling
 
-        # NOTE: the following section from original transformers config
-        # for Qwen2-VL is commented out to address rope config loading issue
-        #
-        # if self.rope_scaling is not None and "type" in self.rope_scaling:
-        #     if self.rope_scaling["type"] == "mrope":
-        #         self.rope_scaling["type"] = "default"
-        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        # rope_config_validation(self)
+        # NOTE(HandH1998): This is necessary for configuring the `rope_type` of qwen2vl models after removing dependencies on vllm.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            if self.rope_scaling["type"] == "mrope":
+                self.rope_scaling["type"] = "default"
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
 
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
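Review note: the effect of the restored block on a sample rope_scaling dict (the mrope_section value is illustrative):

    rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
    # After Qwen2VLConfig.__init__ runs the block above:
    #   {"type": "default", "rope_type": "default", "mrope_section": [16, 24, 24]}
    # i.e. the mrope marker is normalized so downstream rope initialization can
    # read `rope_type` without vllm's config patching.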