[gpt-oss] Add gpt-oss mxfp4 support

This commit is contained in:
2025-08-25 15:31:09 +08:00
parent db7f48eeac
commit 7a35b2f32d
32 changed files with 4835 additions and 1190 deletions

View File

@@ -36,6 +36,7 @@ QuantizationMethods = Literal[
"moe_wna16",
"torchao",
"auto-round",
"mxfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
@@ -143,6 +145,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
"auto-round": AutoRoundConfig,
"mxfp4": Mxfp4Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

View File

@@ -0,0 +1,581 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase, fused_experts)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
triton_kernel_moe_forward)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_can_support_mxfp4, _swizzle_mxfp4)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
next_power_of_2, round_up)
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
# from flashinfer.fused_moe import cutlass_fused_moe
from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 (OCP microscaling fp4: e2m1 values
    with one shared e8m0 scale per 32-element block), used by gpt-oss
    MoE checkpoints."""

    def __init__(self, ignored_layers: Optional[list[str]] = None):
        super().__init__()
        # Layer-name patterns that should stay unquantized.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # NOTE(review): the checkpoint config dict is discarded entirely,
        # so ignored_layers is never populated from it — confirm intended.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        # Ampere (SM80) and newer.
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # No extra config files are needed beyond the model config.
        return []

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer`` or None if unhandled.

        Only fused MoE layers are quantized here; explicitly-ignored
        linear layers fall back to the unquantized method, while other
        linear and attention layers raise NotImplementedError.
        """
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if self.ignored_layers and is_layer_skipped(
                    prefix=prefix,
                    ignored_layers=self.ignored_layers,
                    fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            raise NotImplementedError("Mxfp4 linear layer is not implemented")
        elif isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(layer.moe_config)
        elif isinstance(layer, Attention):
            raise NotImplementedError(
                "Mxfp4 attention layer is not implemented")
        return None
class Mxfp4MoEMethod(FusedMoEMethodBase):
    """Fused-MoE method for MXFP4-quantized experts (gpt-oss).

    Dispatches between the marlin kernel, flashinfer trtllm-gen kernels,
    triton_kernels' matmul_ogs, and a dequantize-to-bf16 fallback.
    """

    def __init__(self, moe: FusedMoEConfig):
        super().__init__()
        # None means callers use their default top-k index dtype.
        self.topk_indices_dtype = None
        self.moe = moe
        self.use_marlin = self._should_use_marlin()

    def _should_use_marlin(self):
        # An explicit env override wins; otherwise marlin is currently
        # disabled (the heuristics below are intentionally commented out).
        if envs.VLLM_MXFP4_USE_MARLIN is not None:
            return envs.VLLM_MXFP4_USE_MARLIN
        # if current_platform.is_cuda() and \
        #         not current_platform.has_device_capability(100):
        #     if not current_platform.is_device_capability(90):
        #         # marlin kernel has better performance on ampere
        #         return True
        #     if not has_triton_kernels():
        #         return True
        #     if not is_torch_equal_or_newer("2.8.0"):
        #         return True
        return False
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register packed mxfp4 expert weights on ``layer``.

        Layout: weights are uint8 with two e2m1 values packed per byte,
        scales are uint8 (e8m0 exponent) with one scale per 32-element
        block, and biases are bf16. The intermediate (and sometimes
        hidden) size is padded to the alignment required by the selected
        backend kernel.
        """
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8
        # FIXME (zyongye): ship after torch and safetensors support mxfp4
        # is_torch_mxfp4_available = (
        #     hasattr(torch, "float4_e2m1fn_x2") and
        #     hasattr(torch, "float8_e8m0fnu"))
        # if is_torch_mxfp4_available:
        #     weight_dtype = torch.float4_e2m1fn_x2
        #     scale_dtype = torch.float8_e8m0fnu
        mxfp4_block = 32
        intermediate_size_per_partition_after_pad = \
            intermediate_size_per_partition
        if self.use_marlin:
            # The moe marlin kernel requires that for each linear
            # n % 256 == 0 and k % 128 == 0.
            # In gate_up_proj:
            #   n = 2 * intermediate_size_per_partition_after_pad
            #   k = hidden_size
            # In down_proj
            #   n = hidden_size
            #   k = intermediate_size_per_partition_after_pad
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 256)
            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = \
                intermediate_size_per_partition_after_pad
        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
            # Pad the intermediate size to a multiple of 2 * mxfp4_block
            # so it can hold non-uniformly sharded tensors and be
            # swizzled; the extra padding also improves performance.
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif current_platform.is_rocm():
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 128)
        else:
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 64)
        self.intermediate_size = intermediate_size_per_partition_after_pad
        self.hidden_size = hidden_size
        # Fused gate_up_proj (column parallel). The trailing dim is
        # halved because two fp4 values are packed per uint8 byte.
        w13_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition_after_pad,
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        # One e8m0 scale per 32-element block along the hidden dim.
        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition_after_pad,
                hidden_size // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        w13_bias = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition_after_pad,
                dtype=torch.bfloat16,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_bias", w13_bias)
        set_weight_attrs(w13_bias, extra_weight_attrs)
        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition_after_pad // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition_after_pad // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        w2_bias = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                dtype=torch.bfloat16,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_bias", w2_bias)
        set_weight_attrs(w2_bias, extra_weight_attrs)
    def process_weights_after_loading(self, layer):
        """Repack the loaded mxfp4 checkpoint weights for the backend.

        - marlin: convert in-place via prepare_moe_fp4_layer_for_marlin.
        - flashinfer trtllm-gen: swap w1/w3 row pairs, then shuffle
          weights, scales and biases into the kernel's transposed-mma
          layout.
        - triton_kernels: swizzle into matmul_ogs layout and build
          PrecisionConfigs; original weight Parameters are dropped.
        - fallback: upcast weights to bf16 for the generic fused_experts
          path.
        """
        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
            # Per-expert swiglu parameters consumed by trtllm-gen:
            # alpha is the swish slope, beta a linear offset, and
            # clamp_limit bounds the activation.
            layer.gemm1_alpha = Parameter(torch.tensor(
                [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False)
            layer.gemm1_beta = Parameter(torch.tensor(
                [1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False)
            layer.gemm1_clamp_limit = Parameter(torch.tensor(
                [7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False)
            sf_block_size = 32  # mxfp4 block size
            # Sanity-check the shapes produced by create_weights.
            assert (layer.w13_weight.dim() == 3
                    and layer.w13_weight.shape[0] == self.num_experts
                    and layer.w13_weight.shape[1] == self.intermediate_size * 2
                    and layer.w13_weight.shape[2] == self.hidden_size // 2)
            assert (layer.w13_weight_scale.dim() == 3
                    and layer.w13_weight_scale.shape[0] == self.num_experts
                    and layer.w13_weight_scale.shape[1]
                    == self.intermediate_size * 2
                    and layer.w13_weight_scale.shape[2]
                    == self.hidden_size // sf_block_size)
            assert (layer.w2_weight.dim() == 3
                    and layer.w2_weight.shape[0] == self.num_experts
                    and layer.w2_weight.shape[1] == self.hidden_size and
                    layer.w2_weight.shape[2] == self.intermediate_size // 2)
            # NOTE(review): unlike the others, this assert does not check
            # shape[0] == num_experts — confirm whether intentional.
            assert (layer.w2_weight_scale.dim() == 3
                    and layer.w2_weight_scale.shape[1] == self.hidden_size
                    and layer.w2_weight_scale.shape[2]
                    == self.intermediate_size // sf_block_size)
            assert (layer.w13_bias.dim() == 2
                    and layer.w13_bias.shape[0] == self.num_experts
                    and layer.w13_bias.shape[1] == self.intermediate_size * 2)
            assert (layer.w2_bias.dim() == 2
                    and layer.w2_bias.shape[0] == self.num_experts
                    and layer.w2_bias.shape[1] == self.hidden_size)
            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

            # Swap w1 and w3 as the definition of
            # swiglu is different in the trtllm-gen kernel.
            def swap_every_two_rows(x, axis=-1):
                # Exchange each adjacent pair of rows along `axis`.
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis
                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)
                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)
            # Do not interleave as the checkpoint is already interleaved.
            # Shuffle weights and scaling factors for transposed mma output.
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
                gemm1_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w13_weight[i].view(torch.uint8),
                                     epilogue_tile_m))
                gemm1_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
                                        epilogue_tile_m))
                gemm1_bias_shuffled.append(
                    shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
                                     epilogue_tile_m))
                gemm2_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w2_weight[i].view(torch.uint8),
                                     epilogue_tile_m))
                gemm2_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
                                        epilogue_tile_m))
                gemm2_bias_shuffled.append(
                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
                                     epilogue_tile_m))
            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
            w13_weight_scale = torch.stack(
                gemm1_scales_mxfp4_shuffled).reshape(
                    self.num_experts, 2 * self.intermediate_size,
                    self.hidden_size // sf_block_size).view(
                        torch.float8_e4m3fn)
            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
            w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
                self.num_experts, self.hidden_size, self.intermediate_size //
                sf_block_size).view(torch.float8_e4m3fn)
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale = Parameter(w13_weight_scale,
                                               requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale = Parameter(w2_weight_scale,
                                              requires_grad=False)
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False)
            layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                self.num_experts, -1),
                requires_grad=False)
        elif has_triton_kernels():
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)
            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)
            # FIXME: num_warps needs to be adjusted based on batch size;
            # this only applies to batched mode.
            if self.moe.use_ep:
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8
            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
                layer.w13_weight, layer.w13_weight_scale, num_warps)
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
                layer.w2_weight, layer.w2_weight_scale, num_warps)
            self.w13_precision_config = PrecisionConfig(
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
            self.w2_precision_config = PrecisionConfig(
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
            self.w13_weight_triton_tensor = w13_weight
            self.w2_weight_triton_tensor = w2_weight
            # Need to delete the original weights to save memory on a
            # single GPU; apply() reads the triton tensors instead.
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = None
            layer.w2_weight = None
            torch.cuda.empty_cache()
        else:
            # Normal triton path: dequantize everything to bf16 up front.
            from .triton_kernels_numerics_details.mxfp import upcast_from_mxfp
            w13_weight = upcast_from_mxfp(
                layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
            )
            w2_weight = upcast_from_mxfp(
                layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
            )
            del layer.w13_weight
            del layer.w2_weight
            del layer.w13_weight_scale
            del layer.w2_weight_scale
            layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
            layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
            torch.cuda.empty_cache()
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Factor to account for the imbalance of the experts.
# factor equals to the
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
# - 1.0 means perfect expert distribution.
# - > 1.0 means some experts have more
# tokens than the perfect distribution.
# - < 1.0 does not make sense.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MoE layer with mxfp4-quantized experts.

        Backend priority: marlin (if enabled) -> flashinfer trtllm-gen
        (if the corresponding env flags are set) -> triton_kernels ->
        bf16 fused_experts fallback (weights were dequantized in
        process_weights_after_loading).
        """
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.use_marlin:
            topk_weights, topk_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias)
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_bias,
                layer.w2_bias,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=None,
                global_scale2=None,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                activation=activation,
                expert_map=expert_map)

        # Non-marlin backends support only a restricted routing
        # configuration; fail loudly otherwise.
        assert _can_support_mxfp4(
            use_grouped_topk, topk_group, num_expert_group, expert_map,
            custom_routing_function, e_score_correction_bias,
            apply_router_weight_on_input, scoring_func, activation,
            expert_load_view, logical_to_physical_map,
            logical_replica_count), (
                "MXFP4 are not supported with this configuration.")

        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
            assert not self.moe.use_ep, (
                "EP is not supported for flashinfer mxfp4 moe backend yet.")
            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
                # Activations stay bf16; only the weights are mxfp4.
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            else:
                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
                x_quant,
                x_scale,
                layer.w13_weight,  # uint8 (e2m1 x 2)
                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                layer.w13_bias,  # fp32 per expert per channel
                layer.gemm1_alpha,  # fp32 per expert
                layer.gemm1_beta,  # fp32 per expert
                layer.gemm1_clamp_limit,  # fp32 per expert
                layer.w2_weight,  # uint8 (e2m1 x 2)
                layer.w2_weight_scale,  # ue8m0
                layer.w2_bias,  # fp32 per expert per channel
                None,  # output1_scale_scalar
                None,  # output1_scale_gate_scalar
                None,  # output2_scale_scalar
                self.num_experts,
                top_k,
                None,  # n_group
                None,  # topk_group
                self.intermediate_size,  # padded to multiple of 256
                0,  # local_expert_offset
                self.num_experts,  # local num experts
                None,
                self._get_tile_tokens_dim(x, top_k),
                1 if renormalize else 0,  # routing_method_type, renormalize
                True,  # do finalize
            )[0]
            return trtllm_gen_output
        elif has_triton_kernels():
            # Weights were swizzled into triton_kernels layout and stored
            # on `self` (layer.w13_weight/w2_weight are None here).
            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=self.w13_weight_triton_tensor,
                w2=self.w2_weight_triton_tensor,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_precision=self.w13_precision_config,
                w2_precision=self.w2_precision_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # Fallback: weights were upcast to bf16 after loading, so run
            # the generic fused_experts kernel.
            topk_weights, topk_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias)
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
            )

