first commit
This commit is contained in:
19
vllm_br/model_executor/layers/quantization/__init__.py
Normal file
19
vllm_br/model_executor/layers/quantization/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# Importing these submodules applies their `patch_to` side effects to the
# upstream vLLM quantization classes on package import.
from . import compressed_tensors, gptq

# Public API of this package (quote style normalized for consistency).
__all__ = ["compressed_tensors", "gptq"]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,18 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
# Re-export the patched compressed-tensors scheme modules.  Importing them
# is what applies their `patch_to` monkeypatches to the upstream vLLM
# classes, so the order/presence of these imports is load-bearing.
from .compressed_tensors import *
from .compressed_tensors_moe import *
from .compressed_tensors_wNa16 import *
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,64 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, cast
|
||||
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||
CompressedTensorsConfig)
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsConfig, cls_method=True)
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
    """Build a ``CompressedTensorsConfig`` from its serialized dict.

    [PatchNote] extends the upstream classmethod with an optional
    ``qkv_quantized`` flag, read from ``config`` (defaults to True).
    """
    ignored_layers: list[str] = cast(list[str], config.get("ignore", []))
    fmt = cast(str, config.get("format"))
    scheme_map = cls._quantization_scheme_map_from_config(config=config)
    sparsity_map, sparsity_ignored = cls._parse_sparsity_config(config=config)
    transform_cfg = config.get("transform_config")

    # [Patch] whether the QKV projections are quantized; absent -> True.
    qkv_quantized = cls.get_from_keys_or(config, ["qkv_quantized"],
                                         default=True)

    return cls(
        target_scheme_map=scheme_map,
        ignore=ignored_layers,
        quant_format=fmt,
        sparsity_scheme_map=sparsity_map,
        sparsity_ignore_list=sparsity_ignored,
        config=config,
        transform_config=transform_cfg,
        qkv_quantized=qkv_quantized,
    )
|
||||
|
||||
|
||||
def wrapper_CompressedTensorsConfig_init(fn):
    """Wrap a ``CompressedTensorsConfig.__init__`` so it accepts one extra
    keyword, ``qkv_quantized`` (default True), and stores it on the instance
    after the wrapped initializer has run."""

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        # Pop the extra keyword first so the wrapped __init__ never sees it.
        flag = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = flag

    return wrapper
|
||||
|
||||
|
||||
# Monkeypatch: install the wrapper so every existing constructor call site
# transparently gains the optional ``qkv_quantized`` keyword.
CompressedTensorsConfig.__init__ = wrapper_CompressedTensorsConfig_init(
    CompressedTensorsConfig.__init__)  # noqa: E501
