Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import ast, re
from abc import abstractmethod
from typing import Any

import torch
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -16,6 +16,7 @@ from vllm.distributed import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.batch_invariant import (
@@ -28,7 +29,9 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
    is_layer_moe_router_gate,
    parse_opt_exclude_layers,
    weight_quant_l1,
    weight_quant_l2,
)
from vllm.model_executor.parameter import (
    BasevLLMParameter,
@@ -41,12 +44,11 @@ from vllm.model_executor.parameter import (
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
import vllm.envs as envs
from compressed_tensors.quantization import QuantizationStrategy

logger = init_logger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
@@ -66,6 +68,14 @@ WEIGHT_LOADER_V2_SUPPORTED = [
    "PetitNvFp4LinearMethod",
]

LINEAR_OPT_SUPPORTED = [
    "ColumnParallelLinear",
    "ReplicatedLinear",
    "RowParallelLinear",
    "QKVParallelLinear",
    "MergedColumnParallelLinear",
]
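# Linear layer classes eligible for the corex quantize-on-load path; LinearBase only
# sets opt_flag when the layer's class name appears in this list and no quant_config
# was supplied (see LinearBase.__init__ below).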


def adjust_marlin_shard(
    param: Parameter,
@@ -135,44 +145,6 @@ def adjust_scalar_to_fused_array(
    return param_data[shard_id], loaded_weight


# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Separate the BitsAndBytes 4-bit shard.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    offset_l = shard_offsets[:2]
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
    quant_state_r = {
        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
        for i in range(1, len(shard_offsets) - 1)
    }
    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right
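# Usage sketch (hypothetical quant states qs0/qs1/qs2, mirroring the docstring above):
#   attrs = {"bnb_shard_offsets": np.array([0, 4, 8, 16]),
#            "bnb_quant_state": {0: qs0, 1: qs1, 2: qs2}}
#   left, right = left_shift_bitsandbytes_4bit_shard(attrs)
#   left["bnb_shard_offsets"]  -> [0, 4];      left["bnb_quant_state"]  -> {0: qs0}
#   right["bnb_shard_offsets"] -> [0, 4, 12];  right["bnb_quant_state"] -> {0: qs1, 1: qs2}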


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""
@@ -231,17 +203,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
        # The weights are not quantized, and they are not sharded.
        # The amount of memory allocated for the weights is
        # sum(output_partition_sizes) * input_size_per_partition.
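        # e.g. output_partition_sizes=[4096, 4096] with input_size_per_partition=2048
        # allocates a single fused (8192, 2048) weight tensor.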
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)
@@ -258,11 +224,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if (
            vllm_is_batch_invariant()
            and current_platform.is_cuda_alike()
            and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
        ):
        if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
            return linear_batch_invariant(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
@@ -305,15 +267,31 @@ class LinearBase(PluggableLayer):
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False
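        # corex overlay: pick the optimization level from VLLM_LINEAR_OPT_LEVEL, optionally
        # overridden per layer via VLLM_LINEAR_SPECIFIED_LAYERS / VLLM_LINEAR_SPECIFIED_KEYS.
        # The quantize-on-load path is only enabled for unquantized layers whose class
        # appears in LINEAR_OPT_SUPPORTED.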
        if quant_config is None:
            self.opt_level = envs.VLLM_LINEAR_OPT_LEVEL
        if parse_opt_exclude_layers(envs.VLLM_LINEAR_SPECIFIED_LAYERS, self.prefix) or \
                (envs.VLLM_LINEAR_SPECIFIED_KEYS != "" and envs.VLLM_LINEAR_SPECIFIED_KEYS in self.prefix):
            self.opt_level = envs.VLLM_LINEAR_SPECIFIED_OPT_LEVEL
        self.opt_flag = quant_config is None and self.opt_level != 0 and \
            self.__class__.__name__ in LINEAR_OPT_SUPPORTED

        if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, self.prefix):
            self.opt_flag = False
            logger.info(f"Excluding layer {self.prefix} from optimization")

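        # When the optimization is active, reuse the compressed-tensors linear method with a
        # channel-wise W8A8 Int8 scheme (treated as W4A8 when opt_level == 2), even though no
        # quantization config was supplied for the model.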
        if self.opt_flag:
            from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod
            from vllm.model_executor.layers.quantization.compressed_tensors.schemes import CompressedTensorsW8A8Int8
            self.quant_method: QuantizeMethodBase | None = CompressedTensorsLinearMethod(None)
            self.scheme = CompressedTensorsW8A8Int8(QuantizationStrategy.CHANNEL, False, True, is_w4a8_linear=self.opt_level == 2)
        elif quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.output_padding_size = 0
        self.disable_tp = disable_tp
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.output_padding_size = 0

    def update_param_tp_status(self):
        for param in self.parameters():
@@ -402,7 +380,7 @@ class ReplicatedLinear(LinearBase):
        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        # Special case for GGUF

        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
@@ -419,7 +397,17 @@ class ReplicatedLinear(LinearBase):
                f"Tried to load weights of size {loaded_weight.size()} "
                f"to a parameter of size {param.size()}"
            )
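        # corex overlay: quantize the checkpoint weight on load (per opt_level) and copy the
        # returned scale into this layer's "weight_scale" parameter after the weight copy.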
        if self.opt_flag:
            if self.opt_level == 1:
                loaded_weight, scale = weight_quant_l1(loaded_weight)
            elif self.opt_level == 2:
                loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")

        param.data.copy_(loaded_weight)
        if self.opt_flag:
            params_dict = dict(self.named_parameters())
            scale_param = params_dict["weight_scale"]
            scale_param.data.copy_(scale)

    def forward(
        self,
@@ -609,7 +597,18 @@ class ColumnParallelLinear(LinearBase):
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        if self.opt_flag:
            if self.opt_level == 1:
                loaded_weight, scale = weight_quant_l1(loaded_weight)
            elif self.opt_level == 2:
                loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")

        param.load_column_parallel_weight(loaded_weight=loaded_weight)
        if self.opt_flag:
            params_dict = dict(self.named_parameters())
            scale_param = params_dict["weight_scale"]
            scale_param.load_column_parallel_weight(loaded_weight=scale)

    def forward(
        self,
@@ -733,16 +732,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        self.validate_shard_id(loaded_shard_id)
        # FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
        if isinstance(loaded_shard_id, tuple):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported in weight_loader, "
                "please use weight_loader_v2 instead."
            )
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
@@ -770,7 +769,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
@@ -782,10 +781,25 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            if (
                use_bitsandbytes_4bit
                and isinstance(loaded_shard_id, tuple)
                and self.tp_size > 1
            ):
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
                    "for BNB quantization with TP yet."
                )
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
            for i, output_size in enumerate(output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
@@ -850,9 +864,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_offsets, str(loaded_shard_id)
                )
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
@@ -921,12 +940,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        if self.opt_flag:
            if self.opt_level == 1:
                loaded_weight, scale = weight_quant_l1(loaded_weight)
            elif self.opt_level == 2:
                loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
        self.validate_shard_id(loaded_shard_id)
        dtype = loaded_weight.dtype
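        # The w4a8 path packs two 4-bit values into each int8 byte, so the on-disk shard
        # sizes along the output dim are half of self.output_sizes when it is enabled.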
        if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8:
            load_sizes = [self.output_sizes[i] // 2 for i in range(len(self.output_sizes))]
        else:
            load_sizes = self.output_sizes
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
@@ -953,19 +972,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear):

        assert loaded_shard_id < len(self.output_sizes)

        # shard_offset = sum(self.output_sizes[:loaded_shard_id])
        # shard_size = self.output_sizes[loaded_shard_id]
        shard_offset = sum(load_sizes[:loaded_shard_id])
        shard_size = load_sizes[loaded_shard_id]
        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]
        shard_offset //= self.tp_size
        shard_size //= self.tp_size
        scale_shard_offset = shard_offset
        scale_shard_size = shard_size
        if self.opt_flag and self.opt_level == 2:
            shard_offset = shard_offset // 2
            shard_size = shard_size // 2

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
@@ -973,6 +994,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )
        if self.opt_flag:
            params_dict = dict(self.named_parameters())
            scale_param = params_dict["weight_scale"]
            scale_param.load_merged_column_weight(
                loaded_weight=scale,
                shard_id=loaded_shard_id,
                shard_offset=scale_shard_offset,
                shard_size=scale_shard_size,
                tp_rank=self.tp_rank,
            )


class QKVParallelLinear(ColumnParallelLinear):
@@ -1128,12 +1159,24 @@ class QKVParallelLinear(ColumnParallelLinear):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            if self.opt_level == 2:
                loaded_weight_shard = loaded_weight.narrow(
                    0, shard_offset, shard_size
                )
            else:
                loaded_weight_shard = loaded_weight.narrow(
                    param.output_dim, shard_offset, shard_size
                )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def quant(self, loaded_weight: torch.Tensor):
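        """Quantize a weight on load according to the active opt_level; returns
        (weight, scale), with scale being None when the optimization is disabled."""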
        if self.opt_flag:
            if self.opt_level == 1:
                return weight_quant_l1(loaded_weight)
            elif self.opt_level == 2:
                return weight_quant_l2(loaded_weight, format="NN")
        return loaded_weight, None

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
@@ -1141,15 +1184,27 @@ class QKVParallelLinear(ColumnParallelLinear):
        loaded_shard_id: str | None = None,
    ):
        self.validate_shard_id(loaded_shard_id)
        params_dict = dict(self.named_parameters())
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                loaded_weight, scale = self.quant(loaded_weight)
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                if self.opt_flag:
                    scale_param = params_dict["weight_scale"]
                    scale_param.load_qkv_weight(
                        loaded_weight=scale, shard_id=0, tp_rank=self.tp_rank
                    )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                loaded_weight, scale = self.quant(loaded_weight)
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                if self.opt_flag:
                    scale_param = params_dict["weight_scale"]
                    scale_param.load_qkv_weight(loaded_weight=scale, tp_rank=self.tp_rank)
                return

            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return
@@ -1158,11 +1213,15 @@ class QKVParallelLinear(ColumnParallelLinear):

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)
        dtype = loaded_weight.dtype
        # The w4a8 GEMM needs the shard divided by 2; the scale does not.
        if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8:
            shard_offset //= 2
            shard_size //= 2

        scale_shard_offset = shard_offset
        scale_shard_size = shard_size

        loaded_weight, scale = self.quant(loaded_weight)

        if self.opt_flag and self.opt_level == 2:
            shard_offset = shard_offset // 2
            shard_size = shard_size // 2

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
@@ -1179,6 +1238,15 @@ class QKVParallelLinear(ColumnParallelLinear):
            tp_rank=self.tp_rank,
        )

        if self.opt_flag:
            scale_param = params_dict["weight_scale"]
            scale_param.load_qkv_weight(loaded_weight=scale,
                                        num_heads=self.num_kv_head_replicas,
                                        shard_id=loaded_shard_id,
                                        shard_offset=scale_shard_offset,
                                        shard_size=scale_shard_size,
                                        tp_rank=self.tp_rank)

    def weight_loader(
        self,
        param: Parameter,
@@ -1525,7 +1593,17 @@ class RowParallelLinear(LinearBase):
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        if self.opt_flag:
            if self.opt_level == 1:
                loaded_weight, scale = weight_quant_l1(loaded_weight)
            elif self.opt_level == 2:
                loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")

        param.load_row_parallel_weight(loaded_weight=loaded_weight)
        if self.opt_flag:
            params_dict = dict(self.named_parameters())
            scale_param = params_dict["weight_scale"]
            scale_param.load_row_parallel_weight(loaded_weight=scale)

    def forward(
        self,