View File

@@ -6,14 +6,16 @@ from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform
logger = init_logger(__name__)
__all__ = ["QuarkW4A4MXFP4"]
@@ -25,7 +27,29 @@ class QuarkW4A4MXFP4(QuarkScheme):
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.emulate = not current_platform.supports_mx()
self.static_input_scales = not input_quant_spec.get("is_dynamic")
if self.static_input_scales:
raise NotImplementedError(
"QuarkW4A4MXFP4 with static input scales is currently not "
"implemented. Please open an issue.")
if not current_platform.supports_mx():
self.emulate = True
logger.warning_once(
"The current platform does not support native MXFP4 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
else:
self.emulate = True
logger.warning_once(
"The current platform supports native MXFP4 "
"computation, but kernels are not yet integrated in vLLM. "
"Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
@classmethod
def get_min_capability(cls) -> int:
@@ -37,43 +61,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.emulate:
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import (
QuantizationSpec)
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
weight_quant_spec = QuantizationSpec.from_dict(
self.weight_quant_spec)
weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data
if not envs.VLLM_QUARK_EMU_MEM_OPT:
layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
else:
self.weight_quantizer = weight_quantizer
layer.weight_scale = None
# This call is necessary to release the scales memory.
torch.cuda.empty_cache()
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
@@ -116,11 +103,10 @@ class QuarkW4A4MXFP4(QuarkScheme):
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.emulate:
if envs.VLLM_QUARK_EMU_MEM_OPT:
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
else:
dq_w = layer.weight
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
return F.linear(qdq_x, dq_w, bias)
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
x = quant_dequant_mxfp4(x)
return F.linear(x, dq_w, bias)
else:
raise NotImplementedError()

View File

@@ -0,0 +1,158 @@
import triton
import triton.language as tl
# fmt: off
MXFP_BLOCK_SIZE = tl.constexpr(32)
@triton.jit
def _get_max_quant_val(dtype: tl.constexpr):
    """Largest representable magnitude of the target mx element type.

    uint8 here denotes packed fp4 e2m1, whose max finite value is 6.0.
    """
    if dtype == tl.uint8:
        return 6.0
    elif dtype == tl.float8e5:
        return 57344.0
    elif dtype == tl.float8e4nv:
        return 448.0
    else:
        tl.static_assert(False, f"Invalid {dtype=}")
@triton.jit
def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
                             DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
    """Quantize a 2D block to an mx format.

    Returns (out_tensor, dequant_scale_exponent): per-32-element
    power-of-two scales stored as biased-exponent uint8, plus the values
    as fp8 or as packed fp4 e2m1 pairs (two nibbles per uint8).
    """
    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
    BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
    BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE

    # Explicit cast to fp32 since most ops are not supported on bfloat16.
    # We avoid needless conversions to and from bf16.
    f32_tensor = src_tensor.to(tl.float32)
    abs_tensor = tl.abs(f32_tensor)
    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation
    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    # Per-block max magnitude determines the shared scale.
    max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
    dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
    if DEQUANT_SCALE_ROUNDING_MODE == 0:
        # DequantScaleRoundingMode.ROUND_UP
        # compute 2 ** ceil(log2(dequant_scale))
        # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
        # A corner case: exponent is 0xFF that will overflow but that's already
        # NaN so assume we don't care.
        dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
    else:
        # DequantScaleRoundingMode.ROUND_DOWN
        # compute 2 ** floor(log2(dequant_scale))
        assert DEQUANT_SCALE_ROUNDING_MODE == 1
        dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
    dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
    # quant_scale is the reciprocal power of two; 0 guards all-zero blocks.
    quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    quant_tensor = f32_tensor * quant_scale

    # Reshape the tensors after scaling
    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
    dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])

    # First, we simply extract the exponent part of the scales and store the result
    dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
    # Now we must convert the tensors to the mx format.
    if is_fp8:
        out_tensor = quant_tensor.to(mx_tensor_dtype)
    else:
        # Manual fp32 -> e2m1 conversion on the raw bit patterns.
        quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
        signs = quant_tensor & 0x80000000
        exponents = (quant_tensor >> 23) & 0xFF
        mantissas = (quant_tensor & 0x7FFFFF)

        # 0.25 <= x < 0.75 maps to 0.5, a denormal number
        E8_BIAS = 127
        E2_BIAS = 1
        # Move implicit bit 1 at the beginning to mantissa for denormals
        adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
        mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)

        # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
        exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

        # Combine sign, exponent, and mantissa, while saturating
        # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
        e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
        e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)

        # Pack adjacent e2m1 values into the low/high nibbles of a byte.
        e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        evens, odds = tl.split(e2m1_value)
        out_tensor = evens | (odds << 4)

    return out_tensor, dequant_scale_exponent
