Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -23,6 +23,10 @@ from vllm.model_executor.model_loader.utils import (
     get_model_architecture,
+    get_model_cls,
 )
+from vllm.model_executor.model_loader.weight_utils import (
+    padding_weight_loader
+)


 logger = init_logger(__name__)

@@ -13,6 +13,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
@@ -32,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 )
+from vllm.tracing import instrument
 from vllm.transformers_utils.repo_utils import list_filtered_repo_files
+from vllm import envs


 logger = init_logger(__name__)

@@ -287,8 +290,7 @@ class DefaultModelLoader(BaseModelLoader):
             self.load_config.safetensors_load_strategy = "torchao"

         weights_to_load = {name for name, _ in model.named_parameters()}
-        all_weights = self.get_all_weights(model_config, model)
-        loaded_weights = model.load_weights(all_weights)
+        loaded_weights = model.load_weights(self.get_all_weights(model_config, model))

         self.counter_after_loading_weights = time.perf_counter()
         logger.info_once(
@@ -298,7 +300,8 @@ class DefaultModelLoader(BaseModelLoader):
         )
         # We currently only enable the strict check for non-quantized
         # models that track their loaded weights.
-        if model_config.quantization is None and loaded_weights is not None:
+        opt_flag = envs.VLLM_MOE_OPT_LEVEL != 0 or envs.VLLM_LINEAR_OPT_LEVEL != 0
+        if model_config.quantization is None and loaded_weights is not None and not opt_flag:
             weights_not_loaded = weights_to_load - loaded_weights
             if weights_not_loaded:
                 raise ValueError(
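For context, VLLM_MOE_OPT_LEVEL and VLLM_LINEAR_OPT_LEVEL appear to be corex-specific environment knobs; the hunk skips the strict missing-weights check when either is non-zero, presumably because the optimized paths change the set of parameter names. A minimal sketch of the gating with a hypothetical helper (not part of vLLM, shown only to make the condition explicit):

import os

def _strict_check_enabled(quantization, loaded_weights) -> bool:
    # Mirrors the condition in the hunk above: any non-zero opt level
    # disables the strict check against model.named_parameters().
    opt_flag = (
        int(os.environ.get("VLLM_MOE_OPT_LEVEL", "0")) != 0
        or int(os.environ.get("VLLM_LINEAR_OPT_LEVEL", "0")) != 0
    )
    return quantization is None and loaded_weights is not None and not opt_flag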
@@ -39,8 +39,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.tracing import instrument
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
import ixformer.inference.functions as ixfop
|
||||
|
||||
try:
|
||||
from runai_model_streamer import SafetensorsStreamer
|
||||
except ImportError:
|
||||
@@ -289,7 +287,17 @@ def get_quant_config(
         )

     if hf_quant_config is not None:
-        return quant_cls.from_config(hf_quant_config)
+        # For modelopt_mixed, config.json's quantization_config may or may
+        # not contain the per-layer quantized_layers map. Newer checkpoints
+        # embed it directly; older ones keep it only in hf_quant_config.json.
+        # If it is missing, fall through to the file-based loading path.
+        if (
+            model_config.quantization == "modelopt_mixed"
+            and "quantized_layers" not in hf_quant_config
+        ):
+            pass  # fall through to file-based loading below
+        else:
+            return quant_cls.from_config(hf_quant_config)

     # if hf_quant_config is None, we will try to get config from
     # hf_overrides
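For illustration, a hypothetical quantization_config from config.json; the exact shape of the quantized_layers map and the per-layer values are assumptions based only on the membership check above:

# Hypothetical config.json excerpt (newer checkpoint: map embedded directly).
# An older checkpoint would omit "quantized_layers" here, and the code above
# falls through to loading hf_quant_config.json from the model directory.
quantization_config = {
    "quant_method": "modelopt_mixed",
    "quantized_layers": {
        "model.layers.0.self_attn.qkv_proj": "fp8",
        "model.layers.0.mlp.gate_up_proj": "int4",
    },
}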
@@ -367,8 +375,8 @@ def get_quant_config(

     if model_config.quantization == "bitsandbytes":
         config["adapter_name_or_path"] = model_config.model
-    elif model_config.quantization == "modelopt":
-        if config["producer"]["name"] == "modelopt":
+    elif model_config.quantization in ("modelopt", "modelopt_mixed"):
+        if config.get("producer", {}).get("name") == "modelopt":
             return quant_cls.from_config(config)
         else:
             raise ValueError(
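The producer check above expects a config of roughly the following shape; this is a hypothetical example, not a real checkpoint file. Note the design choice in the new line: with the chained .get() calls, a config that lacks a "producer" section now reaches the ValueError branch instead of crashing with a KeyError.

# Hypothetical hf_quant_config.json-style dict consumed by the check above.
config = {
    "producer": {"name": "modelopt", "version": "0.15.0"},
    "quantization": {"quant_algo": "FP8"},
}
assert config.get("producer", {}).get("name") == "modelopt"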
@@ -697,13 +705,6 @@ def np_cache_weights_iterator(
             yield name, torch.from_numpy(param)


-FP_TYPES = {
-    "torch.bfloat16",
-    "torch.float16",
-    "torch.float32",
-    "torch.half",
-}
-
 def safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
@@ -714,12 +715,6 @@ def safetensors_weights_iterator(
     if safetensors_load_strategy == "eager":
         loading_desc += " (eager)"

-    CUSTOM_QUANT_CONFIG = os.environ.get("CUSTOM_QUANT_CONFIG", None)
-    try:
-        with open(f"{CUSTOM_QUANT_CONFIG}/quant_map.json", "r") as f:
-            quant_map = json.load(f)
-    except Exception as e:
-        quant_map = None
     leftover_state_dict: dict[str, torch.Tensor] = {}
     for st_file in tqdm(
         sorted(hf_weights_files, key=_natural_sort_key),
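The removed path read an external quant_map.json keyed by tensor name, where each value is either a torch dtype string (left unquantized; see the FP_TYPES set removed above) or a "<qtype>-<qformat>" tag consumed by the quant_type.split("-") logic removed in the next hunk. A hypothetical example of its shape; the names and format tags are illustrative:

{
    "model.layers.0.mlp.experts.w13_weight": "int8-TN",
    "model.layers.0.mlp.experts.w2_weight": "int4-TN",
    "model.layers.0.self_attn.qkv_proj.weight": "torch.bfloat16"
}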
@@ -763,160 +758,9 @@ def safetensors_weights_iterator(
         with safe_open(st_file, framework="pt") as f:
             for name in f.keys():  # noqa: SIM118
                 param = f.get_tensor(name)
-                if not quant_map:
-                    yield name, param
-                    continue
-                quant_type = quant_map.get(name)
-
-                if quant_type is None or quant_type in FP_TYPES:
-                    yield name, param
-                    continue
-
-                qtype, qformat = quant_type.split("-")
-
-                qname = name
-                qscale_name = f"{name}_scale"
-                is_expert = ("expert" in name and "shared" not in name)
-
-                # INT8
-                if qtype == "int8":
-                    if param.ndim == 2:
-                        param.unsqueeze_(0)
-
-                    qweight, qscale = weight_quant_bf16_to_int8(param)
-
-                    A = qweight.shape[0]
-
-                    if is_expert:
-                        qscale = qscale.view(A, 1, -1).transpose(-2, -1).contiguous()
-
-                    if A == 1:
-                        qweight = qweight.squeeze(0)
-                        qscale = qscale.squeeze(0)
-
-                    yield qname, qweight
-                    yield qscale_name, qscale
-                    continue
-
-                # INT4
-                if qtype == "int4":
-                    i8scales, i8zeros = None, None
-
-                    if param.ndim == 2:
-                        param.unsqueeze_(0)
-
-                    qweight, qscale, i8scales, i8zeros = weight_quant_bf16_to_int4pack8(
-                        param,
-                        format=qformat,
-                        symmetric=True,
-                    )
-
-                    A = qweight.shape[0]
-
-                    if is_expert:
-                        qscale = qscale.view(A, 1, -1).contiguous()
-
-                    if A == 1:
-                        qweight = qweight.squeeze(0)
-                        qscale = qscale.squeeze(0)
-
-                    yield qname, qweight
-                    yield qscale_name, qscale
-
-                    if i8scales is not None:
-                        yield f"{name}_i8_weight_scale", i8scales.squeeze_(0)
-
-                    if i8zeros is not None:
-                        yield f"{name}_i8_weight_zero", i8zeros.squeeze_(0)
-
-                    continue
-
                 yield name, param
-
-
-def weight_quant_bf16_to_int8(inputs: torch.Tensor):
-    device = current_platform.current_device()
-
-    assert inputs.dim() == 3, f"expected 3D inputs [batch, output_dim, input_dim], got {inputs.dim()}D"
-
-    ori_device = inputs.device
-    if inputs.device != device:
-        inputs = inputs.to(device)
-
-    qmax = 127.0
-    abs_max = torch.abs(inputs).max(dim=2, keepdim=True)[0]
-    scale = abs_max / qmax
-
-    assert scale.shape == (*inputs.shape[:2], 1)
-
-    quantized = torch.round(inputs / scale)
-    quantized = torch.clamp(quantized, -qmax, qmax)
-    return quantized.to(torch.int8).to(ori_device), scale.to(torch.float32).to(ori_device)
-
-
-def weight_quant_bf16_to_int4pack8(
-    v: torch.Tensor,  # [B, R, C]
-    block_size: int = 128,
-    group_size: int = -1,
-    format: str = "TN",
-    symmetric: bool = True,
-    version: int = 2,
-):
-    """
-    Batched INT4 quantization + packing.
-
-    Args:
-        v: [batch, rows, cols], float tensor
-
-    Returns:
-        i4_weights: [batch, rows, packed_cols]
-        scale: [batch, rows, 1]
-        i8scales: from ixfop.quant_repack_int4
-        i8zeros: from ixfop.quant_repack_int4
-    """
-    device = current_platform.current_device()
-    ori_device = v.device
-    if v.device != device:
-        v = v.to(device)
-    assert v.dim() == 3, f"expected [batch, rows, cols], got {v.shape}"
-
-    B, R, C = v.shape
-
-    qmax = 127.0
-
-    # abs_max: [B, R, 1]
-    abs_max = torch.abs(v).amax(dim=2, keepdim=True)
-    scale = abs_max / qmax  # [B, R, 1]
-
-    # quantized: [B, R, C]
-    quantized = torch.round(v / scale)
-    quantized = torch.clamp(quantized, -qmax, qmax).to(torch.int8)
-
-    # ixfop.quant_repack_int4 expects [batch, rows, cols];
-    # the tensor is already batch-first, so it can be passed in as-is.
-    # Returned shapes are typically:
-    #   i4_weights: [B, R, packed_C]
-    #   i8scales:   [B, R, groups]
-    #   i8zeros:    [B, R, groups]
-    i4_weights, i8scales, i8zeros = ixfop.quant_repack_int4(
-        quantized,  # no unsqueeze needed; already [B, R, C]
-        group_size,
-        version,
-        format,
-        symmetric,
-    )
-    if i8scales is not None:
-        i8scales = i8scales.to(ori_device)
-        i8zeros = i8zeros.to(ori_device)
-    return (
-        i4_weights.to(ori_device),  # [B, R, packed_C]
-        scale.to(torch.float32).to(ori_device),  # [B, R, 1]
-        i8scales,  # from repack
-        i8zeros
-    )


 def multi_thread_safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
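As a sanity check on the per-row symmetric INT8 scheme removed above, here is a self-contained round trip; quantize_rowwise_int8 is a local stand-in for weight_quant_bf16_to_int8 (without the device shuffling and 3D batch handling), not a vLLM or ixformer API:

import torch

def quantize_rowwise_int8(w: torch.Tensor):
    # Per-row symmetric quantization: scale = max|row| / 127.
    qmax = 127.0
    scale = w.abs().amax(dim=-1, keepdim=True) / qmax
    q = torch.clamp(torch.round(w / scale), -qmax, qmax).to(torch.int8)
    return q, scale.float()

w = torch.randn(256, 512)
q, scale = quantize_rowwise_int8(w)
recon = q.float() * scale
# Rounding error is at most half an LSB, i.e. scale / 2 per element.
assert ((recon - w).abs() <= scale / 2 + 1e-6).all()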
@@ -1426,3 +1270,50 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:

     # If there were no matches, return the untouched param name
     return name
+
+
+def padding_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+    """Weight loader that allows padding on the last (and optionally middle) dims.
+
+    If shapes match: copy directly.
+    If shapes differ: copy the overlapping slice (min along each dimension).
+    Special-cases MoE weights that have the expert dim in front (2D/3D).
+    """
+    # Fast path: exact match
+    if param.shape == loaded_weight.shape:
+        param.data.copy_(loaded_weight)
+        return
+
+    # Basic sanity checks
+    if param.ndim != loaded_weight.ndim:
+        raise ValueError(
+            f"Cannot load weight with different ndim: param.ndim={param.ndim}, "
+            f"loaded.ndim={loaded_weight.ndim}, param.shape={tuple(param.shape)}, "
+            f"loaded.shape={tuple(loaded_weight.shape)}"
+        )
+
+    dims = param.ndim
+    if dims not in (2, 3):
+        raise ValueError(
+            f"padding_weight_loader only supports 2D/3D tensors, got {dims}D. "
+            f"param.shape={tuple(param.shape)}, loaded.shape={tuple(loaded_weight.shape)}"
+        )
+
+    # For MoE tensors, dim0 is num_experts and must match.
+    if param.shape[0] != loaded_weight.shape[0]:
+        raise AssertionError(
+            f"Mismatch in number of experts: param={param.shape[0]}, "
+            f"loaded={loaded_weight.shape[0]}"
+        )
+
+    # Copy the overlapping region: [:, :min(dim1), :min(dim2)]
+    if dims == 2:
+        copy_d1 = min(param.shape[1], loaded_weight.shape[1])
+        param.data[:, :copy_d1].copy_(loaded_weight[:, :copy_d1])
+        return
+
+    # dims == 3
+    copy_d1 = min(param.shape[1], loaded_weight.shape[1])
+    copy_d2 = min(param.shape[2], loaded_weight.shape[2])
+    param.data[:, :copy_d1, :copy_d2].copy_(
+        loaded_weight[:, :copy_d1, :copy_d2]
+    )
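A usage sketch for the new loader, assuming a 3D MoE weight whose intermediate dim was padded for kernel alignment; the shapes here are made up for illustration:

import torch

# Checkpoint tensor: 8 experts, intermediate dim 11008, hidden 4096.
loaded = torch.randn(8, 11008, 4096)
# Model parameter padded up to 11264 on dim 1.
param = torch.zeros(8, 11264, 4096)

padding_weight_loader(param, loaded)
assert torch.equal(param[:, :11008, :], loaded)
assert param[:, 11008:, :].abs().sum() == 0  # padding stays zero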