################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from functools import wraps
from typing import Any, Dict, Optional

import torch
import torch_br
from fastcore.basics import patch_to
from torch.nn.parameter import Parameter

from vllm.distributed import (get_pipeline_model_parallel_group,
                              get_tensor_model_parallel_world_size,
                              get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
                                                          GPTQLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)
from vllm_br import envs

from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
from ..linear import (process_weights_MergedColumnParallelLinear,
                      process_weights_QuantMergedColumnParallelLinear,
                      process_weights_ReplicatedLinear)


@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Restrict GPTQ activation dtypes to bfloat16 on this backend."""
    supported_dtypes = [torch.bfloat16]
    return supported_dtypes


@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Resolve the GPTQ linear quant method for ``layer`` at ``prefix``."""
    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)


@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """Build a ``GPTQConfig`` from a raw quantization config dict.

    [PatchNote] add qkv_quantized param support
    """
    # "dynamic" may be present but explicitly null; normalize that to {}.
    raw_dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})

    return cls(
        weight_bits=cls.get_from_keys(config, ["bits"]),
        group_size=cls.get_from_keys(config, ["group_size"]),
        desc_act=cls.get_from_keys(config, ["desc_act"]),
        lm_head_quantized=cls.get_from_keys_or(config, ["lm_head"],
                                               default=False),
        dynamic={} if raw_dynamic is None else raw_dynamic,
        autoround_version=cls.get_from_keys_or(config, ["autoround_version"],
                                               default=""),
        modules_in_block_to_quantize=cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None),
        qkv_quantized=cls.get_from_keys_or(config, ["qkv_quantized"],
                                           default=True),
    )