@triton.jit
def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
                      mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
                      src_ptr, stride_src_outer, stride_src_quant,
                      outer_dim, quant_dim,
                      BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
                      DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
    """Kernel: quantize `src` (bf16/fp16/fp32) into an mx tensor plus
    per-32-element uint8 scales, one (outer, quant) tile per program."""
    tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")

    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
                     f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")
    src_dtype: tl.constexpr = src_ptr.dtype.element_ty
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
    tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8

    # int64 program ids to avoid overflow on large tensors.
    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    # fp4 packs two source values per output byte; fp8 is one-to-one.
    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

    # Advance the three base pointers to this program's tile.
    start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
    mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
    mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer

    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)

    # Bounds masks for the (possibly partial) edge tiles.
    mask_src_quant = start_src_quant + offs_src_quant < quant_dim
    mask_n = start_out + offs_outer < outer_dim
    full_mask_src = mask_src_quant & mask_n

    mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
    full_mask_mxt = mask_mxt_quant & mask_n

    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
    full_scale_mask = scale_mask_k & mask_n

    src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
    mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
    mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
    src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)

    out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
                                                        DEQUANT_SCALE_ROUNDING_MODE)

    tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
    tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
@triton.jit(repr=lambda _: "_dequantize_mxfp8")
def _dequantize_mxfp8_fn(input, mask, pid=None):
    # NOTE(review): despite the name, this *quantizes* `input` to mxfp8
    # (fp8e4m3 values + exponent scales) via _compute_quant_and_scale —
    # confirm the name/repr against its call sites.
    return _compute_quant_and_scale(input, mask, tl.float8e4nv)

