first commit

2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import compressed_tensors, gptq
__all__ = ["gptq", 'compressed_tensors']

View File

@@ -0,0 +1,18 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from .compressed_tensors import *
from .compressed_tensors_moe import *
from .compressed_tensors_wNa16 import *

View File

@@ -0,0 +1,64 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Any, cast
from fastcore.basics import patch_to
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig)
@patch_to(CompressedTensorsConfig, cls_method=True)
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
"""
[PatchNote] add qkv_quantized param support
"""
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config)
transform_config = config.get("transform_config")
qkv_quantized = cls.get_from_keys_or(config, ["qkv_quantized"],
default=True)
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
qkv_quantized=qkv_quantized)
def wrapper_CompressedTensorsConfig_init(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
qkv_quantized = kwargs.pop("qkv_quantized", True)
fn(self, *args, **kwargs)
self.qkv_quantized = qkv_quantized
return wrapper
CompressedTensorsConfig.__init__ = wrapper_CompressedTensorsConfig_init(
CompressedTensorsConfig.__init__) # noqa: E501
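
For reference, a minimal sketch of how the patched `from_config` is expected to pick up the new flag: when a checkpoint's quantization config omits "qkv_quantized", the value falls back to True, matching the default used in both the `from_config` patch and the `__init__` wrapper above. The helper below is a stand-in for the real `get_from_keys_or` classmethod, and the config dict is hypothetical.

```python
from typing import Any

def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any:
    """Stand-in for QuantizationConfig.get_from_keys_or (illustrative only)."""
    for key in keys:
        if key in config:
            return config[key]
    return default

quant_cfg = {"format": "pack-quantized", "ignore": ["lm_head"]}  # hypothetical config
assert get_from_keys_or(quant_cfg, ["qkv_quantized"], default=True) is True
assert get_from_keys_or({**quant_cfg, "qkv_quantized": False},
                        ["qkv_quantized"], default=True) is False
```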

View File

@@ -0,0 +1,594 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from compressed_tensors.compressors.quantized_compressors import (
unpack_from_int32)
from fastcore.basics import patch_to
from torch_br.utils.tensor_methods import Sbp
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
WNA16_SUPPORTED_BITS, CompressedTensorsMoEMethod,
CompressedTensorsWNA16MoEMethod, CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs
from vllm_br import envs
from ...br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, align_n, cross_weight_32)
from ...fused_moe.supa_moe import fused_moe_quant_device, fused_moe_quant_dyn
@patch_to(CompressedTensorsMoEMethod)
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
) -> "CompressedTensorsMoEMethod":
"""NOTE:
1. SUPA only supports CompressedTensorsWNA16MoEMethod without Marlin
2. Only Linear targets are supported for MoE layers
"""
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
keys = list(quant_config.target_scheme_map.keys())
assert len(keys) > 0, "No valid quant key"
# assert "Linear" in quant_config.target_scheme_map
# [Patch]: Only the Linear target is supported for MoE layers; for temporary
# compatibility, remap the first target_scheme_map key to "Linear".
quant_config.target_scheme_map[
"Linear"] = quant_config.target_scheme_map.pop(keys[0])
target_key = "Linear"
# target_key = keys[0]  # normally there is only one key
weight_quant = quant_config.target_scheme_map[target_key].get("weights")
input_quant = quant_config.target_scheme_map[target_key].get(
"input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@patch_to(CompressedTensorsWNA16MoEMethod)
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super(CompressedTensorsWNA16MoEMethod, self).__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
config = self.quant_config.target_scheme_map["Linear"].get("weights")
self.num_bits = config.num_bits
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy
# channelwise is not supported by this kernel
# [Patch]: SUPA use CompressedTensorsWNA16MoEMethod for both channel/group strategies
# assert config.strategy == "group"
self.group_size = config.group_size
# grouped actorder isn't supported by this kernel
# assert config.actorder != "group"
assert config.symmetric, (
"Only symmetric quantization is supported for MoE")
if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS):
raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")
@patch_to(CompressedTensorsWNA16MoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
extra_weight_attrs.update({
"is_transposed": True,
"quant_method": self.strategy
})
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size // self.packed_factor,
2 * intermediate_size_per_partition,
dtype=torch.int32,
device="cpu",
),
requires_grad=False)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
intermediate_size_per_partition // self.packed_factor,
hidden_size,
dtype=torch.int32,
device="cpu"),
requires_grad=False)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w2_scales_size = intermediate_size_per_partition
if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
self.group_size = -1
else:
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size
w13_scale = torch.nn.Parameter(torch.ones(
num_experts,
num_groups_w13,
2 * intermediate_size_per_partition,
dtype=params_dtype,
device="cpu",
),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_scale)
set_weight_attrs(w13_scale, extra_weight_attrs)
w2_scale = torch.nn.Parameter(torch.ones(num_experts,
num_groups_w2,
hidden_size,
dtype=params_dtype,
device="cpu"),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, {"load_full_w2": False})
w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts,
2,
device="cpu"),
requires_grad=False)
layer.register_parameter("w2_weight_shape", w2_weight_shape)
set_weight_attrs(w2_weight_shape, extra_weight_attrs)
w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts,
2,
device="cpu"),
requires_grad=False)
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
device="cpu",
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
device="cpu",
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
device="cpu",
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size_per_partition,
device="cpu",
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
layer.a13_scale = None
layer.a2_scale = None
@patch_to(CompressedTensorsWNA16MoEMethod)
def process_weights_after_loading(self: CompressedTensorsWNA16MoEMethod,
layer: FusedMoE) -> None:
die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
die_num = 1 if die_spc_num <= 16 else 2
spc_num = die_spc_num // die_num
cur_device = torch.supa.current_device()
is_dual_die = (die_spc_num > 16)
if self.num_bits == 8:
# NOTE: w13_weight
# after _load_w13, w13_weight is a colparallel weight, shape
# [num_experts, hidden_size // 4, 2 * intermediate_size_per_partition] INT32
# for SUPA, transform it to a NUMA colmajor weight, shape
# [spc_num * num_experts, wk, wn_block] INT8
wk = layer.hidden_size
wn = layer.intermediate_size_per_partition * 2
align_size = 64
wn_block = align_n(wn // die_num,
align_size=align_size,
spc_num=spc_num)
supa_w13_weight_packed = torch_br._empty_ut_only(
size=(die_spc_num * layer.local_num_experts, wk, wn_block),
dtype=torch.int8,
is_numa=True,
device=cur_device,
tensor_type="colmajor",
axis=0,
sbp="SS" if is_dual_die else None)
for expert_id in range(layer.local_num_experts):
expert_w13 = layer.w13_weight_packed[
expert_id] # hidden_size // 4, 2 * intermediate_size_per_partition
expert_1, expert_3 = expert_w13.chunk(
2, dim=1) # each is a packed int8 weight
unpacked_expert_1 = unpack_from_int32(
expert_1, self.num_bits,
torch.Size(
[layer.hidden_size,
layer.intermediate_size_per_partition]), 0)
unpacked_expert_3 = unpack_from_int32(
expert_3, self.num_bits,
torch.Size(
[layer.hidden_size,
layer.intermediate_size_per_partition]), 0)
pad_expert_w13 = _convert_to_crossed_numa_tensor(
unpacked_expert_1,
unpacked_expert_3,
die_spc_num,
dim=1,
need_pad=True,
layout='COLMAJOR',
do_transpose=False)
hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
narrow_data = supa_w13_weight_packed.view_as_usharp(
"COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
expert_id * hw_size)
narrow_data.copy_(pad_expert_w13)
layer.w13_weight_packed.data = supa_w13_weight_packed
# NOTE: w13_scale
# after _load_w13, w13_scale is a colparallel tensor, shape
# S8: [num_experts, 1, 2 * intermediate_size_per_partition]
# for SUPA, transform it to a NUMA colmajor weight, shape
# [num_experts, wn]
supa_w13_scales = torch_br._empty_ut_only(
size=(layer.local_num_experts, wn),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="linear_bias",
sbp="BB" if is_dual_die else None)
for expert_id in range(layer.local_num_experts):
expert_w13_scales = layer.w13_weight_scale[expert_id]
expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
2, dim=1) # w1 and w3 scales
crossed_expert_w13_scales = cross_weight_32(
expert_1_scale.squeeze(),
expert_3_scale.squeeze(),
die_spc_num,
dim=0,
need_pad=False,
)
narrow_data = supa_w13_scales[expert_id]
narrow_data.copy_(crossed_expert_w13_scales)
layer.w13_weight_scale.data = supa_w13_scales
# NOTE: w2_weight
# after _load_w2, w2_weight is a colparallel weight, shape
# [num_experts, intermediate_size_per_partition // 4, hidden_size] INT32
# for SUPA, transform it to a NUMA colmajor weight, shape
# [spc_num * num_experts, wk, wn_block] INT8
wk = layer.intermediate_size_per_partition
wn = layer.hidden_size
align_size = 32
wn_block = align_n(wn, align_size=align_size, spc_num=spc_num)
supa_w2_weight_packed = torch_br._empty_ut_only(
size=(die_spc_num * layer.local_num_experts, wk // die_num,
wn_block),
dtype=torch.int8,
is_numa=True,
device=cur_device,
tensor_type="colmajor",
axis=0,
sbp="SS" if is_dual_die else None)
for expert_id in range(layer.local_num_experts):
expert_w2 = layer.w2_weight_packed[expert_id]
unpacked_expert_2 = unpack_from_int32(
expert_w2, self.num_bits,
torch.Size(
[layer.intermediate_size_per_partition,
layer.hidden_size]), 0)
pad_expert_w2 = _convert_to_numa_tensor(
unpacked_expert_2,
align_size,
'COLMAJOR',
expert_w2.dtype,
do_transpose=False,
parallel_type="row_parallel")
pad_expert_w2_shape = pad_expert_w2.shape
hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
narrow_data = supa_w2_weight_packed.view_as_usharp(
"COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
expert_id * hw_size)
narrow_data.copy_(pad_expert_w2)
layer.w2_weight_packed.data = supa_w2_weight_packed
# NOTE: w2_scale
# after _load_w2, w2_weight is a colparallel weight, shape
# S8: [num_experts, 1, hidden_size]
# for SUPA, transform it to a NUMA colmajor weight, shape
# [num_experts, wn]
supa_w2_scales = torch_br._empty_ut_only(size=(layer.local_num_experts,
wn),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="linear_bias")
for expert_id in range(layer.local_num_experts):
expert_w2 = layer.w2_weight_scale[expert_id]
narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
narrow_data.copy_(expert_w2)
layer.w2_weight_scale.data = supa_w2_scales
elif self.num_bits == 4:
# NOTE: w13_weight
# after _load_w13, w13_weight is a colparallel weight, shape
# [num_experts, hidden_size // 8, 2 * intermediate_size_per_partition] INT32
# for SUPA, transform it to a NUMA colmajor weight, shape
# [spc_num * num_experts, wk, wn_block] INT32
wk = layer.hidden_size // 8
wn = layer.intermediate_size_per_partition * 2
wn_block = align_n(wn, align_size=64, spc_num=spc_num)
supa_w13_weight_packed = torch_br._empty_ut_only(
size=(spc_num * layer.local_num_experts, wk, wn_block),
dtype=torch.int32,
is_numa=True,
device=cur_device,
tensor_type="colmajor")
for expert_id in range(layer.local_num_experts):
expert_w13 = layer.w13_weight_packed[
expert_id] # hidden_size // 8, 2 * intermediate_size_per_partition
expert_1, expert_3 = expert_w13.chunk(
2, dim=0) # each is a packed int4 weight
pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
expert_3,
spc_num,
dim=1,
need_pad=True,
layout='COLMAJOR',
do_transpose=True)
hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
narrow_data = supa_w13_weight_packed.view_as_usharp(
"COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
expert_id * hw_size)
narrow_data.copy_(pad_expert_w13)
layer.w13_weight_packed.data = supa_w13_weight_packed
# NOTE: w13_scale
# after _load_w13, w13_scale is a colparallel tensor, shape
# S4: [num_experts, hidden_size // 128, 2 * intermediate_size_per_partition]
# for SUPA, transform it to a NUMA colmajor weight, shape
# [num_experts, group_nums, wn]
supa_w13_scales = torch_br._empty_ut_only(
size=(layer.local_num_experts,
layer.hidden_size // self.group_size, wn),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="colmajor")
for expert_id in range(layer.local_num_experts):
expert_w13_scales = layer.w13_weight_scale[expert_id]
expert_1_scale, expert_3_scale = expert_w13_scales.chunk(
2, dim=0) # w1 and w3 group scales
crossed_expert_w13_scales = cross_weight_32(
expert_1_scale,
expert_3_scale,
spc_num,
dim=1,
need_pad=False,
)
narrow_data = supa_w13_scales[expert_id]
narrow_data.copy_(crossed_expert_w13_scales)
layer.w13_weight_scale.data = supa_w13_scales
# NOTE: w2_weight
# after _load_w2, w2_weight is a colparallel weight, shape
# [num_experts, intermediate_size_per_partition // 8, hidden_size] INT32
# for SUPA, transform it to a NUMA colmajor weight, shape
# [spc_num * num_experts, wk, wn_block] INT32
wk = layer.intermediate_size_per_partition // 8
wn = layer.hidden_size
wn_block = align_n(wn, align_size=32, spc_num=spc_num)
supa_w2_weight_packed = torch_br._empty_ut_only(
size=(spc_num * layer.local_num_experts, wk, wn_block),
dtype=torch.int32,
is_numa=True,
device=cur_device,
tensor_type="colmajor")
for expert_id in range(layer.local_num_experts):
expert_w2 = layer.w2_weight_packed[expert_id]
pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
spc_num,
'COLMAJOR',
expert_w2.dtype,
do_transpose=True)
pad_expert_w2_shape = pad_expert_w2.shape
hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
narrow_data = supa_w2_weight_packed.view_as_usharp(
"COLMAJOR", pad_expert_w2_shape, Sbp.ss(0),
expert_id * hw_size)
narrow_data.copy_(pad_expert_w2)
layer.w2_weight_packed.data = supa_w2_weight_packed
# NOTE: w2_scale
# after _load_w2, w2_scale is a colparallel tensor, shape
# S4: [num_experts, intermediate_size_per_partition // 128, hidden_size]
# for SUPA, transform it to a NUMA colmajor weight, shape
# [num_experts, group_nums, wn]
supa_w2_scales = torch_br._empty_ut_only(
size=(layer.local_num_experts,
layer.intermediate_size_per_partition // self.group_size,
wn),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="colmajor")
for expert_id in range(layer.local_num_experts):
expert_w2 = layer.w2_weight_scale[expert_id]
narrow_data = supa_w2_scales[expert_id::layer.local_num_experts]
narrow_data.copy_(expert_w2)
layer.w2_weight_scale.data = supa_w2_scales
else:
raise ValueError(
f"Unsupported num_bits: {self.num_bits}. Only 4 and 8 are supported."
)
# remove other buffers registered by CompressedTensorsWNA16MoEMethod to reduce memory usage
layer.w13_weight_shape = None
layer.w13_weight_g_idx = None
layer.w13_g_idx_sort_indices = None
layer.w2_weight_shape = None
layer.w2_weight_g_idx = None
layer.w2_g_idx_sort_indices = None
@patch_to(CompressedTensorsWNA16MoEMethod)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
b_seq = x.shape[0]
gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
if b_seq > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
return fused_moe_quant_dyn(
x,
shared_gate_up_weight,
shared_down_weight,
layer.w13_weight_packed,
layer.w2_weight_packed,
layer.w13_weight_scale,
layer.w2_weight_scale,
gating_weight,
top_k,
layer.intermediate_size_per_partition,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
ep_rank=layer.ep_rank,
ep_size=layer.ep_size,
)
else:
return fused_moe_quant_device(
x,
shared_gate_up_weight,
shared_down_weight,
layer.w13_weight_packed,
layer.w2_weight_packed,
layer.w13_weight_scale,
layer.w2_weight_scale,
gating_weight,
top_k,
layer.intermediate_size_per_partition,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
tp_rank=get_tp_group().rank_in_group,
global_rank=get_tp_group().rank,
tp_size=get_tensor_model_parallel_world_size(),
ep_rank=layer.ep_rank,
ep_size=layer.ep_size,
)
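
The repacking above relies on the pack-quantized int32 layout: each int32 word holds 32 // num_bits quantized values (4 in the 8-bit branch, 8 in the 4-bit branch), which is where `self.packed_factor = 32 // config.num_bits` and the `// packed_factor` dimensions of the loaded weights come from. Below is a rough, self-contained sketch of that unpacking idea; the bit order and zero-point handling of the real `unpack_from_int32` in compressed-tensors may differ.

```python
import torch

def unpack_int32_sketch(packed: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Expand int32 words into 32 // num_bits unsigned values along the last dim.

    Illustrative only; compressed_tensors' unpack_from_int32 is the real routine.
    """
    mask = (1 << num_bits) - 1
    shifts = torch.arange(0, 32, num_bits, dtype=torch.int32)  # one shift per packed value
    values = (packed.unsqueeze(-1) >> shifts) & mask           # [..., K, 32 // num_bits]
    return values.flatten(-2)                                  # [..., K * (32 // num_bits)]

words = torch.tensor([[0x04030201]], dtype=torch.int32)
print(unpack_int32_sketch(words, 8))  # tensor([[1, 2, 3, 4]], dtype=torch.int32)
```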

View File

@@ -0,0 +1,267 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from compressed_tensors.compressors.quantized_compressors import (
unpack_from_int32)
from fastcore.basics import patch_to
from vllm.distributed import (get_pipeline_model_parallel_group,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
# yapf: enable
from vllm_br import envs
from ...br_utils import _convert_to_numa_tensor
@patch_to(CompressedTensorsWNA16)
def create_weights(self, layer: torch.nn.Module, input_size: int,
input_size_per_partition: int, output_size: int,
output_partition_sizes: list[int],
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.output_size_per_partition = sum(output_partition_sizes)
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
scales_and_zp_size = input_size_per_partition // group_size
weight = PackedvLLMParameter(
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
packed_factor=self.pack_factor,
packed_dim=1,
data=torch.empty(self.output_size_per_partition,
input_size_per_partition // self.pack_factor,
dtype=torch.int32,
device="cpu"))
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(
self.output_size_per_partition,
scales_and_zp_size,
device="cpu",
dtype=params_dtype,
)
}
zeros_args = {
"weight_loader":
weight_loader,
"data":
torch.zeros(
self.output_size_per_partition // self.pack_factor,
scales_and_zp_size,
device="cpu",
dtype=torch.int32,
)
}
if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
if not self.symmetric:
qzeros = PackedColumnParameter(output_dim=0,
packed_dim=0,
packed_factor=self.pack_factor,
**zeros_args)
else:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
if not self.symmetric:
qzeros = PackedvLLMParameter(input_dim=1,
output_dim=0,
packed_dim=0,
packed_factor=self.pack_factor,
**zeros_args)
# A 2D array defining the original shape of the weights
# before packing
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64,
device="cpu"),
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
if not self.symmetric:
layer.register_parameter("weight_zero_point", qzeros)
# group index (for activation reordering)
if self.has_g_idx:
weight_g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
device="cpu",
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)
self.input_size_per_partition = input_size_per_partition
@patch_to(CompressedTensorsWNA16)
def process_weights_after_loading(self: CompressedTensorsWNA16,
layer: torch.nn.Module) -> None:
# spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
# cur_device = torch.supa.current_device()
self.num_bits = 32 // self.pack_factor
layer.weight_packed.data = unpack_from_int32(
layer.weight_packed.data, self.num_bits,
torch.Size(
[self.output_size_per_partition, self.input_size_per_partition]),
1)
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
br_scales = layer.weight_scale.data.to(torch.float32)
layer.weight_scale.data = br_scales
do_transpose = True
parallel_type = "col_parallel"
match layer:
case RowParallelLinear():
parallel_type = "row_parallel"
case _:
pass
if hasattr(layer, 'weight_packed') and len(layer.weight_packed.shape) == 2:
weight_packed = layer.weight_packed.data
layer.weight_packed.data = _convert_to_numa_tensor(
weight_packed,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"colmajor",
torch.int8,
do_transpose=do_transpose,
wk=(weight_packed.shape[1]
if do_transpose else weight_packed.shape[0]),
wn=(weight_packed.shape[0]
if do_transpose else weight_packed.shape[1]),
parallel_type=parallel_type) # noqa: SIM210
if hasattr(layer, 'weight_scale') and layer.weight_scale is not None:
pad_zeros = False
layer.weight_scale.data = _convert_to_numa_tensor(
layer.weight_scale.data.T,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"linear_bias",
torch.float32,
parallel_type=parallel_type,
pad_zeros=pad_zeros)
if hasattr(layer, 'bias') and layer.bias is not None:
pad_zeros = (parallel_type == "row_parallel")
layer.bias.data = _convert_to_numa_tensor(
layer.bias.data,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"linear_bias",
torch.float32,
parallel_type=parallel_type,
pad_zeros=pad_zeros)
@patch_to(CompressedTensorsWNA16)
def apply_weights(self: CompressedTensorsWNA16,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# numa weight is 3-dims
if len(layer.weight_packed.shape) == 3:
output_size = (layer.output_size_per_partition if hasattr(
layer, "output_size_per_partition") else layer.output_size)
act_mode = "act_default"
if isinstance(layer, MergedColumnParallelLinear):
act_mode = "act_swiglu"
output_size //= 2
if isinstance(layer, RowParallelLinear):
seq_len = x.shape[-2]
tp_size = get_tensor_model_parallel_world_size()
# bypass tp8 and tp4pp2 allreduce
pp_size = get_pipeline_model_parallel_group().world_size
all_rank = tp_size * pp_size
support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
layer.reduce_results = not (
all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
(envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
if layer.reduce_results:
return torch_br.br_fused_mlp_infer(
x, [layer.weight_packed.data],
output_w=output_size,
scales=[layer.weight_scale.data]
if layer.weight_scale.data is not None else None,
bias=[bias] if bias is not None else None,
activation_mode=act_mode)
else:
tp_rank = get_tp_group().rank_in_group
global_rank = get_tp_group().rank
rank_i = global_rank % tp_size
assert rank_i == tp_rank
return torch_br.supa_fused_linear_allreduce_opt(
x,
layer.weight_packed.data,
output_size,
tp_rank,
tp_size,
global_rank,
0,
scales=layer.weight_scale.data,
bias=bias,
act_mode=act_mode)
else:
return torch_br.br_fused_mlp_infer(
x, [layer.weight_packed.data],
output_w=output_size,
scales=[layer.weight_scale.data]
if layer.weight_scale.data is not None else None,
bias=[bias] if bias is not None else None,
activation_mode=act_mode)
xn = x.shape[0]
xh = x.shape[1]
ww = layer.weight_packed.shape[1]
# TODO: hard-coded shortcut to skip the dry_run stage
if xh >= 4096:
return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
return torch_br.sudnn_qmatmul_infer(x,
layer.weight_packed,
layer.weight_scale,
bias=bias)
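
`apply_weights` above decides, per call on a RowParallelLinear, whether the output still needs the regular tensor-parallel all-reduce (`layer.reduce_results = True`, the `br_fused_mlp_infer` path) or can go through the fused linear+allreduce kernel. A condensed sketch of that predicate follows; the parameter names stand in for the `vllm_br.envs` values used in the patch, and the example values are illustrative.

```python
def use_fused_linear_allreduce(seq_len: int, tp_size: int, pp_size: int,
                               spc_num: int, fused_allreduce_max_rank: int,
                               decoder_max_len: int) -> bool:
    """Bypass predicate; layer.reduce_results is set to its negation."""
    supported = {(16, 4), (16, 8), (32, 2), (32, 4)}  # (spc_num, tp_size) pairs
    all_rank = tp_size * pp_size
    return (all_rank <= fused_allreduce_max_rank
            and seq_len <= decoder_max_len
            and (spc_num, tp_size) in supported)

# e.g. a 32-SPC device at tp4/pp1 with a short decode batch takes the fused path
assert use_fused_linear_allreduce(seq_len=16, tp_size=4, pp_size=1, spc_num=32,
                                  fused_allreduce_max_rank=8, decoder_max_len=1024)
```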

View File

@@ -0,0 +1,34 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None
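
A quick usage sketch of the helper above; the module paths are illustrative, and only the ".k_proj"/".v_proj" and ".output_scale" pieces of the name matter for the mapping.

```python
assert get_compressed_tensors_cache_scale(
    "model.layers.0.self_attn.k_proj.output_scale"
) == "model.layers.0.self_attn.attn.k_scale"
assert get_compressed_tensors_cache_scale(
    "model.layers.0.mlp.down_proj.weight") is None
```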

View File

@@ -0,0 +1,244 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Any, Dict, Optional
import torch
import torch_br
from fastcore.basics import patch_to
from torch.nn.parameter import Parameter
from vllm.distributed import (get_pipeline_model_parallel_group,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
GPTQLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm_br import envs
from ..br_utils import _br_qweight_cvt, _convert_to_numa_tensor
from ..linear import (process_weights_MergedColumnParallelLinear,
process_weights_QuantMergedColumnParallelLinear,
process_weights_ReplicatedLinear)
@patch_to(GPTQConfig, cls_method=True)
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@patch_to(GPTQConfig)
def get_quant_method(self: GPTQConfig, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQLinearMethod"]:
quant_method = get_linear_quant_method(self, layer, prefix,
GPTQLinearMethod)
return quant_method
@patch_to(GPTQConfig, cls_method=True)
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
"""
[PatchNote] add qkv_quantized param support
"""
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
default="")
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
qkv_quantized = cls.get_from_keys_or(config, ["qkv_quantized"],
default=True)
return cls(weight_bits=weight_bits,
group_size=group_size,
desc_act=desc_act,
lm_head_quantized=lm_head_quantized,
dynamic=dynamic,
autoround_version=autoround_version,
modules_in_block_to_quantize=modules_in_block_to_quantize,
qkv_quantized=qkv_quantized)
def wrapper_GPTQConfig_init(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
qkv_quantized = kwargs.pop("qkv_quantized", True)
fn(self, *args, **kwargs)
self.qkv_quantized = qkv_quantized
return wrapper
GPTQConfig.__init__ = wrapper_GPTQConfig_init(
GPTQConfig.__init__) # noqa: E501
@patch_to(GPTQLinearMethod)
def process_weights_after_loading(self: GPTQLinearMethod,
layer: torch.nn.Module) -> None:
still_need_process = True
merge_col_quant = False
# NOTE: all process_weights funcs should be done before process_weights_after_loading
parallel_type = "col_parallel"
match layer:
case ReplicatedLinear():
process_weights_ReplicatedLinear(layer)
still_need_process = layer.output_size in (64, 256)
case MergedColumnParallelLinear():
if hasattr(layer, "qweight"):
merge_col_quant = True
else:
process_weights_MergedColumnParallelLinear(layer)
still_need_process = False
case RowParallelLinear():
parallel_type = "row_parallel"
case _:
pass
# NOTE: if exllama is used, br gptq needs similar treatment:
# exllama needs to shuffle the weight after it is loaded;
# here we do the shuffle on the first forward pass
if layer.qweight.dtype == torch.int32:
input_size = layer.input_size_per_partition if hasattr(
layer, 'input_size_per_partition') else layer.input_size
output_size = layer.output_size_per_partition if hasattr(
layer, 'output_size_per_partition') else layer.output_size
br_qweight = _br_qweight_cvt(self, layer.qweight, layer.qzeros,
input_size, output_size)
layer.qweight.data = br_qweight
if merge_col_quant:
process_weights_QuantMergedColumnParallelLinear(layer)
still_need_process = False
br_scales = layer.scales.to(torch.float32)
layer.scales.data = br_scales
# for torch.compile
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
if not still_need_process or self.weight_type != "NUMA":
return
if hasattr(layer, 'qweight') and len(layer.qweight.shape) == 2:
layer.qweight.data = _convert_to_numa_tensor(
layer.qweight,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"colmajor",
torch.int8,
parallel_type=parallel_type)
if hasattr(layer, 'scales') and layer.scales is not None:
pad_zeros = False
layer.scales.data = _convert_to_numa_tensor(
layer.scales,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"linear_bias",
torch.float32,
parallel_type=parallel_type,
pad_zeros=pad_zeros)
if hasattr(layer, 'bias') and layer.bias is not None:
pad_zeros = (parallel_type == "row_parallel")
layer.bias.data = _convert_to_numa_tensor(
layer.bias,
envs.VLLM_BR_DEVICE_WARP_SIZE,
"linear_bias",
torch.float32,
parallel_type=parallel_type,
pad_zeros=pad_zeros)
@patch_to(GPTQLinearMethod)
def apply(self: GPTQLinearMethod,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# numa weight is 3-dims
if len(layer.qweight.shape) == 3:
output_size = (layer.output_size_per_partition if hasattr(
layer, "output_size_per_partition") else layer.output_size)
act_mode = "act_default"
if isinstance(layer, MergedColumnParallelLinear):
act_mode = "act_swiglu"
output_size //= 2
if isinstance(layer, RowParallelLinear):
seq_len = x.shape[-2]
tp_size = get_tensor_model_parallel_world_size()
# bypass tp8 and tp4pp2 allreduce
pp_size = get_pipeline_model_parallel_group().world_size
all_rank = tp_size * pp_size
support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
layer.reduce_results = not (
all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and
(envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
if layer.reduce_results:
return torch_br.br_fused_mlp_infer(
x, [layer.qweight],
output_w=output_size,
scales=[layer.scales]
if layer.scales is not None else None,
bias=[bias] if bias is not None else None,
activation_mode=act_mode)
else:
tp_rank = get_tp_group().rank_in_group
global_rank = get_tp_group().rank
rank_i = global_rank % tp_size
assert rank_i == tp_rank
return torch_br.supa_fused_linear_allreduce_opt(
x,
layer.qweight,
output_size,
tp_rank,
tp_size,
global_rank,
0,
scales=layer.scales,
bias=bias,
act_mode=act_mode)
else:
return torch_br.br_fused_mlp_infer(
x, [layer.qweight],
output_w=output_size,
scales=[layer.scales] if layer.scales is not None else None,
bias=[bias] if bias is not None else None,
activation_mode=act_mode)
xn = x.shape[0]
xh = x.shape[1]
ww = layer.qweight.shape[1]
# TODO: hard-coded shortcut to skip the dry_run stage
if xh >= 4096:
return torch.ones((xn, xh, ww), dtype=x.dtype, device=x.device)
return torch_br.sudnn_qmatmul_infer(x,
layer.qweight,
layer.scales,
bias=bias)
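
GPTQConfig and CompressedTensorsConfig both receive the extra flag through the same wrapping pattern: pop the kwarg before the original `__init__` sees it, then attach it to the instance. A standalone sketch of that mechanism follows; `DemoConfig` is a hypothetical stand-in, not the real vLLM class.

```python
from functools import wraps

class DemoConfig:  # hypothetical stand-in for GPTQConfig / CompressedTensorsConfig
    def __init__(self, weight_bits: int):
        self.weight_bits = weight_bits

def add_qkv_quantized(init_fn):
    @wraps(init_fn)
    def wrapper(self, *args, **kwargs):
        # Pop the extra kwarg so the wrapped __init__ keeps its signature,
        # then record it on the instance (defaults to True when absent).
        qkv_quantized = kwargs.pop("qkv_quantized", True)
        init_fn(self, *args, **kwargs)
        self.qkv_quantized = qkv_quantized
    return wrapper

DemoConfig.__init__ = add_qkv_quantized(DemoConfig.__init__)

cfg = DemoConfig(weight_bits=4, qkv_quantized=False)
assert cfg.weight_bits == 4 and cfg.qkv_quantized is False
```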