################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Biren (br) hardware patches for vLLM's GPTQ quantization support.

Monkey-patches ``GPTQConfig`` and ``GPTQLinearMethod`` (via
``fastcore.basics.patch_to``) so that GPTQ weights are converted into the
Biren device layout after loading and dispatched to ``torch_br`` fused
kernels at apply time.
"""
from functools import wraps
from typing import Any, Dict, Optional

import torch
import torch_br
from fastcore.basics import patch_to
from torch.nn.parameter import Parameter
from vllm.distributed import (get_pipeline_model_parallel_group,
                              get_tensor_model_parallel_world_size,
                              get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
                                                          GPTQLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)

from vllm_br import envs

from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
from ..linear import (process_weights_MergedColumnParallelLinear,
                      process_weights_QuantMergedColumnParallelLinear,
                      process_weights_ReplicatedLinear)


@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Restrict GPTQ activation dtypes to bfloat16 on Biren hardware."""
    return [torch.bfloat16]


@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Resolve the quant method for *layer* via vLLM's generic helper.

    Always routes through ``get_linear_quant_method`` with the (patched)
    ``GPTQLinearMethod`` class; may return None for unquantized layers
    (per the helper's contract — see vLLM's gptq_utils).
    """
    quant_method = get_linear_quant_method(self, layer, prefix,
                                           GPTQLinearMethod)
    return quant_method


@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """
    [PatchNote] add qkv_quantized param support
    """
    # "dynamic" may be present but explicitly null — normalize to {}.
    dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
    dynamic = {} if dynamic is None else dynamic
    weight_bits = cls.get_from_keys(config, ["bits"])
    group_size = cls.get_from_keys(config, ["group_size"])
    desc_act = cls.get_from_keys(config, ["desc_act"])
    lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                             default=False)
    autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
                                             default="")
    modules_in_block_to_quantize = cls.get_from_keys_or(
        config, ["modules_in_block_to_quantize"], default=None)
    # Patch addition: checkpoints may mark their QKV projections as
    # unquantized; default True preserves the stock behavior.
    qkv_quantized = cls.get_from_keys_or(config, ["qkv_quantized"],
                                         default=True)
    # qkv_quantized is accepted by the wrapped __init__ below, which pops
    # it before delegating to the original constructor.
    return cls(weight_bits=weight_bits,
               group_size=group_size,
               desc_act=desc_act,
               lm_head_quantized=lm_head_quantized,
               dynamic=dynamic,
               autoround_version=autoround_version,
               modules_in_block_to_quantize=modules_in_block_to_quantize,
               qkv_quantized=qkv_quantized)


def wrapper_GPTQConfig_init(fn):
    """Wrap ``GPTQConfig.__init__`` to accept an extra ``qkv_quantized`` kwarg.

    The kwarg is popped before calling the original constructor (which does
    not know it) and then stored on the instance.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        qkv_quantized = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = qkv_quantized

    return wrapper


GPTQConfig.__init__ = wrapper_GPTQConfig_init(
    GPTQConfig.__init__)  # noqa: E501