View File

@@ -0,0 +1,136 @@
import triton
import triton.language as tl
from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
# fmt: off
@triton.jit
def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
                      stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
                      outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
    # Dequantize one (BLOCK_SIZE_OUT_DIM x BLOCK_SIZE_QUANT_DIM) tile of an mx
    # tensor (mxfp4 packed two-per-byte in uint8, or mxfp8) into out_ptr
    # (fp16 / bf16 / fp32). Every MXFP_BLOCK_SIZE (32) elements along the quant
    # dimension share one uint8 (e8m0-style) scale byte from mx_scale_ptr.
    tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
    tl.static_assert(dst_dtype == tl.float16 or (dst_dtype == tl.bfloat16 or dst_dtype == tl.float32))
    tl.static_assert(
        mx_tensor_dtype == tl.uint8
        or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")

    # Determine if we are dealing with fp8 types.
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
    # fp4 stores two logical elements per byte, so the stored quant dim is halved.
    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

    # Compute starting indices for the quantized (packed) dimension and the outer dimension.
    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
    mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
    out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant

    # Compute offsets and masks.
    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
    offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)

    mask_outer = start_out + offs_outer < outer_dim
    mask_out_quant = start_out_quant + offs_out_quant < quant_dim
    full_mask_out = mask_out_quant & mask_outer

    mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
    full_mask_src = mask_src_quant & mask_outer

    mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
    full_scale_mask = mask_scale & mask_outer

    tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
    scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
    out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer

    # Load the packed tensor and scale.
    tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
    scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)

    # Upcast the scale to the destination type.
    # The scale byte is a biased exponent: shifting it into the exponent field
    # of the destination format reconstructs the power-of-two factor directly.
    if dst_dtype == tl.bfloat16:
        # dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
        dst_scale = (scale.to(tl.uint16) << 7).to(tl.uint16).to(tl.bfloat16, bitcast=True)
    else:
        dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
        if dst_dtype == tl.float16:
            dst_scale = dst_scale.to(tl.float16)

    # Now upcast the tensor.
    # fp32 destinations are materialized through bf16 first, then widened below.
    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
    if is_fp8:
        dst_tensor = tensor.to(intermediate_dtype)
        if tensor.dtype == tl.float8e5:
            from_e_bits: tl.constexpr = 5
            from_m_bits: tl.constexpr = 2
            to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
            to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
            # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
            non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
            non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
            dst_tensor = tl.where(
                (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
                (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
                dst_tensor,
            )
    else:
        assert is_fp4
        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
        # e2m1: low nibble (em0/sign bit 0x08) and high nibble (em1/sign bit 0x80)
        # are expanded in parallel, then interleaved back into logical order.
        em0 = tensor & 0x07
        em1 = tensor & 0x70
        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
        # Three cases:
        # 1) x is normal and non-zero: Correct bias
        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
        # 3) x is zero, do nothing
        if intermediate_dtype == tl.bfloat16:
            dst_tensor = tl.interleave(x0, x1).to(tl.uint16).to(tl.bfloat16, bitcast=True)
        else:
            dst_tensor = tl.interleave(x0, x1).to(tl.float16, bitcast=True)

    # dst_tensor = dst_tensor.to(dst_dtype)
    if dst_dtype == tl.bfloat16:
        dst_tensor = dst_tensor.to(tl.bfloat16)
    elif dst_dtype == tl.float16:
        dst_tensor = dst_tensor.to(tl.float16)
    else:
        dst_tensor = dst_tensor.to(tl.float32)

    # Reshape for proper broadcasting: the scale was stored with a 32-sized "inner" grouping.
    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
    scale = scale.reshape(dst_scale.shape)

    out_tensor = dst_tensor * dst_scale
    # Correct any NaNs encoded via the scale.
    out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)

View File

@@ -0,0 +1,303 @@
# isort: off
# fmt: off
from enum import Enum
import triton
import torch
import torch.nn.functional as F
from ._upcast_from_mxfp import _upcast_from_mxfp
from ._downcast_to_mxfp import _downcast_to_mxfp, _dequantize_mxfp8_fn, MXFP_BLOCK_SIZE
# -----------------------------------------------------------------------------
# Dequantization / Quantization Utilities
# -----------------------------------------------------------------------------
class DequantScaleRoundingMode(Enum):
    """How the per-group dequantization scale is rounded to a power of two."""
    # Round the fp32 scale up to the next power of two (never clips the group max).
    ROUND_UP = 0
    # Truncate the mantissa, i.e. round the scale down to a power of two.
    ROUND_DOWN = 1
def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                     DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
    """
    Convert the src weights to mx format. The src weight is quantized along the axis dimension.

    If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
    Note that this means the k_dim of the tensor will be half of the logical k_dim.

    If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored
    in their respective formats.

    Returns (out_quant_tensor, out_scale): the quantized values plus one uint8
    scale byte per MXFP_BLOCK_SIZE (32) elements along ``axis``.
    """
    ndim = src_tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim

    # downcast
    # Move the quantization axis to the innermost dim so the kernel can treat
    # the problem as 2D (outer, quant).
    src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
    assert is_fp4 or is_fp8

    # fp4 packs two e2m1 values per byte, halving the stored quant dim.
    divisor = 2 if is_fp4 else 1
    L = src_tensor.shape[-1]
    if is_fp4:
        assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"

    out_shape = src_tensor.shape[:-1] + (L // divisor, )
    out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )

    out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
    out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)

    # Empty inputs skip the kernel launch; outputs stay (empty) allocated.
    if src_tensor.numel() > 0:
        # Collapse all leading dims: the kernel operates on a 2D view.
        kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
        kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
        kernel_scale = out_scale.view(-1, out_scale.shape[-1])

        BLOCK_OUT_DIM = 128
        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
        grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
        grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)

        _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
                                                  *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
                                                  *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
                                                  DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)

    # Undo the transpose so the outputs match the caller's axis layout.
    out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
    out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
    return out_quant_tensor, out_scale