def wrapper_GPTQConfig_init(fn):
    """Wrap an ``__init__`` so it accepts an extra ``qkv_quantized`` kwarg.

    The keyword (default ``True``) is stripped from ``kwargs`` before the
    wrapped initializer runs — the original signature never sees it — and is
    stored on the instance afterwards as ``self.qkv_quantized``.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        # Pop our extra flag first so the wrapped __init__ is called with
        # exactly the arguments it expects.
        flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = flag

    return wrapper


# Install the wrapper so every GPTQConfig(...) call accepts the extra
# ``qkv_quantized`` keyword (stored on the instance after init).
GPTQConfig.__init__ = wrapper_GPTQConfig_init(
    GPTQConfig.__init__)  # noqa: E501


@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """One-time post-load conversion of GPTQ weights for the Biren backend.

    Steps, in order:
      1. Layer-type-specific pre-processing (ReplicatedLinear /
         MergedColumnParallelLinear helpers from ``..linear``).
      2. Convert the packed int32 ``qweight`` into the backend layout via
         ``_br_qweight_cvt``.
      3. Widen ``scales`` to float32.
      4. Re-wrap quant tensors as ``Parameter`` (required for torch.compile).
      5. When ``self.weight_type == "NUMA"`` and the layer still needs it,
         re-pack qweight/scales/bias into NUMA layout.
    """
    # Whether the NUMA re-pack at the bottom should still run for this layer.
    still_need_process = True
    # Set when a MergedColumnParallelLinear carries a quantized weight; its
    # merged-column processing must happen AFTER the qweight conversion.
    merge_col_quant = False
    # NOTE: all process_weights funcs should be done before
    # process_weights_after_loading.
    parallel_type = "col_parallel"
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # NOTE(review): 64/256 look like special small output sizes that
            # still need the NUMA re-pack — confirm against the helper.
            still_need_process = layer.output_size == 64 or layer.output_size == 256
        case MergedColumnParallelLinear():
            if hasattr(layer, "qweight"):
                # Quantized path: defer merged-column processing until the
                # qweight has been converted below.
                merge_col_quant = True
            else:
                process_weights_MergedColumnParallelLinear(layer)
                still_need_process = False
        case RowParallelLinear():
            parallel_type = "row_parallel"

        case _:
            pass

    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    # int32 dtype means qweight is still in the packed GPTQ format and has
    # not yet been converted to the backend layout.
    if layer.qweight.dtype == torch.int32:
        # Parallel layers expose per-partition sizes; fall back to the full
        # sizes for replicated layers.
        input_size = layer.input_size_per_partition if hasattr(
            layer, 'input_size_per_partition') else layer.input_size
        output_size = layer.output_size_per_partition if hasattr(
            layer, 'output_size_per_partition') else layer.output_size
        br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
                                     input_size, output_size)
        layer.qweight.data = br_qweight
        if merge_col_quant:
            process_weights_QuantMergedColumnParallelLinear(layer)
            still_need_process = False

    # Backend kernels consume float32 scales (see apply()).
    br_scales = layer.scales.to(torch.float32)
    layer.scales.data = br_scales

    # for torch.compile
    layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
    layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
    layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
    layer.scales = Parameter(layer.scales.data, requires_grad=False)

    # NUMA re-pack only applies when requested and not already handled above.
    if not still_need_process or self.weight_type != "NUMA":
        return

    # A 2-D qweight has not been NUMA-packed yet (packed weights are 3-D,
    # see apply()).
    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            parallel_type=parallel_type)

    if hasattr(layer, 'scales') and layer.scales is not None:
        # Scales are never zero-padded, unlike row-parallel bias below.
        pad_zeros = False
        layer.scales.data = _convert_to_numa_tensor(
            layer.scales,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # NOTE(review): row-parallel bias is zero-padded — presumably so the
        # bias is only added once across TP ranks; confirm against
        # _convert_to_numa_tensor.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)


@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Quantized linear forward on the Biren backend.

    Dispatch:
      * 3-D ``qweight`` (NUMA-packed by ``process_weights_after_loading``):
        use the fused MLP kernel, or — for row-parallel layers under
        specific TP/PP configurations — a fused linear+allreduce kernel.
      * 2-D ``qweight``: plain quantized matmul (``sudnn_qmatmul_infer``).
    """
    # numa weight is 3-dims
    if len(layer.qweight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # Merged gate/up projection: kernel applies SwiGLU and emits
            # half of the merged output width.
            act_mode = "act_swiglu"
            output_size //= 2
        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            # (VLLM_BR_DEVICE_SPC_NUM, tp_size) pairs that support the fused
            # allreduce kernel.
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            # NOTE(review): this mutates layer.reduce_results on every call —
            # presumably consumed by RowParallelLinear.forward to decide
            # whether to run its own allreduce; confirm against vLLM.
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if layer.reduce_results:
                # Caller performs the reduction; just run the fused matmul.
                return torch_br.br_fused_mlp_infer(
                    x, [layer.qweight],
                    output_w=output_size,
                    scales=[layer.scales]
                    if layer.scales is not None else None,
                    bias=[bias] if bias is not None else None,
                    activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                # Sanity check: global rank modulo TP size must equal the
                # in-group TP rank for the fused allreduce addressing.
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.qweight,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.scales,
                    bias=bias,
                    act_mode=act_mode)
        else:
            # Column-parallel / replicated 3-D path: no reduction needed.
            return torch_br.br_fused_mlp_infer(
                x, [layer.qweight],
                output_w=output_size,
                scales=[layer.scales] if layer.scales is not None else None,
                bias=[bias] if bias is not None else None,
                activation_mode=act_mode)
    # --- 2-D (non-NUMA) qweight path ---
    # NOTE(review): assumes x is at least 2-D with (batch, hidden, ...)
    # leading dims — confirm against callers.
    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.qweight.shape[1]
    # TODO, hard code to skip dry_run stage
    # HACK: large hidden dim is treated as a profiling dry run; return a
    # correctly-shaped dummy tensor instead of computing.
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.qweight,
                                        layer.scales,
                                        bias=bias)