@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """Convert freshly-loaded GPTQ weights into the Biren device layout.

    Runs per-layer-type preprocessing, converts int32 GPTQ qweights via
    ``_br_qweight_cvt``, re-wraps tensors as Parameters for torch.compile,
    and finally (for layers still needing it) rewrites qweight/scales/bias
    into NUMA tensor layout. Mutates *layer* in place; returns None.
    """
    still_need_process = True
    merge_col_quant = False
    # NOTE: all process_weights func should done before
    # process_weights_after_loading
    parallel_type = "col_parallel"
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # Only small replicated layers (64/256 outputs) still go through
            # the NUMA conversion below — presumably a kernel-size constraint;
            # TODO(review): confirm against the br kernel requirements.
            still_need_process = layer.output_size == 64 or layer.output_size == 256
        case MergedColumnParallelLinear():
            if hasattr(layer, "qweight"):
                # Quantized merged-column layer: conversion is deferred until
                # after the qweight dtype conversion below.
                merge_col_quant = True
            else:
                process_weights_MergedColumnParallelLinear(layer)
                still_need_process = False
        case RowParallelLinear():
            parallel_type = "row_parallel"
        case _:
            pass
    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    if layer.qweight.dtype == torch.int32:
        # Prefer the per-TP-partition sizes when the layer is sharded.
        input_size = layer.input_size_per_partition if hasattr(
            layer, 'input_size_per_partition') else layer.input_size
        output_size = layer.output_size_per_partition if hasattr(
            layer, 'output_size_per_partition') else layer.output_size
        br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
                                     input_size, output_size)
        layer.qweight.data = br_qweight
        if merge_col_quant:
            process_weights_QuantMergedColumnParallelLinear(layer)
            still_need_process = False
        # br kernels consume float32 scales.
        br_scales = layer.scales.to(torch.float32)
        layer.scales.data = br_scales
    # for torch.compile
    layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
    layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
    layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
    layer.scales = Parameter(layer.scales.data, requires_grad=False)
    # NUMA layout conversion only applies when requested and not already
    # handled by the per-layer-type processing above.
    if not still_need_process or self.weight_type != "NUMA":
        return
    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight, envs.VLLM_BR_DEVICE_WARP_SIZE, "colmajor",
            torch.int8, parallel_type=parallel_type)
    if hasattr(layer, 'scales') and layer.scales is not None:
        pad_zeros = False
        layer.scales.data = _convert_to_numa_tensor(
            layer.scales, envs.VLLM_BR_DEVICE_WARP_SIZE, "linear_bias",
            torch.float32, parallel_type=parallel_type, pad_zeros=pad_zeros)
    if hasattr(layer, 'bias') and layer.bias is not None:
        # Row-parallel bias is zero-padded so only one rank adds the real
        # bias before the all-reduce — TODO(review): confirm this intent.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias, envs.VLLM_BR_DEVICE_WARP_SIZE, "linear_bias",
            torch.float32, parallel_type=parallel_type, pad_zeros=pad_zeros)


@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the quantized linear layer through Biren fused kernels.

    3-dim qweights (NUMA layout, produced by process_weights_after_loading)
    go to ``torch_br`` fused kernels; 2-dim qweights fall through to
    ``sudnn_qmatmul_infer``.
    """
    # numa weight is 3-dims
    if len(layer.qweight.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # Merged gate/up projection: fused SwiGLU halves the output width.
            act_mode = "act_swiglu"
            output_size //= 2
        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            # (VLLM_BR_DEVICE_SPC_NUM, tp_size) pairs eligible for the fused
            # linear+allreduce kernel.
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            # reduce_results=True means vLLM's RowParallelLinear performs the
            # all-reduce after apply(); False means the fused kernel below
            # already includes it.
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
                and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
            if layer.reduce_results:
                return torch_br.br_fused_mlp_infer(
                    x, [layer.qweight],
                    output_w=output_size,
                    scales=[layer.scales] if layer.scales is not None else None,
                    bias=[bias] if bias is not None else None,
                    activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                # Sanity check: global rank modulo TP size must match the
                # in-group rank for the fused all-reduce rank math to hold.
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x, layer.qweight, output_size, tp_rank, tp_size,
                    global_rank, 0, scales=layer.scales, bias=bias,
                    act_mode=act_mode)
        else:
            # Column-parallel / replicated 3-dim path: plain fused kernel.
            return torch_br.br_fused_mlp_infer(
                x, [layer.qweight],
                output_w=output_size,
                scales=[layer.scales] if layer.scales is not None else None,
                bias=[bias] if bias is not None else None,
                activation_mode=act_mode)
    # 2-dim qweight fallback path. Assumes x is at least 2-dim with
    # (batch, seq, ...) leading dims — TODO(review): confirm expected rank.
    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.qweight.shape[1]
    # TODO, hard code to skip dry_run stage
    if xh >= 4096:
        # HACK: returns a dummy all-ones tensor instead of computing —
        # presumably only hit during profiling/dry-run; verify before relying
        # on this path for real inference.
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x, layer.qweight, layer.scales,
                                        bias=bias)