def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int):
    """
    Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16.

    The function assumes that the tensors were quantized along the given axis.
    It permutes the tensor so that the quantized axis is last, reshapes to 2D,
    launches the Triton upcast kernel, and then unpermutes back to the original order.
    """
    ndim = tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim
    assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
                                       f"Got {tensor.ndim=} and {scale.ndim=}")
    # dtype checks
    assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
        f"Invalid tensor dtype {tensor.dtype=}"
    assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
    assert dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {dtype=}"

    # upcast
    # uint8 means packed fp4: each stored byte holds two logical elements.
    logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
    tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
    scale = scale.transpose(axis, scale.ndim - 1).contiguous()
    out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device)

    # 2D (outer, quant) views for the kernel.
    reshaped_out = out.view(-1, out.shape[-1])
    reshaped_tensor = tensor.view(-1, tensor.shape[-1])
    reshaped_scale = scale.view(-1, scale.shape[-1])
    BLOCK_OUT_DIM = 128
    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
    blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
    blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
    _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
                                                          *reshaped_scale.stride(), reshaped_tensor,
                                                          *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
                                                          BLOCK_QUANT_DIM, num_warps=8)
    # Move the upcast axis back to its original position (out/scale have equal ndim).
    out = out.transpose(axis, scale.ndim - 1).contiguous()
    return out
# ------------
def right_shift_unsigned(x, shift):
    """Logical (unsigned) right shift of a 32-bit value.

    CUDA torch does not support bit ops on uint32, so an arithmetic shift is
    performed and the replicated sign bits are masked away, which yields the
    unsigned-shift result for both torch int32 tensors and plain ints.
    """
    keep_bits = 32 - shift
    low_mask = (1 << keep_bits) - 1
    return (x >> shift) & low_mask
def get_max_quant_val(dtype: torch.dtype):
d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0}
assert dtype in d
return d[dtype]
def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                           DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
    """
    Converts the src tensor to the output format specified by out_quant_type.
    Pure-torch reference implementation of `downcast_to_mxfp` (no Triton kernel).

      axis: The axis along which the tensors are contiguous and quantization is applied.
      DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.

    Returns:
      out_quant_tensor: Quantized tensor in mx format.
        - For mxfp8, the output has the same shape as src_tensor.
        - For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
      scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
        Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
        where L is the original length along that axis.
    """
    # This should probably be packed into its own tiny class
    ndim = src_tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    assert src_tensor.dtype in {torch.float32, torch.bfloat16,
                                torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"

    axis = axis if axis >= 0 else axis + ndim
    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = "float8" in str(out_quant_type)
    assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"

    device = src_tensor.device

    # For mxfp4 conversion, we assume the contiguous axis length is even.
    if is_fp4:
        axis_shape = src_tensor.size(axis)
        assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."

    # Permute the tensor so that the contiguous axis becomes the last dimension.
    src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
    axis_shape = src.shape[-1]

    # Pad the axis to be divisible by 32, in case it is not.
    next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
    pad_amount = next_multiple - axis_shape
    padded_src = F.pad(src, (0, pad_amount))
    valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
    padded_axis_shape = padded_src.size(-1)  # now divisible by 32

    # --- Compute per-group maximums for scale ---
    # Set padded entries to -1 so they don't affect the max.
    abs_f = torch.abs(padded_src)
    abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
    # Reshape the last dimension into groups of 32.
    new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    abs_groups = abs_f.view(*new_shape)
    # Compute maximum along the group dimension (of size 32).
    max_val, _ = abs_groups.max(dim=-1, keepdim=True)

    # Choose a max quantization value depending on type.
    max_quant_val = get_max_quant_val(out_quant_type)
    dequant_scale = max_val / max_quant_val  # shape: (..., padded_axis_shape//32, 1)

    # Convert to int to round the FP32 scale, prior to quantization!
    # The bit tricks below round the fp32 scale to a power of two: ROUND_UP adds
    # just under one mantissa ULP before masking, ROUND_DOWN truncates the mantissa.
    ds_int = dequant_scale.view(torch.int32)
    if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
        ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
    else:
        ds_int_rounded = ds_int & 0x7F800000
    # Reinterpret back as float32.
    dequant_scale_rounded = ds_int_rounded.view(torch.float32)

    # Compute the quantization scale.
    # A zero scale (all-zero group) maps to a zero multiplier instead of inf.
    quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)

    # Quantize the tensor
    orig_padded_shape = padded_src.shape
    padded_src_groups = padded_src.view(*new_shape)
    quant_tensor = padded_src_groups * quant_scale
    # Reshape back to the original shape and trim padding
    quant_tensor = quant_tensor.view(orig_padded_shape)
    quant_tensor = quant_tensor[..., :axis_shape]

    # Finally, convert the quantized tensor to the target format
    if is_fp8:
        # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
        quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
        out_weight = quant_tensor.to(out_quant_type)
    else:
        assert is_fp4, f"Invalid output quantization type {out_quant_type}"
        # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
        # First, reinterpret the quantized tensor bits.
        q_int = quant_tensor.contiguous().view(torch.int32)
        # Extract sign, exponent, and mantissa.
        signs = q_int & 0x80000000
        exponents = right_shift_unsigned(q_int, 23) & 0xFF
        mantissas = q_int & 0x7FFFFF

        # fp32 exponent bias is 127; e2m1's is 1.
        E8_BIAS = 127
        E2_BIAS = 1
        # Adjust mantissas for subnormals.
        mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
                                (E8_BIAS - exponents - 1), mantissas)
        exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
        # Assemble e2m1 with round-to-nearest via the +1 then >>1.
        e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
        e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))

        e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8)  # shape: (..., even_axis_shape)

        # Pack pairs of 4-bit values along the last dimension.
        e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
        evens = e2m1_value[..., 0]
        odds = e2m1_value[..., 1]
        out_weight = evens | (odds << 4)  # shape: (..., axis_shape//2)

    # --- Process and output the scale ---
    # Keep only the 8 exponent bits: this is the uint8 (e8m0) scale byte.
    dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8)  # shape: (..., axis_shape//32, 1)
    dq_scale = dq_scale.squeeze(-1)
    out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
    dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
    return out_weight, dq_scale
