################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Any, Dict, Optional
import torch
import torch_br
from fastcore.basics import patch_to
from torch.nn.parameter import Parameter
from vllm.distributed import (get_pipeline_model_parallel_group,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
GPTQLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm_br import envs
from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
from ..linear import (process_weights_MergedColumnParallelLinear,
process_weights_QuantMergedColumnParallelLinear,
process_weights_ReplicatedLinear)
@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Return the activation dtypes this GPTQ backend accepts (bfloat16 only)."""
    supported: list[torch.dtype] = [torch.bfloat16]
    return supported
@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Resolve the linear quant method for *layer* via the shared GPTQ helper."""
    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """Build a GPTQConfig from a raw quantization-config dict.

    [PatchNote] add qkv_quantized param support
    """
    # "dynamic" may be present but explicitly null; normalize that to {}.
    raw_dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
    kwargs: Dict[str, Any] = dict(
        weight_bits=cls.get_from_keys(config, ["bits"]),
        group_size=cls.get_from_keys(config, ["group_size"]),
        desc_act=cls.get_from_keys(config, ["desc_act"]),
        lm_head_quantized=cls.get_from_keys_or(config, ["lm_head"],
                                               default=False),
        dynamic={} if raw_dynamic is None else raw_dynamic,
        autoround_version=cls.get_from_keys_or(config, ["autoround_version"],
                                               default=""),
        modules_in_block_to_quantize=cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None),
        # Patched-in flag; defaults to True when absent from the config.
        qkv_quantized=cls.get_from_keys_or(config, ["qkv_quantized"],
                                           default=True),
    )
    return cls(**kwargs)
def wrapper_GPTQConfig_init(fn):
    """Wrap an ``__init__`` so it accepts and stores a ``qkv_quantized`` kwarg.

    The flag is popped from ``kwargs`` (default ``True``) before delegating to
    the wrapped initializer, then attached to the instance afterwards, so the
    original signature never sees the extra keyword.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = flag

    return wrapper
# Monkey-patch GPTQConfig.__init__ so every instance carries `qkv_quantized`,
# even when constructed through upstream vLLM code paths.
GPTQConfig.__init__ = wrapper_GPTQConfig_init(
    GPTQConfig.__init__)  # noqa: E501
@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """Post-load weight processing for Biren GPTQ linear layers.

    Converts freshly loaded GPTQ tensors (``qweight``/``qzeros``/``scales``)
    into the Biren device layout via ``_br_qweight_cvt``, re-wraps them as
    ``Parameter`` objects for torch.compile, and — when ``self.weight_type``
    is ``"NUMA"`` — further converts weight/scales/bias into NUMA tensors.
    """
    still_need_process = True
    merge_col_quant = False
    # NOTE: all process_weights func should done before process_weights_after_loading
    parallel_type = "col_parallel"
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # NOTE(review): only these two output sizes go through the NUMA
            # conversion below — rationale not visible here, confirm upstream.
            still_need_process = layer.output_size == 64 or layer.output_size == 256
        case MergedColumnParallelLinear():
            if hasattr(layer, "qweight"):
                # Quantized merged-column layer: defer its processing until
                # after the br qweight conversion below.
                merge_col_quant = True
            else:
                process_weights_MergedColumnParallelLinear(layer)
                still_need_process = False
        case RowParallelLinear():
            parallel_type = "row_parallel"
        case _:
            pass
    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    # int32 dtype means the weight is still in the raw GPTQ packing and has
    # not yet been converted to the Biren layout.
    if layer.qweight.dtype == torch.int32:
        input_size = layer.input_size_per_partition if hasattr(
            layer, 'input_size_per_partition') else layer.input_size
        output_size = layer.output_size_per_partition if hasattr(
            layer, 'output_size_per_partition') else layer.output_size
        br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
                                     input_size, output_size)
        layer.qweight.data = br_qweight
        if merge_col_quant:
            process_weights_QuantMergedColumnParallelLinear(layer)
            still_need_process = False
        # Biren kernels consume fp32 scales.
        br_scales = layer.scales.to(torch.float32)
        layer.scales.data = br_scales
    # for torch.compile
    layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
    layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
    layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
    layer.scales = Parameter(layer.scales.data, requires_grad=False)
    # NUMA conversion only applies to layers flagged above and only when the
    # configured weight type is "NUMA".
    if not still_need_process or self.weight_type != "NUMA":
        return
    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            parallel_type=parallel_type)
    if hasattr(layer, 'scales') and layer.scales is not None:
        # Scales are never zero-padded, unlike a row-parallel bias below.
        pad_zeros = False
        layer.scales.data = _convert_to_numa_tensor(
            layer.scales,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
    if hasattr(layer, 'bias') and layer.bias is not None:
        # Row-parallel bias is zero-padded (presumably so only one rank
        # contributes the bias in the reduction — confirm with kernel docs).
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the GPTQ linear forward on Biren hardware.

    Dispatches on the weight layout: 3-D ``qweight`` means the NUMA layout
    (fused MLP / fused linear+allreduce kernels), otherwise falls back to the
    plain quantized matmul kernel.
    """
    # numa weight is 3-dims
    if len(layer.qweight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # Merged gate/up projection: fused SwiGLU halves the output width.
            act_mode = "act_swiglu"
            output_size //= 2
        if isinstance(layer, RowParallelLinear):
            # assumes x is (..., seq, hidden) — seq length drives the fused
            # allreduce eligibility check below.
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            # (VLLM_BR_DEVICE_SPC_NUM, tp_size) pairs the fused
            # linear+allreduce kernel supports.
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            # When the fused path applies, disable the caller-side allreduce
            # (layer.reduce_results) since the kernel reduces internally.
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if layer.reduce_results:
                # Plain fused matmul; the caller performs the allreduce.
                return torch_br.br_fused_mlp_infer(
                    x, [layer.qweight],
                    output_w=output_size,
                    scales=[layer.scales]
                    if layer.scales is not None else None,
                    bias=[bias] if bias is not None else None,
                    activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                # Sanity: global rank modulo tp_size must equal the in-group
                # rank for the fused allreduce kernel's rank mapping.
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.qweight,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.scales,
                    bias=bias,
                    act_mode=act_mode)
        else:
            # Column-parallel / replicated layers: no reduction needed.
            return torch_br.br_fused_mlp_infer(
                x, [layer.qweight],
                output_w=output_size,
                scales=[layer.scales] if layer.scales is not None else None,
                bias=[bias] if bias is not None else None,
                activation_mode=act_mode)
    # Non-NUMA (2-D) weight path.
    # assumes x is 3-D (batch, seq, hidden) here — TODO confirm against callers.
    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.qweight.shape[1]
    # TODO, hard code to skip dry_run stage
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.qweight,
                                        layer.scales,
                                        bias=bias)