Revert "Optimized deepseek-v3/r1 model performance on mxfp4 run (#9671)" (#9959)

This commit is contained in:
Yineng Zhang
2025-09-03 00:50:04 -07:00
committed by GitHub
parent 2c7ca33abb
commit 1b2ff4fb7f
7 changed files with 59 additions and 455 deletions

View File

@@ -8,7 +8,6 @@ import torch.nn.functional as F
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.utility import dtypes
from aiter.utility.fp4_utils import e8m0_shuffle
@@ -39,6 +38,15 @@ class QuarkW4A4MXFP4(QuarkScheme):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return
# for aiter implement
# wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
# w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
# layer.weight = torch.nn.Parameter(wshuffle,
# requires_grad=False)
# layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
# requires_grad=False)
def create_weights(
self,
layer: torch.nn.Module,
@@ -85,53 +93,26 @@ class QuarkW4A4MXFP4(QuarkScheme):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# This path does not have support for bias currently
assert bias is None, "bias is not supported"
three_d = False
x_s = None
y = None
if isinstance(x, tuple):
assert len(x) in [
2,
3,
], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
if len(x) == 2:
x, x_s = x
elif len(x) == 3:
x, x_s, y = x
out_dtype = x.dtype
# M = x.shape[0]
# N = layer.weight.shape[0]
use_fused_quant_gemm = (
x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
# x, x_scales_shuffle = quant_func(x, shuffle=True)
# y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
# out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
# return out[:M]
# triton implement
x_q, x_s = dynamic_mxfp4_quant(x)
y = torch.empty(
x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
)
if x.dim() == 3:
three_d = True
x = x.view(-1, x.shape[-1])
output_shape = [*x.shape[:-1], layer.weight.shape[0]]
out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
# use_fused_quant_gemm = true, x_q is a bf16/fp16 num
# x_s is not None = true, x_q is uint8 num
if use_fused_quant_gemm or x_s is not None:
x_q = x
else:
x_q, x_s = dynamic_mxfp4_quant(x)
if y is None:
y = torch.empty(
x_q.shape[0],
layer.weight.shape[0],
device=x_q.device,
dtype=self.out_dtype,
)
if use_fused_quant_gemm:
gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
y = y.to(x.dtype)
else:
gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
if three_d:
return y.view(*output_shape)
return y
return out

View File

@@ -5,10 +5,6 @@ from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Optional
import torch
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from torch import nn
def deep_compare(dict1: Any, dict2: Any) -> bool:
if type(dict1) is not type(dict2):
@@ -109,96 +105,3 @@ def _is_equal_or_regex_match(
elif target == value:
return True
return False
# utility for tensor dims > 2 cases
def b_dynamic_mxfp4_quant(x):
    """Apply dynamic mxfp4 quantization to a 3-D tensor.

    `dynamic_mxfp4_quant` only accepts 2-D input, so the leading two
    dims are flattened before quantizing and restored afterwards.

    Returns:
        (quantized, scales) — quantized packs two fp4 codes per uint8
        (last dim halved); scales carry one entry per 32 elements.
    """
    heads, batch, dim = x.shape
    flat = x.reshape(-1, dim)
    quantized, scales = dynamic_mxfp4_quant(flat)
    # Re-expand the flattened leading dims; the last-dim divisors come
    # from fp4 packing (2 per byte) and the 1x32 scaling-group size.
    return quantized.view(heads, batch, dim // 2), scales.view(heads, batch, dim // 32)
def mxfp4_to_f32(x, is_threed):
    """Dequantize a packed mxfp4 (fp4 e2m1) uint8 tensor to float32.

    Each uint8 packs two fp4 codes: the low nibble becomes the even
    output element and the high nibble the odd one, so the output's
    last dimension is twice the input's.

    Args:
        x: uint8 tensor of packed fp4 pairs (2-D, or 3-D when is_threed).
        is_threed: retained for interface compatibility; the ellipsis
            indexing below handles 2-D and 3-D layouts identically.

    Returns:
        float32 tensor of decoded values, on the same device as `x`.
    """
    # 2 because we pack fp4 in uint8: duplicate each byte, then carve
    # the low/high nibble out of each duplicated pair.
    x = x.repeat_interleave(2, dim=-1)
    x[..., ::2] = x[..., ::2] & 0xF
    x[..., 1::2] = x[..., 1::2] >> 4
    # fp4 e2m1 code -> value; codes 8..15 are the negated values.
    mxfp4_list = [
        0.0,
        0.5,
        1.0,
        1.5,
        2.0,
        3.0,
        4.0,
        6.0,
        -0.0,
        -0.5,
        -1.0,
        -1.5,
        -2.0,
        -3.0,
        -4.0,
        -6.0,
    ]
    # Build the lookup table on x's device (was hard-coded to "cuda",
    # which broke CPU tensors via cross-device indexing).
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device)
    return mxfp4_in_f32[x.long()]
def e8m0_to_f32(x):
    """Convert an e8m0 tensor (8-bit exponent, 0-bit mantissa) to float32.

    An e8m0 value encodes 2^(exponent - 127), like the IEEE-754 biased
    exponent but with no mantissa. Exponent 255 is the format's NaN
    encoding (e8m0 has no Inf).

    Args:
        x: tensor of raw e8m0 exponent bytes.

    Returns:
        float32 tensor of decoded scale values, NaN where exponent == 255.
    """
    exponents = x.to(torch.float32)
    x_f32 = 2 ** (exponents - 127)
    # BUG FIX: the previous code masked `x_f32 == 128`, i.e. exponent 134,
    # which wrongly NaN-ed the legitimate value 128 and left the real NaN
    # encoding (exponent 255, whose 2^128 overflows float32 to inf) intact.
    # Mask on the raw exponent instead, as the format requires.
    x_f32[exponents == 255] = float("nan")
    return x_f32
def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
    """Split a fused kv_b_proj weight into k/v halves and mxfp4-quantize them.

    Args:
        self_attn: attention module providing `qk_nope_head_dim`,
            `v_head_dim`, and (for the uint8 path) `kv_b_proj.weight_scale`.
        w: fused weight, either bf16 (not yet quantized) or uint8
            (already packed mxfp4, two fp4 codes per byte).
        quant_format: quantization format string; only acted on when it
            contains "mxfp4".

    Returns:
        (w_kc, w_s_kc, w_vc, w_s_vc): quantized k/v weights and their scales.

    NOTE(review): if `quant_format` lacks "mxfp4", or `w.dtype` is neither
    bfloat16 nor uint8, the return hits unbound locals (NameError) — callers
    presumably guarantee one of the two dtypes; confirm at call sites.
    """
    if "mxfp4" in quant_format:
        # when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
        # do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
        # and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
        if w.dtype == torch.bfloat16:
            # Split the fused projection along dim 0 into per-head k-nope
            # and v chunks.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            # Quantize w_kc along its last dim by transposing first, then
            # transpose weight and scales back to the original layout.
            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
            w_kc = w_kc.transpose(-2, -1)
            w_s_kc = w_s_kc.transpose(-2, -1)
            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
            # Force a specific memory layout for the scales; presumably
            # required by the downstream gemm kernels — TODO confirm.
            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
        elif w.dtype == torch.uint8:  # static quant for mxfp4
            # when dtype is uint8, it means the w has been quantized to mxfp4 format
            # but we must separate it to w_kc and w_vc.
            # The quantized tensor size is only half of original tensor size
            # and the scaling factor is 1/32, the transpose behavior will be not correct
            # need to upcast it to fp32 to separate w to w_kc and w_vc
            # to ensure the following transpose behavior is correct
            # and then do mxfp4 quant again
            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
            # Expand per-32-element scales to per-element and apply them to
            # recover the dequantized weight before re-splitting.
            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
            w = w * w_scales
            # Same split + requantize sequence as the bf16 branch above.
            w_kc, w_vc = w.unflatten(
                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
            w_kc = w_kc.transpose(-2, -1)
            w_s_kc = w_s_kc.transpose(-2, -1)
            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
    return w_kc, w_s_kc, w_vc, w_s_vc

View File

@@ -1,13 +0,0 @@
# Re-export aiter Triton mxfp4 kernels at package level so consumers can
# import them from one place instead of the deep aiter.ops.triton paths.
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
    batched_gemm_afp4wfp4_pre_quant,
)
from aiter.ops.triton.fused_mxfp4_quant import (
    fused_flatten_mxfp4_quant,
    fused_rms_mxfp4_quant,
)

# Explicit public API for `from <package> import *`.
__all__ = [
    "fused_rms_mxfp4_quant",
    "fused_flatten_mxfp4_quant",
    "batched_gemm_afp4wfp4_pre_quant",
]