def cvt_e2m1_to_fp32(input_tensor):
    """Unpack two e2m1 (fp4) values per byte and decode them to float32.

    The low nibble is the first logical element, the high nibble the second,
    so the last dimension of the result is twice that of the input.
    """
    assert input_tensor.dtype == torch.uint8
    codes = input_tensor.to(torch.int32)
    low_nibbles = codes & 0xF
    high_nibbles = (codes >> 4) & 0xF
    # e2m1 magnitude table indexed by the 3 exponent/mantissa bits; indices
    # 8-15 (sign bit set) select the negated half of the table.
    magnitudes = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6]
    lut = torch.tensor(magnitudes, dtype=torch.float32, device=input_tensor.device)
    lut = torch.cat([lut, -lut])
    decoded_low = lut[low_nibbles]
    decoded_high = lut[high_nibbles]
    interleaved = torch.stack([decoded_low, decoded_high], dim=-1)
    return interleaved.view(*codes.shape[:-1], -1)
def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
    """
    Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
    Pure-torch reference implementation of `upcast_from_mxfp` (no Triton kernel).

      axis: The axis along which dequantization is applied.

    Returns:
         out_weight: Tensor in the target format.
    """
    ndim = tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
    assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"

    # Permute the tensor and scale so that the quantization axis becomes the last dimension
    axis = axis if axis >= 0 else axis + ndim
    scale = scale.transpose(axis, scale.ndim - 1)
    tensor = tensor.transpose(axis, tensor.ndim - 1)

    dq_scale = (scale.to(torch.int32) << 23).view(torch.float32)  # Shift to the exponent and bitcast to fp32
    if tensor.dtype == torch.uint8:
        # Packed fp4: unpack two e2m1 values per stored byte.
        fp32_tensor = cvt_e2m1_to_fp32(tensor)
    else:
        fp32_tensor = tensor.to(torch.float32)

    logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
    axis_shape = fp32_tensor.size(-1)
    # Pad so the axis is a whole number of 32-element scale groups.
    padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
    pad_size = padded_axis_shape - axis_shape
    padded_tensor = F.pad(fp32_tensor, (0, pad_size))

    # Broadcast one scale over each group of 32 elements.
    new_axis_shape = padded_tensor.shape[-1]
    new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    padded_tensor = padded_tensor.view(*new_shape)
    dq_scale_padded = dq_scale.unsqueeze(-1)  # shape: [..., ceil(axis_shape/32), 1]
    out_padded = padded_tensor * dq_scale_padded

    # Flatten back and remove the padded tail
    out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
    out_tensor = out_padded[..., :axis_shape]

    # Restore the caller's axis layout.
    out_tensor = out_tensor.to(target_dtype).contiguous()
    out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)
    return out_tensor
# Public alias for the Triton mxfp8 (de)quantization hook defined above.
dequantize_mxfp8_fn = _dequantize_mxfp8_fn

View File

@@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s
def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
    """Permute a bias vector into the single-group scale ordering used by the
    Marlin kernel, preserving the input's shape."""
    original_shape = s.shape
    _, perm_single = get_scale_perms()
    permuted = s.reshape((-1, len(perm_single)))[:, perm_single]
    return permuted.reshape(*original_shape).contiguous()
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
@@ -410,6 +417,7 @@ def apply_gptq_marlin_linear(
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
bias,
weight_scale,
None,
weight_zp,
@@ -425,9 +433,6 @@ def apply_gptq_marlin_linear(
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
@@ -456,6 +461,7 @@ def apply_awq_marlin_linear(
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
bias,
weight_scale,
None,
weight_zp,
@@ -470,7 +476,4 @@ def apply_awq_marlin_linear(
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -8,8 +8,8 @@ import torch
import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
should_use_atomic_add_reduce)
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias,
marlin_permute_scales, should_use_atomic_add_reduce)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
@@ -22,7 +22,7 @@ def is_fp4_marlin_supported():
return current_platform.has_device_capability(80)
def fp4_marlin_process_scales(marlin_scales):
def nvfp4_marlin_process_scales(marlin_scales):
if not (marlin_scales >= 0).all():
logger.warning_once(
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
@@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales):
return marlin_scales
def fp4_marlin_process_global_scale(global_scale):
def mxfp4_marlin_process_scales(marlin_scales):
    # Rearrange already marlin-permuted mxfp4 scales into the per-thread
    # fragment order expected by the kernel, then store them as e8m0 bytes.
    # 8 is the number of scale number using by one thread
    marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
    marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
        marlin_scales.size(0) * 2, -1)
    # fit the layout of fp8 dequantization
    marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
        marlin_scales.size(0), -1)
    # mxfp4 scales are pure power-of-two exponents -> e8m0 storage.
    marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
    return marlin_scales
