first commit
This commit is contained in:
244
vllm_br/model_executor/layers/quantization/gptq.py
Normal file
244
vllm_br/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,244 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.distributed import (get_pipeline_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
|
||||
GPTQLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm_br import envs
|
||||
from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
|
||||
from ..linear import (process_weights_MergedColumnParallelLinear,
|
||||
process_weights_QuantMergedColumnParallelLinear,
|
||||
process_weights_ReplicatedLinear)
|
||||
|
||||
|
||||
@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Report the activation dtypes this quantization config supports.

    BR patch: only bfloat16 activations are supported.
    """
    supported_dtypes: list[torch.dtype] = [torch.bfloat16]
    return supported_dtypes
|
||||
|
||||
|
||||
@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Resolve the quantization method for ``layer``.

    Delegates entirely to the shared ``get_linear_quant_method`` helper,
    requesting a ``GPTQLinearMethod`` implementation.
    """
    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
|
||||
@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """Build a ``GPTQConfig`` from a raw quantization-config mapping.

    [PatchNote] add qkv_quantized param support
    """
    # A null "dynamic" entry in the config is treated like an absent one.
    raw_dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})

    init_kwargs = {
        "weight_bits": cls.get_from_keys(config, ["bits"]),
        "group_size": cls.get_from_keys(config, ["group_size"]),
        "desc_act": cls.get_from_keys(config, ["desc_act"]),
        "lm_head_quantized": cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False),
        "dynamic": {} if raw_dynamic is None else raw_dynamic,
        "autoround_version": cls.get_from_keys_or(config,
                                                  ["autoround_version"],
                                                  default=""),
        "modules_in_block_to_quantize": cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None),
        # BR patch: whether the qkv projection is quantized; defaults to True
        # when the key is absent.
        "qkv_quantized": cls.get_from_keys_or(config, ["qkv_quantized"],
                                              default=True),
    }
    return cls(**init_kwargs)
