Support mxfp4 for GPT-OSS (#8843)
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Co-authored-by: zhuofan1123 <zhuofanl@nvidia.com>
Co-authored-by: liz-badada <jinyanc@nvidia.com>
Co-authored-by: xutizhou <xutingz@nvidia.com>
Co-authored-by: linhu-nv <linhu@nvidia.com>
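Background for the recurring `// 2` and `// 32` factors throughout this diff: mxfp4 (OCP Microscaling FP4) packs two e2m1 values into one uint8 and shares one e8m0 scale per block of 32 values (OCP_MX_BLOCK_SIZE = 32). Below is a minimal reference dequantization sketch, purely illustrative and not part of the change; the nibble ordering is an assumption, and the actual kernels in this PR go through quark / triton_kernels instead.

```python
import torch

# All 16 e2m1 code points; the high bit of each nibble is the sign.
E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def dequant_mxfp4_reference(packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """packed: (..., K // 2) uint8, scale: (..., K // 32) uint8 e8m0 exponents."""
    lo = E2M1_VALUES[(packed & 0x0F).long()]   # assuming the low nibble comes first
    hi = E2M1_VALUES[(packed >> 4).long()]
    vals = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1)   # (..., K)
    block_scale = torch.exp2(scale.float() - 127.0)                        # 2 ** (e - 127)
    vals = vals.reshape(*vals.shape[:-1], -1, 32) * block_scale.unsqueeze(-1)
    return vals.reshape(*packed.shape[:-1], -1)
```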
@@ -389,7 +389,7 @@ class FusedMoE(torch.nn.Module):
            # Narrow parameter and load.
            if is_bias:
                # this expert_data is a bias, not weight,
                # for w2_bias in TP, it does not need to be sharded
                # for w2_weight_bias in TP, it does not need to be sharded
                shard_size = expert_data.shape[-1]
            else:
                # this parameter is a weight matrix
@@ -410,10 +410,6 @@ class FusedMoE(torch.nn.Module):
        if not is_bias and not self.use_presharded_weights:
            if self.use_triton_kernels:
                loaded_weight = loaded_weight.transpose(-2, -1)
            if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
                raise ValueError(
                    f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
                )
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )
@@ -461,9 +457,25 @@ class FusedMoE(torch.nn.Module):
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
        expert_id: Optional[int],
    ) -> None:

        # if expert_id is None, then
        # all the experts are loaded at the same time
        if (
            not expert_id
            and self.quant_config is not None
            and self.quant_config.get_name() == "mxfp4"
        ):
            if "bias" in weight_name:
                dim1 = loaded_weight.shape[1]
                param.data[:, :dim1].copy_(loaded_weight)
            else:
                dim1 = loaded_weight.shape[1]
                dim2 = loaded_weight.shape[2]
                param.data[:, :dim1, :dim2].copy_(loaded_weight)
            return

        global_expert_location_metadata = get_global_expert_location_metadata()
        if global_expert_location_metadata is None:
            self._weight_loader_impl(
@@ -502,6 +514,7 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        expert_id: int,
    ) -> None:

        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return
@@ -705,6 +718,18 @@ class FusedMoE(torch.nn.Module):
    ) -> None:
        tp_rank = self.moe_tp_rank

        if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
            if "bias" in weight_name:
                dim1 = loaded_weight.shape[1]
                param.data[:, :dim1].copy_(loaded_weight)
            elif "scale" in weight_name:
                param.data.copy_(loaded_weight)
            else:
                dim1 = loaded_weight.shape[1]
                dim2 = loaded_weight.shape[2]
                param.data[:, :dim1, :dim2].copy_(loaded_weight)
            return

        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO: check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
@@ -854,6 +879,33 @@ class FusedMoE(torch.nn.Module):
            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
        ]

    @classmethod
    def make_expert_params_mapping_fused_mxfp4(
        cls,
        ckpt_gate_up_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_gate_up_proj_bias_name: str,
        ckpt_down_proj_bias_name: str,
        ckpt_gate_up_proj_scale_name: str,
        ckpt_down_proj_scale_name: str,
    ):
        return [
            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
            (
                "experts.w13_weight_bias",
                f"experts.{ckpt_gate_up_proj_bias_name}",
                "w13",
            ),
            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
            (
                "experts.w13_weight_scale",
                f"experts.{ckpt_gate_up_proj_scale_name}",
                "w13",
            ),
            ("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
        ]

    @classmethod
    def make_expert_input_scale_params_mapping(
        cls,
@@ -186,8 +186,10 @@ def triton_kernel_fused_experts(
def triton_kernel_moe_with_bias_forward(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_pcg,
    b1: torch.Tensor,
    w2: torch.Tensor,
    w2_pcg,
    b2: torch.Tensor,
    topk_output: TopKOutput,
    inplace: bool = False,
@@ -209,13 +211,15 @@ def triton_kernel_moe_with_bias_forward(

    return triton_kernel_fused_experts_with_bias(
        hidden_states,
        w1,
        b1,
        w2,
        b2,
        routing_data,
        gather_idx,
        scatter_idx,
        w1=w1,
        w1_pcg=w1_pcg,
        b1=b1,
        w2=w2,
        w2_pcg=w2_pcg,
        b2=b2,
        routing_data=routing_data,
        gather_indx=gather_idx,
        scatter_indx=scatter_idx,
        inplace=inplace,
        activation=activation,
        use_fp8_w8a8=use_fp8_w8a8,
@@ -235,8 +239,10 @@ def triton_kernel_moe_with_bias_forward(
def triton_kernel_fused_experts_with_bias(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_pcg,
    b1: torch.Tensor,
    w2: torch.Tensor,
    w2_pcg,
    b2: torch.Tensor,
    routing_data: RoutingData,
    gather_indx: GatherIndx,
@@ -267,8 +273,10 @@ def triton_kernel_fused_experts_with_bias(

    # type check
    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
    for w in (w1, w2):
        # TODO assert bf16 or mxfp4
        # assert (w.dtype == torch.bfloat16) or check-is-mxfp4, f"w must be bfloat16 or mxfp4 {w1.dtype=}"
        pass

    # Shape check
    assert hidden_states.ndim == 2, "hidden_states must be 2D"
@@ -287,13 +295,15 @@ def triton_kernel_fused_experts_with_bias(
    if global_num_experts == -1:
        global_num_experts = E

    device = "cuda"
    optg = dict()
    w1, w1_flex = quantize(w1, "bf16", device, **optg)
    w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
    # TODO maybe completely remove this branch
    if w1.dtype == torch.bfloat16:
        device = "cuda"
        optg = dict()
        w1, w1_flex = quantize(w1, "bf16", device, **optg)
        w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))

    w2, w2_flex = quantize(w2, "bf16", device, **optg)
    w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
        w2, w2_flex = quantize(w2, "bf16", device, **optg)
        w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))

    act = FusedActivation(
        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
@@ -47,7 +47,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)
from sglang.srt.utils import mxfp_supported
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

is_mxfp_supported = mxfp_supported()
if is_mxfp_supported:
@@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp8Config,
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.utils import get_linear_quant_method
@@ -90,7 +91,16 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "w4afp8": W4AFp8Config,
    "petit_nvfp4": PetitNvFp4Config,
}
if is_mxfp_supported:

if is_cuda():
    BASE_QUANTIZATION_METHODS.update(
        {
            "quark": Mxfp4Config,
            "mxfp4": Mxfp4Config,
        }
    )
elif is_mxfp_supported and is_hip():
    BASE_QUANTIZATION_METHODS.update(
        {
            "quark": MxFp4Config,
@@ -50,315 +50,50 @@ use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear")
OCP_MX_BLOCK_SIZE = 32


class MxFp4Config(QuantizationConfig):
class Mxfp4Config(QuantizationConfig):

    def __init__(
        self,
        is_checkpoint_fp4_serialized: bool = False,
        quant_config: dict[str, Any] = None,
        kv_cache_group: Optional[list[str]] = None,
        kv_cache_config: Optional[dict[str, Any]] = None,
        pack_method: str = "reorder",
        ignored_layers: Optional[List[str]] = None,
    ):
    def __init__(self, ignored_layers: Optional[list[str]] = None):
        super().__init__()
        if kv_cache_group is None:
            kv_cache_group = []
        self.ignored_layers = ignored_layers

        self.is_checkpoint_fp4_serialized = is_checkpoint_fp4_serialized
        self.quant_config = quant_config
        self.kv_cache_group = kv_cache_group
        self.kv_cache_config = kv_cache_config
        self.pack_method = pack_method

        self.packed_modules_mapping = (
            self.quant_config["packed_modules_mapping"]
            if is_checkpoint_fp4_serialized
            else None
        )

        self.ignored_layers = ignored_layers or []

    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]
    @classmethod
    def from_config(cls, config):
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def get_name(self) -> str:
        return "fp4"

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:

        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        # Check if the layer is skipped for quantization.
        if len(self.ignored_layers) > 0 and should_ignore_layer(
            prefix,
            ignore=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedLinearMethod()

        if isinstance(layer, LinearBase):
            if self.is_checkpoint_fp4_serialized:
                scheme = self.get_scheme(layer=layer, layer_name=prefix)
                layer.scheme = scheme
                return MxFp4LinearMethod(self)

            elif use_dynamic_mxfp4_linear:
                return MxFp4LinearMethod(self)
            else:
                return UnquantizedLinearMethod()

        if isinstance(layer, RadixAttention):
            return MxFp4KVCacheMethod(self)

        if isinstance(layer, FusedMoE):
            return MxFp4MoEMethod.get_moe_method(self, module=layer, layer_name=prefix)

        return None
        return 80

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "MxFp4Config":
        if not mxfp_supported():
            platform = torch.cuda.get_device_properties(0).gcnArchName
            raise ValueError(
                f"Current platform {platform} not support mxfp4 computation"
            )
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp4_serialized = (
            True if quant_method else False
        )  # "quark" in quant_method
    def get_name(cls) -> QuantizationMethods:
        return "mxfp4"

        kv_cache_group = []
        pack_method = None

        if is_checkpoint_fp4_serialized:
            export_config = config.get("export")
            if export_config is None:
                raise ValueError(
                    "The export key should be included in "
                    "the configurations of Quark quantized model"
                )

            kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
            pack_method = cast(str, export_config.get("pack_method"))

            # In the export model of quark, the quantization configuration
            # of kv_cache is stored in layer_quant_config. First, it is
            # judged whether kv_cache_group exists, and then it is judged
            # whether layer_quant_config has a quantization configuration
            # that matches kv_cache.
            if len(kv_cache_group) == 0:
                kv_cache_config = None
            else:
                kv_cache_set = set(kv_cache_group)
                layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
                layer_quant_names = list(layer_quant_config.keys())
                layer_quant_set = set(layer_quant_names)

                if not kv_cache_set.issubset(layer_quant_set):
                    raise ValueError(
                        "The Quark quantized model has the "
                        "kv_cache_group parameter setting, "
                        "but no kv_cache quantization settings "
                        "were found in the quantization "
                        "configuration."
                    )

                q_configs = [
                    cast(dict[str, Any], layer_quant_config.get(name))
                    for name in kv_cache_group
                ]
                if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
                    raise ValueError(
                        "The quantization method used for kv_cache should "
                        "be the same, but the quantization method for the "
                        "kv_cache layer in the config is different."
                    )
                kv_cache_config = q_configs[0].get("output_tensors")
                if kv_cache_config is None:
                    raise ValueError("The kv_cache quantization configuration is empty.")

                # Since we have already set kv_cache quantization configurations,
                # we will remove the quantization configuration for the
                # output_tensors corresponding to the kv_cache layer.
                for q_config in q_configs:
                    q_config["output_tensors"] = None

                # In case q_proj output is also quantized, remove the configuration
                # to keep qkv consistency.
                q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
                if q_proj_q_config is not None:
                    q_proj_q_config["output_tensors"] = None

        ignored_layers = cls.get_from_keys_or(config, ["exclude"], None)

        return cls(
            is_checkpoint_fp4_serialized=is_checkpoint_fp4_serialized,
            quant_config=config,
            kv_cache_group=kv_cache_group,
            kv_cache_config=kv_cache_config,
            pack_method=pack_method,
            ignored_layers=ignored_layers,
        )
    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
        capability_tuple = get_device_capability()
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if capability_tuple is not None:
            assert 0 <= capability_tuple[1] < 10
            capability = capability_tuple[0] * 10 + capability_tuple[1]

            supported = capability >= min_capability
            if error and not supported:
                raise RuntimeError(
                    "Quantization scheme is not supported for ",
                    f"the current GPU. Min capability: {min_capability}. ",
                    f"Current capability: {capability}.",
                )
            return supported
        else:
            return False

    def _is_mx_fp4(
        self,
        weight_quant: Optional[dict[str, Any]],
        input_quant: Optional[dict[str, Any]],
    ) -> bool:
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            logger.debug(
                "Quark model is not in MX-FP4 format: "
                "weight_quant or input_quant not set"
            )
            return False

        # Input and weight dtype needs to be fp4.
        if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
            logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
            return False

        # Input and weight qscheme needs to be per group.
        if (
            weight_quant.get("qscheme") != "per_group"
            or input_quant.get("qscheme") != "per_group"
        ):
            logger.debug("Quark model is not in MX-FP4 format: not per_group")
            return False

        # Input and weight group size needs to be 32.
        if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
            logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
            return False

        # Weights need to use static quantization.
        if weight_quant.get("is_dynamic") is True:
            logger.debug("Quark model is not in MX-FP4 format: not weight static")
            return False

        # Activations need to use dynamic quantization.
        if input_quant.get("is_dynamic") is False:
            logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
            return False

        # Activations and weight scales need to be in e8m0 format.
        if (
            weight_quant.get("scale_format") != "e8m0"
            or input_quant.get("scale_format") != "e8m0"
        ):
            logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
            return False

        return True

    def _find_matched_config(
        self, layer_name: str, module: torch.nn.Module
    ) -> dict[str, Any]:

        proj_name = layer_name.split(".")[-1]
        if proj_name in self.packed_modules_mapping:
            shard_proj_names = self.packed_modules_mapping[proj_name]

            # Convert fused_name --> [shard_names]
            shard_names = [
                layer_name.replace(proj_name, shard_proj_name)
                for shard_proj_name in shard_proj_names
            ]
            shard_configs = [
                self._find_matched_config(shard_name, module)
                for shard_name in shard_names
            ]
            if not all(
                deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
        if isinstance(layer, LinearBase):
            if self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                raise ValueError(
                    f"Found a different quantization configuration for "
                    f"{shard_proj_names=} in {layer_name=}. vLLM "
                    "requires all to use the same scheme."
                )
            return shard_configs[0]
        else:
            layer_quant_config = cast(
                dict[str, Any], self.quant_config.get("layer_quant_config")
            )
            for name_pattern in layer_quant_config:
                if fnmatch.fnmatch(layer_name, name_pattern):
                    return layer_quant_config[name_pattern]

            layer_type = cast(str, type(module))
            layer_type_quant_config = cast(
                dict[str, Any], self.quant_config.get("layer_type_quant_config")
            )
            if layer_type in layer_type_quant_config:
                return layer_type_quant_config[layer_type]

            global_quant_config = cast(
                dict[str, Any], self.quant_config.get("global_quant_config")
            )
            return global_quant_config

    def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
        if config.get("output_tensors") or config.get("bias"):
            raise NotImplementedError(
                "Currently, Quark models with output_tensors "
                "and bias quantized are not supported"
            )
        weight_config = cast(dict[str, Any], config.get("weight"))
        input_config = cast(dict[str, Any], config.get("input_tensors"))

        if self._is_mx_fp4(weight_config, input_config):
            return QuarkW4A4MXFP4(weight_config, input_config)

        raise NotImplementedError(
            "No quark compatible scheme was found. "
            f"{weight_config=}, "
            f"{input_config=}"
        )

    def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":

        layer_quant_config = self._find_matched_config(layer_name, layer)

        # Find the quant_scheme
        scheme = self._get_scheme_from_config(layer_quant_config)

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme

    def get_scaled_act_names(self) -> List[str]:
        return []
                return UnquantizedLinearMethod()
            raise NotImplementedError("Mxfp4 linear layer is not implemented")
        elif isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(layer.moe_config)
        elif isinstance(layer, Attention):
            raise NotImplementedError("Mxfp4 attention layer is not implemented")
        return None


class MxFp4LinearMethod(LinearMethodBase):
python/sglang/srt/layers/quantization/mxfp4.py (new file, 443 lines)
@@ -0,0 +1,443 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import importlib
import logging
from typing import TYPE_CHECKING, Callable, List, Optional

import torch
from torch.nn.parameter import Parameter

# from vllm.model_executor.layers.fused_moe import (
#     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
#     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import (
    direct_register_custom_op,
    is_cuda,
    is_flashinfer_available,
    is_hip,
    next_power_of_2,
    round_up,
    set_weight_attrs,
)

has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

if is_flashinfer_available():
    # from flashinfer.fused_moe import cutlass_fused_moe
    from flashinfer import (
        mxfp8_quantize,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        trtllm_fp4_block_scale_moe,
    )

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput

OCP_MX_BLOCK_SIZE = 32


def _swizzle_mxfp4(quant_tensor, scale, num_warps):
    """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
    import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
    from triton_kernels.numerics import InFlexData
    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
    from triton_kernels.tensor_details import layout

    value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
        mx_axis=1
    )
    scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
        mx_axis=1, num_warps=num_warps
    )
    if is_cuda() and torch.cuda.get_device_capability()[0] == 10:
        constraints = {
            "is_persistent": True,
            "epilogue_subtile": 1,
        }
        opt_flags.update_opt_flags_constraints(constraints)
    # transpose the tensor so that the quantization axis is on dim1
    quant_tensor = quant_tensor.transpose(-2, -1)
    scale = scale.transpose(-2, -1)
    quant_tensor = convert_layout(
        wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts
    )
    scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts)
    return quant_tensor, InFlexData(), scale


def _dequant_mxfp4(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
    try:
        from quark.torch.kernel import mx
    except ImportError as err:
        raise ImportError(
            "The package `amd-quark` is required to use "
            "MX-FP4 models. Please install it with `pip install "
            "amd-quark`."
        ) from err

    return mx.dq_mxfp4(x, scale, float_dtype)


def _dequant_mxfp4_fake(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
    return torch.empty(
        (*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device
    )


def _quant_dequant_mxfp4(
    x: torch.Tensor, scale_calculation_mode: str = "even"
) -> torch.Tensor:
    try:
        from quark.torch.kernel import mx
    except ImportError as err:
        raise ImportError(
            "The package `amd-quark` is required to use "
            "MX-FP4 models. Please install it with `pip install "
            "amd-quark`."
        ) from err

    return mx.qdq_mxfp4(x, scale_calculation_mode)


def _quant_dequant_mxfp4_fake(
    x: torch.Tensor, scale_calculation_mode: str = "even"
) -> torch.Tensor:
    return torch.empty_like(x)


try:
    direct_register_custom_op(
        op_name="dequant_mxfp4",
        op_func=_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_dequant_mxfp4_fake,
    )
    dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
except AttributeError as error:
    raise error

try:
    direct_register_custom_op(
        op_name="quant_dequant_mxfp4",
        op_func=_quant_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_quant_dequant_mxfp4_fake,
    )
    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
except AttributeError as error:
    raise error


class Mxfp4Config(QuantizationConfig):

    def __init__(self, ignored_layers: Optional[list[str]] = None):
        super().__init__()
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_name(cls) -> str:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:

        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

        if isinstance(layer, LinearBase):
            if self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
        else:
            raise NotImplementedError("Mxfp4 attention layer is not implemented")
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Mxfp4MoEMethod(FusedMoEMethodBase):

    def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
        super().__init__()
        self.topk_indices_dtype = None
        self.use_triton_kernels = use_triton_kernels
        self.with_bias = with_bias
        self.triton_kernel_moe_forward = None
        self.triton_kernel_moe_with_bias_forward = None
        if torch.cuda.is_available() and has_triton_kernels:
            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                triton_kernel_moe_forward as _tk_forward,
            )
            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
            )

            self.triton_kernel_moe_forward = _tk_forward
            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # print(f"hi {self=} create_weights {layer=}")
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8

        intermediate_size *= 2
        mxfp4_block = 32

        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size,
                hidden_size // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w13_weight_bias = torch.nn.Parameter(
            torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_bias", w13_weight_bias)
        set_weight_attrs(w13_weight_bias, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        w2_weight_bias = torch.nn.Parameter(
            torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_bias", w2_weight_bias)
        set_weight_attrs(w2_weight_bias, extra_weight_attrs)

    def process_weights_after_loading(self, layer):

        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
        w2_weight_bias = layer.w2_weight_bias.to(torch.float32)

        layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
        layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)

        num_warps = 8

        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            layer.w13_weight, layer.w13_weight_scale, num_warps
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            layer.w2_weight, layer.w2_weight_scale, num_warps
        )

        self.w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
        )
        self.w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
        )

        self.w13_weight_triton_tensor = w13_weight
        self.w2_weight_triton_tensor = w2_weight

        # need to delete the original weights to save memory on single GPU
        del layer.w13_weight
        del layer.w2_weight
        layer.w13_weight = None
        layer.w2_weight = None
        torch.cuda.empty_cache()

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
        # Number of tokens in the input tensor.
        num_tokens = x.shape[0]
        # Factor to account for the imbalance of the experts.
        # factor equals to the
        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
        # - 1.0 means perfect expert distribution.
        # - > 1.0 means some experts have more
        #   tokens than the perfect distribution.
        # - < 1.0 does not make sense.
        imbalance_factor = 1.3
        # Calculate the number of tokens per expert
        # assuming perfect distribution.
        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
        # Apply the imbalance factor.
        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
        # Cap to 8-64 tokens per CTA tile
        # as it's the range supported by the kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

        return tile_tokens_dim
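A quick worked example of the tile sizing above, with hypothetical numbers (not taken from any model config):

```python
# num_tokens=1000, top_k=4, num_experts=128
#   perfect load   = (1000 * 4) // 128   = 31 tokens/expert
#   with 1.3 skew  = int(31 * 1.3)       = 40
#   next_power_of_2(40)                  = 64
#   clamp to [8, 64]                     -> tile_tokens_dim = 64
```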
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
        activation_alpha: Optional[float] = None,
        swiglu_limit: Optional[float] = None,
    ) -> torch.Tensor:
        # avoid import error when triton_kernel is not installed
        # from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
        #     triton_kernel_moe_forward)

        """
        if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
                or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
            assert not self.moe.use_ep, (
                "EP is not supported for flashinfer mxfp4 moe backend yet.")
            if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            else:
                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
                x_quant,
                x_scale,
                layer.w13_weight,  # uint8 (e2m1 x 2)
                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                layer.w13_weight_bias,  # fp32 per expert per channel
                layer.gemm1_alpha,  # fp32 per expert
                layer.gemm1_beta,  # fp32 per expert
                layer.gemm1_clamp_limit,  # fp32 per expert
                layer.w2_weight,  # uint8 (e2m1 x 2)
                layer.w2_weight_scale,  # ue8m0
                layer.w2_weight_bias,  # fp32 per expert per channel
                None,  # output1_scale_scalar
                None,  # output1_scale_gate_scalar
                None,  # output2_scale_scalar
                self.num_experts,
                top_k,
                None,  # n_group
                None,  # topk_group
                self.intermediate_size,  # padded to multiple of 256
                0,  # local_expert_offset
                self.num_experts,  # local num experts
                None,
                self._get_tile_tokens_dim(x, top_k),
                1,  # routing_method_type, renormalize
                True,  # do finalize
            )[0]
            return trtllm_gen_output
        """

        if self.use_triton_kernels:
            if self.with_bias:
                # TODO why we do not put weights on layer?
                assert layer.w13_weight is None
                assert layer.w2_weight is None
                return self.triton_kernel_moe_with_bias_forward(
                    hidden_states=x,
                    w1=self.w13_weight_triton_tensor,
                    w1_pcg=self.w13_precision_config,
                    w2=self.w2_weight_triton_tensor,
                    w2_pcg=self.w2_precision_config,
                    b1=layer.w13_weight_bias,
                    b2=layer.w2_weight_bias,
                    topk_output=topk_output,
                    activation=activation,
                    activation_alpha=activation_alpha,
                    swiglu_limit=swiglu_limit,
                )
            else:
                return self.triton_kernel_moe_forward(
                    hidden_states=x,
                    w1=layer.w13_weight,
                    w2=layer.w2_weight,
                    topk_output=topk_output,
                )
        else:
            raise NotImplementedError()
@@ -272,6 +272,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                activation=activation,
                activation_alpha=activation_alpha,
                swiglu_limit=swiglu_limit,
                w1_pcg=None,
                w2_pcg=None,
            )
        else:
            return self.triton_kernel_moe_forward(
@@ -25,6 +25,8 @@ from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
    get_moe_expert_parallel_rank,
    get_moe_expert_parallel_world_size,
    get_moe_tensor_parallel_rank,
    get_pp_group,
    get_tensor_model_parallel_rank,
@@ -108,11 +110,15 @@ class GptOssSparseMoeBlock(nn.Module):
        experts_type = get_moe_impl_class()
        extra_kwargs = {}
        if experts_type.__name__ == "FusedMoE":
            quant_config_name = (
                quant_config.get_name() if quant_config is not None else None
            )
            extra_kwargs = {
                "enable_flashinfer_cutlass_moe": global_server_args_dict[
                    "enable_flashinfer_cutlass_moe"
                ],
                "use_weight_loader_fused": True,  # for moe gate_up_proj and down_proj and their bias loading
                # for moe gate_up_proj and down_proj and their bias loading
                "use_weight_loader_fused": quant_config_name != "mxfp4",
            }
        self.experts = experts_type(
            num_experts=config.num_local_experts
@@ -350,7 +356,6 @@ class GptOssDecoderLayer(nn.Module):
            head_dim=head_dim,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            sliding_window_size=self.sliding_window_size,
            layer_type=config.layer_types[layer_id],
@@ -538,7 +543,7 @@ class GptOssForCausalLM(nn.Module):
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            # quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
@@ -652,11 +657,188 @@ class GptOssForCausalLM(nn.Module):

        return weight_mapping

    # TODO beautify code
    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        is_nextn: bool = False,
        weight_name_mapping: dict = None,
    ):
        quant_config_name = (
            self.quant_config.get_name() if self.quant_config is not None else None
        )
        if quant_config_name != "mxfp4":
            self._load_normal_weights(
                weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
            )
        else:
            self._load_weights_mxfp4(
                weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
            )

    def _load_weights_mxfp4(self, weights, is_nextn, weight_name_mapping):
        mxfp4_weights = []
        normal_weights = []

        for name, weight in weights:
            if (
                ".experts" in name
                and self.quant_config is not None
                and self.quant_config.get_name() == "mxfp4"
            ):
                mxfp4_weights.append((name, weight))
            else:
                normal_weights.append((name, weight))

        mxfp4_loaded_params = self._load_mxfp4_experts_weights(mxfp4_weights)
        self._load_normal_weights(
            normal_weights,
            is_nextn=is_nextn,
            weight_name_mapping=weight_name_mapping,
            other_loaded_param_names=mxfp4_loaded_params,
        )

    def _load_mxfp4_experts_weights(self, weights):

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        mxfp4_block = 32

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = intermediate_size_block // tp_size
        per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        num_experts = self.config.num_local_experts

        for name, weight in weights:
            weight = weight.cuda()

            if "gate_up_proj_blocks" in name:
                # Handle MLP gate and up projection weights
                new_name = name.replace("gate_up_proj_blocks", "w13_weight")

                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()

                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=new_name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(new_name)

            elif "down_proj_blocks" in name:
                # Handle MLP down projection weights
                new_name = name.replace("down_proj_blocks", "w2_weight")
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=new_name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(new_name)

            elif "gate_up_proj_scales" in name:
                # Handle MLP gate and up projection weights scale
                new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=new_name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(new_name)

            elif "down_proj_scales" in name:
                # Handle MLP down projection weights
                new_name = name.replace("down_proj_scales", "w2_weight_scale")
                narrow_weight = weight[
                    ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
                ]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=new_name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(new_name)
            elif "gate_up_proj_bias" in name:
                # Handle MLP gate and up projection biases
                new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")

                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=new_name,
                    shard_id=None,
                    expert_id=None,
                )
                loaded_params.add(new_name)

            elif "down_proj_bias" in name:
                if get_moe_tensor_parallel_rank() != 0:
                    weight = torch.zeros_like(weight)

                # Handle MLP down projection bias
                new_name = name.replace("down_proj_bias", "w2_weight_bias")
                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param, weight, weight_name=new_name, shard_id=None, expert_id=None
                )
                loaded_params.add(new_name)

        return loaded_params
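A note on the slice bounds in the loader above: tp_rank_start / tp_rank_end are expressed in un-packed output-channel units, so the packed w2 blocks (two FP4 values per uint8) are sliced with those bounds divided by 2, and the per-block e8m0 scales with the bounds divided by 32. A minimal sketch of the same idea, using made-up shapes rather than the real checkpoint:

```python
# Hypothetical illustration of how one TP rank slices packed mxfp4 tensors.
intermediate_size, tp_size, tp_rank, mxfp4_block = 2880, 2, 1, 32

per_rank = (intermediate_size // mxfp4_block // tp_size) * mxfp4_block      # 1440
start = tp_rank * per_rank                                                   # 1440
end = min((tp_rank + 1) * per_rank, intermediate_size)                       # 2880

# down_proj input dim is packed 2 values per byte -> slice bounds halve
packed_cols = slice(start // 2, end // 2)                                    # 720:1440
# e8m0 scales cover 32 values each -> slice bounds shrink by 32x
scale_cols = slice(start // mxfp4_block, end // mxfp4_block)                 # 45:90
```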
    def _load_normal_weights(
        self,
        weights,
        is_nextn: bool,
        weight_name_mapping: dict,
        other_loaded_param_names=[],
    ):
        tp_rank = get_tensor_model_parallel_rank()
        if is_nextn:
@@ -725,15 +907,33 @@ class GptOssForCausalLM(nn.Module):
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
            ckpt_gate_up_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
            ckpt_down_proj_bias_name="down_proj_bias",
        )
        if self.quant_config is not None and (self.quant_config.get_name() == "mxfp4"):
            expert_params_mapping = (
                get_moe_impl_class().make_expert_params_mapping_fused_mxfp4(
                    ckpt_gate_up_proj_name="gate_up_proj_blocks",
                    ckpt_down_proj_name="down_proj_blocks",
                    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
                    ckpt_down_proj_bias_name="down_proj_bias",
                    ckpt_gate_up_proj_scale_name="gate_up_proj_scales",
                    ckpt_down_proj_scale_name="down_proj_scales",
                )
            )
        else:
            expert_params_mapping = (
                get_moe_impl_class().make_expert_params_mapping_fused(
                    ckpt_gate_up_proj_name="gate_up_proj",
                    ckpt_down_proj_name="down_proj",
                    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
                    ckpt_down_proj_bias_name="down_proj_bias",
                )
            )

        params_dict = dict(self.named_parameters())
        params_checker = {k: False for k, v in params_dict.items()}

        for other_loaded_param_name in other_loaded_param_names:
            params_checker[other_loaded_param_name] = True

        for name, loaded_weight in weights:
            loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
@@ -464,6 +464,16 @@ class ServerArgs:
            self.enable_triton_kernel_moe = True
            self.disable_hybrid_swa_memory = True

            quantization_config = getattr(
                self.get_hf_config(), "quantization_config", None
            )
            if (
                quantization_config is not None
                and quantization_config.get("quant_method") == "mxfp4"
            ):
                # use bf16 for mxfp4 triton kernels
                self.dtype = "bfloat16"

        # Set page size
        if self.page_size is None:
            self.page_size = 1
@@ -2124,6 +2124,10 @@ def next_power_of_2(n: int):
    return 1 << (n - 1).bit_length() if n > 0 else 1


def round_up(x: int, y: int) -> int:
    return ((x - 1) // y + 1) * y


setattr(triton, "next_power_of_2", next_power_of_2)