init
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Literal, get_args

QuantizationMethods = Literal[
    # "aqlm",
    "awq",
    "deepspeedfp",
    "tpu_int8",
    "fp8",
    "ptpc_fp8",
    "fbgemm_fp8",
    "modelopt",
    "modelopt_fp4",
    "bitblas",
    "gguf",
    "gptq_marlin_24",
    "gptq_marlin",
    "gptq_bitblas",
    "awq_marlin",
    "gptq",
    "compressed-tensors",
    "bitsandbytes",
    "hqq",
    "experts_int8",
    "ipex",
    "quark",
    "moe_wna16",
    "torchao",
    "auto-round",
    "rtn",
    "inc",
    "mxfp4",
    "petit_nvfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
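For reference, `get_args` flattens the `Literal` alias into a plain list at import time, so `QUANTIZATION_METHODS` always tracks the type. A minimal sketch of validating a name against it (`resolve_quant_method` is a hypothetical helper, not part of this commit):

from typing import Literal, get_args

Methods = Literal["awq", "fp8", "gptq"]  # trimmed to three entries for the example
METHODS: list[str] = list(get_args(Methods))

def resolve_quant_method(name: str) -> str:
    # Reject anything that was never registered in the Literal above.
    if name not in METHODS:
        raise ValueError(f"Unknown quantization method: {name!r}")
    return name

assert resolve_quant_method("fp8") == "fp8"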
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
vllm_vacc/vllm/model_executor/layers/quantization/fp8.py (615 lines, new file)
@@ -0,0 +1,615 @@

import functools
import importlib.util
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               QKVParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_block_linear, check_aiter_fp8_linear_support,
    create_fp8_input_scale, create_fp8_scale_parameter,
    create_fp8_weight_parameter, expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
    validate_fp8_block_shape)
# from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
#     all_close_1d, apply_fp8_linear, convert_to_channelwise,
#     cutlass_block_fp8_supported, cutlass_fp8_supported,
#     normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
#     requantize_with_max_scale)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp, all_close_1d, convert_to_channelwise,
    cutlass_block_fp8_supported, cutlass_fp8_supported,
    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.utils import has_deep_gemm

logger = init_logger(__name__)

# has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def Fp8LinearMethod__init(self, quant_config: Fp8Config):
    self.quant_config = quant_config
    self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
    self.out_dtype = torch.get_default_dtype()

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm
    if current_platform.is_rocm():
        self.use_marlin = False

    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant = self.quant_config.weight_block_size is not None
    self.act_q_static = self.quant_config.activation_scheme == "static"
    # Use per-token quantization for better perf if dynamic and cutlass
    if not self.act_q_static and cutlass_fp8_supported():
        self.act_q_group_shape = GroupShape.PER_TOKEN
    else:
        self.act_q_group_shape = GroupShape.PER_TENSOR

    if self.block_quant:
        self.block_size = self.quant_config.weight_block_size
        # Marlin doesn't support block-wise fp8
        self.use_marlin = False
        self.scale_k = 1
        self.scale_n = 1
        self.scale_n_prefill = 1  # only for fp8 moe

    self.fp8_linear = Fp8LinearOp(
        act_quant_static=self.act_q_static,
        act_quant_group_shape=self.act_q_group_shape)
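For context, `Fp8LinearMethod__init` is a free function rather than a method, which suggests it is meant to be rebound over the upstream initializer at import time. A minimal sketch of that monkeypatch pattern (the `apply_vacc_patches` wiring is a hypothetical illustration, not part of this commit):

# Hypothetical wiring: rebind the forked initializer onto the upstream
# class so Fp8LinearMethod instances created anywhere in vLLM pick up the
# vacc-specific setup defined above.
from vllm.model_executor.layers.quantization import fp8 as upstream_fp8

def apply_vacc_patches() -> None:
    upstream_fp8.Fp8LinearMethod.__init__ = Fp8LinearMethod__init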
class Fp8LinearMethod(LinearMethodBase):

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        if self.block_quant:
            scale_n = extra_weight_attrs.get("scale_n")
            scale_k = extra_weight_attrs.get("scale_k")
            if scale_n is not None:
                self.scale_n = scale_n
            if scale_k is not None:
                self.scale_k = scale_k

            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size

            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            block_n, block_k = (
                self.quant_config.weight_block_size[0] // self.scale_n,
                self.quant_config.weight_block_size[1] // self.scale_k,
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)
        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)
            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)
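As a sanity check of the ceil-division scale shapes above, here is a small worked sketch (sizes are made up for illustration):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# A 1024 x 4096 weight quantized in 128 x 128 blocks needs an 8 x 32 scale
# tensor: one fp32 scale per block.
out_per_partition, in_per_partition = 1024, 4096
block_n, block_k = 128, 128
scale_shape = (ceil_div(out_per_partition, block_n),
               ceil_div(in_per_partition, block_k))
assert scale_shape == (8, 32)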
    def process_weights_after_loading(self, layer: Module) -> None:
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            if isinstance(layer, QKVParallelLinear):
                # NOTE: for QKVParallelLinear the first dim of
                # weight_scale_inv must be divisible by 8 (DSP constraint).
                # Double the row count until it is, then repeat the scales.
                shape = weight_scale_inv.shape[0]
                repeat = 1
                while shape % 8 != 0:
                    repeat *= 2
                    shape *= 2
                weight_scale_inv = torch.repeat_interleave(
                    weight_scale_inv, repeats=repeat, dim=0)

            # weight = self._maybe_pad_weight(weight)
            # if self.block_quant:
            #     maybe_post_process_fp8_weight_block(
            #         layer, self.cutlass_block_fp8_supported)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)
            return
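The repeat loop above keeps doubling until the scale row count divides by 8; a small illustration of the intended arithmetic (toy sizes, runnable on CPU):

import torch

scale = torch.ones(6, 32)   # 6 scale rows: not divisible by 8
rows, repeat = scale.shape[0], 1
while rows % 8 != 0:
    repeat *= 2             # 6 -> 12 -> 24; repeat ends at 4
    rows *= 2
expanded = torch.repeat_interleave(scale, repeats=repeat, dim=0)
assert expanded.shape == (24, 32) and expanded.shape[0] % 8 == 0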
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        # Note: lazy import to avoid triton import error.
        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
            apply_w8a8_block_fp8_linear)
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            # Infer the block shape from the weight and scale shapes.
            return apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=[
                    layer.weight.shape[0] // layer.weight_scale_inv.shape[0],
                    layer.weight.shape[1] // layer.weight_scale_inv.shape[1],
                ],
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
        # return apply_fp8_linear(
        #     input=x,
        #     weight=layer.weight,
        #     weight_scale=layer.weight_scale,
        #     input_scale=layer.input_scale,
        #     bias=bias,
        #     cutlass_fp8_supported=self.cutlass_fp8_supported,
        #     # Default to using per_token quantization if cutlass is supported
        #     use_per_token_if_dynamic=self.cutlass_fp8_supported)
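The block-quant branch above infers the block shape by dividing the weight dims by the scale dims; a quick numeric sketch of that inference (shapes are illustrative):

import torch

weight = torch.empty(1024, 4096)        # (N, K)
weight_scale_inv = torch.empty(8, 32)   # one scale per block
block_size = [weight.shape[0] // weight_scale_inv.shape[0],
              weight.shape[1] // weight_scale_inv.shape[1]]
assert block_size == [128, 128]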
def Fp8MoEMethod_init_(self, quant_config: Fp8Config, layer: torch.nn.Module):
    self.layer = layer
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
    self.quant_config = quant_config
    self.block_quant = self.quant_config.weight_block_size is not None
    self.flashinfer_moe_backend = None

    self.scale_k = 1
    self.scale_n = 1
    self.scale_n_prefill = 1
    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm and vacc
    if current_platform.is_rocm() or current_platform.is_vacc():
        self.use_marlin = False

    # Check for DeepGemm support.
    self.allow_deep_gemm = False
    if envs.VLLM_USE_DEEP_GEMM:
        if not has_deep_gemm():
            logger.warning_once("Failed to import DeepGemm kernels.")
        elif not self.block_quant:
            logger.warning_once("Model is not block quantized. Not using "
                                "DeepGemm kernels")
        elif (current_platform.is_cuda()
              and current_platform.has_device_capability(90)):
            logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
            self.allow_deep_gemm = True
        else:
            logger.warning_once(
                "DeepGemm not supported on the current platform.")

    # Check for CutlassBlockScaledGroupedGemm support.
    self.allow_cutlass_block_scaled_grouped_gemm = False
    if not self.block_quant:
        logger.warning_once("Model is not block quantized. Not using "
                            "CutlassBlockScaledGroupedGemm kernels")
    elif (current_platform.is_cuda()
          and current_platform.has_device_capability(100)):
        logger.info_once(
            "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod.")
        self.allow_cutlass_block_scaled_grouped_gemm = True
    else:
        logger.warning_once(
            "CutlassBlockScaledGroupedGemm not supported on the current "
            "platform.")

    self.topk_indices_dtype = None
    self.fused_experts = functools.partial(  # type: ignore
        fused_experts,
        use_fp8_w8a8=True,
        block_shape=self.quant_config.weight_block_size,
        allow_deep_gemm=self.allow_deep_gemm,
        allow_cutlass_block_scaled_grouped_gemm=(
            self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8MoEMethod(FusedMoEMethodBase):

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            scale_n = extra_weight_attrs.get("scale_n")
            scale_n_prefill = extra_weight_attrs.get("scale_n_prefill")
            scale_k = extra_weight_attrs.get("scale_k")
            if scale_n is not None:
                self.scale_n = scale_n
            if scale_k is not None:
                self.scale_k = scale_k
            if scale_n_prefill is not None:
                self.scale_n_prefill = scale_n_prefill

            if self.quant_config is not None and \
                    self.quant_config.weight_block_size is not None:
                self.gcd_value = self.quant_config.weight_block_size[0]

                output_size_no_merge = intermediate_size_per_partition
                # assert isinstance(output_size_no_merge, int), \
                #     f"merged output size should divide to an int, " \
                #     f"value is: {output_size_no_merge}"

                if output_size_no_merge % self.quant_config.weight_block_size[0]:
                    import math
                    gcd_value = math.gcd(
                        output_size_no_merge %
                        self.quant_config.weight_block_size[0],
                        self.quant_config.weight_block_size[0])
                    self.scale_n = (self.scale_n *
                                    self.quant_config.weight_block_size[0] //
                                    gcd_value)
                    self.scale_n_prefill = (
                        self.scale_n_prefill *
                        self.quant_config.weight_block_size[0] // gcd_value)
                if hidden_size % self.quant_config.weight_block_size[1]:
                    import math
                    gcd_value = math.gcd(
                        hidden_size % self.quant_config.weight_block_size[1],
                        self.quant_config.weight_block_size[1])
                    self.scale_k = (self.scale_k *
                                    self.quant_config.weight_block_size[1] //
                                    gcd_value)
                # self.scale_k = self.scale_n

                # print('output_size_no_merge', output_size_no_merge)
                # Split across cores by block_size:
                # output_size_no_merge = 384
                #   block_size = 128: 384 = 3x128, only 3 cores x 128
                #   block_size = 16:  384 = 24x16, 8 cores x (3x16) uses all 8
                #
                # output_size_no_merge = 512
                #   block_size = 128: 512 = 4x128, only 4 cores x 128
                #   block_size = 64:  512 = 8x64, 8 cores x 64
                #
                # output_size_no_merge = 768
                #   block_size = 128: 768 = 6x128, only 6 cores x 128
                #   block_size = 32:  768 = 8x(3x32), 8 cores x (3x32)

                core_num = 8
                min_block_size = 4
                block_size_tmp = (self.quant_config.weight_block_size[0] //
                                  self.scale_n)
                if output_size_no_merge > block_size_tmp and \
                        output_size_no_merge % block_size_tmp == 0 and \
                        output_size_no_merge // block_size_tmp < core_num and \
                        output_size_no_merge % core_num == 0:
                    core_num_old = output_size_no_merge // block_size_tmp
                    import math
                    gcd_value = math.gcd(core_num, core_num_old)
                    new_scale = core_num // gcd_value
                    if block_size_tmp // new_scale >= min_block_size:
                        self.scale_n = new_scale * self.scale_n

            # print("moe scale n is:", self.scale_n, self.scale_k,
            #       intermediate_size_per_partition)
            tp_size = get_tensor_model_parallel_world_size()
            if self.scale_n != self.scale_n_prefill:
                block_n_prefill = (self.quant_config.weight_block_size[0] //
                                   self.scale_n_prefill)

            block_n, block_k = (
                self.quant_config.weight_block_size[0] // self.scale_n,
                self.quant_config.weight_block_size[1] // self.scale_k,
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1
                    and hidden_size % block_k != 0):
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{hidden_size} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) //
                         block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_k - 1) // block_k,
                    (intermediate_size_per_partition + block_n - 1) // block_n,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )

            if self.scale_n != self.scale_n_prefill:
                w13_weight_scale_prefill = torch.nn.Parameter(
                    torch.ones(
                        num_experts,
                        2 * ((intermediate_size_per_partition +
                              block_n_prefill - 1) // block_n_prefill),
                        (hidden_size + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    requires_grad=False,
                )
                w2_weight_scale_prefill = torch.nn.Parameter(
                    torch.ones(
                        num_experts,
                        (hidden_size + block_k - 1) // block_k,
                        (intermediate_size_per_partition +
                         block_n_prefill - 1) // block_n_prefill,
                        dtype=torch.float32,
                    ),
                    requires_grad=False,
                )
                layer.register_parameter("w13_weight_scale_inv_prefill",
                                         w13_weight_scale_prefill)
                layer.register_parameter("w2_weight_scale_inv_prefill",
                                         w2_weight_scale_prefill)

            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant else
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
            if self.scale_n != self.scale_n_prefill:
                set_weight_attrs(w13_weight_scale_prefill, extra_weight_attrs)
                set_weight_attrs(w2_weight_scale_prefill, extra_weight_attrs)
        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None
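The core-splitting heuristic above trades a smaller effective block size for better occupancy across 8 cores; a worked sketch of the 384-row case from the comments (core_num and min_block_size as in the code):

import math

core_num, min_block_size = 8, 4
output_size, block_size = 384, 128      # 384 = 3 x 128: only 3 cores busy

if (output_size > block_size
        and output_size % block_size == 0
        and output_size // block_size < core_num
        and output_size % core_num == 0):
    cores_used = output_size // block_size                   # 3
    new_scale = core_num // math.gcd(core_num, cores_used)   # 8
    if block_size // new_scale >= min_block_size:
        block_size //= new_scale                             # 128 -> 16
# 384 = 24 x 16 now tiles across all 8 cores, 3 blocks per core.
assert block_size == 16 and output_size // block_size == 24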
def moe_fp8_apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe import FusedMoE
    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )

    try:
        from torch_vacc.vacc.custom_ops import fused_experts
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
        experts_output = None
        if memory_recycler is not None:
            # remove MOE_EXPERT_OUT_BUFFER
            # experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
            experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_fp8_w8a8=True,
            w13_scale=(layer.w13_weight_scale_inv
                       if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
            decode_with_batch=layer.is_decode and x.shape[0] > 1,
            output_opt=experts_output
        )
    except Exception as e:
        logger.warning(
            "vacc fused_experts failed (%s); falling back to unfused ops", e)
        from torch_vacc.vacc.custom_ops_cpu import fused_experts
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_fp8_w8a8=True,
            w13_scale=(layer.w13_weight_scale_inv
                       if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )
vllm_vacc/vllm/model_executor/layers/quantization/gptq.py (282 lines, new file)
@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0

import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
import numpy as np

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           RowvLLMParameter)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig as GPTQConfigOrig
from vllm.model_executor.layers.quantization.gptq import ExllamaState
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT
import math


def GPTQLinearMethod__init(self, quant_config: GPTQConfigOrig):
    self.quant_config = quant_config
    self.scale_k = 1
    self.split_num = 4


def int32_to_int4(s0, axis=-2):
    # Flatten to shape [1, n] first.
    # Each int32 splits into 8 int4 values (held in 8 int32s), giving [8, n].
    #
    # x32 (int32) => 32 bits => 8 x 4-bit nibbles x4[0..7]
    #
    # x32 bits 31-28 => x4[7]
    # x32 bits 27-24 => x4[6]
    # ...
    # x32 bits 3-0   => x4[0]
    #
    # x32[index=0] => x4[7,6,5,4,3,2,1,0]
    #
    # Mapping a 4-bit pattern to its real value:
    # not two's complement; subtract 8 instead.
    #
    # 1111 => 15 => 7    (15 - 8 = 7)
    #
    # 0110 => 6 => -2    (6 - 8 = -2)
    #
    # 0x6ACB372B (laid out in memory as 2B 37 CB 6A) => B273BCA6 => (-8) =>
    #   int4: 3, -6, -1, -5, 3, 4, 2, -2
    #
    # Memory uses little-endian layout:
    # int32 bytes 2B 37 CB 6A => nibbles 2,11,3,7,12,11,6,10 => (-8) =>
    #   -6,3, -5,-1, 4,3, -2,2 => swap the two nibbles within each byte =>
    #   int4: 3, -6, -1, -5, 3, 4, 2, -2

    s = s0.view(torch.uint32)
    parts = []
    for i in range(8):
        x = 15 << (i * 4)
        # s2 = torch.bitwise_and(x, s)
        s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
        s3 = s2 / (2 ** (i * 4))
        s4 = s3.to(torch.int32)
        # Two's-complement correction gives wrong results:
        # s4[s4 > 7] = s4[s4 > 7] - 16
        # Subtracting 8 directly is correct; range: -8..7
        s4 = s4 - 8
        parts.append(s4.reshape(1, *s4.shape))
    parts = torch.concatenate(parts, 0)
    if axis == -2 or axis == 0:
        # 8, K//8, N => K//8, 8, N => K, N
        parts = parts.transpose(-2, 0).reshape(-1, parts.shape[-1]).contiguous()
    else:
        # 8, N, K//8 => N, K//8, 8 => N, K
        parts = parts.permute(1, 2, 0).reshape(parts.shape[-2], -1).contiguous()
    return parts
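A quick check of `int32_to_int4` against the worked byte example in the comments (runnable on CPU, assuming the torch build supports the uint32 view used above):

import torch

packed = torch.tensor([[0x6ACB372B]], dtype=torch.int32)  # [K//8 = 1, N = 1]
unpacked = int32_to_int4(packed, axis=-2)                 # -> [K = 8, N = 1]
assert unpacked.flatten().tolist() == [3, -6, -1, -5, 3, 4, 2, -2]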
def dequant_weight(qw, scales, group_size=128):
    N = qw.shape[1]
    int4_to_int32_axis = -2
    if TRANSPOSE_GPTQ_WEIGHT:
        N = qw.shape[0]
        int4_to_int32_axis = -1
    # int32 => 8 x int4 => fp16
    qweight = int32_to_int4(qw, int4_to_int32_axis).to(torch.float16)

    if TRANSPOSE_GPTQ_WEIGHT:
        scales = scales.T.contiguous()
        qweight = qweight.T.contiguous()

    # Expand scales by group_size: every group_size weights share one scale.
    scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)

    # print('qweight', qweight.shape, qweight.dtype)
    # print('scale', scales.shape, scales.dtype)

    return qweight * scales  # dequantize
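The scale expansion in `dequant_weight` can be verified in isolation: each of the `group_size` rows within a group is multiplied by the same scale row (toy sizes):

import torch

N, group_size = 3, 2
scales = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])   # [K // group_size, N]
expanded = torch.concatenate([scales] * group_size, 1).reshape(-1, N)
assert torch.equal(expanded, torch.tensor([[1.0, 2.0, 3.0],
                                           [1.0, 2.0, 3.0],
                                           [4.0, 5.0, 6.0],
                                           [4.0, 5.0, 6.0]]))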
class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]
class GPTQLinearMethod(LinearMethodBase):

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs.get("weight_loader")
        # if input_size_per_partition % self.quant_config.group_size != 0:
        #     raise ValueError(
        #         "The input size is not aligned with the quantized "
        #         "weight shape. This can be caused by too large "
        #         "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if (output_size_per_partition % self.quant_config.pack_factor.numerator
                != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use Exllama for row parallel layer
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        g_idx = RowvLLMParameter(data=torch.tensor(
            [
                i // self.quant_config.group_size
                for i in range(input_size_per_partition)
            ],
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)
        qzeros_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if scale_and_zero_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.exllama_state = exllama_state

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # for torch.compile
        # self.quant_config.weight_bits == 4
        if TRANSPOSE_GPTQ_WEIGHT:
            layer.qzeros = Parameter(layer.qzeros.data.T.contiguous(),
                                     requires_grad=False)
            layer.qweight = Parameter(layer.qweight.data.T.contiguous(),
                                      requires_grad=False)
            layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
            layer.scales = Parameter(layer.scales.data.T.contiguous(),
                                     requires_grad=False)
        else:
            layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
            layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
            layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
            layer.scales = Parameter(layer.scales.data, requires_grad=False)

        # exllama needs to shuffle the weight after the weight is loaded
        # here we do the shuffle on first forward pass
        if layer.exllama_state == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
                layer.exllama_state = ExllamaState.READY
                ops.gptq_shuffle(layer.qweight, layer.g_idx,
                                 self.quant_config.weight_bits)
            else:
                layer.g_idx.data = torch.empty((0, ),
                                               dtype=torch.int,
                                               device=layer.g_idx.device)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Output shape is (M, N).
        out_shape = x.shape[:-1] + (
            layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], )
        reshaped_x = x.reshape(-1, x.shape[-1])

        # print(f"~~~~ start dequant")
        # import time
        # start_quant_time = time.time()
        # weight = dequant_weight(
        #     layer.qweight.cpu(), layer.scales.cpu(),
        #     self.quant_config.group_size // self.scale_k
        # ).to(layer.qweight.device)
        # end_quant_time = time.time()
        # print(f"~~~~ dequant time: {end_quant_time - start_quant_time}")
        # if torch.distributed.get_rank() == 0:
        #     print(f"~~~~ weight shape: {weight.shape}, dtype: {weight.dtype}")
        # output = torch.matmul(reshaped_x, weight)
        # print("entering GPTQLinearMethod apply, reshaped_x shape:",
        #       reshaped_x.shape, "reshaped_x stride", reshaped_x.stride(),
        #       "input_tensor", x.shape, "qweight shape:", layer.qweight.shape,
        #       "scales shape:", layer.scales.shape)
        output = torch.vacc.w4a8_block_int4_matmul(
            reshaped_x,
            layer.qweight.transpose(-1, -2),
            layer.scales.transpose(-1, -2),
            [1, self.quant_config.group_size // self.scale_k],
        )
        # print("exiting GPTQLinearMethod apply, output shape:", output.shape)
        # end_gemm_time = time.time()
        # if torch.distributed.get_rank() == 0:
        #     print(f"~~~~ gemm time: {end_gemm_time - end_quant_time}")
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
||||
vllm_vacc/vllm/model_executor/layers/quantization/moe_wna16.py (372 lines, new file)
@@ -0,0 +1,372 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Optional, Union

import torch

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
# [num_experts, N, K//8], int32 ==> [num_experts, N, K], int4 ==> [num_experts, N//8, K], int32
def repack_quant_moe_weight_old(original_packed_tensor):
    num_experts = original_packed_tensor.shape[0]
    N = original_packed_tensor.shape[1]
    K = original_packed_tensor.shape[2] * 8
    if original_packed_tensor.dtype != torch.int32:
        raise ValueError("data type of input tensor should be int32")
    if N % 8 != 0:
        raise ValueError("N of input tensor should be divisible by 8")

    # --- 1. Unpack: expand the int32 tensor into a logical int4 tensor ---
    # Temporary tensor holding every unpacked int4 value; torch.uint8 serves
    # as interim int4 storage since PyTorch has no native int4 dtype.
    unpacked_int4_tensor = torch.zeros(
        (num_experts, N, K),
        dtype=torch.uint8,
        device=original_packed_tensor.device
    )
    mask = 0b1111
    for i in range(8):
        # Shift right by (i * 4) bits to bring the i-th 4-bit integer to the
        # least significant position, then mask off those 4 bits.
        extracted_int4s = (original_packed_tensor >> (i * 4)) & mask
        # Scatter the extracted int4 values into place: the slice `i::8`
        # fills every 8th position starting from index `i`.
        unpacked_int4_tensor[:, :, i::8] = extracted_int4s

    # --- 2. Repack: pack the logical int4 tensor into a new int32 tensor ---
    new_packed_tensor = torch.zeros(
        (num_experts, N // 8, K),
        dtype=torch.int32,
        device=original_packed_tensor.device
    )
    for i in range(8):
        # Take the next int4 strip to pack, slicing `i::8` along N.
        current_int4_segment = unpacked_int4_tensor[:, i::8, :]
        # Cast to int32 (the packing target), shift into position within the
        # new int32 word, and merge with bitwise OR.
        new_packed_tensor |= (current_int4_segment.to(torch.int32) << (i * 4))

    return new_packed_tensor


def repack_quant_moe_weight(original_packed_tensor):
    if original_packed_tensor.dtype != torch.int32:
        raise ValueError("data type of input tensor should be int32")

    num_experts, N, K_packed = original_packed_tensor.shape
    K = K_packed * 8

    if N % 8 != 0:
        raise ValueError("N of input tensor should be divisible by 8")

    new_packed_tensor = torch.zeros((num_experts, N // 8, K),
                                    dtype=torch.int32,
                                    device=original_packed_tensor.device)
    for i in range(8):
        source_slice = original_packed_tensor[:, i::8, :]
        for j in range(8):
            unpacked_strip = (source_slice >> (j * 4)) & 0b1111
            new_packed_tensor[:, :, j::8] |= (unpacked_strip.to(torch.int32)
                                              << (i * 4))

    return new_packed_tensor
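Since `repack_quant_moe_weight` is a vectorized rewrite of the `_old` reference, the two should agree bit for bit; a small randomized cross-check (CPU):

import torch

packed = torch.randint(torch.iinfo(torch.int32).min,
                       torch.iinfo(torch.int32).max,
                       (2, 16, 3), dtype=torch.int32)  # [experts, N, K//8]
assert torch.equal(repack_quant_moe_weight(packed),
                   repack_quant_moe_weight_old(packed))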
class MoeWNA16Method(FusedMoEMethodBase):

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        self.moe = layer
        layer.quant_config = self.quant_config
        bit8_pack_factor = self.quant_config.bit8_pack_factor
        bit32_pack_factor = 32 // self.quant_config.weight_bits
        group_size = self.quant_config.group_size
        group_size_div_factor = 1
        group_size_w13 = self.quant_config.group_size
        group_size_div_factor_w13 = 1
        group_size_w2 = self.quant_config.group_size
        group_size_div_factor_w2 = 1

        # make intermediate_size and hidden_size divisible by group_size
        # we reduce the group size to ensure that
        # and we would repeat the loaded_weight later
        while intermediate_size_per_partition % group_size or \
                hidden_size % group_size:
            group_size = group_size // 2
            group_size_div_factor *= 2
            assert group_size >= 32
        layer.group_size = group_size
        layer.group_size_div_factor = group_size_div_factor

        while intermediate_size_per_partition % group_size_w2:
            group_size_w2 = group_size_w2 // 2
            group_size_div_factor_w2 *= 2
            assert group_size_w2 >= 32
        layer.w2_block_size = group_size_w2
        layer.group_size_div_factor_w2 = group_size_div_factor_w2

        while hidden_size % group_size_w13:
            group_size_w13 = group_size_w13 // 2
            group_size_div_factor_w13 *= 2
            assert group_size_w13 >= 32
        layer.w13_block_size = group_size_w13
        layer.group_size_div_factor_w13 = group_size_div_factor_w13

        strategy = FusedMoeWeightScaleSupported.GROUP.value
        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": False
        })

        assert 'weight_loader' in extra_weight_attrs
        weight_loader = extra_weight_attrs['weight_loader']
        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
            layer, weight_loader)
        extra_weight_attrs['weight_loader'] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        # w13_qweight = torch.nn.Parameter(torch.empty(
        #     num_experts,
        #     2 * intermediate_size_per_partition,
        #     hidden_size // bit8_pack_factor,
        #     dtype=torch.uint8),
        #     requires_grad=False)
        w13_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // bit32_pack_factor,
            dtype=torch.int32),
                                         requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        # down_proj (row parallel)
        # w2_qweight = torch.nn.Parameter(torch.empty(
        #     num_experts,
        #     hidden_size,
        #     intermediate_size_per_partition // bit8_pack_factor,
        #     dtype=torch.uint8),
        #     requires_grad=False)
        w2_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // bit32_pack_factor,
            dtype=torch.int32),
                                        requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        w13_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // group_size_w13,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // group_size_w2,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        if self.quant_config.has_zp:
            w13_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition // bit8_pack_factor,
                hidden_size // group_size,
                dtype=torch.uint8),
                                            requires_grad=False)
            layer.register_parameter("w13_qzeros", w13_qzeros)
            set_weight_attrs(w13_qzeros, extra_weight_attrs)

            w2_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                hidden_size // bit8_pack_factor,
                intermediate_size_per_partition // group_size,
                dtype=torch.uint8),
                                           requires_grad=False)
            layer.register_parameter("w2_qzeros", w2_qzeros)
            set_weight_attrs(w2_qzeros, extra_weight_attrs)

        if self.quant_config.linear_quant_method == "gptq":
            # some param are unused, but we need to init them in order to
            # load weights
            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
            if not self.quant_config.has_zp:
                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
            for key in invalid_param_keys:
                param = torch.nn.Parameter(torch.empty((0, ),
                                                       dtype=torch.int32),
                                           requires_grad=False)
                layer.register_parameter(key, param)
                set_weight_attrs(param, extra_weight_attrs)
    @staticmethod
    def get_weight_loader(layer, weight_loader):

        def convert_awq_tensor(tensor, tensor_type):
            # convert awq qweight/qzeros to a standard format (assume int4)
            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
            # qzeros: (k // group_size, n // pack_factor_bit32) ->
            #     (n // pack_factor_bit8, k // group_size)
            # pack_factor_bit32 = 32 // weight_bits
            # pack_factor_bit8 = 8 // weight_bits

            # 0. suppose origin shape (a, b), dtype int32
            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
            size0 = tensor.size(0)
            tensor = tensor.view(torch.uint8)

            # 2. unpack to uint4 (only when weight_bits == 4)
            #    shape (a, 4 * b) -> (a, 4 * b, 2)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF

            # 3. change order, see
            # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
            #    shape -> (a, 4 * b * pack_factor_bit8)
            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
            tensor = tensor.view(size0, -1)

            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
            tensor = tensor.T.contiguous()

            # 5. repack (only when weight_bits == 4)
            #    qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
            #    qzeros shape -> (4 * b, a)
            if tensor_type == "qweight":
                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
            elif tensor_type == "qzeros":
                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
            return tensor

        def convert_gptq_int4_qzeros(tensor):
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor

        def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                    loaded_weight: torch.Tensor,
                                    weight_name: str,
                                    shard_id: str,
                                    expert_id: int,
                                    return_success: bool = False):
            if "g_idx" in weight_name:
                return False if return_success else None
            if not layer.quant_config.has_zp and "qzeros" in weight_name:
                return False if return_success else None

            device = get_tp_group().device
            tp_rank = get_tensor_model_parallel_rank()
            loaded_weight = loaded_weight.to(device)
            shard_size = layer.intermediate_size_per_partition

            # convert gptq and awq weight to a standard format
            if layer.quant_config.linear_quant_method == "awq":
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight,
                                                       "qweight")
                elif "zeros" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
                else:
                    loaded_weight = loaded_weight.T
            elif layer.quant_config.linear_quant_method == "gptq":
                assert layer.quant_config.weight_bits in [4, 8]
                if "weight" in weight_name:
                    # loaded_weight = loaded_weight.T.contiguous().view(
                    #     torch.uint8)
                    loaded_weight = loaded_weight.T.contiguous()
                elif "zeros" in weight_name:
                    # add 1 to gptq qzeros to align with awq
                    loaded_weight = loaded_weight.view(torch.uint8)
                    if layer.quant_config.weight_bits == 4:
                        loaded_weight = convert_gptq_int4_qzeros(
                            loaded_weight).T
                    else:
                        loaded_weight = loaded_weight.T + 1
                else:
                    # loaded_weight = loaded_weight.T
                    loaded_weight = loaded_weight.T.contiguous()

            # repeat the qzeros/scales to fit new group size
            if layer.group_size_div_factor_w13 > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name) and \
                    shard_id in ("w1", "w3"):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor_w13, 1)
            elif layer.group_size_div_factor_w2 > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name) and \
                    shard_id == "w2":
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor_w2, 1)
            elif layer.group_size_div_factor > 1 and \
                    ("qzeros" in weight_name or "scales" in weight_name):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1)

            if "w13_qzeros" in weight_name:
                tensor = loaded_weight.view(layer.tp_size, -1,
                                            loaded_weight.size(1))[tp_rank]
                if shard_id == "w1":
                    param.data[expert_id, :shard_size // 2] = tensor
                else:
                    param.data[expert_id, shard_size // 2:] = tensor
                return True if return_success else None
            elif "w2_qzeros" in weight_name:
                param.data[expert_id] = loaded_weight.view(
                    loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
                return True if return_success else None
            else:
                # Delegate to the original loader, passing return_success
                return weight_loader(param,
                                     loaded_weight,
                                     weight_name,
                                     shard_id,
                                     expert_id,
                                     return_success=return_success)

        return moe_wna16_weight_loader
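For reference, the `reverse_awq_pack_order` trick inside `convert_awq_tensor` can be checked in isolation. AWQ stores logical nibbles 0..7 in packed order 0,2,4,6,1,3,5,7 (per AutoAWQ), so indexing with the reverse order restores logical order (uint8 stand-ins for nibbles):

import torch

reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
awq_packed = torch.tensor([[0, 2, 4, 6, 1, 3, 5, 7]], dtype=torch.uint8)
restored = awq_packed.view(-1, 8)[:, reverse_awq_pack_order]
assert restored.flatten().tolist() == [0, 1, 2, 3, 4, 5, 6, 7]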
    def process_weights_after_loading(self, layer):
        dev_w2 = layer.w2_qweight.device
        # torch.Size([128, 2048, 24]), torch.int32, strides: (49152, 24, 1)
        # ======>
        # torch.Size([128, 256, 192]), torch.int32, strides: (49152, 1, 256)
        layer.w2_qweight = torch.nn.Parameter(
            repack_quant_moe_weight(layer.w2_qweight.cpu())
            .transpose(-1, -2).contiguous().transpose(-1, -2)
            .to(device=dev_w2),
            requires_grad=False)
        # torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 3, 1)
        # ======>
        # torch.Size([128, 2048, 3]), torch.float16, strides: (6144, 1, 2048)
        layer.w2_scales = torch.nn.Parameter(
            layer.w2_scales.transpose(-1, -2).contiguous().transpose(-1, -2),
            requires_grad=False)
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,39 @@
import torch
from typing import Optional


def _apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: list[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_block_fp8_supported: bool = True,
    use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
    assert input_scale is None
    assert len(block_size) == 2, "only 2-D blocks are supported for now"
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    try:
        from torch_vacc.vacc.custom_ops import w8a8_block_fp8_linear
        from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler

        mla_oproj_output = None
        if memory_recycler is not None:
            os1, os2 = memory_recycler.MLA_OPROJ_OUT_BUFFER.shape
            if os1 == input_2d.size(0) and os2 == weight.size(0):
                mla_oproj_output = memory_recycler.MLA_OPROJ_OUT_BUFFER

        output = w8a8_block_fp8_linear(input_2d, weight, input_scale,
                                       weight_scale, block_size,
                                       output=mla_oproj_output)
    except Exception as e:
        print("vacc fused fp8 matmul failed:", e, "; falling back to unfused ops")
        from torch_vacc.vacc.custom_ops_cpu import w8a8_block_fp8_linear
        output = w8a8_block_fp8_linear(input_2d, weight, input_scale,
                                       weight_scale, block_size)

    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).view(*output_shape)
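A minimal call sketch for `_apply_w8a8_block_fp8_linear` (shapes are illustrative; running it requires a torch_vacc installation, since even the except branch imports the CPU fallback op):

import torch

x = torch.randn(4, 4096, dtype=torch.bfloat16)        # (tokens, K)
w = torch.randn(1024, 4096).to(torch.float8_e4m3fn)   # (N, K) fp8 weight
w_scale = torch.ones(8, 32, dtype=torch.float32)      # one scale per 128x128 block
y = _apply_w8a8_block_fp8_linear(x, w, [128, 128], w_scale)
assert y.shape == (4, 1024)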