|
||||
@@ -0,0 +1,594 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from compressed_tensors.compressors.quantized_compressors import (
|
||||
unpack_from_int32)
|
||||
from fastcore.basics import patch_to
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
WNA16_SUPPORTED_BITS, CompressedTensorsMoEMethod,
|
||||
CompressedTensorsWNA16MoEMethod, CompressionFormat)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm_br import envs
|
||||
from ...br_utils import (_convert_to_crossed_numa_tensor,
|
||||
_convert_to_numa_tensor, align_n, cross_weight_32)
|
||||
from ...fused_moe.supa_moe import fused_moe_quant_device, fused_moe_quant_dyn
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsMoEMethod)
def get_moe_method(
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    layer: torch.nn.Module,
) -> "CompressedTensorsMoEMethod":
    """NOTE:
    1. SUPA only supports CompressedTensorsWNA16MoEMethod without Marlin
    2. Only Linear targets are supported for MoE layers
    """
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    keys = list(quant_config.target_scheme_map.keys())
    assert len(keys) > 0, ("No valid quant key!!!")
    # assert "Linear" in quant_config.target_scheme_map
    # [Patch]: only the Linear target is supported for MoE layers; for
    # temporary compatibility, rename the first entry of target_scheme_map
    # to "Linear".  NOTE: this mutates the shared quant_config in place.
    quant_config.target_scheme_map[
        "Linear"] = quant_config.target_scheme_map.pop(keys[0])
    target_key = "Linear"
    # target_key = keys[0] # normal only one key
    weight_quant = quant_config.target_scheme_map[target_key].get("weights")
    input_quant = quant_config.target_scheme_map[target_key].get(
        "input_activations")

    # Only the WNA16 group/channel scheme is accepted on SUPA.
    if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
        logger.info_once("Using CompressedTensorsWNA16MoEMethod")
        return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
    else:
        raise RuntimeError(
            f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def __init__(
    self,
    quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
    moe: FusedMoEConfig,
):
    """[Patch] WNA16 fused-MoE method initializer for SUPA.

    Differs from upstream: both "group" and "channel" strategies are
    accepted (the ``strategy == "group"`` assert is dropped), and the
    Marlin path is never taken.

    Raises:
        ValueError: if the quant format is not pack-quantized or the bit
            width is not in WNA16_SUPPORTED_BITS.
    """
    super(CompressedTensorsWNA16MoEMethod, self).__init__(moe)
    self.quant_config = quant_config
    # TODO: @dsikka: refactor this to use schemes as other kernels
    # are supported + check if the layer is being ignored.
    config = self.quant_config.target_scheme_map["Linear"].get("weights")
    self.num_bits = config.num_bits
    # Number of quantized values packed into one int32 word.
    self.packed_factor = 32 // config.num_bits
    self.strategy = config.strategy
    # channelwise is not supported by this kernel
    # [Patch]: SUPA uses CompressedTensorsWNA16MoEMethod for both
    # channel/group strategies.
    # assert config.strategy == "group"
    self.group_size = config.group_size
    # grouped actorder isn't supported by this kernel
    # assert config.actorder != "group"
    assert config.symmetric, (
        "Only symmetric quantization is supported for MoE")

    if not (self.quant_config.quant_format
            == CompressionFormat.pack_quantized.value
            and self.num_bits in WNA16_SUPPORTED_BITS):
        # Fix: pass ONE message string.  The original passed four separate
        # positional arguments, so str(exc) rendered as a tuple of strings
        # instead of a readable message.
        raise ValueError(
            "For Fused MoE layers, only "
            f"{CompressionFormat.pack_quantized.value} "
            "is supported for the following bits: "
            f"{WNA16_SUPPORTED_BITS}")
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register CPU-resident placeholder parameters for a WNA16 fused-MoE
    layer: packed weights, scale tensors, original-shape records, and
    activation-reordering (g_idx) buffers.

    All tensors are created on "cpu"; they are converted to SUPA device
    layouts later by ``process_weights_after_loading``, which also frees
    the shape/g_idx buffers.
    """

    # Will transpose the loaded weight along the
    # intermediate and hidden dim sizes. Will
    # shard for TP along the transposed dims
    extra_weight_attrs.update({
        "is_transposed": True,
        "quant_method": self.strategy
    })
    # Packed gate/up projection: each int32 word holds `packed_factor`
    # (= 32 // num_bits) quantized values along the hidden dim.
    w13_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        hidden_size // self.packed_factor,
        2 * intermediate_size_per_partition,
        dtype=torch.int32,
        device="cpu",
    ),
                                    requires_grad=False)
    layer.register_parameter("w13_weight_packed", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    # Packed down projection, packed along the intermediate dim.
    w2_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        intermediate_size_per_partition // self.packed_factor,
        hidden_size,
        dtype=torch.int32,
        device="cpu"),
                                   requires_grad=False)
    layer.register_parameter("w2_weight_packed", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    w2_scales_size = intermediate_size_per_partition

    if self.strategy == "channel":
        # Channelwise quantization: a single scale group per output
        # channel; group_size == -1 marks "no grouping" downstream.
        num_groups_w2 = num_groups_w13 = 1
        self.group_size = -1
    else:
        num_groups_w2 = w2_scales_size // self.group_size
        num_groups_w13 = hidden_size // self.group_size

    w13_scale = torch.nn.Parameter(torch.ones(
        num_experts,
        num_groups_w13,
        2 * intermediate_size_per_partition,
        dtype=params_dtype,
        device="cpu",
    ),
                                   requires_grad=False)
    layer.register_parameter("w13_weight_scale", w13_scale)
    set_weight_attrs(w13_scale, extra_weight_attrs)

    w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                             num_groups_w2,
                                             hidden_size,
                                             dtype=params_dtype,
                                             device="cpu"),
                                  requires_grad=False)
    layer.register_parameter("w2_weight_scale", w2_scale)
    set_weight_attrs(w2_scale, extra_weight_attrs)
    set_weight_attrs(w2_scale, {"load_full_w2": False})

    # Original (pre-packing) 2D shapes; filled in by the weight loader.
    w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts,
                                                     2,
                                                     device="cpu"),
                                         requires_grad=False)
    layer.register_parameter("w2_weight_shape", w2_weight_shape)
    set_weight_attrs(w2_weight_shape, extra_weight_attrs)
    w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts,
                                                      2,
                                                      device="cpu"),
                                          requires_grad=False)

    layer.register_parameter("w13_weight_shape", w13_weight_shape)
    set_weight_attrs(w13_weight_shape, extra_weight_attrs)

    # Activation-reordering group indices and their sort permutations.
    # These are only needed at load time and are set to None again in
    # process_weights_after_loading to reclaim memory.
    w13_g_idx = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight_g_idx", w13_g_idx)
    set_weight_attrs(w13_g_idx, extra_weight_attrs)

    w2_g_idx = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight_g_idx", w2_g_idx)
    set_weight_attrs(w2_g_idx, extra_weight_attrs)

    w13_g_idx_sort_indices = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
    set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)

    w2_g_idx_sort_indices = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
    set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

    # WNA16 quantizes weights only; no activation scales are used.
    layer.a13_scale = None
    layer.a2_scale = None
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def process_weights_after_loading(self: CompressedTensorsWNA16MoEMethod,
                                  layer: FusedMoE) -> None:
    """Convert the CPU-loaded WNA16 MoE weights and scales into SUPA
    NUMA device layouts, then release load-time-only buffers.

    Two paths depending on ``self.num_bits``:
    * 8-bit: weights are unpacked from int32 words to int8, then crossed
      and padded per expert into a colmajor NUMA tensor.
    * 4-bit: weights remain packed in int32 words; scales keep their
      per-group dimension.
    """
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # More than 16 SPCs implies a dual-die device split in two.
    die_num = 1 if die_spc_num <= 16 else 2
    spc_num = die_spc_num // die_num
    cur_device = torch.supa.current_device()
    is_dual_die = (die_spc_num > 16)

    if self.num_bits == 8:
        # NOTE: w13_weight
        # after _load_w13, w13_weight is a colparallel weight, shape
        # [num_experts, hidden_size // 4, 2 * intermediate_size_per_partition] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT8
        wk = layer.hidden_size
        wn = layer.intermediate_size_per_partition * 2
        align_size = 64
        wn_block = align_n(wn // die_num,
                           align_size=align_size,
                           spc_num=spc_num)

        supa_w13_weight_packed = torch_br._empty_ut_only(
            size=(die_spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int8,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor",
            axis=0,
            sbp="SS" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w13 = layer.w13_weight_packed[
                expert_id]  # hidden_size // 4, 2 * intermediate_size_per_partition
            # Split the merged gate/up weight into its two halves.
            expert_1, expert_3 = expert_w13.chunk(
                2, dim=1)  # each is a packed weight half

            unpacked_expert_1 = unpack_from_int32(
                expert_1, self.num_bits,
                torch.Size(
                    [layer.hidden_size,
                     layer.intermediate_size_per_partition]), 0)

            unpacked_expert_3 = unpack_from_int32(
                expert_3, self.num_bits,
                torch.Size(
                    [layer.hidden_size,
                     layer.intermediate_size_per_partition]), 0)

            # Interleave gate/up halves and pad into NUMA colmajor layout.
            pad_expert_w13 = _convert_to_crossed_numa_tensor(
                unpacked_expert_1,
                unpacked_expert_3,
                die_spc_num,
                dim=1,
                need_pad=True,
                layout='COLMAJOR',
                do_transpose=False)
            # Copy into this expert's slot of the device tensor via a
            # zero-copy usharp view at the expert's flat offset.
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)

        layer.w13_weight_packed.data = supa_w13_weight_packed

        # NOTE: w13_scale
        # after _load_w13, w13_weight is a colparallel weight, shape
        # S8: [num_experts, 1, 2 * intermediate_size_per_partition]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, wn]
        supa_w13_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="linear_bias",
            sbp="BB" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w13_scales = layer.w13_weight_scale[expert_id]
            expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
                2, dim=1)  # gate / up scale halves
            # Cross the scale halves to match the crossed weight layout.
            crossed_expert_w13_scales = cross_weight_32(
                expert_1_scale.squeeze(),
                expert_3_scale.squeeze(),
                die_spc_num,
                dim=0,
                need_pad=False,
            )
            narrow_data = supa_w13_scales[expert_id]
            narrow_data.copy_(crossed_expert_w13_scales)

        layer.w13_weight_scale.data = supa_w13_scales

        # NOTE: w2_weight
        # after _load_w2, w2_weight is a colparallel weight, shape
        # [num_experts, intermediate_size_per_partition // 4, hidden_size] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT8
        wk = layer.intermediate_size_per_partition
        wn = layer.hidden_size
        align_size = 32
        wn_block = align_n(wn, align_size=align_size, spc_num=spc_num)

        supa_w2_weight_packed = torch_br._empty_ut_only(
            size=(die_spc_num * layer.local_num_experts, wk // die_num,
                  wn_block),
            dtype=torch.int8,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor",
            axis=0,
            sbp="SS" if is_dual_die else None)

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_packed[expert_id]

            unpacked_expert_2 = unpack_from_int32(
                expert_w2, self.num_bits,
                torch.Size(
                    [layer.intermediate_size_per_partition,
                     layer.hidden_size]), 0)

            pad_expert_w2 = _convert_to_numa_tensor(
                unpacked_expert_2,
                align_size,
                'COLMAJOR',
                expert_w2.dtype,
                do_transpose=False,
                parallel_type="row_parallel")
            pad_expert_w2_shape = pad_expert_w2.shape
            hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
            narrow_data = supa_w2_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w2)

        layer.w2_weight_packed.data = supa_w2_weight_packed

        # NOTE: w2_scale
        # after _load_w2, w2_weight is a colparallel weight, shape
        # S8: [num_experts, 1, hidden_size]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, wn]
        supa_w2_scales = torch_br._empty_ut_only(size=(layer.local_num_experts,
                                                       wn),
                                                 dtype=torch.float32,
                                                 is_numa=False,
                                                 device=cur_device,
                                                 tensor_type="linear_bias")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_scale[expert_id]
            # NOTE(review): strided slot assignment — every local_num_experts-th
            # row belongs to this expert; confirm against the kernel's layout.
            narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
            narrow_data.copy_(expert_w2)

        layer.w2_weight_scale.data = supa_w2_scales

    elif self.num_bits == 4:
        # NOTE: w13_weight
        # after _load_w13, w13_weight is a colparallel weight, shape
        # [num_experts, hidden_size // 8, 2 * intermediate_size_per_partition] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT32
        wk = layer.hidden_size // 8
        wn = layer.intermediate_size_per_partition * 2
        wn_block = align_n(wn, align_size=64, spc_num=spc_num)

        supa_w13_weight_packed = torch_br._empty_ut_only(
            size=(spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int32,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w13 = layer.w13_weight_packed[
                expert_id]  # hidden_size // 8, 2 * intermediate_size_per_partition
            expert_1, expert_3 = expert_w13.chunk(
                2, dim=0)  # each is a packed int4 weight

            # 4-bit weights stay packed; cross/pad with transpose.
            pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
                                                             expert_3,
                                                             spc_num,
                                                             dim=1,
                                                             need_pad=True,
                                                             layout='COLMAJOR',
                                                             do_transpose=True)
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)

        layer.w13_weight_packed.data = supa_w13_weight_packed

        # NOTE: w13_scale
        # after _load_w13, w13_weight is a colparallel weight, shape
        # S4: [num_experts, hidden_size // 128, 2 * intermediate_size_per_partition]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, group_nums, wn]
        supa_w13_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.hidden_size // self.group_size, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w13_scales = layer.w13_weight_scale[expert_id]
            expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
                2, dim=0)  # gate / up scale halves
            crossed_expert_w13_scales = cross_weight_32(
                expert_1_scale,
                expert_3_scale,
                spc_num,
                dim=1,
                need_pad=False,
            )
            narrow_data = supa_w13_scales[expert_id]
            narrow_data.copy_(crossed_expert_w13_scales)

        layer.w13_weight_scale.data = supa_w13_scales

        # NOTE: w2_weight
        # after _load_w2, w2_weight is a colparallel weight, shape
        # [num_experts, intermediate_size_per_partition // 8, hidden_size] INT32
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [spc_num * num_experts, wk, wn_block] INT32
        wk = layer.intermediate_size_per_partition // 8
        wn = layer.hidden_size
        wn_block = align_n(wn, align_size=32, spc_num=spc_num)

        supa_w2_weight_packed = torch_br._empty_ut_only(
            size=(spc_num * layer.local_num_experts, wk, wn_block),
            dtype=torch.int32,
            is_numa=True,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_packed[expert_id]
            pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
                                                    spc_num,
                                                    'COLMAJOR',
                                                    expert_w2.dtype,
                                                    do_transpose=True)
            pad_expert_w2_shape = pad_expert_w2.shape
            hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
            narrow_data = supa_w2_weight_packed.view_as_usharp(
                "COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w2)

        layer.w2_weight_packed.data = supa_w2_weight_packed

        # NOTE: w2_scale
        # after _load_w2, w2_weight is a colparallel weight, shape
        # S4: [num_experts, intermediate_size_per_partition // 128, hidden_size]
        # for SUPA, transform it to a NUMA colmajor weight, shape
        # [num_experts, group_nums, wn]
        supa_w2_scales = torch_br._empty_ut_only(
            size=(layer.local_num_experts,
                  layer.intermediate_size_per_partition // self.group_size,
                  wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="colmajor")

        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_weight_scale[expert_id]
            narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
            narrow_data.copy_(expert_w2)

        layer.w2_weight_scale.data = supa_w2_scales

    else:
        raise ValueError(
            f"Unsupported num_bits: {self.num_bits}. Only 4 and 8 are supported."
        )

    # Remove other CompressedTensorsWNA16MoEMethod registered buffers to
    # reduce memory usage (they were only needed during weight loading).
    layer.w13_weight_shape = None
    layer.w13_weight_g_idx = None
    layer.w13_g_idx_sort_indices = None

    layer.w2_weight_shape = None
    layer.w2_weight_g_idx = None
    layer.w2_g_idx_sort_indices = None
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16MoEMethod)
def apply(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    """Run the fused-MoE forward pass on SUPA.

    NOTE(review): despite the upstream ``torch.Tensor`` annotation,
    ``router_logits`` is unpacked here as a 3-tuple
    ``(gating_weight, shared_gate_up_weight, shared_down_weight)`` —
    callers of this patched method must pass those weights, not a logits
    tensor.  ``global_num_experts``, ``expert_map``,
    ``apply_router_weight_on_input`` and ``activation`` are accepted for
    interface compatibility but are not used in this implementation.
    """
    b_seq = x.shape[0]
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
    # Batches longer than the static-decoder limit go through the dynamic
    # kernel; shorter ones use the static device kernel (which also takes
    # TP rank/size information).
    if b_seq > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
        return fused_moe_quant_dyn(
            x,
            shared_gate_up_weight,
            shared_down_weight,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            gating_weight,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size,
        )
    else:
        return fused_moe_quant_device(
            x,
            shared_gate_up_weight,
            shared_down_weight,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            gating_weight,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            tp_rank=get_tp_group().rank_in_group,
            global_rank=get_tp_group().rank,
            tp_size=get_tensor_model_parallel_world_size(),
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size,
        )
|
||||
@@ -0,0 +1,267 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from compressed_tensors.compressors.quantized_compressors import (
|
||||
unpack_from_int32)
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.distributed import (get_pipeline_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm_br import envs
|
||||
from ...br_utils import _convert_to_numa_tensor
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def create_weights(self, layer: torch.nn.Module, input_size: int,
                   input_size_per_partition: int, output_size: int,
                   output_partition_sizes: list[int],
                   params_dtype: torch.dtype, weight_loader: Callable,
                   **kwargs):
    """Register CPU placeholder parameters for a WNA16 linear layer:
    packed weight, scales (channel- or group-wise), optional zero points
    (asymmetric only), original-shape record, and optional g_idx."""
    self.output_size_per_partition = sum(output_partition_sizes)
    # group_size == -1 means channelwise: one group spanning input_size.
    group_size = self.group_size if self.group_size != -1 else input_size
    row_parallel = (input_size != input_size_per_partition)
    # Whether scales are sharded across TP ranks or replicated on all.
    partition_scales = not marlin_repeat_scales_on_all_ranks(
        self.has_g_idx, self.group_size, row_parallel)
    scales_and_zp_size = input_size // group_size
    if partition_scales:
        assert input_size_per_partition % group_size == 0
        scales_and_zp_size = input_size_per_partition // group_size
    # Quantized weight packed along the input dim into int32 words.
    weight = PackedvLLMParameter(
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
        packed_factor=self.pack_factor,
        packed_dim=1,
        data=torch.empty(self.output_size_per_partition,
                         input_size_per_partition // self.pack_factor,
                         dtype=torch.int32,
                         device="cpu"))

    weight_scale_args = {
        "weight_loader":
        weight_loader,
        "data":
        torch.empty(
            self.output_size_per_partition,
            scales_and_zp_size,
            device="cpu",
            dtype=params_dtype,
        )
    }
    zeros_args = {
        "weight_loader":
        weight_loader,
        "data":
        torch.zeros(
            self.output_size_per_partition // self.pack_factor,
            scales_and_zp_size,
            device="cpu",
            dtype=torch.int32,
        )
    }
    if not partition_scales:
        weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                  **weight_scale_args)
        if not self.symmetric:
            qzeros = PackedColumnParameter(output_dim=0,
                                           packed_dim=0,
                                           packed_factor=self.pack_factor,
                                           **zeros_args)
    else:
        weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                input_dim=1,
                                                **weight_scale_args)
        if not self.symmetric:
            qzeros = PackedvLLMParameter(input_dim=1,
                                         output_dim=0,
                                         packed_dim=0,
                                         packed_factor=self.pack_factor,
                                         **zeros_args)
    # A 2D array defining the original shape of the weights
    # before packing
    weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                      dtype=torch.int64,
                                                      device="cpu"),
                                     weight_loader=weight_loader)
    layer.register_parameter("weight_packed", weight)
    layer.register_parameter("weight_scale", weight_scale)
    layer.register_parameter("weight_shape", weight_shape)

    if not self.symmetric:
        layer.register_parameter("weight_zero_point", qzeros)
    # group index (for activation reordering)
    if self.has_g_idx:
        weight_g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            device="cpu",
            dtype=torch.int32,
        ),
                                        input_dim=0,
                                        weight_loader=weight_loader)
        layer.register_parameter("weight_g_idx", weight_g_idx)
    # Stashed for process_weights_after_loading (unpacking needs the
    # original per-partition input size).
    self.input_size_per_partition = input_size_per_partition
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def process_weights_after_loading(self: CompressedTensorsWNA16,
                                  layer: torch.nn.Module) -> None:
    """Post-load conversion of WNA16 weights into the Biren NUMA layout.

    Unpacks the int32-packed quantized weights, casts scales to float32,
    and rewrites weight/scale/bias tensors through ``_convert_to_numa_tensor``
    so the fused Biren kernels can consume them.

    :param layer: linear layer whose ``weight_packed``/``weight_scale``/``bias``
        parameters are mutated in place.
    """
    # spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # cur_device = torch.supa.current_device()
    # pack_factor is elements-per-int32, so 32 // pack_factor recovers the
    # quantized bit width (e.g. pack_factor 8 -> 4-bit weights).
    self.num_bits = 32 // self.pack_factor
    layer.weight_packed.data = unpack_from_int32(
        layer.weight_packed.data, self.num_bits,
        torch.Size(
            [self.output_size_per_partition, self.input_size_per_partition]),
        1)
    # Re-wrap as a plain Parameter (no loader metadata) for inference.
    layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                            requires_grad=False)

    # Biren kernels expect float32 scales regardless of the checkpoint dtype.
    br_scales = layer.weight_scale.data.to(torch.float32)
    layer.weight_scale.data = br_scales

    do_transpose = True
    parallel_type = "col_parallel"
    match layer:
        case RowParallelLinear():
            parallel_type = "row_parallel"
        case _:
            pass

    # Only 2-D weights still need conversion; a 3-D tensor would already be
    # in the NUMA layout (see the rank check in apply_weights).
    if hasattr(layer, 'weight_packed') and len(layer.weight_packed.shape) == 2:
        weight_packed = layer.weight_packed.data
        layer.weight_packed.data = _convert_to_numa_tensor(
            weight_packed,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            do_transpose=do_transpose,
            # wk/wn swap roles depending on whether the tensor is transposed
            # during conversion.
            wk=(weight_packed.shape[1]
                if do_transpose else weight_packed.shape[0]),
            wn=(weight_packed.shape[0]
                if do_transpose else weight_packed.shape[1]),
            parallel_type=parallel_type)  # noqa: SIM210

    if hasattr(layer, 'weight_scale') and layer.weight_scale is not None:
        pad_zeros = False
        layer.weight_scale.data = _convert_to_numa_tensor(
            layer.weight_scale.data.T,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # NOTE(review): bias padding appears to be needed only for the
        # row-parallel partial-sum layout — confirm against
        # _convert_to_numa_tensor.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias.data,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
|
||||
|
||||
|
||||
@patch_to(CompressedTensorsWNA16)
def apply_weights(self: CompressedTensorsWNA16,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Quantized linear forward for a WNA16 layer on the Biren backend.

    Dispatches on the packed-weight rank: NUMA-converted weights are
    3-dimensional and go through the fused-MLP / fused-allreduce kernels,
    while plain 2-D weights fall back to ``sudnn_qmatmul_infer``.

    :param layer: the linear layer owning ``weight_packed``/``weight_scale``.
    :param x: input activations.
    :param bias: optional bias added by the kernel.
    :return: output activations.
    """
    # numa weight is 3-dims
    if len(layer.weight_packed.shape) == 3:
        output_size = (layer.output_size_per_partition if hasattr(
            layer, "output_size_per_partition") else layer.output_size)
        act_mode = "act_default"
        if isinstance(layer, MergedColumnParallelLinear):
            # Gate/up projections are fused; SwiGLU halves the output width.
            act_mode = "act_swiglu"
            output_size //= 2

        if isinstance(layer, RowParallelLinear):
            seq_len = x.shape[-2]
            tp_size = get_tensor_model_parallel_world_size()
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pipeline_model_parallel_group().world_size
            all_rank = tp_size * pp_size
            support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
            # Use the fused-allreduce kernel only for small decode batches on
            # supported (SPC, TP) topologies; otherwise keep the regular
            # reduce_results path.
            layer.reduce_results = not (
                all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
                and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
                (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)

            if layer.reduce_results:
                return torch_br.br_fused_mlp_infer(
                    x, [layer.weight_packed.data],
                    output_w=output_size,
                    scales=[layer.weight_scale.data]
                    if layer.weight_scale.data is not None else None,
                    bias=[bias] if bias is not None else None,
                    # BUGFIX: keyword was misspelled "activaion_mode"; every
                    # other br_fused_mlp_infer call site in this file spells
                    # it "activation_mode".
                    activation_mode=act_mode)
            else:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % tp_size
                assert rank_i == tp_rank
                return torch_br.supa_fused_linear_allreduce_opt(
                    x,
                    layer.weight_packed.data,
                    output_size,
                    tp_rank,
                    tp_size,
                    global_rank,
                    0,
                    scales=layer.weight_scale.data,
                    bias=bias,
                    act_mode=act_mode)
        else:
            return torch_br.br_fused_mlp_infer(
                x, [layer.weight_packed.data],
                output_w=output_size,
                scales=[layer.weight_scale.data]
                if layer.weight_scale.data is not None else None,
                bias=[bias] if bias is not None else None,
                activation_mode=act_mode)

    # 2-D (non-NUMA) fallback path.
    xn = x.shape[0]
    xh = x.shape[1]
    ww = layer.weight_packed.shape[1]
    # TODO, hard code to skip dry_run stage
    if xh >= 4096:
        return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
    return torch_br.sudnn_qmatmul_infer(x,
                                        layer.weight_packed,
                                        layer.weight_scale,
                                        bias=bias)
|
||||
@@ -0,0 +1,34 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """Map a compressed-tensors k/v cache scale param name to vLLM's.

    compressed-tensors checkpoints store KV-cache scales as
    ``*.k_proj.output_scale`` / ``*.v_proj.output_scale``; vLLM expects
    ``*.attn.k_scale`` / ``*.attn.v_scale``.

    :param name: param name
    :return: matching vLLM param name, or ``None`` if *name* is not a
        KV-cache scale.
    """
    if not name.endswith(".output_scale"):
        return None
    if ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    return None
|
||||
244
vllm_br/model_executor/layers/quantization/gptq.py
Normal file
244
vllm_br/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,244 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.distributed import (get_pipeline_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
|
||||
GPTQLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm_br import envs
|
||||
from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
|
||||
from ..linear import (process_weights_MergedColumnParallelLinear,
|
||||
process_weights_QuantMergedColumnParallelLinear,
|
||||
process_weights_ReplicatedLinear)
|
||||
|
||||
|
||||
@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    # Patched to advertise only bfloat16 activations for GPTQ on this backend.
    return [torch.bfloat16]
|
||||
|
||||
|
||||
@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQLinearMethod"]:
    """Resolve the quantization method for *layer* via the stock helper."""
    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
|
||||
@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
    """Build a ``GPTQConfig`` from a checkpoint quantization config dict.

    [PatchNote] add qkv_quantized param support
    """
    dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
    if dynamic is None:
        dynamic = {}

    return cls(
        weight_bits=cls.get_from_keys(config, ["bits"]),
        group_size=cls.get_from_keys(config, ["group_size"]),
        desc_act=cls.get_from_keys(config, ["desc_act"]),
        lm_head_quantized=cls.get_from_keys_or(config, ["lm_head"],
                                               default=False),
        dynamic=dynamic,
        autoround_version=cls.get_from_keys_or(config, ["autoround_version"],
                                               default=""),
        modules_in_block_to_quantize=cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None),
        # Extra backend flag; accepted by the patched __init__ below.
        qkv_quantized=cls.get_from_keys_or(config, ["qkv_quantized"],
                                           default=True),
    )
|
||||
|
||||
|
||||
def wrapper_GPTQConfig_init(fn):
    """Wrap an ``__init__`` so it accepts an extra ``qkv_quantized`` kwarg.

    The flag is popped before delegating to the original initializer (which
    does not know about it) and stored on the instance afterwards. Defaults
    to ``True`` when absent.
    """

    @wraps(fn)
    def wrapped(self, *args, **kwargs):
        qkv = kwargs.pop("qkv_quantized", True)
        fn(self, *args, **kwargs)
        self.qkv_quantized = qkv

    return wrapped
|
||||
|
||||
|
||||
# Patch GPTQConfig.__init__ in place so every construction path (including
# the patched from_config above) accepts the extra ``qkv_quantized`` keyword.
GPTQConfig.__init__ = wrapper_GPTQConfig_init(
    GPTQConfig.__init__)  # noqa: E501
|
||||
|
||||
|
||||
@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
                                  layer: torch.nn.Module) -> None:
    """Post-load conversion of GPTQ weights for the Biren backend.

    Runs per-layer-type pre-processing, converts the int32-packed qweight
    via ``_br_qweight_cvt``, casts scales to float32, re-wraps parameters
    for torch.compile, and finally (when applicable) rewrites
    qweight/scales/bias into the NUMA layout. Mutates ``layer`` in place.
    """
    still_need_process = True
    merge_col_quant = False
    # NOTE: all process_weights func should done before process_weights_after_loading
    parallel_type = "col_parallel"
    match layer:
        case ReplicatedLinear():
            process_weights_ReplicatedLinear(layer)
            # NOTE(review): 64/256 look like special small-projection sizes
            # that still need NUMA conversion — confirm intent.
            still_need_process = layer.output_size == 64 or layer.output_size == 256
        case MergedColumnParallelLinear():
            # Quantized merged-column layers are converted only after
            # _br_qweight_cvt (see merge_col_quant below).
            if hasattr(layer, "qweight"):
                merge_col_quant = True
            else:
                process_weights_MergedColumnParallelLinear(layer)
                still_need_process = False
        case RowParallelLinear():
            parallel_type = "row_parallel"

        case _:
            pass

    # NOTE: if use exllama, br gptq needs similar treatment
    # exllama needs to shuffle the weight after the weight is loaded
    # here we do the shuffle on first forward pass
    # int32 dtype means the weight is still in the packed checkpoint format.
    if layer.qweight.dtype == torch.int32:
        input_size = layer.input_size_per_partition if hasattr(
            layer, 'input_size_per_partition') else layer.input_size
        output_size = layer.output_size_per_partition if hasattr(
            layer, 'output_size_per_partition') else layer.output_size
        br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
                                     input_size, output_size)
        layer.qweight.data = br_qweight
        if merge_col_quant:
            process_weights_QuantMergedColumnParallelLinear(layer)
            still_need_process = False

    # Biren kernels expect float32 scales regardless of checkpoint dtype.
    br_scales = layer.scales.to(torch.float32)
    layer.scales.data = br_scales

    # for torch.compile
    layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
    layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
    layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
    layer.scales = Parameter(layer.scales.data, requires_grad=False)

    # Everything below is the NUMA-layout conversion; skip it when the layer
    # was already handled above or the method is not configured for NUMA.
    if not still_need_process or self.weight_type != "NUMA":
        return

    if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
        layer.qweight.data = _convert_to_numa_tensor(
            layer.qweight,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "colmajor",
            torch.int8,
            parallel_type=parallel_type)

    if hasattr(layer, 'scales') and layer.scales is not None:
        pad_zeros = False
        layer.scales.data = _convert_to_numa_tensor(
            layer.scales,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)

    if hasattr(layer, 'bias') and layer.bias is not None:
        # NOTE(review): bias zero-padding appears specific to the
        # row-parallel layout — confirm against _convert_to_numa_tensor.
        pad_zeros = (parallel_type == "row_parallel")
        layer.bias.data = _convert_to_numa_tensor(
            layer.bias,
            envs.VLLM_BR_DEVICE_WARP_SIZE,
            "linear_bias",
            torch.float32,
            parallel_type=parallel_type,
            pad_zeros=pad_zeros)
