[gpt-oss] Add gpt-oss mxfp4 support
This commit is contained in:
@@ -36,6 +36,7 @@ QuantizationMethods = Literal[
|
||||
"moe_wna16",
|
||||
"torchao",
|
||||
"auto-round",
|
||||
"mxfp4",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
@@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .marlin import MarlinConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .neuron_quant import NeuronQuantConfig
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .qqq import QQQConfig
|
||||
@@ -143,6 +145,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"moe_wna16": MoeWNA16Config,
|
||||
"torchao": TorchAOConfig,
|
||||
"auto-round": AutoRoundConfig,
|
||||
"mxfp4": Mxfp4Config,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
581
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
581
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
@@ -0,0 +1,581 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase, fused_experts)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
triton_kernel_moe_forward)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
_can_support_mxfp4, _swizzle_mxfp4)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
|
||||
next_power_of_2, round_up)
|
||||
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
# from flashinfer.fused_moe import cutlass_fused_moe
|
||||
from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
|
||||
|
||||
|
||||
class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 (OCP microscaling fp4) checkpoints."""

    def __init__(self, ignored_layers: Optional[list[str]] = None):
        super().__init__()
        # Layer-name prefixes that must stay unquantized.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # NOTE(review): the checkpoint config dict is currently ignored, so
        # ignored_layers can only be supplied programmatically — confirm
        # this is intentional.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer`` (None if unhandled).

        MoE layers get :class:`Mxfp4MoEMethod`; linear layers are either
        skipped (unquantized) or unsupported; attention is unsupported.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            skipped = self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping)
            if skipped:
                return UnquantizedLinearMethod()
            raise NotImplementedError("Mxfp4 linear layer is not implemented")
        if isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(layer.moe_config)
        if isinstance(layer, Attention):
            raise NotImplementedError(
                "Mxfp4 attention layer is not implemented")
        return None
|
||||
|
||||
|
||||
class Mxfp4MoEMethod(FusedMoEMethodBase):
    """FusedMoE method operating on MXFP4-quantized expert weights."""

    def __init__(self, moe: FusedMoEConfig):
        super().__init__()
        self.topk_indices_dtype = None
        self.moe = moe
        self.use_marlin = self._should_use_marlin()

    def _should_use_marlin(self):
        """Decide whether the marlin MoE kernel should be used.

        Only the explicit env override is honored for now; the automatic
        heuristics (Ampere capability, triton-kernels availability,
        torch >= 2.8) are intentionally disabled.
        """
        override = envs.VLLM_MXFP4_USE_MARLIN
        if override is not None:
            return override
        return False
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate packed MXFP4 expert weights, scales and biases on ``layer``.

    Payloads and scales are both stored as raw ``uint8``: two e2m1 values
    per byte for the fp4 weights, and one exponent byte per 32-element
    block for the scales. Sizes may be padded up to satisfy the selected
    backend kernel.
    """
    self.num_experts = num_experts
    # FIXME (zyongye): switch to torch.float4_e2m1fn_x2 /
    # torch.float8_e8m0fnu once torch and safetensors support mxfp4.
    weight_dtype = torch.uint8
    scale_dtype = torch.uint8
    mxfp4_block = 32  # elements covered by one shared scale

    padded_intermediate = intermediate_size_per_partition
    if self.use_marlin:
        # The moe marlin kernel requires, per linear, n % 256 == 0 and
        # k % 128 == 0. gate_up_proj: n = 2 * I, k = H; down_proj:
        # n = H, k = I.
        padded_intermediate = round_up(intermediate_size_per_partition, 128)
        hidden_size = round_up(hidden_size, 256)
        layer.params_dtype = params_dtype
        layer.num_experts = num_experts
        layer.hidden_size = hidden_size
        layer.intermediate_size_per_partition = padded_intermediate
    elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
          or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
        # Pad so non-uniform shards fit, swizzling works, and the kernel
        # performs well.
        padded_intermediate = round_up(intermediate_size_per_partition, 256)
        hidden_size = round_up(hidden_size, 256)
    elif current_platform.is_rocm():
        padded_intermediate = round_up(intermediate_size_per_partition, 128)
    else:
        padded_intermediate = round_up(intermediate_size_per_partition, 64)

    self.intermediate_size = padded_intermediate
    self.hidden_size = hidden_size

    def _add_param(name, shape, dtype):
        # Zero-initialized, frozen parameter carrying the loader attrs.
        param = torch.nn.Parameter(torch.zeros(shape, dtype=dtype),
                                   requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    # Fused gate_up_proj (column parallel).
    _add_param("w13_weight",
               (num_experts, 2 * padded_intermediate, hidden_size // 2),
               weight_dtype)
    _add_param("w13_weight_scale",
               (num_experts, 2 * padded_intermediate,
                hidden_size // mxfp4_block),
               scale_dtype)
    _add_param("w13_bias", (num_experts, 2 * padded_intermediate),
               torch.bfloat16)

    # down_proj (row parallel).
    _add_param("w2_weight",
               (num_experts, hidden_size, padded_intermediate // 2),
               weight_dtype)
    _add_param("w2_weight_scale",
               (num_experts, hidden_size,
                padded_intermediate // mxfp4_block),
               scale_dtype)
    _add_param("w2_bias", (num_experts, hidden_size), torch.bfloat16)
|
||||
|
||||
def process_weights_after_loading(self, layer):
    """Post-load massaging of the raw checkpoint weights for the backend.

    One of four mutually exclusive paths runs, chosen at init/env time:
    (a) marlin repack, (b) flashinfer trtllm-gen shuffle, (c)
    triton_kernels swizzle, or (d) bf16 upcast fallback.
    """
    if self.use_marlin:
        # Repacks fp4 weights/scales on `layer` for the marlin kernel.
        prepare_moe_fp4_layer_for_marlin(layer)
    elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
          or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
        # Per-expert swiglu constants consumed by the trtllm-gen kernel:
        # alpha (sigmoid scale), beta (linear term), and clamp limit.
        layer.gemm1_alpha = Parameter(torch.tensor(
            [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                      requires_grad=False)
        layer.gemm1_beta = Parameter(torch.tensor(
            [1.0] * self.num_experts, dtype=torch.float32).cuda(),
                                     requires_grad=False)
        layer.gemm1_clamp_limit = Parameter(torch.tensor(
            [7.0] * self.num_experts, dtype=torch.float32).cuda(),
                                            requires_grad=False)
        sf_block_size = 32  # mxfp4 block size

        # Sanity-check shapes created by create_weights (weights are
        # byte-packed: 2 fp4 values per byte, 1 scale byte per block).
        assert (layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2)
        assert (layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1]
                == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2]
                == self.hidden_size // sf_block_size)
        assert (layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size and
                layer.w2_weight.shape[2] == self.intermediate_size // 2)
        # NOTE(review): unlike its siblings this assert omits the
        # shape[0] == num_experts check — presumably an oversight.
        assert (layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2]
                == self.intermediate_size // sf_block_size)
        assert (layer.w13_bias.dim() == 2
                and layer.w13_bias.shape[0] == self.num_experts
                and layer.w13_bias.shape[1] == self.intermediate_size * 2)
        assert (layer.w2_bias.dim() == 2
                and layer.w2_bias.shape[0] == self.num_experts
                and layer.w2_bias.shape[1] == self.hidden_size)

        w13_weight_scale = layer.w13_weight_scale.data
        w2_weight_scale = layer.w2_weight_scale.data
        w13_weight = layer.w13_weight.data
        w2_weight = layer.w2_weight.data
        # Kernel consumes biases in fp32.
        w13_bias = layer.w13_bias.data.to(torch.float32)
        w2_bias = layer.w2_bias.data.to(torch.float32)

        # Swap w1 and w3 as the definition of
        # swiglu is different in the trtllm-gen
        def swap_every_two_rows(x, axis=-1):
            # Swap adjacent pairs along `axis`: rows (0,1),(2,3),... flip.
            shape = x.shape
            if axis < 0:
                axis = len(shape) + axis

            # Create a new shape with pairs swapped along specified axis
            new_shape = list(shape)
            new_shape[axis] = shape[axis] // 2
            new_shape.insert(axis + 1, 2)

            # Reshape to expose pairs, swap them, and reshape back
            x = x.reshape(*new_shape)
            x = x.flip(axis + 1)
            new_shape = list(shape)
            return x.reshape(*new_shape)

        w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
        w13_weight = swap_every_two_rows(w13_weight, -2)
        w13_bias = swap_every_two_rows(w13_bias, -1)

        # Do not interleave as the checkpoint is already interleaved

        # Shuffle weights and scaling factors for transposed mma output
        gemm1_weights_mxfp4_shuffled = []
        gemm1_scales_mxfp4_shuffled = []
        gemm2_weights_mxfp4_shuffled = []
        gemm2_scales_mxfp4_shuffled = []
        gemm1_bias_shuffled = []
        gemm2_bias_shuffled = []
        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        for i in range(self.num_experts):
            gemm1_weights_mxfp4_shuffled.append(
                shuffle_matrix_a(w13_weight[i].view(torch.uint8),
                                 epilogue_tile_m))
            gemm1_scales_mxfp4_shuffled.append(
                shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
                                    epilogue_tile_m))
            gemm1_bias_shuffled.append(
                shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
                                 epilogue_tile_m))

            gemm2_weights_mxfp4_shuffled.append(
                shuffle_matrix_a(w2_weight[i].view(torch.uint8),
                                 epilogue_tile_m))
            gemm2_scales_mxfp4_shuffled.append(
                shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
                                    epilogue_tile_m))
            gemm2_bias_shuffled.append(
                shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
                                 epilogue_tile_m))

        # Re-stack per-expert results; scales are reinterpreted as e4m3
        # bytes as expected by the kernel.
        w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
        w13_weight_scale = torch.stack(
            gemm1_scales_mxfp4_shuffled).reshape(
                self.num_experts, 2 * self.intermediate_size,
                self.hidden_size // sf_block_size).view(
                    torch.float8_e4m3fn)

        w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
        w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
            self.num_experts, self.hidden_size, self.intermediate_size //
            sf_block_size).view(torch.float8_e4m3fn)

        layer.w13_weight = Parameter(w13_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(w13_weight_scale,
                                           requires_grad=False)
        layer.w2_weight = Parameter(w2_weight, requires_grad=False)
        layer.w2_weight_scale = Parameter(w2_weight_scale,
                                          requires_grad=False)
        layer.w13_bias = Parameter(
            torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
            requires_grad=False)
        layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
            self.num_experts, -1),
                                  requires_grad=False)
    elif has_triton_kernels():
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        w13_bias = layer.w13_bias.to(torch.float32)
        w2_bias = layer.w2_bias.to(torch.float32)

        layer.w13_bias = Parameter(w13_bias, requires_grad=False)
        layer.w2_bias = Parameter(w2_bias, requires_grad=False)

        # FIXME warp need to be adjusted based on batch size
        # only apply to batched mode
        if self.moe.use_ep:
            num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
        else:
            num_warps = 8

        # Swizzle fp4 payload + scales into the layout matmul_ogs expects;
        # also yields flex contexts used in the precision configs below.
        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            layer.w13_weight, layer.w13_weight_scale, num_warps)
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            layer.w2_weight, layer.w2_weight_scale, num_warps)

        self.w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
        self.w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))

        # Swizzled tensors are kept on self, not as layer parameters.
        self.w13_weight_triton_tensor = w13_weight
        self.w2_weight_triton_tensor = w2_weight

        # need to delete the original weights to save memory on single GPU
        del layer.w13_weight
        del layer.w2_weight
        layer.w13_weight = None
        layer.w2_weight = None
        torch.cuda.empty_cache()
    else:
        # normal triton
        from .triton_kernels_numerics_details.mxfp import upcast_from_mxfp
        # Fallback: dequantize everything to bf16 and use generic kernels.
        w13_weight = upcast_from_mxfp(
            layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
        )
        w2_weight = upcast_from_mxfp(
            layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
        )
        # Drop packed weights and scales before binding the bf16 copies.
        del layer.w13_weight
        del layer.w2_weight
        del layer.w13_weight_scale
        del layer.w2_weight_scale
        layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
        torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
    """Pick the per-CTA token tile size for the trtllm-gen MoE kernel.

    Estimates tokens routed to each expert under perfect load balance,
    inflates that by an imbalance factor (>1.0 means some experts receive
    more than their fair share), rounds up to a power of two, and clamps
    to the kernel-supported 8..64 range.
    """
    num_tokens = x.shape[0]
    # Empirical fudge factor:
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert.
    imbalance_factor = 1.3
    # Tokens per expert assuming a perfectly even routing distribution.
    balanced_load = (num_tokens * top_k) // self.num_experts
    inflated_load = int(balanced_load * imbalance_factor)
    # Power-of-two pad, then clamp into the kernel's supported tile range.
    tile = next_power_of_2(inflated_load)
    return min(max(tile, 8), 64)