def nvfp4_marlin_process_global_scale(global_scale):
assert global_scale.dtype in [torch.half, torch.bfloat16]
fp4_exponent = 2
if global_scale.dtype == torch.half:
@@ -73,7 +86,7 @@ def apply_fp4_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
weight_scale_2: Optional[torch.Tensor],
workspace: torch.Tensor,
size_n: int,
size_k: int,
@@ -94,6 +107,7 @@ def apply_fp4_marlin_linear(
output = ops.gptq_marlin_gemm(a=reshaped_x,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
global_scale=weight_scale_2,
b_zeros=None,
@@ -107,9 +121,6 @@ def apply_fp4_marlin_linear(
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
@@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
is_nvfp4 = hasattr(layer, "weight_scale_2")
group_size = 16 if is_nvfp4 else 32
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
param_dtype = layer.params_dtype
@@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# Permute scales
weight_scale = layer.weight_scale.T.to(param_dtype)
weight_scale = layer.weight_scale.T.contiguous()
if not is_nvfp4:
weight_scale = weight_scale.view(torch.float8_e8m0fnu)
weight_scale = weight_scale.to(param_dtype)
weight_scale = marlin_permute_scales(s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=16)
weight_scale = fp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
group_size=group_size)
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
requires_grad=False)
if is_nvfp4:
weight_scale = nvfp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale,
requires_grad=False)
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
requires_grad=False)
else:
weight_scale = mxfp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale,
requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None:
assert layer.bias.shape == (part_size_n, )
bias = marlin_permute_bias(layer.bias)
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
return
@@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
group_size = 16 if is_nvfp4 else 32
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
@@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# Permute scales
for name in ["w13", "w2"]:
scales = getattr(layer, name + "_weight_scale").to(param_dtype)
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
scales = getattr(layer, name + "_weight_scale")
if not is_nvfp4:
scales = scales.view(torch.float8_e8m0fnu)
scales = scales.to(param_dtype)
if is_nvfp4:
global_scale = getattr(layer,
name + "_weight_scale_2").to(param_dtype)
tensor_list = []
if "w13" in name:
@@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
size_n, size_k = k, n
for i in range(e):
marlin_scales = marlin_permute_scales(s=scales[i].T,
scale = scales[i].T
marlin_scales = marlin_permute_scales(s=scale,
size_k=size_k,
size_n=size_n,
group_size=16)
marlin_scales = fp4_marlin_process_scales(marlin_scales)
group_size=group_size)
if is_nvfp4:
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
else:
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales)
global_scale = fp4_marlin_process_global_scale(global_scale)
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
setattr(layer, name + "_weight_scale_2", global_scale)
if is_nvfp4:
global_scale = nvfp4_marlin_process_global_scale(global_scale)
global_scale = torch.nn.Parameter(global_scale,
requires_grad=False)
setattr(layer, name + "_weight_scale_2", global_scale)
# BIAS
# Permute bias
for name in ["w13_bias", "w2_bias"]:
if not hasattr(layer, name):
continue
bias = getattr(layer, name).to(param_dtype)
tensor_list = []
for i in range(e):
expert_bias = bias[i]
tensor_list.append(marlin_permute_bias(expert_bias))
bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
bias = torch.nn.Parameter(bias, requires_grad=False)
setattr(layer, name, bias)
def rand_marlin_weight_fp4_like(weight, group_size):
def rand_marlin_weight_nvfp4_like(weight, group_size):
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
@@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size):
size_k=size_k,
size_n=size_n,
group_size=group_size)
marlin_scales = fp4_marlin_process_scales(marlin_scales)
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
global_scale = fp4_marlin_process_global_scale(global_scale)
global_scale = nvfp4_marlin_process_global_scale(global_scale)
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
def rand_marlin_weight_mxfp4_like(weight, group_size):
    """Build a random mxfp4 Marlin weight with the same (n, k) shape as `weight`.

    Returns (weight_ref.T, marlin_qweight, marlin_scales): a dequantized
    reference matrix for correctness checks, the marlin-repacked fp4 payload,
    and the processed e8m0 scales. Test-helper counterpart of
    `rand_marlin_weight_nvfp4_like` for the mxfp4 path.
    """
    assert group_size > 0
    size_n, size_k = weight.shape
    device = weight.device

    # Random e8m0 scale bytes (biased exponents) - one per group of group_size.
    scales = torch.randint(100,
                           125, (size_n, size_k // group_size),
                           dtype=torch.uint8,
                           device=weight.device)
    scales = scales.view(torch.float8_e8m0fnu)

    # Random packed payload: two e2m1 values per byte.
    fp4_weight = torch.randint(0,
                               256, (size_n, size_k // 2),
                               dtype=torch.uint8,
                               device=weight.device)

    # Decode the high nibble: keep the sign bit, shift exponent/mantissa into
    # e4m3 position, then multiply by 2**6 to compensate the bias difference.
    fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
                         ((fp4_weight & 0b01110000) >> 2))
    fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
    fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)

    # Decode the low nibble the same way after moving it to the top.
    fp4_weight2 = fp4_weight << 4
    fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
                         ((fp4_weight2 & 0b01110000) >> 2))
    fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
    fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)

    # Interleave low/high nibbles back into the logical k dimension and apply
    # the per-group scales to obtain the dequantized reference weight.
    weight_ref = torch.cat(
        [fp4_weight_part_2.unsqueeze(2),
         fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
    weight_ref = weight_ref * \
        scales.repeat_interleave(group_size, 1).to(weight.dtype)

    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=4,
    )

    marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=group_size)
    marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
    return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)

View File

@@ -1,45 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
OCP_MX_BLOCK_SIZE = 32
def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> tuple[torch.Tensor, torch.Tensor]:
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
    """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel

    Wraps the packed fp4 values and their scales into triton_kernels tensors
    and converts them to the matmul-friendly layouts. Returns
    (quant_tensor, InFlexData(), scale).
    """
    import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
    from triton_kernels.numerics import InFlexData
    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
    from triton_kernels.tensor_details import layout
    from triton_kernels.tensor_details.layout import StridedLayout

    # Hopper (SM90) on torch < 2.8.1 cannot use the swizzled layouts; fall
    # back to plain strided layouts (slower but functional).
    if (current_platform.is_cuda()
            and current_platform.is_device_capability(90)
            and not is_torch_equal_or_newer("2.8.1")):
        logger.warning_once(
            "Mxfp4 on hopper is running on torch < 2.8.1, "
            "this cause swizling to be disabled, which may "
            "cause performance degradation. Please upgrade to torch nightly")
        value_layout, value_layout_opts = StridedLayout, dict()
        scale_layout, scale_layout_opts = StridedLayout, dict()
    else:
        value_layout, value_layout_opts = \
            layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
        scale_layout, scale_layout_opts = (
            layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=1, num_warps=num_warps))
    # On Blackwell (SM100) constrain opt_flags to the persistent matmul path.
    if current_platform.is_cuda() and \
            current_platform.is_device_capability(100):
        constraints = {
            "is_persistent": True,
            "epilogue_subtile": 1,
        }
        opt_flags.update_opt_flags_constraints(constraints)
    # transpose the tensor so that the quantization axis is on dim1
    quant_tensor = quant_tensor.transpose(-2, -1)
    scale = scale.transpose(-2, -1)
    quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
                                  value_layout, **value_layout_opts)
    scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
                           **scale_layout_opts)
    return quant_tensor, InFlexData(), scale
