[gpt-oss] Add gpt-oss mxfp4 support
This commit is contained in:
@@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
|
||||
return s
|
||||
|
||||
|
||||
def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
|
||||
origin_shape = s.shape
|
||||
_, scale_perm_single = get_scale_perms()
|
||||
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
|
||||
return s.reshape(*origin_shape).contiguous()
|
||||
|
||||
|
||||
def marlin_moe_permute_scales(
|
||||
s: torch.Tensor,
|
||||
size_k: int,
|
||||
@@ -410,6 +417,7 @@ def apply_gptq_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
@@ -425,9 +433,6 @@ def apply_gptq_marlin_linear(
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
@@ -456,6 +461,7 @@ def apply_awq_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
@@ -470,7 +476,4 @@ def apply_awq_marlin_linear(
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
@@ -8,8 +8,8 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
||||
should_use_atomic_add_reduce)
|
||||
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias,
|
||||
marlin_permute_scales, should_use_atomic_add_reduce)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@@ -22,7 +22,7 @@ def is_fp4_marlin_supported():
|
||||
return current_platform.has_device_capability(80)
|
||||
|
||||
|
||||
def fp4_marlin_process_scales(marlin_scales):
|
||||
def nvfp4_marlin_process_scales(marlin_scales):
|
||||
if not (marlin_scales >= 0).all():
|
||||
logger.warning_once(
|
||||
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
|
||||
@@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales):
|
||||
return marlin_scales
|
||||
|
||||
|
||||
def fp4_marlin_process_global_scale(global_scale):
|
||||
def mxfp4_marlin_process_scales(marlin_scales):
|
||||
# 8 is the number of scale number using by one thread
|
||||
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
|
||||
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
|
||||
marlin_scales.size(0) * 2, -1)
|
||||
|
||||
# fit the layout of fp8 dequantization
|
||||
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
|
||||
marlin_scales.size(0), -1)
|
||||
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
|
||||
return marlin_scales
|
||||
|
||||
|
||||
def nvfp4_marlin_process_global_scale(global_scale):
|
||||
assert global_scale.dtype in [torch.half, torch.bfloat16]
|
||||
fp4_exponent = 2
|
||||
if global_scale.dtype == torch.half:
|
||||
@@ -73,7 +86,7 @@ def apply_fp4_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
weight_scale_2: torch.Tensor,
|
||||
weight_scale_2: Optional[torch.Tensor],
|
||||
workspace: torch.Tensor,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
@@ -94,6 +107,7 @@ def apply_fp4_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(a=reshaped_x,
|
||||
c=None,
|
||||
b_q_weight=weight,
|
||||
b_bias=bias,
|
||||
b_scales=weight_scale,
|
||||
global_scale=weight_scale_2,
|
||||
b_zeros=None,
|
||||
@@ -107,9 +121,6 @@ def apply_fp4_marlin_linear(
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
@@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
is_nvfp4 = hasattr(layer, "weight_scale_2")
|
||||
group_size = 16 if is_nvfp4 else 32
|
||||
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
param_dtype = layer.params_dtype
|
||||
@@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Permute scales
|
||||
weight_scale = layer.weight_scale.T.to(param_dtype)
|
||||
weight_scale = layer.weight_scale.T.contiguous()
|
||||
|
||||
if not is_nvfp4:
|
||||
weight_scale = weight_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
weight_scale = weight_scale.to(param_dtype)
|
||||
weight_scale = marlin_permute_scales(s=weight_scale,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=16)
|
||||
weight_scale = fp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||
group_size=group_size)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
|
||||
weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
|
||||
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
|
||||
requires_grad=False)
|
||||
if is_nvfp4:
|
||||
weight_scale = nvfp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
|
||||
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
|
||||
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
|
||||
requires_grad=False)
|
||||
else:
|
||||
weight_scale = mxfp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
assert layer.bias.shape == (part_size_n, )
|
||||
bias = marlin_permute_bias(layer.bias)
|
||||
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
|
||||
return
|
||||
|
||||
@@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
|
||||
group_size = 16 if is_nvfp4 else 32
|
||||
|
||||
e = layer.num_experts
|
||||
k = layer.hidden_size
|
||||
n = layer.intermediate_size_per_partition
|
||||
@@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
# WEIGHT SCALES
|
||||
# Permute scales
|
||||
for name in ["w13", "w2"]:
|
||||
scales = getattr(layer, name + "_weight_scale").to(param_dtype)
|
||||
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
|
||||
scales = getattr(layer, name + "_weight_scale")
|
||||
if not is_nvfp4:
|
||||
scales = scales.view(torch.float8_e8m0fnu)
|
||||
scales = scales.to(param_dtype)
|
||||
if is_nvfp4:
|
||||
global_scale = getattr(layer,
|
||||
name + "_weight_scale_2").to(param_dtype)
|
||||
|
||||
tensor_list = []
|
||||
if "w13" in name:
|
||||
@@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
size_n, size_k = k, n
|
||||
|
||||
for i in range(e):
|
||||
marlin_scales = marlin_permute_scales(s=scales[i].T,
|
||||
scale = scales[i].T
|
||||
|
||||
marlin_scales = marlin_permute_scales(s=scale,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=16)
|
||||
marlin_scales = fp4_marlin_process_scales(marlin_scales)
|
||||
group_size=group_size)
|
||||
if is_nvfp4:
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
else:
|
||||
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
|
||||
tensor_list.append(marlin_scales)
|
||||
|
||||
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
scales = torch.nn.Parameter(scales, requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale", scales)
|
||||
|
||||
global_scale = fp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||
if is_nvfp4:
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = torch.nn.Parameter(global_scale,
|
||||
requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||
|
||||
# BIAS
|
||||
# Permute bias
|
||||
for name in ["w13_bias", "w2_bias"]:
|
||||
if not hasattr(layer, name):
|
||||
continue
|
||||
bias = getattr(layer, name).to(param_dtype)
|
||||
|
||||
tensor_list = []
|
||||
for i in range(e):
|
||||
expert_bias = bias[i]
|
||||
|
||||
tensor_list.append(marlin_permute_bias(expert_bias))
|
||||
|
||||
bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
setattr(layer, name, bias)
|
||||
|
||||
|
||||
def rand_marlin_weight_fp4_like(weight, group_size):
|
||||
def rand_marlin_weight_nvfp4_like(weight, group_size):
|
||||
assert group_size > 0
|
||||
size_n, size_k = weight.shape
|
||||
device = weight.device
|
||||
@@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size):
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=group_size)
|
||||
marlin_scales = fp4_marlin_process_scales(marlin_scales)
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
|
||||
global_scale = fp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
||||
|
||||
|
||||
def rand_marlin_weight_mxfp4_like(weight, group_size):
|
||||
assert group_size > 0
|
||||
size_n, size_k = weight.shape
|
||||
device = weight.device
|
||||
|
||||
scales = torch.randint(100,
|
||||
125, (size_n, size_k // group_size),
|
||||
dtype=torch.uint8,
|
||||
device=weight.device)
|
||||
scales = scales.view(torch.float8_e8m0fnu)
|
||||
|
||||
fp4_weight = torch.randint(0,
|
||||
256, (size_n, size_k // 2),
|
||||
dtype=torch.uint8,
|
||||
device=weight.device)
|
||||
fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
|
||||
((fp4_weight & 0b01110000) >> 2))
|
||||
fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
|
||||
fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
|
||||
|
||||
fp4_weight2 = fp4_weight << 4
|
||||
fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
|
||||
((fp4_weight2 & 0b01110000) >> 2))
|
||||
fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
|
||||
fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
|
||||
|
||||
weight_ref = torch.cat(
|
||||
[fp4_weight_part_2.unsqueeze(2),
|
||||
fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
|
||||
weight_ref = weight_ref * \
|
||||
scales.repeat_interleave(group_size, 1).to(weight.dtype)
|
||||
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
|
||||
perm=torch.empty(0, dtype=torch.int, device=device),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
num_bits=4,
|
||||
)
|
||||
|
||||
marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=group_size)
|
||||
|
||||
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)
|
||||
|
||||
@@ -1,45 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def per_token_group_quant_mxfp4(x: torch.Tensor,
|
||||
block_k: int,
|
||||
scale_calculation_mode: str = "even"
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
""" weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
|
||||
"""
|
||||
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout
|
||||
if (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and not is_torch_equal_or_newer("2.8.1")):
|
||||
logger.warning_once(
|
||||
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
||||
"this cause swizling to be disabled, which may "
|
||||
"cause performance degradation. Please upgrade to torch nightly")
|
||||
value_layout, value_layout_opts = StridedLayout, dict()
|
||||
scale_layout, scale_layout_opts = StridedLayout, dict()
|
||||
else:
|
||||
value_layout, value_layout_opts = \
|
||||
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
scale_layout, scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(
|
||||
mx_axis=1, num_warps=num_warps))
|
||||
if current_platform.is_cuda() and \
|
||||
current_platform.is_device_capability(100):
|
||||
constraints = {
|
||||
"is_persistent": True,
|
||||
"epilogue_subtile": 1,
|
||||
}
|
||||
opt_flags.update_opt_flags_constraints(constraints)
|
||||
# transpose the tensor so that the quantization axis is on dim1
|
||||
quant_tensor = quant_tensor.transpose(-2, -1)
|
||||
scale = scale.transpose(-2, -1)
|
||||
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
|
||||
value_layout, **value_layout_opts)
|
||||
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
|
||||
**scale_layout_opts)
|
||||
return quant_tensor, InFlexData(), scale
|
||||
|
||||
|
||||
def _can_support_mxfp4(use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
scoring_func: str = "softmax",
|
||||
activation: str = "swiglu_oai",
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None):
|
||||
return not (use_grouped_topk or topk_group or num_expert_group
|
||||
or expert_map or custom_routing_function
|
||||
or e_score_correction_bias or apply_router_weight_on_input
|
||||
or scoring_func != "softmax" or activation != "swiglu_oai"
|
||||
or expert_load_view or logical_to_physical_map
|
||||
or logical_replica_count)
|
||||
|
||||
|
||||
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
||||
float_dtype: torch.dtype) -> torch.Tensor:
|
||||
try:
|
||||
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
|
||||
fake_quantize_fp4_fp6_per_group_with_scale)
|
||||
from quark.torch.quantization.utils import (even_round,
|
||||
reshape_to_blocks)
|
||||
from quark.torch.kernel import mx
|
||||
except ImportError as err:
|
||||
raise ImportError("The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
axis = -1
|
||||
block_x = reshape_to_blocks(x, block_k, axis)
|
||||
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
|
||||
amax = amax.squeeze(-1)
|
||||
return mx.dq_mxfp4(x, scale, float_dtype)
|
||||
|
||||
# TODO: there are other rounding strategies supported in quark and in the
|
||||
# config.json that we do not check for here!
|
||||
if scale_calculation_mode != "even":
|
||||
raise NotImplementedError(
|
||||
f"Scale calculation mode {scale_calculation_mode} is not yet "
|
||||
"supported in MX-FP4 quantization")
|
||||
scale = even_round(amax, "fp4")
|
||||
|
||||
# Apply dequantize(quantize(x)).
|
||||
x = fake_quantize_fp4_fp6_per_group_with_scale(
|
||||
x,
|
||||
scale.to(x.device),
|
||||
axis=axis,
|
||||
group_size=block_k,
|
||||
quant_dtype="fp4",
|
||||
def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor,
|
||||
float_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.empty((*x.shape[:-1], x.shape[-1] * 2),
|
||||
dtype=float_dtype,
|
||||
device=x.device)
|
||||
|
||||
|
||||
def _quant_dequant_mxfp4(x: torch.Tensor,
|
||||
scale_calculation_mode: str = "even") -> torch.Tensor:
|
||||
try:
|
||||
from quark.torch.kernel import mx
|
||||
except ImportError as err:
|
||||
raise ImportError("The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
return mx.qdq_mxfp4(x, scale_calculation_mode)
|
||||
|
||||
|
||||
def _quant_dequant_mxfp4_fake(x: torch.Tensor,
|
||||
scale_calculation_mode: str = "even"
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="dequant_mxfp4",
|
||||
op_func=_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_dequant_mxfp4_fake,
|
||||
)
|
||||
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
return x, scale
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="quant_dequant_mxfp4",
|
||||
op_func=_quant_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_quant_dequant_mxfp4_fake,
|
||||
)
|
||||
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
@@ -3,22 +3,41 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Optional
|
||||
from typing import ClassVar, NamedTuple, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
row: int
|
||||
col: int
|
||||
|
||||
|
||||
class GroupShape(_GroupShape):
|
||||
"""
|
||||
This class describes the quantization group shape.
|
||||
It includes static members for common shapes (per-tensor, per-token).
|
||||
"""
|
||||
|
||||
# Aliases for common quantization group shapes
|
||||
PER_TENSOR: ClassVar['GroupShape']
|
||||
PER_TOKEN: ClassVar['GroupShape']
|
||||
|
||||
|
||||
GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
|
||||
int]):
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
|
||||
@@ -58,7 +77,7 @@ def group_broadcast(t, shape):
|
||||
# (i.e. per-token-per-group)
|
||||
def scaled_quantize(
|
||||
x: torch.Tensor,
|
||||
group_shape: tuple[int, int],
|
||||
group_shape: GroupShape,
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
@@ -99,7 +118,7 @@ def scaled_quantize(
|
||||
def scaled_dequantize(
|
||||
x_q: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
group_shape: Optional[tuple[int, int]] = None,
|
||||
group_shape: Optional[GroupShape] = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if group_shape is not None:
|
||||
@@ -332,6 +351,10 @@ def quantize_weights(w: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def gptq_quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
@@ -571,3 +594,56 @@ def awq_pack(
|
||||
q_w = q_w.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return pack_cols(q_w, num_bits, size_k, size_n)
|
||||
|
||||
|
||||
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Pad and block-interleave the FP4 block-scales so that they match the data
|
||||
layout expected by the CUTLASS / FlashInfer kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scale: torch.Tensor
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The swizzled tensor with the same logical shape as *scale*.
|
||||
"""
|
||||
assert scale.dtype == torch.float8_e4m3fn, (
|
||||
"swizzle_blockscale expects the input tensor to be in "
|
||||
"torch.float8_e4m3fn format.")
|
||||
|
||||
scale_ndim = scale.ndim
|
||||
if scale_ndim == 2:
|
||||
scale = scale.unsqueeze(0) # (1, M, K)
|
||||
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
|
||||
|
||||
B, M, K = scale.shape
|
||||
|
||||
def _round_up(x: int, m: int) -> int:
|
||||
return (x + m - 1) // m * m
|
||||
|
||||
M_padded = _round_up(M, 128)
|
||||
K_padded = _round_up(K, 4)
|
||||
|
||||
padded = torch.zeros((B, M_padded, K_padded),
|
||||
dtype=scale.dtype,
|
||||
device=scale.device)
|
||||
padded[:B, :M, :K] = scale
|
||||
|
||||
# Reshape / permute to the layout required by the kernel.
|
||||
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
|
||||
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
|
||||
|
||||
if scale_ndim == 2:
|
||||
return swizzled.reshape(M, K)
|
||||
return swizzled.reshape(B, M, K)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
Reference in New Issue
Block a user