|
||||
|
||||
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the MXFP4 MoE forward pass via the configured backend.

    Dispatches, in order: marlin, flashinfer trtllm-gen, triton_kernels,
    then the generic fused_experts fallback. Returns the combined expert
    output for `x`.
    """

    if enable_eplb:
        raise NotImplementedError("EPLB is not supported for mxfp4")

    if self.use_marlin:
        # Routing happens on the host side; the marlin kernel only does
        # the grouped GEMMs.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_bias,
            layer.w2_bias,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            global_scale1=None,
            global_scale2=None,
            quant_type_id=scalar_types.float4_e2m1f.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            activation=activation,
            expert_map=expert_map)

    # Non-marlin paths support only a restricted routing configuration.
    assert _can_support_mxfp4(
        use_grouped_topk, topk_group, num_expert_group, expert_map,
        custom_routing_function, e_score_correction_bias,
        apply_router_weight_on_input, scoring_func, activation,
        expert_load_view, logical_to_physical_map,
        logical_replica_count), (
            "MXFP4 are not supported with this configuration.")

    if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
            or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
        assert not self.moe.use_ep, (
            "EP is not supported for flashinfer mxfp4 moe backend yet.")
        if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
            # Activations stay in bf16; the kernel quantizes internally.
            assert x.dtype == torch.bfloat16
            x_quant = x
            x_scale = None
        else:
            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
        # Positional argument order is dictated by the flashinfer API;
        # weights/scales were pre-shuffled in process_weights_after_loading.
        trtllm_gen_output = trtllm_fp4_block_scale_moe(
            router_logits.to(torch.bfloat16),
            None,  # routing_bias
            x_quant,
            x_scale,
            layer.w13_weight,  # uint8 (e2m1 x 2)
            layer.w13_weight_scale,  # uint8 (e4m3 x 2)
            layer.w13_bias,  # fp32 per expert per channel
            layer.gemm1_alpha,  # fp32 per expert
            layer.gemm1_beta,  # fp32 per expert
            layer.gemm1_clamp_limit,  # fp32 per expert
            layer.w2_weight,  # uint8 (e2m1 x 2)
            layer.w2_weight_scale,  # ue8m0
            layer.w2_bias,  # fp32 per expert per channel
            None,  # output1_scale_scalar
            None,  # output1_scale_gate_scalar
            None,  # output2_scale_scalar
            self.num_experts,
            top_k,
            None,  # n_group
            None,  # topk_group
            self.intermediate_size,  # padded to multiple of 256
            0,  # local_expert_offset
            self.num_experts,  # local num experts
            None,
            self._get_tile_tokens_dim(x, top_k),
            1 if renormalize else 0,  # routing_method_type, renormalize
            True,  # do finalize
        )[0]
        return trtllm_gen_output
    elif has_triton_kernels():
        # Routing and GEMMs fused inside the triton_kernels path; weights
        # live on self (swizzled), biases on the layer.
        return triton_kernel_moe_forward(
            hidden_states=x,
            w1=self.w13_weight_triton_tensor,
            w2=self.w2_weight_triton_tensor,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            w1_precision=self.w13_precision_config,
            w2_precision=self.w2_precision_config,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    else:
        # Fallback: host-side routing + generic fused_experts on the bf16
        # weights produced by process_weights_after_loading.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
        )
|
||||
@@ -6,14 +6,16 @@ from typing import Any, Callable, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
|
||||
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["QuarkW4A4MXFP4"]
|
||||
|
||||
|
||||
@@ -25,7 +27,29 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
self.emulate = not current_platform.supports_mx()
|
||||
|
||||
self.static_input_scales = not input_quant_spec.get("is_dynamic")
|
||||
|
||||
if self.static_input_scales:
|
||||
raise NotImplementedError(
|
||||
"QuarkW4A4MXFP4 with static input scales is currently not "
|
||||
"implemented. Please open an issue.")
|
||||
|
||||
if not current_platform.supports_mx():
|
||||
self.emulate = True
|
||||
logger.warning_once(
|
||||
"The current platform does not support native MXFP4 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision.")
|
||||
else:
|
||||
self.emulate = True
|
||||
logger.warning_once(
|
||||
"The current platform supports native MXFP4 "
|
||||
"computation, but kernels are not yet integrated in vLLM. "
|
||||
"Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision.")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@@ -37,43 +61,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.emulate:
|
||||
try:
|
||||
from quark.torch.export.nn.modules import realquantizer
|
||||
from quark.torch.quantization.config.config import (
|
||||
QuantizationSpec)
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"The package `amd-quark` is required to use AMD Quark "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
weight_quant_spec = QuantizationSpec.from_dict(
|
||||
self.weight_quant_spec)
|
||||
|
||||
weight_quantizer = realquantizer.get_real_quantizer(
|
||||
qspec=weight_quant_spec,
|
||||
quantizer=None,
|
||||
real_quantized=True,
|
||||
reorder=False,
|
||||
float_dtype=self.out_dtype,
|
||||
scale_shape=layer.weight_scale.shape,
|
||||
zero_point_shape=None,
|
||||
)
|
||||
weight_quantizer.scale.data = layer.weight_scale.data
|
||||
|
||||
if not envs.VLLM_QUARK_EMU_MEM_OPT:
|
||||
layer.weight = torch.nn.Parameter(
|
||||
weight_quantizer(layer.weight.data).to(self.out_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
self.weight_quantizer = weight_quantizer
|
||||
layer.weight_scale = None
|
||||
|
||||
# This call is necessary to release the scales memory.
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
@@ -116,11 +103,10 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.emulate:
|
||||
if envs.VLLM_QUARK_EMU_MEM_OPT:
|
||||
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
|
||||
else:
|
||||
dq_w = layer.weight
|
||||
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
|
||||
return F.linear(qdq_x, dq_w, bias)
|
||||
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
|
||||
|
||||
x = quant_dequant_mxfp4(x)
|
||||
|
||||
return F.linear(x, dq_w, bias)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# fmt: off
|
||||
|
||||
|
||||
MXFP_BLOCK_SIZE = tl.constexpr(32)
|
||||
|
||||
|
||||
@triton.jit
def _get_max_quant_val(dtype: tl.constexpr):
    """Largest finite magnitude representable in the target element type."""
    # uint8 here stands for packed fp4 e2m1 (max value 6.0).
    if dtype == tl.uint8:
        return 6.0
    elif dtype == tl.float8e4nv:
        return 448.0
    elif dtype == tl.float8e5:
        return 57344.0
    else:
        tl.static_assert(False, f"Invalid {dtype=}")
|
||||
|
||||
@triton.jit
def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
                             DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
    """Quantize a 2D tile to an MX format, returning (values, e8m0 scales).

    One power-of-two scale is computed per 32-element block along dim 1;
    the payload is either fp8 or packed fp4 e2m1 (two values per uint8).
    """
    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
    BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
    BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE

    # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
    f32_tensor = src_tensor.to(tl.float32)
    abs_tensor = tl.abs(f32_tensor)
    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation
    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
    # Per-block dequant scale: block max divided by the format's max value.
    dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
    if DEQUANT_SCALE_ROUNDING_MODE == 0:
        # DequantScaleRoundingMode.ROUND_UP
        # compute 2 ** ceil(log2(dequant_scale))
        # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
        # A corner case: exponent is 0xFF that will overflow but that's already
        # NaN so assume we don't care.
        dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
    else:
        # DequantScaleRoundingMode.ROUND_DOWN
        # compute 2 ** floor(log2(dequant_scale))
        assert DEQUANT_SCALE_ROUNDING_MODE == 1
        dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
    dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
    # Guard the all-zero block case (scale 0 -> quantize by 0, not inf).
    quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)

    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    quant_tensor = f32_tensor * quant_scale

    # Reshape the tensors after scaling
    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
    dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])

    # First, we simply extract the exponent part of the scales and store the result
    dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
    # Now we must convert the tensors to the mx format.
    if is_fp8:
        out_tensor = quant_tensor.to(mx_tensor_dtype)
    else:
        # fp4 path: hand-rolled fp32 -> e2m1 conversion on the raw bits.
        quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
        signs = quant_tensor & 0x80000000
        exponents = (quant_tensor >> 23) & 0xFF
        mantissas = (quant_tensor & 0x7FFFFF)

        # 0.25 <= x < 0.75 maps to 0.5, a denormal number
        E8_BIAS = 127
        E2_BIAS = 1
        # Move implicit bit 1 at the beginning to mantissa for denormals
        adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
        mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)

        # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
        exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

        # Combine sign, exponent, and mantissa, while saturating
        # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
        e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
        e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)

        # Pack two e2m1 nibbles per byte: even indices low, odd high.
        e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        evens, odds = tl.split(e2m1_value)
        out_tensor = evens | (odds << 4)

    return out_tensor, dequant_scale_exponent
|
||||
|
||||
@triton.jit
def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
                      mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
                      src_ptr, stride_src_outer, stride_src_quant,
                      outer_dim, quant_dim,
                      BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
                      DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
    """Quantize a 2D (outer_dim, quant_dim) source tensor into mx format.

    Each program instance processes one (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM)
    tile: it loads the source values, delegates the numeric conversion to
    `_compute_quant_and_scale`, and stores the quantized tensor plus one uint8
    exponent scale per MXFP_BLOCK_SIZE (32) source elements.
    """
    # The packed output must be contiguous along the quantized dim so that fp4
    # nibble pairs land in adjacent bytes.
    tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
    # One scale is produced per group of MXFP_BLOCK_SIZE source elements.
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")

    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
                     f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")

    src_dtype: tl.constexpr = src_ptr.dtype.element_ty
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
    tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8

    # Tile coordinates; widened to int64 so offset arithmetic cannot overflow
    # for large tensors.
    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    # fp4 packs two source values per output byte, halving the packed dim.
    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

    # Starting element indices of this tile in each of the three buffers.
    start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    # Advance base pointers to the tile origin.
    src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
    mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
    mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer

    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)

    # Boundary masks for the (possibly partial) tile at the tensor edge.
    mask_src_quant = start_src_quant + offs_src_quant < quant_dim
    mask_n = start_out + offs_outer < outer_dim
    full_mask_src = mask_src_quant & mask_n

    mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
    full_mask_mxt = mask_mxt_quant & mask_n

    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
    full_scale_mask = scale_mask_k & mask_n

    src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
    mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
    mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
    src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)

    # All of the actual quantization math (scale derivation, rounding,
    # fp4 packing) lives in the shared helper.
    out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
                                                        DEQUANT_SCALE_ROUNDING_MODE)

    tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
    tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
|
||||
|
||||
|
||||
@triton.jit(repr=lambda _: "_dequantize_mxfp8")
def _dequantize_mxfp8_fn(input, mask, pid=None):
    # NOTE(review): despite the "dequantize" name, this delegates to
    # _compute_quant_and_scale with an fp8e4m3 target, i.e. it performs an
    # mxfp8 quantize step (returning quantized values + scales) — presumably
    # to emulate the mxfp8 round-trip precision inline; confirm against the
    # call sites before renaming.
    return _compute_quant_and_scale(input, mask, tl.float8e4nv)
|
||||
@@ -0,0 +1,136 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
|
||||
|
||||
|
||||
# fmt: off
|
||||
@triton.jit
def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
                      stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
                      outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
    """Dequantize an mx-format tensor (fp4-packed uint8 or fp8) back to
    float16/bfloat16/float32.

    Each program instance handles one (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM)
    tile of the *logical* (unpacked) output: it loads the packed values and the
    per-32-element uint8 exponent scales, upcasts both, multiplies, and stores.
    """
    tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
    tl.static_assert(dst_dtype == tl.float16 or (dst_dtype == tl.bfloat16 or dst_dtype == tl.float32))
    tl.static_assert(
        mx_tensor_dtype == tl.uint8
        or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")

    # Determine if we are dealing with fp8 types.
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
    # fp4 stores two values per byte, so the packed dim is half the logical dim.
    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

    # Compute starting indices for the quantized (packed) dimension and the outer dimension.
    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
    mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
    out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant

    # Compute offsets and masks.
    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
    offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)

    mask_outer = start_out + offs_outer < outer_dim
    mask_out_quant = start_out_quant + offs_out_quant < quant_dim
    full_mask_out = mask_out_quant & mask_outer

    mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
    full_mask_src = mask_src_quant & mask_outer

    mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
    full_scale_mask = mask_scale & mask_outer

    tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
    scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
    out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer

    # Load the packed tensor and scale.
    tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
    scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)

    # Upcast the scale to the destination type. The uint8 scale is an e8m0
    # exponent, so placing it in the destination's exponent field yields the
    # power-of-two scale directly.
    if dst_dtype == tl.bfloat16:
        # dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
        dst_scale = (scale.to(tl.uint16) << 7).to(tl.uint16).to(tl.bfloat16, bitcast=True)
    else:
        dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
        if dst_dtype == tl.float16:
            dst_scale = dst_scale.to(tl.float16)

    # Now upcast the tensor.
    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
    if is_fp8:
        dst_tensor = tensor.to(intermediate_dtype)
        if tensor.dtype == tl.float8e5:
            from_e_bits: tl.constexpr = 5
            from_m_bits: tl.constexpr = 2
            to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
            to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10

            # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
            non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
            non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
            dst_tensor = tl.where(
                (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
                (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
                dst_tensor,
            )
    else:
        assert is_fp4
        # Bit-level e2m1 -> fp16/bf16 conversion, two nibbles per byte.
        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
        # e2m1
        em0 = tensor & 0x07
        em1 = tensor & 0x70
        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
        # Three cases:
        # 1) x is normal and non-zero: Correct bias
        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
        # 3) x is zero, do nothing
        # Interleave low/high nibbles back into logical element order.
        if intermediate_dtype == tl.bfloat16:
            dst_tensor = tl.interleave(x0, x1).to(tl.uint16).to(tl.bfloat16, bitcast=True)
        else:
            dst_tensor = tl.interleave(x0, x1).to(tl.float16, bitcast=True)
    # dst_tensor = dst_tensor.to(dst_dtype)
    if dst_dtype == tl.bfloat16:
        dst_tensor = dst_tensor.to(tl.bfloat16)
    elif dst_dtype == tl.float16:
        dst_tensor = dst_tensor.to(tl.float16)
    else:
        dst_tensor = dst_tensor.to(tl.float32)

    # Reshape for proper broadcasting: the scale was stored with a 32-sized "inner" grouping.
    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
    scale = scale.reshape(dst_scale.shape)

    out_tensor = dst_tensor * dst_scale
    # Correct any NaNs encoded via the scale (0xFF is the e8m0 NaN encoding).
    out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
|
||||
@@ -0,0 +1,303 @@
|
||||
# isort: off
|
||||
# fmt: off
|
||||
from enum import Enum
|
||||
import triton
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from ._upcast_from_mxfp import _upcast_from_mxfp
|
||||
from ._downcast_to_mxfp import _downcast_to_mxfp, _dequantize_mxfp8_fn, MXFP_BLOCK_SIZE
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dequantization / Quantization Utilities
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DequantScaleRoundingMode(Enum):
    """Rounding mode applied to the power-of-two dequantization scale.

    ROUND_UP rounds the scale's exponent up (never under-scales);
    ROUND_DOWN truncates it.
    """

    ROUND_UP = 0
    ROUND_DOWN = 1
|
||||
|
||||
|
||||
def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                     DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
    """
    Convert the src weights to mx format. The src weight is quantized along the axis dimension.

    If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
    Note that this means the k_dim of the tensor will be half of the logical k_dim.

    If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored
    in their respective formats.

    Returns:
        (out_quant_tensor, out_scale) — the quantized tensor and the uint8
        e8m0 scales (one per 32 elements along `axis`).
    """
    ndim = src_tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim
    # downcast
    # Move the quantization axis last so the kernel can treat the tensor as
    # 2D (outer, quant).
    src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
    assert is_fp4 or is_fp8
    # fp4 packs two values per byte, so the stored quant dim is halved.
    divisor = 2 if is_fp4 else 1
    L = src_tensor.shape[-1]
    if is_fp4:
        assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
    out_shape = src_tensor.shape[:-1] + (L // divisor, )
    out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )

    out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
    out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)

    # Empty tensors skip the kernel launch entirely (grid would be 0).
    if src_tensor.numel() > 0:
        # Flatten all leading dims: the kernel works on a 2D view.
        kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
        kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
        kernel_scale = out_scale.view(-1, out_scale.shape[-1])

        BLOCK_OUT_DIM = 128
        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
        grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
        grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)

        _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
                                                  *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
                                                  *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
                                                  DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)

    # Undo the transpose (transpose is its own inverse for a dim pair).
    out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
    out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
    return out_quant_tensor, out_scale
|
||||
|
||||
|
||||
def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int):
    """
    Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16.

    The function assumes that the tensors were quantized along the given axis.
    It permutes the tensor so that the quantized axis is last, reshapes to 2D,
    launches the Triton upcast kernel, and then unpermutes back to the original order.
    """
    ndim = tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim
    assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
                                       f"Got {tensor.ndim=} and {scale.ndim=}")
    # dtype checks
    assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
        f"Invalid tensor dtype {tensor.dtype=}"
    assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
    assert dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {dtype=}"
    # upcast
    # uint8 means fp4-packed: each stored byte holds two logical elements.
    logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
    # The kernel requires the quant axis to be last and contiguous.
    tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
    scale = scale.transpose(axis, scale.ndim - 1).contiguous()
    out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device)
    # 2D views over all leading dims for the kernel launch.
    reshaped_out = out.view(-1, out.shape[-1])
    reshaped_tensor = tensor.view(-1, tensor.shape[-1])
    reshaped_scale = scale.view(-1, scale.shape[-1])
    BLOCK_OUT_DIM = 128
    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
    blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
    blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
    _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
                                                          *reshaped_scale.stride(), reshaped_tensor,
                                                          *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
                                                          BLOCK_QUANT_DIM, num_warps=8)
    # Undo the transpose (scale.ndim == tensor.ndim == out.ndim here).
    out = out.transpose(axis, scale.ndim - 1).contiguous()
    return out
|
||||
|
||||
|
||||
# ------------
|
||||
|
||||
|
||||
def right_shift_unsigned(x, shift):
    """Logical (unsigned) right shift of a 32-bit value.

    CUDA torch lacks bit ops on uint32, so the arithmetic shift result is
    masked down to its low ``32 - shift`` bits, which is exactly what an
    unsigned shift would have produced.
    """
    low_bits_mask = (1 << (32 - shift)) - 1
    return (x >> shift) & low_bits_mask
|
||||
|
||||
|
||||
def get_max_quant_val(dtype: torch.dtype):
|
||||
d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0}
|
||||
assert dtype in d
|
||||
return d[dtype]
|
||||
|
||||
|
||||
def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                           DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
    """
    Converts the src tensor to the output format specified by out_quant_type.
    axis: The axis along which the tensors are contiguous and quantization is applied.
    DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.

    Returns:
      out_quant_tensor: Quantized tensor in mx format.
        - For mxfp8, the output has the same shape as src_tensor.
        - For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
      scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
        Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
        where L is the original length along that axis.
    """
    # This should probably be packed into its own tiny class
    ndim = src_tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    assert src_tensor.dtype in {torch.float32, torch.bfloat16,
                                torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"

    axis = axis if axis >= 0 else axis + ndim
    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = "float8" in str(out_quant_type)
    assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"

    device = src_tensor.device

    # For mxfp4 conversion, we assume the contiguous axis length is even.
    if is_fp4:
        axis_shape = src_tensor.size(axis)
        assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."

    # Permute the tensor so that the contiguous axis becomes the last dimension.
    # All bit manipulation below is done on the fp32 representation.
    src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
    axis_shape = src.shape[-1]

    # Pad the axis to be divisible by 32, in case it is not.
    next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
    pad_amount = next_multiple - axis_shape
    padded_src = F.pad(src, (0, pad_amount))
    valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
    padded_axis_shape = padded_src.size(-1)  # now divisible by 32

    # --- Compute per-group maximums for scale ---
    # Set padded entries to -1 so they don't affect the max.
    abs_f = torch.abs(padded_src)
    abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
    # Reshape the last dimension into groups of 32.
    new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    abs_groups = abs_f.view(*new_shape)
    # Compute maximum along the group dimension (of size 32).
    max_val, _ = abs_groups.max(dim=-1, keepdim=True)

    # Choose a max quantization value depending on type.
    max_quant_val = get_max_quant_val(out_quant_type)
    dequant_scale = max_val / max_quant_val  # shape: (..., padded_axis_shape//32, 1)

    # Convert to int to round the FP32 scale, prior to quantization!
    # The scale is forced to a power of two by keeping only the exponent bits.
    ds_int = dequant_scale.view(torch.int32)
    if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
        # Adding (mantissa-full) 0x007FFFFF before masking rounds the exponent up.
        ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
    else:
        # Masking alone truncates the mantissa, rounding the exponent down.
        ds_int_rounded = ds_int & 0x7F800000
    # Reinterpret back as float32.
    dequant_scale_rounded = ds_int_rounded.view(torch.float32)

    # Compute the quantization scale.
    quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)

    # Quantize the tensor
    orig_padded_shape = padded_src.shape
    padded_src_groups = padded_src.view(*new_shape)
    quant_tensor = padded_src_groups * quant_scale
    # Reshape back to the original shape and trim padding
    quant_tensor = quant_tensor.view(orig_padded_shape)
    quant_tensor = quant_tensor[..., :axis_shape]

    # Finally, convert the quantized tensor to the target format
    if is_fp8:
        # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
        quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
        out_weight = quant_tensor.to(out_quant_type)
    else:
        assert is_fp4, f"Invalid output quantization type {out_quant_type}"
        # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
        # First, reinterpret the quantized tensor bits.
        q_int = quant_tensor.contiguous().view(torch.int32)
        # Extract sign, exponent, and mantissa.
        signs = q_int & 0x80000000
        exponents = right_shift_unsigned(q_int, 23) & 0xFF
        mantissas = q_int & 0x7FFFFF

        E8_BIAS = 127
        E2_BIAS = 1
        # Adjust mantissas for subnormals.
        # Move the implicit leading 1 into the mantissa and shift right so that
        # fp32 subnormal-range values map onto e2m1 denormals.
        mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
                                (E8_BIAS - exponents - 1), mantissas)
        # Re-bias normal exponents from 127 (e8) to 1 (e2); subnormals stay at 0.
        exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
        # Round-to-nearest (ties up) by adding 1 below the LSB, then saturate to 0x7.
        e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
        e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
        e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8)  # shape: (..., even_axis_shape)

        # Pack pairs of 4-bit values along the last dimension.
        e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
        evens = e2m1_value[..., 0]
        odds = e2m1_value[..., 1]
        out_weight = evens | (odds << 4)  # shape: (..., axis_shape//2)

    # --- Process and output the scale ---
    # Keep only the e8m0 exponent byte of the rounded power-of-two scale.
    dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8)  # shape: (..., axis_shape//32, 1)
    dq_scale = dq_scale.squeeze(-1)
    # Undo the initial transpose on both outputs.
    out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
    dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
    return out_weight, dq_scale
|
||||
|
||||
|
||||
def cvt_e2m1_to_fp32(input_tensor):
    """Decode packed fp4 e2m1 values (two per uint8 byte) to float32.

    The low nibble of each byte is the even (first) element, the high nibble
    the odd (second); the output's last dim is twice the input's.
    """
    assert input_tensor.dtype == torch.uint8

    codes = input_tensor.to(torch.int32)
    low_nibbles = codes & 0xF
    high_nibbles = (codes >> 4) & 0xF

    # e2m1 magnitudes for codes 0..7; codes 8..15 are the negated mirror
    # (bit 3 is the sign bit).
    magnitudes = torch.tensor([0.0, 0.5, 1, 1.5, 2, 3, 4, 6],
                              dtype=torch.float32, device=input_tensor.device)
    lut = torch.cat([magnitudes, -magnitudes])

    decoded_low = lut[low_nibbles]
    decoded_high = lut[high_nibbles]
    interleaved = torch.stack([decoded_low, decoded_high], dim=-1)
    return interleaved.view(*codes.shape[:-1], -1)
|
||||
|
||||
|
||||
def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
    """
    Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
    axis: The axis along which dequantization is applied.

    Returns:
      out_weight: Tensor in the target format.
    """

    ndim = tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
    assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"

    # Permute the tensor and scale so that the quantization axis becomes the last dimension
    axis = axis if axis >= 0 else axis + ndim
    scale = scale.transpose(axis, scale.ndim - 1)
    tensor = tensor.transpose(axis, tensor.ndim - 1)

    dq_scale = (scale.to(torch.int32) << 23).view(torch.float32)  # Shift to the exponent and bitcast to fp32
    # uint8 means fp4-packed: decode both nibbles, doubling the last dim.
    if tensor.dtype == torch.uint8:
        fp32_tensor = cvt_e2m1_to_fp32(tensor)
    else:
        fp32_tensor = tensor.to(torch.float32)

    # Pad the logical quant dim up to a multiple of 32 so it groups evenly
    # against the per-32-element scales.
    logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
    axis_shape = fp32_tensor.size(-1)
    padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
    pad_size = padded_axis_shape - axis_shape
    padded_tensor = F.pad(fp32_tensor, (0, pad_size))

    # Group the last dim into (num_groups, 32) and broadcast one scale per group.
    new_axis_shape = padded_tensor.shape[-1]
    new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    padded_tensor = padded_tensor.view(*new_shape)
    dq_scale_padded = dq_scale.unsqueeze(-1)  # shape: [..., ceil(axis_shape/32), 1]
    out_padded = padded_tensor * dq_scale_padded

    # Flatten back and remove the padded tail
    out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
    out_tensor = out_padded[..., :axis_shape]

    # Cast to the requested dtype and undo the initial transpose.
    out_tensor = out_tensor.to(target_dtype).contiguous()
    out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)

    return out_tensor
|
||||
|
||||
|
||||
# Public alias for the Triton mxfp8 dequantize helper defined above.
dequantize_mxfp8_fn = _dequantize_mxfp8_fn
|
||||
@@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
|
||||
return s
|
||||
|
||||
|
||||
def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
    """Permute a bias vector into the per-thread scale layout expected by the
    Marlin kernel, preserving the original shape."""
    original_shape = s.shape
    _, single_perm = get_scale_perms()
    permuted = s.reshape((-1, len(single_perm)))[:, single_perm]
    return permuted.reshape(*original_shape).contiguous()
|
||||
|
||||
|
||||
def marlin_moe_permute_scales(
|
||||
s: torch.Tensor,
|
||||
size_k: int,
|
||||
@@ -410,6 +417,7 @@ def apply_gptq_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
@@ -425,9 +433,6 @@ def apply_gptq_marlin_linear(
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
@@ -456,6 +461,7 @@ def apply_awq_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
@@ -470,7 +476,4 @@ def apply_awq_marlin_linear(
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
@@ -8,8 +8,8 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
||||
should_use_atomic_add_reduce)
|
||||
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias,
|
||||
marlin_permute_scales, should_use_atomic_add_reduce)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@@ -22,7 +22,7 @@ def is_fp4_marlin_supported():
|
||||
return current_platform.has_device_capability(80)
|
||||
|
||||
|
||||
def fp4_marlin_process_scales(marlin_scales):
|
||||
def nvfp4_marlin_process_scales(marlin_scales):
|
||||
if not (marlin_scales >= 0).all():
|
||||
logger.warning_once(
|
||||
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
|
||||
@@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales):
|
||||
return marlin_scales
|
||||
|
||||
|
||||
def fp4_marlin_process_global_scale(global_scale):
|
||||
def mxfp4_marlin_process_scales(marlin_scales):
    """Rearrange mxfp4 scales into the layout the Marlin kernel reads and
    reinterpret them as e8m0 floats."""
    # 8 is the number of scale values consumed by one thread; interleave
    # row pairs so each thread's 8 scales are contiguous.
    marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
    # NOTE: size(0) here is the already-halved dim, so * 2 restores the
    # original row count.
    marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
        marlin_scales.size(0) * 2, -1)

    # fit the layout of fp8 dequantization
    marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
        marlin_scales.size(0), -1)
    marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
    return marlin_scales
||||
|
||||
|
||||
def nvfp4_marlin_process_global_scale(global_scale):
|
||||
assert global_scale.dtype in [torch.half, torch.bfloat16]
|
||||
fp4_exponent = 2
|
||||
if global_scale.dtype == torch.half:
|
||||
@@ -73,7 +86,7 @@ def apply_fp4_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
weight_scale_2: torch.Tensor,
|
||||
weight_scale_2: Optional[torch.Tensor],
|
||||
workspace: torch.Tensor,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
@@ -94,6 +107,7 @@ def apply_fp4_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(a=reshaped_x,
|
||||
c=None,
|
||||
b_q_weight=weight,
|
||||
b_bias=bias,
|
||||
b_scales=weight_scale,
|
||||
global_scale=weight_scale_2,
|
||||
b_zeros=None,
|
||||
@@ -107,9 +121,6 @@ def apply_fp4_marlin_linear(
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
@@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
is_nvfp4 = hasattr(layer, "weight_scale_2")
|
||||
group_size = 16 if is_nvfp4 else 32
|
||||
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
param_dtype = layer.params_dtype
|
||||
@@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Permute scales
|
||||
weight_scale = layer.weight_scale.T.to(param_dtype)
|
||||
weight_scale = layer.weight_scale.T.contiguous()
|
||||
|
||||
if not is_nvfp4:
|
||||
weight_scale = weight_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
weight_scale = weight_scale.to(param_dtype)
|
||||
weight_scale = marlin_permute_scales(s=weight_scale,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=16)
|
||||
weight_scale = fp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||
group_size=group_size)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
|
||||
weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
|
||||
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
|
||||
requires_grad=False)
|
||||
if is_nvfp4:
|
||||
weight_scale = nvfp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
|
||||
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
|
||||
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
|
||||
requires_grad=False)
|
||||
else:
|
||||
weight_scale = mxfp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
assert layer.bias.shape == (part_size_n, )
|
||||
bias = marlin_permute_bias(layer.bias)
|
||||
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
|
||||
return
|
||||
|
||||
@@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
|
||||
group_size = 16 if is_nvfp4 else 32
|
||||
|
||||
e = layer.num_experts
|
||||
k = layer.hidden_size
|
||||
n = layer.intermediate_size_per_partition
|
||||
@@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
# WEIGHT SCALES
|
||||
# Permute scales
|
||||
for name in ["w13", "w2"]:
|
||||
scales = getattr(layer, name + "_weight_scale").to(param_dtype)
|
||||
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
|
||||
scales = getattr(layer, name + "_weight_scale")
|
||||
if not is_nvfp4:
|
||||
scales = scales.view(torch.float8_e8m0fnu)
|
||||
scales = scales.to(param_dtype)
|
||||
if is_nvfp4:
|
||||
global_scale = getattr(layer,
|
||||
name + "_weight_scale_2").to(param_dtype)
|
||||
|
||||
tensor_list = []
|
||||
if "w13" in name:
|
||||
@@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
size_n, size_k = k, n
|
||||
|
||||
for i in range(e):
|
||||
marlin_scales = marlin_permute_scales(s=scales[i].T,
|
||||
scale = scales[i].T
|
||||
|
||||
marlin_scales = marlin_permute_scales(s=scale,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=16)
|
||||
marlin_scales = fp4_marlin_process_scales(marlin_scales)
|
||||
group_size=group_size)
|
||||
if is_nvfp4:
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
else:
|
||||
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
|
||||
tensor_list.append(marlin_scales)
|
||||
|
||||
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
scales = torch.nn.Parameter(scales, requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale", scales)
|
||||
|
||||
global_scale = fp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||
if is_nvfp4:
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = torch.nn.Parameter(global_scale,
|
||||
requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||
|
||||
# BIAS
|
||||
# Permute bias
|
||||
for name in ["w13_bias", "w2_bias"]:
|
||||
if not hasattr(layer, name):
|
||||
continue
|
||||
bias = getattr(layer, name).to(param_dtype)
|
||||
|
||||
tensor_list = []
|
||||
for i in range(e):
|
||||
expert_bias = bias[i]
|
||||
|
||||
tensor_list.append(marlin_permute_bias(expert_bias))
|
||||
|
||||
bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
setattr(layer, name, bias)
|
||||
|
||||
|
||||
def rand_marlin_weight_fp4_like(weight, group_size):
|
||||
def rand_marlin_weight_nvfp4_like(weight, group_size):
|
||||
assert group_size > 0
|
||||
size_n, size_k = weight.shape
|
||||
device = weight.device
|
||||
@@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size):
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=group_size)
|
||||
marlin_scales = fp4_marlin_process_scales(marlin_scales)
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
|
||||
global_scale = fp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
||||
|
||||
|
||||
def rand_marlin_weight_mxfp4_like(weight, group_size):
|
||||
assert group_size > 0
|
||||
size_n, size_k = weight.shape
|
||||
device = weight.device
|
||||
|
||||
scales = torch.randint(100,
|
||||
125, (size_n, size_k // group_size),
|
||||
dtype=torch.uint8,
|
||||
device=weight.device)
|
||||
scales = scales.view(torch.float8_e8m0fnu)
|
||||
|
||||
fp4_weight = torch.randint(0,
|
||||
256, (size_n, size_k // 2),
|
||||
dtype=torch.uint8,
|
||||
device=weight.device)
|
||||
fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
|
||||
((fp4_weight & 0b01110000) >> 2))
|
||||
fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
|
||||
fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
|
||||
|
||||
fp4_weight2 = fp4_weight << 4
|
||||
fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
|
||||
((fp4_weight2 & 0b01110000) >> 2))
|
||||
fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
|
||||
fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
|
||||
|
||||
weight_ref = torch.cat(
|
||||
[fp4_weight_part_2.unsqueeze(2),
|
||||
fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
|
||||
weight_ref = weight_ref * \
|
||||
scales.repeat_interleave(group_size, 1).to(weight.dtype)
|
||||
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
|
||||
perm=torch.empty(0, dtype=torch.int, device=device),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
num_bits=4,
|
||||
)
|
||||
|
||||
marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=group_size)
|
||||
|
||||
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)
|
||||
|
||||
@@ -1,45 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def per_token_group_quant_mxfp4(x: torch.Tensor,
|
||||
block_k: int,
|
||||
scale_calculation_mode: str = "even"
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
""" weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
|
||||
"""
|
||||
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout
|
||||
if (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and not is_torch_equal_or_newer("2.8.1")):
|
||||
logger.warning_once(
|
||||
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
||||
"this cause swizling to be disabled, which may "
|
||||
"cause performance degradation. Please upgrade to torch nightly")
|
||||
value_layout, value_layout_opts = StridedLayout, dict()
|
||||
scale_layout, scale_layout_opts = StridedLayout, dict()
|
||||
else:
|
||||
value_layout, value_layout_opts = \
|
||||
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
scale_layout, scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(
|
||||
mx_axis=1, num_warps=num_warps))
|
||||
if current_platform.is_cuda() and \
|
||||
current_platform.is_device_capability(100):
|
||||
constraints = {
|
||||
"is_persistent": True,
|
||||
"epilogue_subtile": 1,
|
||||
}
|
||||
opt_flags.update_opt_flags_constraints(constraints)
|
||||
# transpose the tensor so that the quantization axis is on dim1
|
||||
quant_tensor = quant_tensor.transpose(-2, -1)
|
||||
scale = scale.transpose(-2, -1)
|
||||
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
|
||||
value_layout, **value_layout_opts)
|
||||
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
|
||||
**scale_layout_opts)
|
||||
return quant_tensor, InFlexData(), scale
|
||||
|
||||
|
||||
def _can_support_mxfp4(use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
scoring_func: str = "softmax",
|
||||
activation: str = "swiglu_oai",
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None):
|
||||
return not (use_grouped_topk or topk_group or num_expert_group
|
||||
or expert_map or custom_routing_function
|
||||
or e_score_correction_bias or apply_router_weight_on_input
|
||||
or scoring_func != "softmax" or activation != "swiglu_oai"
|
||||
or expert_load_view or logical_to_physical_map
|
||||
or logical_replica_count)
|
||||
|
||||
|
||||
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
||||
float_dtype: torch.dtype) -> torch.Tensor:
|
||||
try:
|
||||
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
|
||||
fake_quantize_fp4_fp6_per_group_with_scale)
|
||||
from quark.torch.quantization.utils import (even_round,
|
||||
reshape_to_blocks)
|
||||
from quark.torch.kernel import mx
|
||||
except ImportError as err:
|
||||
raise ImportError("The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
axis = -1
|
||||
block_x = reshape_to_blocks(x, block_k, axis)
|
||||
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
|
||||
amax = amax.squeeze(-1)
|
||||
return mx.dq_mxfp4(x, scale, float_dtype)
|
||||
|
||||
# TODO: there are other rounding strategies supported in quark and in the
|
||||
# config.json that we do not check for here!
|
||||
if scale_calculation_mode != "even":
|
||||
raise NotImplementedError(
|
||||
f"Scale calculation mode {scale_calculation_mode} is not yet "
|
||||
"supported in MX-FP4 quantization")
|
||||
scale = even_round(amax, "fp4")
|
||||
|
||||
# Apply dequantize(quantize(x)).
|
||||
x = fake_quantize_fp4_fp6_per_group_with_scale(
|
||||
x,
|
||||
scale.to(x.device),
|
||||
axis=axis,
|
||||
group_size=block_k,
|
||||
quant_dtype="fp4",
|
||||
def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor,
|
||||
float_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.empty((*x.shape[:-1], x.shape[-1] * 2),
|
||||
dtype=float_dtype,
|
||||
device=x.device)
|
||||
|
||||
|
||||
def _quant_dequant_mxfp4(x: torch.Tensor,
|
||||
scale_calculation_mode: str = "even") -> torch.Tensor:
|
||||
try:
|
||||
from quark.torch.kernel import mx
|
||||
except ImportError as err:
|
||||
raise ImportError("The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
return mx.qdq_mxfp4(x, scale_calculation_mode)
|
||||
|
||||
|
||||
def _quant_dequant_mxfp4_fake(x: torch.Tensor,
|
||||
scale_calculation_mode: str = "even"
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="dequant_mxfp4",
|
||||
op_func=_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_dequant_mxfp4_fake,
|
||||
)
|
||||
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
return x, scale
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="quant_dequant_mxfp4",
|
||||
op_func=_quant_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_quant_dequant_mxfp4_fake,
|
||||
)
|
||||
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
@@ -3,22 +3,41 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Optional
|
||||
from typing import ClassVar, NamedTuple, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
row: int
|
||||
col: int
|
||||
|
||||
|
||||
class GroupShape(_GroupShape):
|
||||
"""
|
||||
This class describes the quantization group shape.
|
||||
It includes static members for common shapes (per-tensor, per-token).
|
||||
"""
|
||||
|
||||
# Aliases for common quantization group shapes
|
||||
PER_TENSOR: ClassVar['GroupShape']
|
||||
PER_TOKEN: ClassVar['GroupShape']
|
||||
|
||||
|
||||
GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
|
||||
int]):
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
|
||||
@@ -58,7 +77,7 @@ def group_broadcast(t, shape):
|
||||
# (i.e. per-token-per-group)
|
||||
def scaled_quantize(
|
||||
x: torch.Tensor,
|
||||
group_shape: tuple[int, int],
|
||||
group_shape: GroupShape,
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
@@ -99,7 +118,7 @@ def scaled_quantize(
|
||||
def scaled_dequantize(
|
||||
x_q: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
group_shape: Optional[tuple[int, int]] = None,
|
||||
group_shape: Optional[GroupShape] = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if group_shape is not None:
|
||||
@@ -332,6 +351,10 @@ def quantize_weights(w: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def gptq_quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
@@ -571,3 +594,56 @@ def awq_pack(
|
||||
q_w = q_w.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return pack_cols(q_w, num_bits, size_k, size_n)
|
||||
|
||||
|
||||
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Pad and block-interleave the FP4 block-scales so that they match the data
|
||||
layout expected by the CUTLASS / FlashInfer kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scale: torch.Tensor
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The swizzled tensor with the same logical shape as *scale*.
|
||||
"""
|
||||
assert scale.dtype == torch.float8_e4m3fn, (
|
||||
"swizzle_blockscale expects the input tensor to be in "
|
||||
"torch.float8_e4m3fn format.")
|
||||
|
||||
scale_ndim = scale.ndim
|
||||
if scale_ndim == 2:
|
||||
scale = scale.unsqueeze(0) # (1, M, K)
|
||||
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
|
||||
|
||||
B, M, K = scale.shape
|
||||
|
||||
def _round_up(x: int, m: int) -> int:
|
||||
return (x + m - 1) // m * m
|
||||
|
||||
M_padded = _round_up(M, 128)
|
||||
K_padded = _round_up(K, 4)
|
||||
|
||||
padded = torch.zeros((B, M_padded, K_padded),
|
||||
dtype=scale.dtype,
|
||||
device=scale.device)
|
||||
padded[:B, :M, :K] = scale
|
||||
|
||||
# Reshape / permute to the layout required by the kernel.
|
||||
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
|
||||
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
|
||||
|
||||
if scale_ndim == 2:
|
||||
return swizzled.reshape(M, K)
|
||||
return swizzled.reshape(B, M, K)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
Reference in New Issue
Block a user