Support OCP MXFP4 quantization on AMD GPUs (#8255)
Co-authored-by: wunhuang <wunhuang@amd.com> Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
This commit is contained in:
@@ -401,6 +401,8 @@ class ModelConfig:
|
|||||||
"fbgemm_fp8",
|
"fbgemm_fp8",
|
||||||
"w8a8_fp8",
|
"w8a8_fp8",
|
||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
|
"quark",
|
||||||
|
"mxfp4",
|
||||||
]
|
]
|
||||||
optimized_quantization_methods = [
|
optimized_quantization_methods = [
|
||||||
"fp8",
|
"fp8",
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
|||||||
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils import mxfp_supported
|
||||||
|
|
||||||
|
is_mxfp_supported = mxfp_supported()
|
||||||
|
if is_mxfp_supported:
|
||||||
|
from sglang.srt.layers.quantization.fp4 import MxFp4Config
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||||
from sglang.srt.layers.quantization.gptq import (
|
from sglang.srt.layers.quantization.gptq import (
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
@@ -84,7 +90,13 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"w4afp8": W4AFp8Config,
|
"w4afp8": W4AFp8Config,
|
||||||
"petit_nvfp4": PetitNvFp4Config,
|
"petit_nvfp4": PetitNvFp4Config,
|
||||||
}
|
}
|
||||||
|
if is_mxfp_supported:
|
||||||
|
BASE_QUANTIZATION_METHODS.update(
|
||||||
|
{
|
||||||
|
"quark": MxFp4Config,
|
||||||
|
"mxfp4": MxFp4Config,
|
||||||
|
}
|
||||||
|
)
|
||||||
# VLLM-dependent quantization methods
|
# VLLM-dependent quantization methods
|
||||||
VLLM_QUANTIZATION_METHODS = {
|
VLLM_QUANTIZATION_METHODS = {
|
||||||
"aqlm": AQLMConfig,
|
"aqlm": AQLMConfig,
|
||||||
|
|||||||
822
python/sglang/srt/layers/quantization/fp4.py
Normal file
822
python/sglang/srt/layers/quantization/fp4.py
Normal file
@@ -0,0 +1,822 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import fnmatch
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
|
import aiter
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from aiter import ActivationType, QuantType, dtypes
|
||||||
|
from aiter.fused_moe import fused_moe
|
||||||
|
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
|
||||||
|
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
|
||||||
|
from aiter.ops.quant import get_torch_quant
|
||||||
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||||
|
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||||
|
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||||
|
from sglang.srt.layers.parameter import ModelWeightParameter
|
||||||
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
|
LinearMethodBase,
|
||||||
|
QuantizationConfig,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4
|
||||||
|
from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.utils import (
|
||||||
|
get_bool_env_var,
|
||||||
|
get_device_capability,
|
||||||
|
log_info_on_rank0,
|
||||||
|
mxfp_supported,
|
||||||
|
set_weight_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear")
|
||||||
|
|
||||||
|
OCP_MX_BLOCK_SIZE = 32
|
||||||
|
|
||||||
|
|
||||||
|
class MxFp4Config(QuantizationConfig):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
is_checkpoint_fp4_serialized: bool = False,
|
||||||
|
quant_config: dict[str, Any] = None,
|
||||||
|
kv_cache_group: Optional[list[str]] = None,
|
||||||
|
kv_cache_config: Optional[dict[str, Any]] = None,
|
||||||
|
pack_method: str = "reorder",
|
||||||
|
ignored_layers: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if kv_cache_group is None:
|
||||||
|
kv_cache_group = []
|
||||||
|
|
||||||
|
self.is_checkpoint_fp4_serialized = is_checkpoint_fp4_serialized
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.kv_cache_group = kv_cache_group
|
||||||
|
self.kv_cache_config = kv_cache_config
|
||||||
|
self.pack_method = pack_method
|
||||||
|
|
||||||
|
self.packed_modules_mapping = (
|
||||||
|
self.quant_config["packed_modules_mapping"]
|
||||||
|
if is_checkpoint_fp4_serialized
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ignored_layers = ignored_layers or []
|
||||||
|
|
||||||
|
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||||
|
return [torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 70
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "fp4"
|
||||||
|
|
||||||
|
def get_quant_method(
|
||||||
|
self, layer: torch.nn.Module, prefix: str
|
||||||
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
|
# Check if the layer is skipped for quantization.
|
||||||
|
if len(self.ignored_layers) > 0 and should_ignore_layer(
|
||||||
|
prefix,
|
||||||
|
ignore=self.ignored_layers,
|
||||||
|
fused_mapping=self.packed_modules_mapping,
|
||||||
|
):
|
||||||
|
return UnquantizedLinearMethod()
|
||||||
|
|
||||||
|
if isinstance(layer, LinearBase):
|
||||||
|
if self.is_checkpoint_fp4_serialized:
|
||||||
|
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||||
|
layer.scheme = scheme
|
||||||
|
return MxFp4LinearMethod(self)
|
||||||
|
|
||||||
|
elif use_dynamic_mxfp4_linear:
|
||||||
|
return MxFp4LinearMethod(self)
|
||||||
|
else:
|
||||||
|
return UnquantizedLinearMethod()
|
||||||
|
|
||||||
|
if isinstance(layer, RadixAttention):
|
||||||
|
return MxFp4KVCacheMethod(self)
|
||||||
|
|
||||||
|
if isinstance(layer, FusedMoE):
|
||||||
|
return MxFp4MoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: dict[str, Any]) -> "MxFp4Config":
|
||||||
|
if not mxfp_supported():
|
||||||
|
platform = torch.cuda.get_device_properties(0).gcnArchName
|
||||||
|
raise ValueError(
|
||||||
|
f"Current platform {platform} not support mxfp4 computation"
|
||||||
|
)
|
||||||
|
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||||
|
is_checkpoint_fp4_serialized = (
|
||||||
|
True if quant_method else False
|
||||||
|
) # "quark" in quant_method
|
||||||
|
|
||||||
|
kv_cache_group = []
|
||||||
|
pack_method = None
|
||||||
|
|
||||||
|
if is_checkpoint_fp4_serialized:
|
||||||
|
export_config = config.get("export")
|
||||||
|
if export_config is None:
|
||||||
|
raise ValueError(
|
||||||
|
"The export key should be included in "
|
||||||
|
"the configurations of Quark quantized model"
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
|
||||||
|
pack_method = cast(str, export_config.get("pack_method"))
|
||||||
|
|
||||||
|
# In the export model of quark, the quantization configuration
|
||||||
|
# of kv_cache is stored in layer_quant_config. First, it is
|
||||||
|
# judged whether kv_cache_group exists, and then it is judged
|
||||||
|
# whether layer_quant_config has a quantization configuration
|
||||||
|
# that matches kv_cache.
|
||||||
|
if len(kv_cache_group) == 0:
|
||||||
|
kv_cache_config = None
|
||||||
|
else:
|
||||||
|
kv_cache_set = set(kv_cache_group)
|
||||||
|
layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
|
||||||
|
layer_quant_names = list(layer_quant_config.keys())
|
||||||
|
layer_quant_set = set(layer_quant_names)
|
||||||
|
|
||||||
|
if not kv_cache_set.issubset(layer_quant_set):
|
||||||
|
raise ValueError(
|
||||||
|
"The Quark quantized model has the "
|
||||||
|
"kv_cache_group parameter setting, "
|
||||||
|
"but no kv_cache quantization settings "
|
||||||
|
"were found in the quantization "
|
||||||
|
"configuration."
|
||||||
|
)
|
||||||
|
|
||||||
|
q_configs = [
|
||||||
|
cast(dict[str, Any], layer_quant_config.get(name))
|
||||||
|
for name in kv_cache_group
|
||||||
|
]
|
||||||
|
if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
|
||||||
|
raise ValueError(
|
||||||
|
"The quantization method used for kv_cache should "
|
||||||
|
"be the same, but the quantization method for the "
|
||||||
|
"kv_cache layer in the config is different."
|
||||||
|
)
|
||||||
|
kv_cache_config = q_configs[0].get("output_tensors")
|
||||||
|
if kv_cache_config is None:
|
||||||
|
raise ValueError("The kv_cache quantization configuration is empty.")
|
||||||
|
|
||||||
|
# Since we have already set kv_cache quantization configurations,
|
||||||
|
# we will remove the quantization configuration for the
|
||||||
|
# output_tensors corresponding to the kv_cache layer.
|
||||||
|
for q_config in q_configs:
|
||||||
|
q_config["output_tensors"] = None
|
||||||
|
|
||||||
|
# In case q_proj output is also quantized, remove the configuration
|
||||||
|
# to keep qkv consistency.
|
||||||
|
q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
|
||||||
|
if q_proj_q_config is not None:
|
||||||
|
q_proj_q_config["output_tensors"] = None
|
||||||
|
|
||||||
|
ignored_layers = cls.get_from_keys_or(config, ["exclude"], None)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
is_checkpoint_fp4_serialized=is_checkpoint_fp4_serialized,
|
||||||
|
quant_config=config,
|
||||||
|
kv_cache_group=kv_cache_group,
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
pack_method=pack_method,
|
||||||
|
ignored_layers=ignored_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_filenames(cls) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
|
||||||
|
capability_tuple = get_device_capability()
|
||||||
|
|
||||||
|
if capability_tuple is not None:
|
||||||
|
assert 0 <= capability_tuple[1] < 10
|
||||||
|
capability = capability_tuple[0] * 10 + capability_tuple[1]
|
||||||
|
|
||||||
|
supported = capability >= min_capability
|
||||||
|
if error and not supported:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Quantization scheme is not supported for ",
|
||||||
|
f"the current GPU. Min capability: {min_capability}. ",
|
||||||
|
f"Current capability: {capability}.",
|
||||||
|
)
|
||||||
|
return supported
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_mx_fp4(
|
||||||
|
self,
|
||||||
|
weight_quant: Optional[dict[str, Any]],
|
||||||
|
input_quant: Optional[dict[str, Any]],
|
||||||
|
) -> bool:
|
||||||
|
# Confirm weights and input quantized.
|
||||||
|
if weight_quant is None or input_quant is None:
|
||||||
|
logger.debug(
|
||||||
|
"Quark model is not in MX-FP4 format: "
|
||||||
|
"weight_quant or input_quant not set"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Input and weight dtype needs to be fp4.
|
||||||
|
if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
|
||||||
|
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Input and weight qscheme needs to be per group.
|
||||||
|
if (
|
||||||
|
weight_quant.get("qscheme") != "per_group"
|
||||||
|
or input_quant.get("qscheme") != "per_group"
|
||||||
|
):
|
||||||
|
logger.debug("Quark model is not in MX-FP4 format: not per_group")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Input and weight group size needs to be 32.
|
||||||
|
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
|
||||||
|
logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Weights need to use static quantization.
|
||||||
|
if weight_quant.get("is_dynamic") is True:
|
||||||
|
logger.debug("Quark model is not in MX-FP4 format: not weight static")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Activations need to use dynamic quantization.
|
||||||
|
if input_quant.get("is_dynamic") is False:
|
||||||
|
logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Activations and weight scales need to be in e8m0 format.
|
||||||
|
if (
|
||||||
|
weight_quant.get("scale_format") != "e8m0"
|
||||||
|
or input_quant.get("scale_format") != "e8m0"
|
||||||
|
):
|
||||||
|
logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _find_matched_config(
|
||||||
|
self, layer_name: str, module: torch.nn.Module
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
|
||||||
|
proj_name = layer_name.split(".")[-1]
|
||||||
|
if proj_name in self.packed_modules_mapping:
|
||||||
|
shard_proj_names = self.packed_modules_mapping[proj_name]
|
||||||
|
|
||||||
|
# Convert fused_name --> [shard_names]
|
||||||
|
shard_names = [
|
||||||
|
layer_name.replace(proj_name, shard_proj_name)
|
||||||
|
for shard_proj_name in shard_proj_names
|
||||||
|
]
|
||||||
|
shard_configs = [
|
||||||
|
self._find_matched_config(shard_name, module)
|
||||||
|
for shard_name in shard_names
|
||||||
|
]
|
||||||
|
if not all(
|
||||||
|
deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Found a different quantization configuration for "
|
||||||
|
f"{shard_proj_names=} in {layer_name=}. vLLM "
|
||||||
|
"requires all to use the same scheme."
|
||||||
|
)
|
||||||
|
return shard_configs[0]
|
||||||
|
else:
|
||||||
|
layer_quant_config = cast(
|
||||||
|
dict[str, Any], self.quant_config.get("layer_quant_config")
|
||||||
|
)
|
||||||
|
for name_pattern in layer_quant_config:
|
||||||
|
if fnmatch.fnmatch(layer_name, name_pattern):
|
||||||
|
return layer_quant_config[name_pattern]
|
||||||
|
|
||||||
|
layer_type = cast(str, type(module))
|
||||||
|
layer_type_quant_config = cast(
|
||||||
|
dict[str, Any], self.quant_config.get("layer_type_quant_config")
|
||||||
|
)
|
||||||
|
if layer_type in layer_type_quant_config:
|
||||||
|
return layer_type_quant_config[layer_type]
|
||||||
|
|
||||||
|
global_quant_config = cast(
|
||||||
|
dict[str, Any], self.quant_config.get("global_quant_config")
|
||||||
|
)
|
||||||
|
return global_quant_config
|
||||||
|
|
||||||
|
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
|
||||||
|
if config.get("output_tensors") or config.get("bias"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Currently, Quark models with output_tensors "
|
||||||
|
"and bias quantized are not supported"
|
||||||
|
)
|
||||||
|
weight_config = cast(dict[str, Any], config.get("weight"))
|
||||||
|
input_config = cast(dict[str, Any], config.get("input_tensors"))
|
||||||
|
|
||||||
|
if self._is_mx_fp4(weight_config, input_config):
|
||||||
|
return QuarkW4A4MXFP4(weight_config, input_config)
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
"No quark compatible scheme was found. "
|
||||||
|
f"{weight_config=}, "
|
||||||
|
f"{input_config=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
|
||||||
|
|
||||||
|
layer_quant_config = self._find_matched_config(layer_name, layer)
|
||||||
|
|
||||||
|
# Find the quant_scheme
|
||||||
|
scheme = self._get_scheme_from_config(layer_quant_config)
|
||||||
|
|
||||||
|
# Raise error if device does not support the scheme
|
||||||
|
# (e.g. fp8 needs ada lovelace)
|
||||||
|
self._check_scheme_supported(scheme.get_min_capability())
|
||||||
|
|
||||||
|
return scheme
|
||||||
|
|
||||||
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class MxFp4LinearMethod(LinearMethodBase):
|
||||||
|
|
||||||
|
def __init__(self, quantization_config: MxFp4Config):
|
||||||
|
self.quantization_config = quantization_config
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
return
|
||||||
|
# if self.quantization_config.is_checkpoint_fp4_serialized:
|
||||||
|
# layer.scheme.process_weights_after_loading(layer)
|
||||||
|
# else:
|
||||||
|
# #w, w_scales = dynamic_mxfp4_quant(layer.weight.data)
|
||||||
|
# ##log_info_on_rank0(logger, f"w.shape: {w.shape}")
|
||||||
|
|
||||||
|
# #wshuffle = w#shuffle_weight(w, layout=(16, 16))
|
||||||
|
# #w_scales_shuffle = w_scales#e8m0_shuffle(w_scales).view(dtypes.fp8_e8m0)
|
||||||
|
|
||||||
|
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
|
||||||
|
|
||||||
|
# w, w_scales_shuffle = quant_func(layer.weight.data, shuffle=True)
|
||||||
|
|
||||||
|
# wshuffle = shuffle_weight(w, layout=(16, 16))
|
||||||
|
|
||||||
|
# layer.weight = torch.nn.Parameter(wshuffle,
|
||||||
|
# requires_grad=False)
|
||||||
|
# layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
|
||||||
|
# requires_grad=False)
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
input_size_per_partition: int,
|
||||||
|
output_partition_sizes: list[int],
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Use the CompressedTensorsScheme associated with each layer to create
|
||||||
|
the necessary parameters for the layer. See LinearMethodBase for param
|
||||||
|
details
|
||||||
|
"""
|
||||||
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
|
|
||||||
|
if self.quantization_config.is_checkpoint_fp4_serialized:
|
||||||
|
layer.scheme.create_weights(
|
||||||
|
layer=layer,
|
||||||
|
input_size=input_size,
|
||||||
|
input_size_per_partition=input_size_per_partition,
|
||||||
|
output_partition_sizes=output_partition_sizes,
|
||||||
|
output_size=output_size,
|
||||||
|
params_dtype=params_dtype,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
layer.logical_widths = output_partition_sizes
|
||||||
|
layer.input_size_per_partition = input_size_per_partition
|
||||||
|
layer.output_size_per_partition = output_size_per_partition
|
||||||
|
layer.orig_dtype = params_dtype
|
||||||
|
|
||||||
|
weight_dtype = params_dtype
|
||||||
|
|
||||||
|
weight = ModelWeightParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
output_size_per_partition,
|
||||||
|
input_size_per_partition,
|
||||||
|
dtype=weight_dtype,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
layer.register_parameter("weight", weight)
|
||||||
|
layer.register_parameter("weight_scale", None)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Use the output of create_weights and the CompressedTensorsScheme
|
||||||
|
associated with the layer to apply the forward pass with the
|
||||||
|
layer input. See LinearMethodBase for param details
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.quantization_config.is_checkpoint_fp4_serialized:
|
||||||
|
scheme = layer.scheme
|
||||||
|
if scheme is None:
|
||||||
|
raise ValueError("A scheme must be defined for each layer")
|
||||||
|
return scheme.apply_weights(layer, x, bias=bias)
|
||||||
|
else:
|
||||||
|
out_dtype = x.dtype
|
||||||
|
|
||||||
|
# ck or asm implement
|
||||||
|
# M = x.shape[0]
|
||||||
|
# N = layer.weight.shape[0]
|
||||||
|
|
||||||
|
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
|
||||||
|
|
||||||
|
# x, x_scales_shuffle = quant_func(x, shuffle=True)
|
||||||
|
|
||||||
|
# y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=out_dtype)
|
||||||
|
|
||||||
|
# out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
|
||||||
|
|
||||||
|
# return out[:M]
|
||||||
|
|
||||||
|
# triton implement
|
||||||
|
x_q, x_s = dynamic_mxfp4_quant(x)
|
||||||
|
y = torch.empty(
|
||||||
|
x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
out = gemm_afp4wfp4(
|
||||||
|
x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MxFp4MoEMethod:
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if not hasattr(cls, "_initialized"):
|
||||||
|
original_init = cls.__init__
|
||||||
|
new_cls = type(
|
||||||
|
cls.__name__,
|
||||||
|
(FusedMoEMethodBase,),
|
||||||
|
{
|
||||||
|
"__init__": original_init,
|
||||||
|
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||||
|
obj.__init__(*args, **kwargs)
|
||||||
|
return obj
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_moe_method(
|
||||||
|
quant_config: "MxFp4Config", # type: ignore # noqa E501 # noqa F821
|
||||||
|
module: torch.nn.Module,
|
||||||
|
layer_name: str,
|
||||||
|
) -> "MxFp4MoEMethod":
|
||||||
|
|
||||||
|
if quant_config.is_checkpoint_fp4_serialized:
|
||||||
|
layer_quant_config = quant_config._find_matched_config(layer_name, module)
|
||||||
|
|
||||||
|
if layer_quant_config.get("output_tensors") or layer_quant_config.get(
|
||||||
|
"bias"
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Currently, Quark models with "
|
||||||
|
"output_tensors and bias "
|
||||||
|
"quantized are not supported"
|
||||||
|
)
|
||||||
|
weight_config = layer_quant_config.get("weight")
|
||||||
|
input_config = layer_quant_config.get("input_tensors")
|
||||||
|
|
||||||
|
if quant_config._is_mx_fp4(weight_config, input_config):
|
||||||
|
return W4A4MXFp4MoEStaticMethod(weight_config, input_config)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||||
|
else:
|
||||||
|
return W4A4MXFp4MoEDynamicMethod(quant_config)
|
||||||
|
|
||||||
|
|
||||||
|
class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
|
||||||
|
def __init__(self, quant_config):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
hidden_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size_per_partition,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# Allocate 2 scales for w1 and w3 respectively.
|
||||||
|
# They will be combined to a single scale after weight loading.
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
# Add the quantization method used (per tensor/grouped/channel)
|
||||||
|
# to ensure the weight scales are loaded in properly
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||||
|
)
|
||||||
|
|
||||||
|
layer.w13_input_scale = None
|
||||||
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
|
def mxfp4_quantize(self, w):
|
||||||
|
w_shape = w.shape
|
||||||
|
w_need_reshape = True if w.dim() != 2 else False
|
||||||
|
|
||||||
|
if w_need_reshape:
|
||||||
|
w_last_dim_size = w_shape[-1]
|
||||||
|
w = w.view(-1, w_last_dim_size)
|
||||||
|
|
||||||
|
# log_info_on_rank0(logger, f"[Pre-quant] w.shape: {w.shape}")
|
||||||
|
w, mx_scales = dynamic_mxfp4_quant(w)
|
||||||
|
# log_info_on_rank0(logger, f"[Post-quant] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}")
|
||||||
|
|
||||||
|
if w_need_reshape:
|
||||||
|
w_new_shape = w_shape[:-1] + (w.shape[-1],)
|
||||||
|
w = w.view(w_new_shape)
|
||||||
|
|
||||||
|
# log_info_on_rank0(logger, f"[re-shape] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}")
|
||||||
|
|
||||||
|
mx_scales = e8m0_shuffle(mx_scales)
|
||||||
|
|
||||||
|
return w, mx_scales
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
|
||||||
|
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
|
||||||
|
|
||||||
|
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
|
||||||
|
layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
|
||||||
|
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
|
||||||
|
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
topk_output: TopKOutput,
|
||||||
|
*,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|
||||||
|
return fused_moe(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_type=QuantType.per_1x32,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
activation=(
|
||||||
|
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||||
|
),
|
||||||
|
doweight_stage1=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
|
||||||
|
|
||||||
|
def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
|
||||||
|
self.weight_quant = weight_config
|
||||||
|
self.input_quant = input_config
|
||||||
|
|
||||||
|
weight_qscheme = self.weight_quant.get("qscheme")
|
||||||
|
input_qscheme = self.input_quant.get("qscheme")
|
||||||
|
if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
|
||||||
|
raise ValueError(
|
||||||
|
"For MX(FP4) Fused MoE layers, only per-group scales "
|
||||||
|
"for weights and activations are supported. Found "
|
||||||
|
f"{weight_qscheme=}, {input_qscheme=}"
|
||||||
|
) # noqa E501
|
||||||
|
|
||||||
|
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
|
# Add the quantization method used (per tensor/grouped/channel)
|
||||||
|
# to ensure the weight scales are loaded in properly
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||||
|
)
|
||||||
|
|
||||||
|
params_dtype = torch.uint8
|
||||||
|
|
||||||
|
# WEIGHTS
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
hidden_size // 2,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size_per_partition // 2,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# WEIGHT_SCALES
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
hidden_size // OCP_MX_BLOCK_SIZE,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
float_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
# Pre-shuffle weight scales
|
||||||
|
s0, s1, _ = layer.w13_weight_scale.shape
|
||||||
|
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
|
||||||
|
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
|
||||||
|
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
|
||||||
|
|
||||||
|
s0, s1, _ = layer.w2_weight_scale.shape
|
||||||
|
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
|
||||||
|
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
|
||||||
|
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
topk_output: TopKOutput,
|
||||||
|
*,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|
||||||
|
return fused_moe(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_type=QuantType.per_1x32,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
activation=(
|
||||||
|
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||||
|
),
|
||||||
|
doweight_stage1=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MxFp4KVCacheMethod(BaseKVCacheMethod):
|
||||||
|
"""
|
||||||
|
Supports loading kv-cache scaling factors from quark checkpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_config: MxFp4Config):
|
||||||
|
self.validate_kv_cache_config(quant_config.kv_cache_config)
|
||||||
|
super().__init__(quant_config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
|
||||||
|
"""
|
||||||
|
Validator for the kv cache configuration. Useful for controlling the
|
||||||
|
kv cache quantization schemes, that are being supported in vLLM
|
||||||
|
:param kv_cache_config: the quark kv cache scheme
|
||||||
|
"""
|
||||||
|
if kv_cache_config is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
dtype = kv_cache_config.get("dtype")
|
||||||
|
if dtype != "fp8_e4m3":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Currently supported kv cache quantization is "
|
||||||
|
f"dtype=fp8_e4m3, however received {dtype}"
|
||||||
|
)
|
||||||
|
|
||||||
|
qscheme = kv_cache_config.get("qscheme")
|
||||||
|
if qscheme != "per_tensor":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Only support per-tensor scaling factor "
|
||||||
|
"for quark KV cache. "
|
||||||
|
f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
|
||||||
|
)
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from .quark_scheme import QuarkScheme
|
||||||
|
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
|
||||||
|
|
||||||
|
__all__ = ["QuarkScheme", "QuarkW4A4MXFP4"]
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
__all__ = ["QuarkScheme"]
|
||||||
|
|
||||||
|
|
||||||
|
class QuarkScheme(ABC):
|
||||||
|
"""
|
||||||
|
Abstract class used to describe the weight creation and forward pass
|
||||||
|
of different quantization schemes supported by Quark.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
"""
|
||||||
|
Get minimum device capability.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_weights(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Weight creation for the particular scheme. Inputs to this function
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply_weights(
|
||||||
|
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run the forward pass for the particular scheme. This is where
|
||||||
|
scheme-specific dequant/quant steps/kernels should be applied.
|
||||||
|
|
||||||
|
:param layer: torch.nn.Module with the registered weights and
|
||||||
|
other parameters relevant to the particular scheme.
|
||||||
|
:param x: input to the layer
|
||||||
|
:param bias: bias parameter
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Called after weight loading is complete for any cleanup that
|
||||||
|
needs to occur.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import aiter
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
|
||||||
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||||
|
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||||
|
from aiter.utility import dtypes
|
||||||
|
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||||
|
|
||||||
|
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||||
|
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
|
||||||
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
|
__all__ = ["QuarkW4A4MXFP4"]
|
||||||
|
|
||||||
|
OCP_MX_BLOCK_SIZE = 32
|
||||||
|
|
||||||
|
|
||||||
|
class QuarkW4A4MXFP4(QuarkScheme):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
|
||||||
|
):
|
||||||
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
self.qscheme = "per_group"
|
||||||
|
self.weight_quant_spec = weight_quant_spec
|
||||||
|
self.input_quant_spec = input_quant_spec
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
return 70
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# for aiter implement
|
||||||
|
# wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
|
||||||
|
# w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
|
||||||
|
|
||||||
|
# layer.weight = torch.nn.Parameter(wshuffle,
|
||||||
|
# requires_grad=False)
|
||||||
|
# layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
|
||||||
|
# requires_grad=False)
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
output_partition_sizes: list[int],
|
||||||
|
input_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
weight_loader: Callable,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
layer.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
|
# WEIGHT
|
||||||
|
weight = PackedvLLMParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
output_size_per_partition,
|
||||||
|
input_size_per_partition // 2,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
packed_dim=1,
|
||||||
|
packed_factor=2,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("weight", weight)
|
||||||
|
|
||||||
|
# WEIGHT SCALE
|
||||||
|
weight_scale = GroupQuantScaleParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
output_size_per_partition,
|
||||||
|
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
out_dtype = x.dtype
|
||||||
|
# M = x.shape[0]
|
||||||
|
# N = layer.weight.shape[0]
|
||||||
|
|
||||||
|
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
|
||||||
|
# x, x_scales_shuffle = quant_func(x, shuffle=True)
|
||||||
|
|
||||||
|
# y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
|
||||||
|
|
||||||
|
# out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
|
||||||
|
|
||||||
|
# return out[:M]
|
||||||
|
|
||||||
|
# triton implement
|
||||||
|
x_q, x_s = dynamic_mxfp4_quant(x)
|
||||||
|
y = torch.empty(
|
||||||
|
x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
|
||||||
|
|
||||||
|
return out
|
||||||
107
python/sglang/srt/layers/quantization/quark/utils.py
Normal file
107
python/sglang/srt/layers/quantization/quark/utils.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable, Mapping
|
||||||
|
from types import MappingProxyType
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||||
|
if type(dict1) is not type(dict2):
|
||||||
|
return False
|
||||||
|
if isinstance(dict1, dict):
|
||||||
|
if dict1.keys() != dict2.keys():
|
||||||
|
return False
|
||||||
|
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
|
||||||
|
elif isinstance(dict1, list):
|
||||||
|
return set(dict1) == set(dict2)
|
||||||
|
else:
|
||||||
|
return dict1 == dict2
|
||||||
|
|
||||||
|
|
||||||
|
def should_ignore_layer(
|
||||||
|
layer_name: Optional[str],
|
||||||
|
ignore: Iterable[str],
|
||||||
|
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||||
|
) -> bool:
|
||||||
|
if layer_name is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||||
|
# proj_name = qkv_proj
|
||||||
|
proj_name = layer_name.split(".")[-1]
|
||||||
|
|
||||||
|
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||||
|
# in the safetensors checkpoint. So, we convert the name
|
||||||
|
# from the fused version to unfused + check to make sure that
|
||||||
|
# each shard of the fused layer has the same scheme.
|
||||||
|
if proj_name in fused_mapping:
|
||||||
|
shard_proj_names = fused_mapping[proj_name]
|
||||||
|
|
||||||
|
# Convert fused_name --> [shard_names]
|
||||||
|
shard_names = [
|
||||||
|
layer_name.replace(proj_name, shard_proj_name)
|
||||||
|
for shard_proj_name in shard_proj_names
|
||||||
|
]
|
||||||
|
|
||||||
|
# Layer should be ignored if shards are ignored.
|
||||||
|
should_ignore_layer = None
|
||||||
|
for shard_name in shard_names:
|
||||||
|
should_ignore_shard = check_equal_or_regex_match(
|
||||||
|
layer_name=shard_name, targets=ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
# If shard_idx=0, set layer ignore to match shard.
|
||||||
|
if should_ignore_layer is None:
|
||||||
|
should_ignore_layer = should_ignore_shard
|
||||||
|
|
||||||
|
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||||
|
elif should_ignore_shard != should_ignore_layer:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found a different quantization schemes for "
|
||||||
|
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||||
|
"requires all to use the same scheme."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unfused layers like down_proj and o_proj will match
|
||||||
|
# the safetensors checkpoint already.
|
||||||
|
else:
|
||||||
|
should_ignore_layer = check_equal_or_regex_match(
|
||||||
|
layer_name=layer_name, targets=ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
assert should_ignore_layer is not None
|
||||||
|
|
||||||
|
return should_ignore_layer
|
||||||
|
|
||||||
|
|
||||||
|
def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether a layer_name is exactly equal or a regex match for
|
||||||
|
if target starts with 're:' to any target in list.
|
||||||
|
"""
|
||||||
|
for target in targets:
|
||||||
|
if _is_equal_or_regex_match(layer_name, target):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_equal_or_regex_match(
|
||||||
|
value: str, target: str, check_contains: bool = False
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether a value is exactly equal or a regex match for target
|
||||||
|
if target starts with 're:'. If check_contains is set to True,
|
||||||
|
additionally checks if the target string is contained within the value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if target.startswith("re:"):
|
||||||
|
pattern = target[3:]
|
||||||
|
if re.match(pattern, value):
|
||||||
|
return True
|
||||||
|
elif check_contains:
|
||||||
|
if target.lower() in value.lower():
|
||||||
|
return True
|
||||||
|
elif target == value:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
@@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
return remapped_name
|
return remapped_name
|
||||||
|
|
||||||
|
quark_scale_names = {
|
||||||
|
".q_proj.output_scale": ".attn.q_scale",
|
||||||
|
".k_proj.output_scale": ".attn.k_scale",
|
||||||
|
".v_proj.output_scale": ".attn.v_scale",
|
||||||
|
"self_attn.prob_output_scale": ".attn.prob_scale",
|
||||||
|
}
|
||||||
|
for quark_scale_name, sglang_scale_name in quark_scale_names.items():
|
||||||
|
if name.endswith(quark_scale_name):
|
||||||
|
return name.replace(quark_scale_name, sglang_scale_name)
|
||||||
|
|
||||||
# If there were no matches, return the untouched param name
|
# If there were no matches, return the untouched param name
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|||||||
@@ -2061,6 +2061,8 @@ class DeepseekV2Model(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DeepseekV2ForCausalLM(nn.Module):
|
class DeepseekV2ForCausalLM(nn.Module):
|
||||||
|
# for quark model load
|
||||||
|
packed_modules_mapping = {}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -2069,6 +2071,18 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# for quark model load
|
||||||
|
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
||||||
|
self.fuse_qkv_a_proj = (
|
||||||
|
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
|
||||||
|
)
|
||||||
|
if self.fuse_qkv_a_proj:
|
||||||
|
self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
|
||||||
|
"q_a_proj",
|
||||||
|
"kv_a_proj_with_mqa",
|
||||||
|
]
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|||||||
@@ -813,6 +813,7 @@ class ServerArgs:
|
|||||||
"moe_wna16",
|
"moe_wna16",
|
||||||
"qoq",
|
"qoq",
|
||||||
"w4afp8",
|
"w4afp8",
|
||||||
|
"mxfp4",
|
||||||
],
|
],
|
||||||
help="The quantization method.",
|
help="The quantization method.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2832,6 +2832,17 @@ def parse_module_path(module_path, function_name, create_dummy):
|
|||||||
return final_module, None
|
return final_module, None
|
||||||
|
|
||||||
|
|
||||||
|
def mxfp_supported():
|
||||||
|
"""
|
||||||
|
Returns whether the current platform supports MX types.
|
||||||
|
"""
|
||||||
|
if torch.version.hip:
|
||||||
|
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
||||||
|
return any(gfx in gcn_arch for gfx in ["gfx95"])
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# LoRA-related constants and utilities
|
# LoRA-related constants and utilities
|
||||||
SUPPORTED_LORA_TARGET_MODULES = [
|
SUPPORTED_LORA_TARGET_MODULES = [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
|
|||||||
Reference in New Issue
Block a user