|
||||
|
||||
|
||||
def wrapper_GPTQConfig_init(fn):
    """Wrap an ``__init__`` so it accepts an extra ``qkv_quantized`` keyword.

    The keyword (default ``True``) is removed from ``kwargs`` before the
    wrapped initializer runs, then stored on the instance afterwards; all
    other arguments are forwarded unchanged.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        # Strip the extra keyword first — the wrapped __init__ does not
        # know about it and would raise TypeError otherwise.
        qkv_flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = qkv_flag

    return wrapper
|
||||
|
||||
|
||||
# Install the wrapper once at import time so every GPTQConfig construction
# accepts the extra qkv_quantized keyword.
GPTQConfig.__init__ = wrapper_GPTQConfig_init(GPTQConfig.__init__)  # noqa: E501
|
||||
|
||||
|
||||
@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """Post-load weight processing for GPTQ layers on BR hardware.

    Converts the packed int32 GPTQ qweight to the BR device layout, casts
    scales to float32, re-wraps the quantization tensors as Parameters for
    torch.compile, and optionally converts weights/scales/bias to NUMA
    tensor layout depending on layer type and ``self.weight_type``.
    """
    still_need_process = True
    merge_col_quant = False
    # NOTE: all process_weights func should done before process_weights_after_loading
    parallel_type = "col_parallel"
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # NOTE(review): only these two output sizes continue to the NUMA
            # conversion below — presumably specific small projection layers;
            # the meaning of 64/256 is not visible here, confirm with
            # process_weights_ReplicatedLinear.
            still_need_process = layer.output_size == 64 or layer.output_size == 256
        case MergedColumnParallelLinear():
            if hasattr(layer, "qweight"):
                # Quantized merged-column layer: defer its processing until
                # after the qweight conversion below.
                merge_col_quant = True
            else:
                process_weights_MergedColumnParallelLinear(layer)
                still_need_process = False
        case RowParallelLinear():
            parallel_type = "row_parallel"

        case _:
            pass

    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    # int32 dtype indicates the qweight is still in packed GPTQ format and
    # has not yet been converted to the BR layout.
    if layer.qweight.dtype == torch.int32:
        # Per-partition sizes exist on parallel layers; fall back to the
        # full sizes for replicated layers.
        input_size = layer.input_size_per_partition if hasattr(
            layer, 'input_size_per_partition') else layer.input_size
        output_size = layer.output_size_per_partition if hasattr(
            layer, 'output_size_per_partition') else layer.output_size
        br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
                                     input_size, output_size)
        layer.qweight.data = br_qweight
        if merge_col_quant:
            # Quantized merged-column processing must run on the converted
            # qweight, hence it happens here rather than in the match above.
            process_weights_QuantMergedColumnParallelLinear(layer)
            still_need_process = False

        # BR kernels consume float32 scales (see the float32 target of the
        # NUMA conversion below).
        br_scales = layer.scales.to(torch.float32)
        layer.scales.data = br_scales

    # for torch.compile
    layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
    layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
    layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
    layer.scales = Parameter(layer.scales.data, requires_grad=False)

    # NUMA layout conversion applies only when requested by the quant method
    # and not already handled by a layer-specific process_weights above.
    if not still_need_process or self.weight_type != "NUMA":
        return

    # A 2-D qweight has not been NUMA-converted yet (converted weights are
    # 3-D — see the shape check in apply()).
    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            parallel_type=parallel_type)

    if hasattr(layer, 'scales') and layer.scales is not None:
        # Scales are never zero-padded, unlike a row-parallel bias below.
        pad_zeros = False
        layer.scales.data = _convert_to_numa_tensor(
            layer.scales,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # Row-parallel bias is zero-padded; presumably so the partial sums
        # across TP ranks do not add the bias more than once — confirm
        # against _convert_to_numa_tensor.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
|
||||
|
||||
|
||||
@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the quantized linear forward on BR hardware.

    Dispatches between several torch_br kernels depending on whether the
    weight is in NUMA layout (3-D qweight), the layer's parallelism type,
    and whether a fused linear+allreduce kernel can be used.
    """
    # numa weight is 3-dims
    if len(layer.qweight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # Merged gate/up projection: the fused kernel applies SwiGLU and
            # produces half the merged width.
            act_mode = "act_swiglu"
            output_size //= 2
        if isinstance(layer, RowParallelLinear):
            # assumes x is (..., seq, hidden) — seq length read from dim -2;
            # TODO confirm against callers.
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            # (VLLM_BR_DEVICE_SPC_NUM, tp_size) pairs for which the fused
            # allreduce kernel is available.
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            # reduce_results=True keeps the standard path (caller performs
            # the allreduce); False routes to the fused linear+allreduce
            # kernel below.
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if layer.reduce_results:
                return torch_br.br_fused_mlp_infer(
                    x, [layer.qweight],
                    output_w=output_size,
                    scales=[layer.scales]
                    if layer.scales is not None else None,
                    bias=[bias] if bias is not None else None,
                    activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                # Sanity check: global rank modulo tp_size must match the
                # in-group rank (holds for the rank layout assumed here).
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.qweight,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.scales,
                    bias=bias,
                    act_mode=act_mode)
        else:
            # Column-parallel / replicated NUMA weight: plain fused kernel,
            # no allreduce needed.
            return torch_br.br_fused_mlp_infer(
                x, [layer.qweight],
                output_w=output_size,
                scales=[layer.scales] if layer.scales is not None else None,
                bias=[bias] if bias is not None else None,
                activation_mode=act_mode)
    # Non-NUMA (2-D qweight) fallback path.
    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.qweight.shape[1]
    # TODO, hard code to skip dry_run stage
    # NOTE(review): returns a dummy all-ones tensor for large inputs instead
    # of computing — only valid during profiling/dry-run; confirm this is
    # never hit in real serving.
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.qweight,
                                        layer.scales,
                                        bias=bias)
|
||||
Reference in New Issue
Block a user