import functools
import math
import importlib.util
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, expert_weight_is_col_major,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.utils import has_deep_gemm
logger = init_logger(__name__)
# has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
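# NOTE: the module-level helpers below (Fp8LinearMethod__init,
# Fp8MoEMethod_init_) take `self` and mirror the constructors of vLLM's
# Fp8LinearMethod / Fp8MoEMethod. They are presumably bound onto those
# classes elsewhere in this file (the binding is not shown in this excerpt)
# to adapt FP8 linear and MoE quantization to the vacc backend.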
def Fp8LinearMethod__init(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.quant_config.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
if self.block_quant:
self.block_size = self.quant_config.weight_block_size
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.scale_k = 1
self.scale_n = 1
self.scale_n_prefill = 1 # only for fp8 moe
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)
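# Illustrative sketch (not called anywhere in this module): what the two
# activation-quant group shapes chosen in Fp8LinearMethod__init mean in
# practice. GroupShape.PER_TENSOR uses a single scale for the whole
# activation; GroupShape.PER_TOKEN uses one scale per row, which is why
# per-token is preferred when cutlass FP8 is available. The helper name
# `_sketch_quantize_activation` is hypothetical; the real quantization is
# done inside Fp8LinearOp / the custom ops.
def _sketch_quantize_activation(x: torch.Tensor, per_token: bool):
    finfo = torch.finfo(torch.float8_e4m3fn)
    if per_token:
        # One scale per token (row of x).
        amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    else:
        # One scale for the entire tensor.
        amax = x.abs().amax().clamp(min=1e-12)
    scale = (amax / finfo.max).to(torch.float32)
    x_q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q, scale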
class Fp8LinearMethod(LinearMethodBase):
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
if self.block_quant:
scale_n = extra_weight_attrs.get("scale_n")
scale_k = extra_weight_attrs.get("scale_k")
if scale_n is not None:
self.scale_n = scale_n
if scale_k is not None:
self.scale_k = scale_k
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_config.weight_block_size is not None
block_n, block_k = (
                self.quant_config.weight_block_size[0] // self.scale_n,
                self.quant_config.weight_block_size[1] // self.scale_k,
)
# Required by row parallel
if (tp_size > 1
and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition
== tp_size) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if not self.block_quant:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
else:
assert self.quant_config.activation_scheme == "dynamic"
scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
# TODO(rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
if current_platform.is_fp8_fnuz():
weight, weight_scale_inv, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv)
else:
weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data
if isinstance(layer, QKVParallelLinear):
                # NOTE: for QKVParallelLinear the number of weight-scale rows
                # must be divisible by 8 (the DSP count), so repeat the rows
                # by the smallest power-of-two factor that achieves that.
                rows = weight_scale_inv.shape[0]
                repeat = 1
                while (rows * repeat) % 8 != 0:
                    repeat *= 2
                weight_scale_inv = torch.repeat_interleave(
                    weight_scale_inv, repeats=repeat, dim=0)
# weight = self._maybe_pad_weight(weight)
# if self.block_quant:
# maybe_post_process_fp8_weight_block(
# layer, self.cutlass_block_fp8_supported)
# Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale_inv,
requires_grad=False)
return
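        # Worked example of the scale-row padding above: 3 rows -> repeat 8
        # (24 rows); 12 rows -> repeat 2 (24 rows). Both results are
        # multiples of 8, matching the 8-DSP requirement noted above.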
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
# Note: lazy import to avoid triton import error.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
if self.block_quant:
assert self.quant_config.weight_block_size is not None
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
                block_size=[layer.weight.shape[0] // layer.weight_scale_inv.shape[0],
                            layer.weight.shape[1] // layer.weight_scale_inv.shape[1]],
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
)
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)
def Fp8MoEMethod_init_(self, quant_config: Fp8Config, layer: torch.nn.Module):
self.layer = layer
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
self.flashinfer_moe_backend = None
self.scale_k = 1
self.scale_n = 1
self.scale_n_prefill = 1
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm and vacc
    if current_platform.is_rocm() or current_platform.is_vacc():
self.use_marlin = False
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm():
logger.warning_once("Failed to import DeepGemm kernels.")
elif not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
" DeepGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
else:
logger.warning_once(
"DeepGemm not supported on the current platform.")
# Check for CutlassBlockScaledGroupedGemm support.
self.allow_cutlass_block_scaled_grouped_gemm = False
if not self.block_quant:
logger.warning_once("Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(100)):
logger.info_once(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
)
self.allow_cutlass_block_scaled_grouped_gemm = True
else:
logger.warning_once(
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
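# Illustrative sketch (hypothetical helper, not called directly): the
# block-splitting heuristic applied in Fp8MoEMethod.create_weights below.
# When the per-partition output size fills fewer than 8 blocks of the
# configured block size, the block is shrunk (scale_n is multiplied) so the
# rows spread across all 8 DSP cores, provided the resulting block stays at
# least `min_block_size` wide. See the worked 384/512/768 examples in the
# comments inside create_weights.
def _sketch_split_blocks_across_cores(output_size: int,
                                      block_size: int,
                                      core_num: int = 8,
                                      min_block_size: int = 4) -> int:
    """Return the factor by which scale_n would be multiplied."""
    if (output_size > block_size
            and output_size % block_size == 0
            and output_size // block_size < core_num
            and output_size % core_num == 0):
        cores_used = output_size // block_size
        factor = core_num // math.gcd(core_num, cores_used)
        if block_size // factor >= min_block_size:
            return factor
    return 1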
class Fp8MoEMethod(FusedMoEMethodBase):
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.quant_config.weight_block_size is not None
scale_n = extra_weight_attrs.get("scale_n")
scale_n_prefill = extra_weight_attrs.get("scale_n_prefill")
scale_k = extra_weight_attrs.get("scale_k")
if scale_n is not None:
self.scale_n = scale_n
if scale_k is not None:
self.scale_k = scale_k
if scale_n_prefill is not None:
self.scale_n_prefill = scale_n_prefill
            if self.quant_config is not None and self.quant_config.weight_block_size is not None:
                self.gcd_value = self.quant_config.weight_block_size[0]
                output_size_no_merge = intermediate_size_per_partition
                # assert isinstance(output_size_no_merge, int), f"merged output size should be an int, value is: {output_size_no_merge}"
                if output_size_no_merge % self.quant_config.weight_block_size[0]:
                    gcd_value = math.gcd(output_size_no_merge % self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[0])
                    self.scale_n = self.scale_n * self.quant_config.weight_block_size[0] // gcd_value
                    self.scale_n_prefill = self.scale_n_prefill * self.quant_config.weight_block_size[0] // gcd_value
                if hidden_size % self.quant_config.weight_block_size[1]:
                    gcd_value = math.gcd(hidden_size % self.quant_config.weight_block_size[1], self.quant_config.weight_block_size[1])
                    self.scale_k = self.scale_k * self.quant_config.weight_block_size[1] // gcd_value
                # self.scale_k = self.scale_n
# print('output_size_no_merge', output_size_no_merge)
                # Split the work across cores by block_size (a worked version
                # of this heuristic is sketched in
                # _sketch_split_blocks_across_cores above, before this class):
                # output_size_no_merge = 384
                #   block_size = 128: 384 = 3x128    -> only 3 cores x 128
                #   block_size = 16:  384 = 24x16    -> 8 cores x (3x16)
                # output_size_no_merge = 512
                #   block_size = 128: 512 = 4x128    -> only 4 cores x 128
                #   block_size = 64:  512 = 8x64     -> 8 cores x 64
                # output_size_no_merge = 768
                #   block_size = 128: 768 = 6x128    -> only 6 cores x 128
                #   block_size = 32:  768 = 8x(3x32) -> 8 cores x (3x32)
core_num = 8
min_block_size = 4
block_size_tmp = self.quant_config.weight_block_size[0] // self.scale_n
if output_size_no_merge > block_size_tmp and \
output_size_no_merge % block_size_tmp == 0 and \
output_size_no_merge // block_size_tmp < core_num and \
output_size_no_merge % core_num == 0:
core_num_old = output_size_no_merge // block_size_tmp
gcd_value = math.gcd(core_num, core_num_old)
new_scale = core_num // gcd_value
if block_size_tmp // new_scale >= min_block_size:
self.scale_n = new_scale * self.scale_n
#print("moe scale n is:", self.scale_n, self.scale_k, intermediate_size_per_partition)
tp_size = get_tensor_model_parallel_world_size()
if self.scale_n != self.scale_n_prefill:
block_n_prefill = self.quant_config.weight_block_size[0] // self.scale_n_prefill
block_n, block_k = (
self.quant_config.weight_block_size[0] // self.scale_n,
self.quant_config.weight_block_size[1] // self.scale_k,
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}.")
if (tp_size > 1
and hidden_size % block_k != 0):
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if not self.block_quant:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
else:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) //
block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_n - 1) // block_n,
dtype=torch.float32,
),
requires_grad=False,
)
if self.scale_n != self.scale_n_prefill:
w13_weight_scale_prefill = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n_prefill - 1) //
block_n_prefill),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale_prefill = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_k - 1) // block_k,
(intermediate_size_per_partition + block_n_prefill - 1) // block_n_prefill,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv_prefill", w13_weight_scale_prefill)
layer.register_parameter("w2_weight_scale_inv_prefill", w2_weight_scale_prefill)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.
value} if self.block_quant else
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if self.scale_n != self.scale_n_prefill:
set_weight_attrs(w13_weight_scale_prefill, extra_weight_attrs)
set_weight_attrs(w2_weight_scale_prefill, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def moe_fp8_apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import FusedMoE
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
try:
from torch_vacc.vacc.custom_ops import fused_experts
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
experts_output = None
if memory_recycler is not None:
# remove MOE_EXPERT_OUT_BUFFER
# experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
decode_with_batch=layer.is_decode and x.shape[0] > 1,
output_opt=experts_output
)
except Exception as e:
print(f"vacc fused_expert run fail, now using unfused ops: {e}")
from torch_vacc.vacc.custom_ops_cpu import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w13_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)