Support OCP MXFP4 quantization on AMD GPUs (#8255)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from .quark_scheme import QuarkScheme
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4

__all__ = ["QuarkScheme", "QuarkW4A4MXFP4"]
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Optional

import torch

__all__ = ["QuarkScheme"]


class QuarkScheme(ABC):
    """
    Abstract class used to describe the weight creation and forward pass
    of different quantization schemes supported by Quark.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def create_weights(self, *args, **kwargs):
"""
|
||||
Weight creation for the particular scheme. Inputs to this function
|
||||
|
||||
"""
|
||||
raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param x: input to the layer
        :param bias: bias parameter
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError
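The abstract interface above leaves weight registration, post-load processing, and the forward pass to each scheme. For orientation only, a minimal hypothetical subclass that satisfies the interface could look like the sketch below; the class name, the capability value, and the plain-Parameter registration are illustrative and not part of this commit.

# Hypothetical sketch of a concrete QuarkScheme; illustrative only.
class QuarkUnquantizedPassthrough(QuarkScheme):

    @classmethod
    def get_min_capability(cls) -> int:
        # Assumed minimum device capability; real schemes choose their own value.
        return 70

    def create_weights(self, layer, output_size_per_partition,
                       input_size_per_partition, params_dtype, weight_loader, **kwargs):
        # Register an ordinary full-precision weight; real schemes register
        # packed/quantized parameters with the appropriate loader metadata.
        weight = torch.nn.Parameter(
            torch.empty(output_size_per_partition, input_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Nothing to repack for this toy scheme.
        return

    def apply_weights(self, layer, x, bias=None):
        return torch.nn.functional.linear(x, layer.weight, bias)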
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Optional

import aiter
import torch
import torch.nn.functional as F
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.utility import dtypes
from aiter.utility.fp4_utils import e8m0_shuffle

from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
from sglang.srt.utils import get_bool_env_var

__all__ = ["QuarkW4A4MXFP4"]

OCP_MX_BLOCK_SIZE = 32


class QuarkW4A4MXFP4(QuarkScheme):

    def __init__(
        self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
    ):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        return

        # AITER a4w4 GEMM path (pre-shuffled weights and scales), currently disabled:
        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)

        # layer.weight = torch.nn.Parameter(wshuffle,
        #                                   requires_grad=False)
        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
        #                                         requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        out_dtype = x.dtype
        # M = x.shape[0]
        # N = layer.weight.shape[0]

        # AITER a4w4 GEMM path, currently disabled:
        # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
        # x, x_scales_shuffle = quant_func(x, shuffle=True)

        # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)

        # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)

        # return out[:M]

        # Triton path: dynamically quantize the activations to MXFP4, then run
        # the FP4-activation x FP4-weight GEMM.
        x_q, x_s = dynamic_mxfp4_quant(x)
        y = torch.empty(
            x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
        )

        out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)

        return out
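A note on the shapes used above: OCP MXFP4 packs two 4-bit values into each uint8 byte and attaches one shared E8M0 scale (also stored as a uint8) to every block of OCP_MX_BLOCK_SIZE = 32 elements along the input dimension, which is why the weight tensor has input_size_per_partition // 2 columns and the scale tensor has input_size_per_partition // OCP_MX_BLOCK_SIZE columns. A small self-contained sketch of that arithmetic, with made-up layer sizes:

# Illustrative shape arithmetic for the MXFP4 parameters registered above.
OCP_MX_BLOCK_SIZE = 32

def mxfp4_param_shapes(out_features: int, in_features: int):
    packed_weight_shape = (out_features, in_features // 2)          # 2 x FP4 per uint8 byte
    scale_shape = (out_features, in_features // OCP_MX_BLOCK_SIZE)  # 1 x E8M0 per 32 elements
    return packed_weight_shape, scale_shape

# Hypothetical 4096 x 14336 projection:
#   packed weight -> (4096, 7168) uint8, scales -> (4096, 448) uint8
print(mxfp4_param_shapes(4096, 14336))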
python/sglang/srt/layers/quantization/quark/utils.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0

import re
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Optional


def deep_compare(dict1: Any, dict2: Any) -> bool:
    if type(dict1) is not type(dict2):
        return False
    if isinstance(dict1, dict):
        if dict1.keys() != dict2.keys():
            return False
        return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
    elif isinstance(dict1, list):
        return set(dict1) == set(dict2)
    else:
        return dict1 == dict2

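One behavior of deep_compare worth calling out: lists are compared as sets, so element order is ignored while dict keys and values still have to match. A quick illustration with made-up config fragments:

# Illustrative only: deep_compare ignores list order but not differing values.
a = {"dtype": "fp4", "targets": ["q_proj", "k_proj"]}
b = {"dtype": "fp4", "targets": ["k_proj", "q_proj"]}
c = {"dtype": "fp8", "targets": ["q_proj", "k_proj"]}

assert deep_compare(a, b)      # list order does not matter
assert not deep_compare(a, c)  # differing values do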

def should_ignore_layer(
    layer_name: Optional[str],
    ignore: Iterable[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
    if layer_name is None:
        return False

    # layer_name = model.layers.0.self_attn.qkv_proj
    # proj_name = qkv_proj
    proj_name = layer_name.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping:
        shard_proj_names = fused_mapping[proj_name]

        # Convert fused_name --> [shard_names]
        shard_names = [
            layer_name.replace(proj_name, shard_proj_name)
            for shard_proj_name in shard_proj_names
        ]

        # Layer should be ignored if shards are ignored.
        should_ignore_layer = None
        for shard_name in shard_names:
            should_ignore_shard = check_equal_or_regex_match(
                layer_name=shard_name, targets=ignore
            )

            # If shard_idx=0, set layer ignore to match shard.
            if should_ignore_layer is None:
                should_ignore_layer = should_ignore_shard

            # If shard_idx=1+, confirm the scheme matches the prior shards.
            elif should_ignore_shard != should_ignore_layer:
                raise ValueError(
                    f"Found different quantization schemes for "
                    f"{shard_proj_names} in {layer_name}. vLLM "
                    "requires all to use the same scheme."
                )

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    else:
        should_ignore_layer = check_equal_or_regex_match(
            layer_name=layer_name, targets=ignore
        )

    assert should_ignore_layer is not None

    return should_ignore_layer

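A hypothetical usage sketch for should_ignore_layer; the layer names, ignore patterns, and fused mapping below are illustrative only, with the qkv_proj mapping mirroring the convention described in the comments above:

# Illustrative usage; assumes the definitions in this file are in scope.
fused_mapping = MappingProxyType({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
ignore = ["lm_head", "re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]

# All three shards of the fused qkv_proj match the ignore patterns,
# so the fused layer is ignored as a whole.
assert should_ignore_layer(
    "model.layers.0.self_attn.qkv_proj", ignore, fused_mapping
)

# An unfused layer is checked directly against the ignore patterns.
assert not should_ignore_layer("model.layers.0.mlp.down_proj", ignore, fused_mapping)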

def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
    """
    Checks whether layer_name exactly equals, or (for targets prefixed
    with 're:') regex-matches, any target in the list.
    """
    for target in targets:
        if _is_equal_or_regex_match(layer_name, target):
            return True
    return False


def _is_equal_or_regex_match(
    value: str, target: str, check_contains: bool = False
) -> bool:
    """
    Checks whether a value is exactly equal or a regex match for target
    if target starts with 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.
    """

    if target.startswith("re:"):
        pattern = target[3:]
        if re.match(pattern, value):
            return True
    elif check_contains:
        if target.lower() in value.lower():
            return True
    elif target == value:
        return True

    return False
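The 're:' prefix switches a target to regex matching; since re.match is used, the pattern is anchored at the start of the name, while plain targets must match exactly. A short illustration with made-up names:

# Illustrative behavior of the 're:' prefix (hypothetical layer names).
assert _is_equal_or_regex_match("lm_head", "lm_head")                              # exact match
assert _is_equal_or_regex_match("model.layers.3.mlp.gate_proj", "re:.*gate_proj")  # regex match
assert not _is_equal_or_regex_match("model.layers.3.mlp.gate_proj", "gate_proj")   # not exact
# re.match anchors at the start, so the pattern must match from the beginning:
assert not _is_equal_or_regex_match("model.lm_head", "re:lm_head")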