This commit is contained in:
@@ -8,7 +8,6 @@ import torch.nn.functional as F
|
||||
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
|
||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||
from aiter.utility import dtypes
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
@@ -39,6 +38,15 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-load weight processing hook — intentionally a no-op.

    Weights and scales are left exactly as loaded; quantization happens
    dynamically at apply time.  The commented-out block below is an
    alternative aiter shuffle path kept for reference (it would pre-shuffle
    the weight and its e8m0 scales into the layout the aiter gemm expects).
    """
    return

    # for aiter implement
    # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
    # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)

    # layer.weight = torch.nn.Parameter(wshuffle,
    #                                   requires_grad=False)
    # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
    #                                         requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -85,53 +93,26 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# This path does not have support for bias currently
|
||||
assert bias is None, "bias is not supported"
|
||||
|
||||
three_d = False
|
||||
x_s = None
|
||||
y = None
|
||||
if isinstance(x, tuple):
|
||||
assert len(x) in [
|
||||
2,
|
||||
3,
|
||||
], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
|
||||
if len(x) == 2:
|
||||
x, x_s = x
|
||||
elif len(x) == 3:
|
||||
x, x_s, y = x
|
||||
out_dtype = x.dtype
|
||||
# M = x.shape[0]
|
||||
# N = layer.weight.shape[0]
|
||||
|
||||
use_fused_quant_gemm = (
|
||||
x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
|
||||
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
|
||||
# x, x_scales_shuffle = quant_func(x, shuffle=True)
|
||||
|
||||
# y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
|
||||
|
||||
# out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
|
||||
|
||||
# return out[:M]
|
||||
|
||||
# triton implement
|
||||
x_q, x_s = dynamic_mxfp4_quant(x)
|
||||
y = torch.empty(
|
||||
x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
|
||||
)
|
||||
|
||||
if x.dim() == 3:
|
||||
three_d = True
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output_shape = [*x.shape[:-1], layer.weight.shape[0]]
|
||||
out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
|
||||
|
||||
# use_fused_quant_gemm = true, x_q is a bf16/fp16 num
|
||||
# x_s is not None = true, x_q is uint8 num
|
||||
if use_fused_quant_gemm or x_s is not None:
|
||||
x_q = x
|
||||
else:
|
||||
x_q, x_s = dynamic_mxfp4_quant(x)
|
||||
|
||||
if y is None:
|
||||
y = torch.empty(
|
||||
x_q.shape[0],
|
||||
layer.weight.shape[0],
|
||||
device=x_q.device,
|
||||
dtype=self.out_dtype,
|
||||
)
|
||||
|
||||
if use_fused_quant_gemm:
|
||||
gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
|
||||
y = y.to(x.dtype)
|
||||
else:
|
||||
gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
|
||||
|
||||
if three_d:
|
||||
return y.view(*output_shape)
|
||||
|
||||
return y
|
||||
return out
|
||||
|
||||
@@ -5,10 +5,6 @@ from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||
from torch import nn
|
||||
|
||||
|
||||
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||
if type(dict1) is not type(dict2):
|
||||
@@ -109,96 +105,3 @@ def _is_equal_or_regex_match(
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# utility for tensor dims > 2 cases
|
||||
def b_dynamic_mxfp4_quant(x):
    """Dynamically mxfp4-quantize a 3-D tensor (dims > 2 utility).

    ``dynamic_mxfp4_quant`` works on 2-D input, so the two leading dims are
    collapsed for the call and restored afterwards: the packed result holds
    two fp4 values per byte (last dim halves) and one e8m0 scale per 32
    values (last dim // 32).
    """
    num_heads, batch, dim = x.shape
    flat_q, flat_scales = dynamic_mxfp4_quant(x.reshape(-1, dim))
    packed = flat_q.view(num_heads, batch, dim // 2)
    scales = flat_scales.view(num_heads, batch, dim // 32)
    return packed, scales
|
||||
|
||||
|
||||
def mxfp4_to_f32(x, is_threed):
    """Dequantize packed mxfp4 (e2m1) codes to float32.

    Each uint8 in ``x`` packs two fp4 values — low nibble first, high
    nibble second — so the last dimension doubles in the result.

    Args:
        x: uint8 tensor of packed fp4 pairs; 3-D when ``is_threed`` is
           True, otherwise 2-D.
        is_threed: selects 3-D (``[..., ::2]``) vs 2-D (``[:, ::2]``)
           nibble indexing.

    Returns:
        float32 tensor with last dim ``2 * x.shape[-1]`` on ``x.device``.
    """
    # 2 because we pack fp4 in uint8: duplicate every byte so even
    # positions keep the low nibble and odd positions the high nibble
    # after the mask/shift below.
    x = x.repeat_interleave(2, dim=-1)
    if is_threed:
        x[..., ::2] = x[..., ::2] & 0xF
        x[..., 1::2] = x[..., 1::2] >> 4
    else:
        x[:, ::2] = x[:, ::2] & 0xF
        x[:, 1::2] = x[:, 1::2] >> 4

    # e2m1 code -> value table (codes 0..7 positive, 8..15 negative).
    mxfp4_list = [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ]
    # Build the lookup table on the input's device instead of the previous
    # hard-coded "cuda": CPU tensors and non-default CUDA devices now work,
    # and the gather below never crosses devices.
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device)
    return mxfp4_in_f32[x.long()]
|
||||
|
||||
|
||||
def e8m0_to_f32(x):
    """Convert e8m0 codes to float32.

    e8m0 is an 8-bit floating point format with 8 exponent bits and no
    mantissa: a code ``e`` decodes to ``2 ** (e - 127)``, like the IEEE-754
    exponent field.  The all-ones code (255) is the NaN sentinel.

    Args:
        x: integer tensor of e8m0 codes in [0, 255].

    Returns:
        float32 tensor of decoded scales, with code 255 mapped to NaN.
    """
    # Compute 2^(code - 127) in float32 (bias 127).
    x_f32 = 2 ** ((x.to(torch.float32)) - 127)

    # Code 255 is reserved for NaN.  Mask on the raw code: the previous
    # check (`x_f32 == 128`) compared the *decoded value*, which flagged
    # code 134 (2^7 == 128) and missed 255 entirely (2^128 overflows
    # float32 to inf, never equal to 128).
    x_f32[x == 255] = float("nan")
    return x_f32
|
||||
|
||||
|
||||
def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
    """Split a fused kv_b_proj weight into w_kc / w_vc and mxfp4-quantize both.

    Args:
        self_attn: attention module providing ``qk_nope_head_dim``,
            ``v_head_dim`` and, for the uint8 path,
            ``kv_b_proj.weight_scale`` (e8m0 codes).
        w: the kv_b_proj weight — bf16 (dynamically quantized here) or
            uint8 (already mxfp4-packed; dequantized, split, re-quantized).
        quant_format: quantization format name; must contain "mxfp4".

    Returns:
        Tuple ``(w_kc, w_s_kc, w_vc, w_s_vc)`` of packed weights and scales.

    Raises:
        ValueError: on a non-mxfp4 format or unsupported dtype (previously
            both fell through to a confusing NameError at the return).
    """
    if "mxfp4" not in quant_format:
        raise ValueError(
            f"unsupported quant_format for kv_b_proj post-load processing: {quant_format}"
        )

    if w.dtype == torch.uint8:
        # Static mxfp4 checkpoint: the packed tensor is half-size and its
        # scaling factor is 1/32-size, so splitting/transposing it directly
        # would be wrong.  Upcast to f32, apply the stored e8m0 scales to
        # recover bf16 values, then fall through to the shared split +
        # re-quantize path below.
        w = mxfp4_to_f32(w, True).to(torch.bfloat16)
        w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
        w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
        w = w * w_scales
    elif w.dtype != torch.bfloat16:
        raise ValueError(
            f"unsupported kv_b_proj weight dtype for mxfp4 post-load processing: {w.dtype}"
        )

    # Shared path (previously duplicated verbatim in both dtype branches):
    # split the fused projection into the k-nope and v parts per head.
    w_kc, w_vc = w.unflatten(
        0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
    ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)

    # Quantize w_kc along its other dim by transposing first, then restore.
    w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
    w_kc = w_kc.transpose(-2, -1)
    w_s_kc = w_s_kc.transpose(-2, -1)
    w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)

    # Force the scale tensors into contiguous storage under the transposed
    # view — NOTE(review): presumably the layout downstream gemm kernels
    # expect; ordering kept exactly as the original.
    w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
    w_s_vc = w_s_vc.contiguous().transpose(1, 2)

    return w_kc, w_s_kc, w_vc, w_s_vc
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
|
||||
batched_gemm_afp4wfp4_pre_quant,
|
||||
)
|
||||
from aiter.ops.triton.fused_mxfp4_quant import (
|
||||
fused_flatten_mxfp4_quant,
|
||||
fused_rms_mxfp4_quant,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"fused_rms_mxfp4_quant",
|
||||
"fused_flatten_mxfp4_quant",
|
||||
"batched_gemm_afp4wfp4_pre_quant",
|
||||
]
|
||||
Reference in New Issue
Block a user