This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Literal, Type, get_args
QuantizationMethods = Literal[
# "aqlm",
"awq",
"deepspeedfp",
"tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
"modelopt_fp4",
"bitblas",
"gguf",
"gptq_marlin_24",
"gptq_marlin",
"gptq_bitblas",
"awq_marlin",
"gptq",
"compressed-tensors",
"bitsandbytes",
"hqq",
"experts_int8",
"ipex",
"quark",
"moe_wna16",
"torchao",
"auto-round",
"rtn",
"inc",
"mxfp4",
"petit_nvfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
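# Example (illustrative): backends validate a user-supplied method name with a
# plain membership test against the list derived above, e.g.
#   "fp8" in QUANTIZATION_METHODS     -> True
#   "aqlm" in QUANTIZATION_METHODS    -> False (commented out in the Literal above)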

View File

@@ -0,0 +1,615 @@
import functools
import importlib.util
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, expert_weight_is_col_major,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
validate_fp8_block_shape)
# from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
# all_close_1d, apply_fp8_linear, convert_to_channelwise,
# cutlass_block_fp8_supported, cutlass_fp8_supported,
# normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
# requantize_with_max_scale)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, is_layer_skipped)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.utils import has_deep_gemm
logger = init_logger(__name__)
# has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def Fp8LinearMethod__init(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.quant_config.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
if self.block_quant:
self.block_size = self.quant_config.weight_block_size
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.scale_k = 1
self.scale_n = 1
self.scale_n_prefill = 1 # only for fp8 moe
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)
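# The two activation-scale granularities selected above, illustrated on a toy
# tensor (a sketch; 448.0 is the max representable value of float8_e4m3fn):
#   x = torch.randn(4, 8)
#   per_tensor_scale = x.abs().max() / 448.0                      # one scalar for the tensor
#   per_token_scale = x.abs().amax(dim=-1, keepdim=True) / 448.0  # shape (4, 1), one per token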
class Fp8LinearMethod(LinearMethodBase):
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
if self.block_quant:
scale_n = extra_weight_attrs.get("scale_n")
scale_k = extra_weight_attrs.get("scale_k")
if scale_n is not None:
self.scale_n = scale_n
if scale_k is not None:
self.scale_k = scale_k
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None
block_n, block_k = (
self.quant_config.weight_block_size[0] // self.scale_n,
self.quant_config.weight_block_size[1] // self.scale_k,
)
# Required by row parallel
if (tp_size > 1
and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition
== tp_size) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if not self.block_quant:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
else:
assert self.quant_config.activation_scheme == "dynamic"
scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
# TODO(rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
if current_platform.is_fp8_fnuz():
weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv)
else:
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
if isinstance(layer, QKVParallelLinear):
# NOTE: for QKVParallelLinear the number of weight_scale rows must be
# divisible by 8 (8 DSPs), so repeat the rows until the padded row
# count is a multiple of 8.
rows = weight_scale_inv.shape[0]
repeat = 1
while (rows * repeat) % 8 != 0:
repeat *= 2
if repeat > 1:
weight_scale_inv = torch.repeat_interleave(weight_scale_inv, repeats=repeat, dim=0)
# weight = self._maybe_pad_weight(weight)
# if self.block_quant:
# maybe_post_process_fp8_weight_block(
# layer, self.cutlass_block_fp8_supported)
# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False)
return
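# Worked example of the padding above (illustrative): a fused QKV
# weight_scale_inv with 6 rows needs repeat=4 before the row count is a
# multiple of 8:
#   scale = torch.randn(6, 56)
#   torch.repeat_interleave(scale, repeats=4, dim=0).shape  # -> torch.Size([24, 56])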
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
# Note: lazy import to avoid triton import error.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
if self.block_quant:
assert self.quant_config.weight_block_size is not None
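# NOTE: block_size is re-derived from the weight/scale shapes below instead of
# read from the config, because process_weights_after_loading may have
# repeated the QKV scale rows. Illustrative shapes (assumed, not from any
# checkpoint): weight (4096, 7168) with weight_scale_inv (32, 56) gives
# block_size [128, 128].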
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=[layer.weight.shape[0] // layer.weight_scale_inv.shape[0],
layer.weight.shape[1] // layer.weight_scale_inv.shape[1]],
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)
# return apply_fp8_linear(
# input=x,
# weight=layer.weight,
# weight_scale=layer.weight_scale,
# input_scale=layer.input_scale,
# bias=bias,
# cutlass_fp8_supported=self.cutlass_fp8_supported,
# # Default to using per_token quantization if cutlass is supported
# use_per_token_if_dynamic=self.cutlass_fp8_supported)
def Fp8MoEMethod_init_(self, quant_config: Fp8Config, layer: torch.nn.Module):
self.layer = layer
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
self.flashinfer_moe_backend = None
self.scale_k = 1
self.scale_n = 1
self.scale_n_prefill = 1
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm() or current_platform.is_vacc():
self.use_marlin = False
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm():
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
else:
logger.warning_once(
"DeepGemm not supported on the current platform.")
# Check for CutlassBlockScaledGroupedGemm support.
self.allow_cutlass_block_scaled_grouped_gemm = False
if not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(100)):
logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
)
self.allow_cutlass_block_scaled_grouped_gemm = True
else:
logger.warning_once(
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8MoEMethod(FusedMoEMethodBase):
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.quant_config.weight_block_size is not None
scale_n = extra_weight_attrs.get("scale_n")
scale_n_prefill = extra_weight_attrs.get("scale_n_prefill")
scale_k = extra_weight_attrs.get("scale_k")
if scale_n is not None:
self.scale_n = scale_n
if scale_k is not None:
self.scale_k = scale_k
if scale_n_prefill is not None:
self.scale_n_prefill = scale_n_prefill
if self.quant_config is not None and self.quant_config.weight_block_size is not None:
self.gcd_value = self.quant_config.weight_block_size[0]
output_size_no_merge = intermediate_size_per_partition
# assert isinstance(output_size_no_merge, int), f"merged output size should be an int, value is: {output_size_no_merge}"
if output_size_no_merge % self.quant_config.weight_block_size[0]:
import math
gcd_value = math.gcd(output_size_no_merge % self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[0])
self.scale_n = self.scale_n * self.quant_config.weight_block_size[0] // gcd_value
self.scale_n_prefill = self.scale_n_prefill * self.quant_config.weight_block_size[0] // gcd_value
if hidden_size % self.quant_config.weight_block_size[1]:
import math
gcd_value = math.gcd(hidden_size % self.quant_config.weight_block_size[1], self.quant_config.weight_block_size[1])
self.scale_k = self.scale_k * self.quant_config.weight_block_size[1] // gcd_value
# self.scale_k = self.scale_n
# print('output_size_no_merge', output_size_no_merge)
# split the work across cores according to block_size
# output_size_no_merge = 384
#   block_size = 128: 384 = 3x128, only 3 cores x 128
#   block_size = 16:  384 = 24x16, splits as 8 cores x (3x16), so all 8 cores are used
# output_size_no_merge = 512
#   block_size = 128: 512 = 4x128, only 4 cores x 128
#   block_size = 64:  512 = 8x64, splits as 8 cores x 64
# output_size_no_merge = 768
#   block_size = 128: 768 = 6x128, only 6 cores x 128
#   block_size = 32:  768 = 8x(3x32), splits as 8 cores x (3x32)
core_num = 8
min_block_size = 4
block_size_tmp = self.quant_config.weight_block_size[0] // self.scale_n
if output_size_no_merge > block_size_tmp and \
output_size_no_merge % block_size_tmp == 0 and \
output_size_no_merge // block_size_tmp < core_num and \
output_size_no_merge % core_num == 0:
core_num_old = output_size_no_merge // block_size_tmp
import math
gcd_value = math.gcd(core_num, core_num_old)
new_scale = core_num // gcd_value
if block_size_tmp // new_scale >= min_block_size:
self.scale_n = new_scale * self.scale_n
#print("moe scale n is:", self.scale_n, self.scale_k, intermediate_size_per_partition)
tp_size = get_tensor_model_parallel_world_size()
if self.scale_n != self.scale_n_prefill:
block_n_prefill = self.quant_config.weight_block_size[0] // self.scale_n_prefill
block_n, block_k = (
self.quant_config.weight_block_size[0] // self.scale_n,
self.quant_config.weight_block_size[1] // self.scale_k,
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}.")
if (tp_size > 1
and hidden_size % block_k != 0):
# Required by row parallel
raise ValueError(
f"hidden_size = {hidden_size} is not divisible by "
f"weight quantization block_k = {block_k}.")
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if not self.block_quant:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
else:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) //
block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_n - 1) // block_n,
dtype=torch.float32,
),
requires_grad=False,
)
if self.scale_n != self.scale_n_prefill:
w13_weight_scale_prefill = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n_prefill - 1) //
block_n_prefill),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale_prefill = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_n_prefill - 1) // block_n_prefill,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv_prefill", w13_weight_scale_prefill)
layer.register_parameter("w2_weight_scale_inv_prefill", w2_weight_scale_prefill)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.
value} if self.block_quant else
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()).
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if self.scale_n != self.scale_n_prefill:
set_weight_attrs(w13_weight_scale_prefill, extra_weight_attrs)
set_weight_attrs(w2_weight_scale_prefill, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def moe_fp8_apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import FusedMoE
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
try:
from torch_vacc.vacc.custom_ops import fused_experts
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
experts_output = None
if memory_recycler is not None:
# remove MOE_EXPERT_OUT_BUFFER
# experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
decode_with_batch=layer.is_decode and x.shape[0] > 1,
output_opt=experts_output
)
except Exception as e:
print(f"vacc fused_expert run fail, now using unfused ops: {e}")
from torch_vacc.vacc.custom_ops_cpu import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
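# Shape sketch for the routing and expert computation above (illustrative):
# for x of shape (num_tokens, hidden_size) and top_k experts per token,
# FusedMoE.select_experts returns
#   topk_weights: (num_tokens, top_k), float
#   topk_ids:     (num_tokens, top_k), int
# and fused_experts applies, per selected expert e,
#   silu(x @ w1_e.T) * (x @ w3_e.T) @ w2_e.T
# (w13 stores w1 and w3 fused), accumulating the results weighted by
# topk_weights.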

View File

@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig as GPTQConfigOrig
from vllm.model_executor.layers.quantization.gptq import ExllamaState
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT
import math
def GPTQLinearMethod__init(self, quant_config: GPTQConfigOrig):
self.quant_config = quant_config
self.scale_k = 1
self.split_num = 4
def int32_to_int4(s0, axis=-2):
# Flatten to shape [m, n] first; each int32 word holds 8 int4 values, so the
# per-nibble extraction below stacks into [8, m, n] before interleaving back.
# Bit layout of one word x32 (bits 31..0) versus the 8 int4 slots x4[7..0]:
#   x32 bits 31-28 => x4[7]
#   x32 bits 27-24 => x4[6]
#   ...
#   x32 bits  3-0  => x4[0]
# i.e. x32[index=0] expands to x4[7, 6, 5, 4, 3, 2, 1, 0].
# A 4-bit field maps to its real value with a bias of 8, NOT two's complement:
#   1111 => 15 => 15 - 8 = 7
#   0110 =>  6 =>  6 - 8 = -2
# Worked example: the word 0x6ACB372B (stored little-endian in memory as
# bytes 2B 37 CB 6A). Reading nibbles from bit 0 upwards gives
# 11, 2, 7, 3, 11, 12, 10, 6; subtracting the bias 8 yields the int4 values
# 3, -6, -1, -5, 3, 4, 2, -2.
s = s0.view(torch.uint32)
all = []
for i in range(8):
x = 15 << (i*4)
# s2 = torch.bitwise_and(x,s)
s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
s3 = s2 / (2 ** (i*4))
s4 = s3.to(torch.int32)
# two's-complement interpretation gives the wrong result:
# s4[s4 > 7] = s4[s4 > 7] - 16
# subtracting 8 directly is correct; value range: -8..7
s4 = s4 - 8
all.append(s4.reshape(1,*s4.shape))
all = torch.concatenate(all, 0)
if axis == -2 or axis == 0:
# 8,K//8,N => K//8,8,N => K,N
all = all.transpose(-2,0).reshape(-1,all.shape[-1]).contiguous()
else:
# 8,N,K//8 => N,K//8,8 => N,K
all = all.permute(1,2,0).reshape(all.shape[-2],-1).contiguous()
return all
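# Illustrative self-test sketch for int32_to_int4, reusing the worked example
# from the comments above (assumes a torch build that supports viewing int32
# tensors as uint32, which int32_to_int4 relies on):
def _selftest_int32_to_int4():
    word = torch.tensor([[0x6ACB372B]], dtype=torch.int32)  # packed: K//8 = 1, N = 1
    expected = torch.tensor([3, -6, -1, -5, 3, 4, 2, -2],
                            dtype=torch.int32).reshape(8, 1)
    assert torch.equal(int32_to_int4(word), expected)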
def dequant_weight(qw, scales, group_size = 128):
N = qw.shape[1]
int4_to_int32_axis = -2
if TRANSPOSE_GPTQ_WEIGHT:
N = qw.shape[0]
int4_to_int32_axis = -1
qweight = int32_to_int4(qw, int4_to_int32_axis).to(torch.float16)  # int32 => 8 x int4 => fp16
if TRANSPOSE_GPTQ_WEIGHT:
scales = scales.T.contiguous()
qweight = qweight.T.contiguous()
scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)  # expand scales by group_size: every group_size consecutive rows share one scale
# print('qweight', qweight.shape, qweight.dtype)
# print('scale', scales.shape, scales.dtype)
weight = qweight * scales  # dequantize
return weight
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
class GPTQLinearMethod(LinearMethodBase):
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
# if input_size_per_partition % self.quant_config.group_size != 0:
# raise ValueError(
# "The input size is not aligned with the quantized "
# "weight shape. This can be caused by too large "
# "tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
# For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED
else:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
layer.exllama_state = exllama_state
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
# assumes self.quant_config.weight_bits == 4
if TRANSPOSE_GPTQ_WEIGHT:
layer.qzeros = Parameter(layer.qzeros.data.T.contiguous(), requires_grad=False)
layer.qweight = Parameter(layer.qweight.data.T.contiguous(), requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data.T.contiguous(), requires_grad=False)
else:
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], ) # M,N
reshaped_x = x.reshape(-1, x.shape[-1])
# print(f"~~~~ start dequant")
# import time
# start_quant_time = time.time()
# weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(), self.quant_config.group_size // self.scale_k).to(layer.qweight.device)
# end_quant_time = time.time()
# print(f"~~~~ dequant time: {end_quant_time - start_quant_time}")
# if torch.distributed.get_rank() == 0:
# print(f"~~~~ weight shape: {weight.shape}, dtype: {weight.dtype}")
# output = torch.matmul(reshaped_x, weight)
# print("entering GPTQLinearMethod apply, reshaped_x shape:", reshaped_x.shape, "reshaped_x stride", reshaped_x.stride(), "input_tensor", x.shape, "qweight shape:", layer.qweight.shape, "scales shape:", layer.scales.shape)
output = torch.vacc.w4a8_block_int4_matmul(
reshaped_x,
layer.qweight.transpose(-1, -2),
layer.scales.transpose(-1, -2),
[1, self.quant_config.group_size // self.scale_k],
)
# print("exiting GPTQLinearMethod apply, output shape:", output.shape)
# end_gemm_time = time.time()
# if torch.distributed.get_rank() == 0:
# print(f"~~~~ gemm time: {end_gemm_time - end_quant_time}")
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
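# Debug/reference sketch (hedged): the commented-out path above can be revived
# to cross-check the fused vendor kernel against a plain dequantize-then-matmul
# on CPU. Assumes the non-transposed layout (TRANSPOSE_GPTQ_WEIGHT == False).
def _reference_gptq_linear(x, qweight, scales, group_size):
    # dequantize the packed int4 weight to an fp16 (K, N) matrix, then matmul
    weight = dequant_weight(qweight.cpu(), scales.cpu(), group_size)
    return torch.matmul(x.reshape(-1, x.shape[-1]).to(weight.dtype).cpu(), weight)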

View File

@@ -0,0 +1,372 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional, Union
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
# [num_experts, N, K//8], int32 ==> [num_experts, N, K], int4 ==> [num_experts, N//8, K], int32
def repack_quant_moe_weight_old(original_packed_tensor):
num_experts = original_packed_tensor.shape[0]
N = original_packed_tensor.shape[1]
K = original_packed_tensor.shape[2] * 8
if original_packed_tensor.dtype != torch.int32:
raise ValueError("data type of input tensor should be int32")
if N % 8 != 0:
raise ValueError("N of input tensor should be divisible by 8")
# --- 1. Unpack: expand the int32 tensor into a logical int4 tensor ---
# Temporary tensor holding every unpacked int4 value; torch.uint8 is used
# as the storage type because PyTorch has no native int4 dtype.
unpacked_int4_tensor = torch.zeros(
(num_experts, N, K),
dtype=torch.uint8,
device=original_packed_tensor.device
)
mask = 0b1111
for i in range(8):
# Extract the i-th 4-bit value from every int32 word:
# right-shift by (i * 4) bits to move it into the least-significant position,
# then AND with the mask to keep just those 4 bits.
extracted_int4s = (original_packed_tensor >> (i * 4)) & mask
# Scatter the extracted int4 values into unpacked_int4_tensor:
# the slice `i::8` fills every 8th position starting at index i.
unpacked_int4_tensor[:, :, i::8] = extracted_int4s
# --- 2. Repack: pack the logical int4 tensor into a new int32 tensor along N ---
new_packed_tensor = torch.zeros(
(num_experts, N//8, K),
dtype=torch.int32,
device=original_packed_tensor.device
)
for i in range(8):
# Take the int4 slice to pack into slot i, using `i::8` along the N dimension.
current_int4_segment = unpacked_int4_tensor[:, i::8, :]
# Convert the slice to int32 (the packing target), shift it to its bit
# position within the new int32 word, then OR it into new_packed_tensor.
new_packed_tensor |= (current_int4_segment.to(torch.int32) << (i * 4))
return new_packed_tensor
def repack_quant_moe_weight(original_packed_tensor):
if original_packed_tensor.dtype != torch.int32:
raise ValueError("data type of input tensor should be int32")
num_experts, N, K_packed = original_packed_tensor.shape
K = K_packed * 8
if N % 8 != 0:
raise ValueError("N of input tensor should be divisible by 8")
new_packed_tensor = torch.zeros((num_experts, N // 8, K),
dtype=torch.int32,
device=original_packed_tensor.device)
for i in range(8):
source_slice = original_packed_tensor[:, i::8, :]
for j in range(8):
unpacked_strip = (source_slice >> (j * 4)) & 0b1111
new_packed_tensor[:, :, j::8] |= (unpacked_strip.to(torch.int32) << (i * 4))
return new_packed_tensor
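# Consistency sketch (illustrative): the vectorized repack above should agree
# with the step-by-step reference implementation repack_quant_moe_weight_old.
def _check_repack_equivalence(num_experts=2, n=16, k_packed=4):
    iinfo = torch.iinfo(torch.int32)
    packed = torch.randint(iinfo.min, iinfo.max,
                           (num_experts, n, k_packed), dtype=torch.int32)
    assert torch.equal(repack_quant_moe_weight(packed),
                       repack_quant_moe_weight_old(packed))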
class MoeWNA16Method(FusedMoEMethodBase):
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
self.moe = layer
layer.quant_config = self.quant_config
bit8_pack_factor = self.quant_config.bit8_pack_factor
bit32_pack_factor = 32 // self.quant_config.weight_bits
group_size = self.quant_config.group_size
group_size_div_factor = 1
group_size_w13 = self.quant_config.group_size
group_size_div_factor_w13 = 1
group_size_w2 = self.quant_config.group_size
group_size_div_factor_w2 = 1
# make intermediate_size and hidden_size divisible by group_size
# we reduce the group size to ensure that
# and we would repeat the loaded_weight later
while intermediate_size_per_partition % group_size or \
hidden_size % group_size:
group_size = group_size // 2
group_size_div_factor *= 2
assert group_size >= 32
layer.group_size = group_size
layer.group_size_div_factor = group_size_div_factor
while intermediate_size_per_partition % group_size_w2:
group_size_w2 = group_size_w2 // 2
group_size_div_factor_w2 *= 2
assert group_size_w2 >= 32
layer.w2_block_size = group_size_w2
layer.group_size_div_factor_w2 = group_size_div_factor_w2
while hidden_size % group_size_w13:
group_size_w13 = group_size_w13 // 2
group_size_div_factor_w13 *= 2
assert group_size_w13 >= 32
layer.w13_block_size = group_size_w13
layer.group_size_div_factor_w13 = group_size_div_factor_w13
strategy = FusedMoeWeightScaleSupported.GROUP.value
extra_weight_attrs.update({
"quant_method": strategy,
"is_transposed": False
})
assert 'weight_loader' in extra_weight_attrs
weight_loader = extra_weight_attrs['weight_loader']
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
layer, weight_loader)
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
# w13_qweight = torch.nn.Parameter(torch.empty(
# num_experts,
# 2 * intermediate_size_per_partition,
# hidden_size // bit8_pack_factor,
# dtype=torch.uint8),
# requires_grad=False)
w13_qweight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // bit32_pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
# w2_qweight = torch.nn.Parameter(torch.empty(
# num_experts,
# hidden_size,
# intermediate_size_per_partition // bit8_pack_factor,
# dtype=torch.uint8),
# requires_grad=False)
w2_qweight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // bit32_pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
w13_scales = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // group_size_w13,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = torch.nn.Parameter(torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // group_size_w2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
if self.quant_config.has_zp:
w13_qzeros = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition // bit8_pack_factor,
hidden_size // group_size,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = torch.nn.Parameter(torch.zeros(
num_experts,
hidden_size // bit8_pack_factor,
intermediate_size_per_partition // group_size,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
if self.quant_config.linear_quant_method == "gptq":
# some param are unused, but we need to init them in order to
# load weights
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
if not self.quant_config.has_zp:
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
for key in invalid_param_keys:
param = torch.nn.Parameter(torch.empty((0, ),
dtype=torch.int32),
requires_grad=False)
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)
@staticmethod
def get_weight_loader(layer, weight_loader):
def convert_awq_tensor(tensor, tensor_type):
# convert awq qweight/qzeros to a standard format (assume int4)
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
# qzeros: (k // group_size, n // pack_factor_bit32) ->
# (n // pack_factor_bit8, k // group_size)
# pack_factor_bit32 = 32 // weight_bits
# pack_factor_bit8 = 8 // weight_bits
# 0. suppose origin shape (a, b), dtype int32
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
size0 = tensor.size(0)
tensor = tensor.view(torch.uint8)
# 2. unpack to uint4 (only when weight_bits == 4)
# shape (a, 4 * b) -> (a, 4 * b, 2)
shifter = torch.tensor([0, 4],
dtype=torch.uint8,
device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
# 3. change order, see
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
# shape -> (a, 4 * b * pack_factor_bit8)
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
tensor = tensor.view(size0, -1)
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
tensor = tensor.T.contiguous()
# 5. repack (only when weight_bits == 4)
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
# qzeros shape -> (4 * b, a)
if tensor_type == "qweight":
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
elif tensor_type == "qzeros":
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
return tensor
def convert_gptq_int4_qzeros(tensor):
tensor = tensor.view(torch.uint8)
shifter = torch.tensor([0, 4],
dtype=torch.uint8,
device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
tensor = tensor + 1
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
return tensor
def moe_wna16_weight_loader(param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
return_success: bool = False):
if "g_idx" in weight_name:
return False if return_success else None
if not layer.quant_config.has_zp and "qzeros" in weight_name:
return False if return_success else None
device = get_tp_group().device
tp_rank = get_tensor_model_parallel_rank()
loaded_weight = loaded_weight.to(device)
shard_size = layer.intermediate_size_per_partition
# convert gptq and awq weight to a standard format
if layer.quant_config.linear_quant_method == "awq":
assert layer.quant_config.weight_bits == 4
if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight,
"qweight")
elif "zeros" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
else:
loaded_weight = loaded_weight.T
elif layer.quant_config.linear_quant_method == "gptq":
assert layer.quant_config.weight_bits in [4, 8]
if "weight" in weight_name:
# loaded_weight = loaded_weight.T.contiguous().view(
# torch.uint8)
loaded_weight = loaded_weight.T.contiguous()
elif "zeros" in weight_name:
# add 1 to gptq qzeros to align with awq
loaded_weight = loaded_weight.view(torch.uint8)
if layer.quant_config.weight_bits == 4:
loaded_weight = convert_gptq_int4_qzeros(
loaded_weight).T
else:
loaded_weight = loaded_weight.T + 1
else:
# loaded_weight = loaded_weight.T
loaded_weight = loaded_weight.T.contiguous()
# repeat the qzeros/scales to fit new group size
if layer.group_size_div_factor_w13 > 1 and \
("qzeros" in weight_name or "scales" in weight_name) and \
shard_id in ("w1", "w3"):
loaded_weight = loaded_weight.repeat_interleave(
layer.group_size_div_factor_w13, 1)
elif layer.group_size_div_factor_w2 > 1 and \
("qzeros" in weight_name or "scales" in weight_name) and \
shard_id == "w2":
loaded_weight = loaded_weight.repeat_interleave(
layer.group_size_div_factor_w2, 1)
elif layer.group_size_div_factor > 1 and \
("qzeros" in weight_name or "scales" in weight_name):
loaded_weight = loaded_weight.repeat_interleave(
layer.group_size_div_factor, 1)
if "w13_qzeros" in weight_name:
tensor = loaded_weight.view(layer.tp_size, -1,
loaded_weight.size(1))[tp_rank]
if shard_id == "w1":
param.data[expert_id, :shard_size // 2] = tensor
else:
param.data[expert_id, shard_size // 2:] = tensor
return True if return_success else None
elif "w2_qzeros" in weight_name:
param.data[expert_id] = loaded_weight.view(
loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
return True if return_success else None
else:
# Delegate to the original loader, passing return_success
return weight_loader(param,
loaded_weight,
weight_name,
shard_id,
expert_id,
return_success=return_success)
return moe_wna16_weight_loader
def process_weights_after_loading(self, layer):
dev_w2 = layer.w2_qweight.device
# torch.Size([128, 2048, 24]), torch.int32, strides: (49152, 24, 1)
# ======>
# torch.Size([128, 256, 192]), torch.int32, strides: (49152, 1, 256)
layer.w2_qweight = torch.nn.Parameter(
repack_quant_moe_weight(layer.w2_qweight.cpu())
.transpose(-1, -2).contiguous().transpose(-1, -2)
.to(device=dev_w2),
requires_grad=False)
# torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 3, 1)
# ======>
# torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 1, 2048)
layer.w2_scales = torch.nn.Parameter(
layer.w2_scales.transpose(-1, -2).contiguous().transpose(-1, -2),
requires_grad=False)
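# Stride sketch (illustrative): transpose(-1, -2).contiguous().transpose(-1, -2)
# keeps the logical shape but stores the last two dims column-major, matching
# the strides quoted in the comments above, e.g.
#   t = torch.empty(128, 2048, 24, dtype=torch.int32)
#   t.transpose(-1, -2).contiguous().transpose(-1, -2).stride()  # -> (49152, 1, 2048)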

View File

@@ -0,0 +1,39 @@
import torch
from typing import List, Optional
def _apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = True,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
assert len(block_size) == 2, "only 2-D block sizes are supported"
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
try:
from torch_vacc.vacc.custom_ops import w8a8_block_fp8_linear
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
mla_oproj_output = None
if memory_recycler is not None:
os1, os2 = memory_recycler.MLA_OPROJ_OUT_BUFFER.shape
if os1 == input_2d.size(0) and os2 == weight.size(0):
mla_oproj_output = memory_recycler.MLA_OPROJ_OUT_BUFFER
output = w8a8_block_fp8_linear(input_2d, weight, input_scale, weight_scale, block_size, output=mla_oproj_output)
except Exception as e:
print("vacc fuse fp8 matmul run fail:", e, " , now use unfused ops")
from torch_vacc.vacc.custom_ops_cpu import w8a8_block_fp8_linear
output = w8a8_block_fp8_linear(input_2d, weight, input_scale, weight_scale, block_size)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
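# Reference sketch (hedged, pure torch): the math a block-wise W8A8 FP8 linear
# computes, useful for cross-checking the fused vendor kernel. weight is an
# FP8 tensor of shape (N, K), weight_scale has shape (ceil(N/bn), ceil(K/bk)),
# and every (bn, bk) block of weight shares one scale. Activation quantization
# is omitted here, so this only checks the weight-dequant path.
def _reference_w8a8_block_fp8_linear(input_2d, weight, weight_scale, block_size):
    bn, bk = block_size
    n, k = weight.shape
    # expand per-block scales to a full (N, K) map and dequantize the weight
    scale_full = weight_scale.repeat_interleave(bn, dim=0)[:n]
    scale_full = scale_full.repeat_interleave(bk, dim=1)[:, :k]
    weight_fp = weight.to(torch.float32) * scale_full.to(torch.float32)
    return (input_2d.to(torch.float32) @ weight_fp.t()).to(input_2d.dtype)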