def _can_support_mxfp4(use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
scoring_func: str = "softmax",
activation: str = "swiglu_oai",
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None):
return not (use_grouped_topk or topk_group or num_expert_group
or expert_map or custom_routing_function
or e_score_correction_bias or apply_router_weight_on_input
or scoring_func != "softmax" or activation != "swiglu_oai"
or expert_load_view or logical_to_physical_map
or logical_replica_count)
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)
from quark.torch.quantization.utils import (even_round,
reshape_to_blocks)
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
axis = -1
block_x = reshape_to_blocks(x, block_k, axis)
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
amax = amax.squeeze(-1)
return mx.dq_mxfp4(x, scale, float_dtype)
# TODO: there are other rounding strategies supported in quark and in the
# config.json that we do not check for here!
if scale_calculation_mode != "even":
raise NotImplementedError(
f"Scale calculation mode {scale_calculation_mode} is not yet "
"supported in MX-FP4 quantization")
scale = even_round(amax, "fp4")
# Apply dequantize(quantize(x)).
x = fake_quantize_fp4_fp6_per_group_with_scale(
x,
scale.to(x.device),
axis=axis,
group_size=block_k,
quant_dtype="fp4",
def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
return torch.empty((*x.shape[:-1], x.shape[-1] * 2),
dtype=float_dtype,
device=x.device)
def _quant_dequant_mxfp4(x: torch.Tensor,
scale_calculation_mode: str = "even") -> torch.Tensor:
try:
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
return mx.qdq_mxfp4(x, scale_calculation_mode)
def _quant_dequant_mxfp4_fake(x: torch.Tensor,
scale_calculation_mode: str = "even"
) -> torch.Tensor:
return torch.empty_like(x)
# Register _dequant_mxfp4 as a torch custom op (with its fake/meta
# implementation for shape inference) and expose the registered op at
# module level.
try:
    direct_register_custom_op(
        op_name="dequant_mxfp4",
        op_func=_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_dequant_mxfp4_fake,
    )
    dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
except AttributeError:
    # Bare `raise` re-raises the original exception with its traceback
    # intact (clearer than `raise error`).
    raise
# Register _quant_dequant_mxfp4 as a torch custom op (with its fake/meta
# implementation for shape inference) and expose the registered op at
# module level.
try:
    direct_register_custom_op(
        op_name="quant_dequant_mxfp4",
        op_func=_quant_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_quant_dequant_mxfp4_fake,
    )
    quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
except AttributeError:
    # Bare `raise` re-raises the original exception with its traceback
    # intact (clearer than `raise error`).
    raise

View File

@@ -3,22 +3,41 @@
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import Optional
from typing import ClassVar, NamedTuple, Optional
import numpy
import torch
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
    # (row, col) extent of a quantization group; -1 means "full extent"
    # along that dimension (see _normalize_quant_group_shape).
    row: int
    col: int
class GroupShape(_GroupShape):
    """
    This class describes the quantization group shape.
    It includes static members for common shapes (per-tensor, per-token).
    """
    # Aliases for common quantization group shapes
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']

# (-1, -1): a single group spans the whole tensor;
# (1, -1): one group per row (i.e. per token).
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
int]):
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
@@ -58,7 +77,7 @@ def group_broadcast(t, shape):
# (i.e. per-token-per-group)
def scaled_quantize(
x: torch.Tensor,
group_shape: tuple[int, int],
group_shape: GroupShape,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape)
@@ -99,7 +118,7 @@ def scaled_quantize(
def scaled_dequantize(
x_q: torch.Tensor,
x_s: torch.Tensor,
group_shape: Optional[tuple[int, int]] = None,
group_shape: Optional[GroupShape] = None,
out_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None:
@@ -332,6 +351,10 @@ def quantize_weights(w: torch.Tensor,
)
# NOTE(review): these two constants duplicate identical definitions near
# the top of this file — likely merge/diff residue. The redefinition is
# harmless (same values) but one copy should be removed.
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def gptq_quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
@@ -571,3 +594,56 @@ def awq_pack(
q_w = q_w.reshape((-1, size_n)).contiguous()
return pack_cols(q_w, num_bits, size_k, size_n)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    """
    Pad and block-interleave the FP4 block-scales so that they match the data
    layout expected by the CUTLASS / FlashInfer kernels.

    Parameters
    ----------
    scale: torch.Tensor
        A 2-D (M, K) or 3-D (B, M, K) tensor of block scales in
        ``torch.float8_e4m3fn``.

    Returns
    -------
    torch.Tensor
        The swizzled tensor with the same logical shape as *scale*.
        NOTE: the result is moved to the default CUDA device by the
        ``.cuda()`` call below, regardless of the input's device.

    NOTE(review): the final ``reshape(M, K)`` / ``reshape(B, M, K)`` only
    succeeds when M is already a multiple of 128 and K a multiple of 4
    (otherwise the padded element count differs from M*K) — presumably
    callers always pass aligned shapes; confirm at the call sites.
    """
    assert scale.dtype == torch.float8_e4m3fn, (
        "swizzle_blockscale expects the input tensor to be in "
        "torch.float8_e4m3fn format.")
    scale_ndim = scale.ndim
    if scale_ndim == 2:
        scale = scale.unsqueeze(0)  # (1, M, K)
    assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
    B, M, K = scale.shape

    def _round_up(x: int, m: int) -> int:
        # Smallest multiple of m that is >= x.
        return (x + m - 1) // m * m

    # Zero-pad rows to a multiple of 128 and columns to a multiple of 4,
    # the tile granularity of the swizzled layout.
    M_padded = _round_up(M, 128)
    K_padded = _round_up(K, 4)
    padded = torch.zeros((B, M_padded, K_padded),
                         dtype=scale.dtype,
                         device=scale.device)
    padded[:B, :M, :K] = scale
    # Reshape / permute to the layout required by the kernel.
    padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
    swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
    if scale_ndim == 2:
        return swizzled.reshape(M, K)
    return swizzled.reshape(B, M, K)
def cutlass_fp4_supported() -> bool:
    """Whether the CUTLASS scaled-mm FP4 kernels are usable here.

    Non-CUDA platforms are rejected outright; otherwise the decision is
    delegated to the custom-op capability query (an unknown capability
    is passed through as -1).
    """
    if not current_platform.is_cuda():
        return False
    cap = current_platform.get_device_capability()
    cap_int = -1 if cap is None else cap.to_int()
    return cutlass_scaled_mm_supports_fp4(cap_int)