|
||||
|
||||
|
||||
@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Quantized linear forward for GPTQ layers on the Biren backend.

    3-D (NUMA-converted) weights use the fused-MLP / fused-allreduce
    kernels; plain 2-D weights take the quantized-matmul fallback.
    """
    # Non-NUMA fallback: weights are still 2-D.
    if len(layer.qweight.shape) != 3:
        batch = x.shape[0]
        hidden = x.shape[1]
        packed_cols = layer.qweight.shape[1]
        # TODO, hard code to skip dry_run stage
        if hidden >= 4096:
            return torch.ones((batch, hidden, packed_cols),
                              dtype=x.dtype,
                              device=x.device)
        return torch_br.sudnn_qmatmul_infer(x,
                                            layer.qweight,
                                            layer.scales,
                                            bias=bias)

    out_features = (layer.output_size_per_partition
                    if hasattr(layer, "output_size_per_partition") else
                    layer.output_size)
    act_mode = "act_default"
    if isinstance(layer, MergedColumnParallelLinear):
        # Fused gate/up projection: SwiGLU halves the output width.
        act_mode = "act_swiglu"
        out_features //= 2

    # Column-parallel / replicated layers never need an allreduce.
    if not isinstance(layer, RowParallelLinear):
        return torch_br.br_fused_mlp_infer(
            x, [layer.qweight],
            output_w=out_features,
            scales=[layer.scales] if layer.scales is not None else None,
            bias=[bias] if bias is not None else None,
            activation_mode=act_mode)

    seq_len = x.shape[-2]
    tp_size = get_tensor_model_parallel_world_size()
    # bypass tp8 and tp4pp2 allreduce
    pp_size = get_pipeline_model_parallel_group().world_size
    world = tp_size * pp_size
    fused_allreduce_ok = (
        world <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
        and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
        and (envs.VLLM_BR_DEVICE_SPC_NUM,
             tp_size) in ((16, 4), (16, 8), (32, 2), (32, 4)))
    layer.reduce_results = not fused_allreduce_ok

    if layer.reduce_results:
        return torch_br.br_fused_mlp_infer(
            x, [layer.qweight],
            output_w=out_features,
            scales=[layer.scales] if layer.scales is not None else None,
            bias=[bias] if bias is not None else None,
            activation_mode=act_mode)

    tp_rank = get_tp_group().rank_in_group
    global_rank = get_tp_group().rank
    assert global_rank % tp_size == tp_rank
    return torch_br.supa_fused_linear_allreduce_opt(x,
                                                    layer.qweight,
                                                    out_features,
                                                    tp_rank,
                                                    tp_size,
                                                    global_rank,
                                                    0,
                                                    scales=layer.scales,
                                                    bias=bias,
                                                    act_mode=act_mode)
|
||||
Reference in New Issue
Block a user