Support nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8/NVFP4 (#11866)

This commit is contained in:
Netanel Haber
2025-10-23 12:29:02 +03:00
committed by GitHub
parent 36a4cad7b0
commit d6fee73d1f
10 changed files with 207 additions and 127 deletions

View File

@@ -90,7 +90,50 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
ACTIVATION_SCHEMES = ["static"]
class ModelOptFp8Config(QuantizationConfig):
class ModelOptQuantConfig(QuantizationConfig):
    """Shared base for ModelOpt quantization configs (FP8 and NVFP4).

    Holds the fields common to both formats (KV-cache quant algorithm,
    excluded module list, packed-module mapping) and the common
    layer -> quant-method dispatch logic, parameterized by the concrete
    linear/MoE method classes of each subclass.
    """

    def __init__(
        self,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: Optional[List[str]],
        packed_modules_mapping: Optional[Dict[str, List[str]]],
    ):
        super().__init__()
        self.packed_modules_mapping = packed_modules_mapping
        # Normalize None to [] so callers can iterate unconditionally.
        self.exclude_modules = exclude_modules or []
        self.kv_cache_quant_algo = kv_cache_quant_algo

    def _get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
        *,
        Linear: type[LinearMethodBase],
        Moe: type[FusedMoEMethodBase],
    ) -> Optional[QuantizeMethodBase]:
        """Dispatch *layer* to the right quantize method.

        Args:
            layer: The module being wired up.
            prefix: Dotted parameter-name prefix of the layer, used for
                exclusion matching.
            Linear: Subclass-specific linear quant-method class.
            Moe: Subclass-specific fused-MoE quant-method class.

        Returns:
            A quantize method instance, ``UnquantizedLinearMethod`` for
            excluded linear layers, or ``None`` when no method applies.
        """
        # Imported locally to avoid circular imports at module load time.
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            # Skip quantization for layers excluded by name or pattern.
            if is_layer_skipped(
                prefix, self.exclude_modules, self.packed_modules_mapping
            ) or self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            return Linear(self)
        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
            # KV-cache quantization uses the FP8 KV-cache method for both
            # FP8 and NVFP4 checkpoints.
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return Moe(self)
        return None

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        # ModelOpt ships its quantization metadata in this single file.
        return ["hf_quant_config.json"]

    def get_scaled_act_names(self) -> List[str]:
        return []
class ModelOptFp8Config(ModelOptQuantConfig):
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
def __init__(
@@ -98,14 +141,14 @@ class ModelOptFp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized: bool = False,
kv_cache_quant_method: Optional[str] = None,
exclude_modules: Optional[List[str]] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
) -> None:
"""
Args:
is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
"""
super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
self.exclude_modules = exclude_modules
if is_checkpoint_fp8_serialized:
logger.warning(
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -128,10 +171,6 @@ class ModelOptFp8Config(QuantizationConfig):
def get_min_capability(cls) -> int:
return 89 # Minimum hardware capability (e.g., Hopper GPUs).
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
# Handle two different config formats:
@@ -186,37 +225,27 @@ class ModelOptFp8Config(QuantizationConfig):
is_checkpoint_fp8_serialized=True,
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=exclude_modules,
packed_modules_mapping=config.get("packed_modules_mapping"),
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if self.exclude_modules and any(
def is_layer_excluded(self, prefix: str) -> bool:
if len(self.exclude_modules) == 0:
return False
return any(
module in prefix
or (
prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model.")
)
for module in self.exclude_modules
):
return None
)
if isinstance(layer, LinearBase):
return ModelOptFp8LinearMethod(self)
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
if isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
    """Return the FP8 quantize method for *layer*, or None.

    Delegates to the shared dispatch in the base class with the
    FP8-specific linear and MoE method classes.
    """
    return self._get_quant_method(
        layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
    )
class ModelOptFp8LinearMethod(LinearMethodBase):
@@ -512,7 +541,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
return self.runner.run(dispatch_output, quant_info)
class ModelOptFp4Config(QuantizationConfig):
class ModelOptFp4Config(ModelOptQuantConfig):
"""Config class for FP4."""
def __init__(
@@ -521,7 +550,9 @@ class ModelOptFp4Config(QuantizationConfig):
kv_cache_quant_algo: str = None,
group_size: int = None,
exclude_modules: List[str] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
) -> None:
super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
@@ -529,8 +560,6 @@ class ModelOptFp4Config(QuantizationConfig):
"format is experimental and subject to change."
)
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
@@ -549,10 +578,6 @@ class ModelOptFp4Config(QuantizationConfig):
def get_min_capability(cls) -> int:
return 100
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@staticmethod
def common_group_size(cfg: dict) -> int:
"""Return the unique group_size across the config; raise if missing/mismatched."""
@@ -668,14 +693,15 @@ class ModelOptFp4Config(QuantizationConfig):
kv_cache_quant_algo,
group_size,
exclude_modules,
config.get("packed_modules_mapping"),
)
def is_layer_excluded(self, prefix: str, exclude_modules: list):
def is_layer_excluded(self, prefix: str):
import regex as re
fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
prefix_split = prefix.split(".")
for pattern in exclude_modules:
for pattern in self.exclude_modules:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
pattern_split = pattern.split(".")
if re.fullmatch(regex_str, prefix):
@@ -691,30 +717,17 @@ class ModelOptFp4Config(QuantizationConfig):
return True
return False
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
prefix, self.exclude_modules
):
return UnquantizedLinearMethod()
return ModelOptFp4LinearMethod(self)
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FlashInferFP4MoE):
# FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
return ModelOptNvFp4FusedMoEMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
Moe = (
FlashInferFP4MoE # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
if isinstance(layer, FlashInferFP4MoE)
else ModelOptNvFp4FusedMoEMethod
)
return self._get_quant_method(
layer, prefix, Linear=ModelOptFp4LinearMethod, Moe=Moe
)
class ModelOptFp4LinearMethod(LinearMethodBase):

View File

@@ -180,11 +180,12 @@ def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig,
packed_modules_mapping: Dict[str, List[str]],
remap_prefix: Dict[str, str] | None = None,
) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(
model_config, load_config, packed_modules_mapping
model_config, load_config, packed_modules_mapping, remap_prefix
)
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
if quant_config is None:
@@ -220,6 +221,7 @@ def _initialize_model(
"""Initialize a model with the given configurations."""
model_class, _ = get_model_architecture(model_config)
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
remap_prefix = getattr(model_class, "remap_prefix", None)
if _is_npu:
packed_modules_mapping.update(
{
@@ -243,7 +245,7 @@ def _initialize_model(
)
quant_config = _get_quantization_config(
model_config, load_config, packed_modules_mapping
model_config, load_config, packed_modules_mapping, remap_prefix
)
# Build kwargs conditionally

View File

@@ -37,7 +37,10 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp4Config,
ModelOptFp8Config,
)
from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
from sglang.utils import is_in_ci
@@ -135,11 +138,26 @@ def convert_bin_to_safetensor_file(
raise RuntimeError(f"The output tensors do not match for key {k}")
def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
    """Rewrite the leading prefix of *key* per *prefix_mapping*.

    Each ``(old, new)`` pair is tried in mapping iteration order against
    the running result, so a later pair sees the output of an earlier
    rewrite. Only the leading occurrence is replaced.

    Args:
        key: The original (e.g. checkpoint weight) name.
        prefix_mapping: Maps an old leading prefix to its replacement.

    Returns:
        The rewritten key (unchanged when no prefix matches).
    """
    result = key
    for old, new in prefix_mapping.items():
        if result.startswith(old):
            # Splice the replacement onto the tail — equivalent to
            # replacing only the first occurrence, which is at index 0.
            result = new + result[len(old):]
    return result
def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
    """Return *key* with every occurrence of each mapped substring replaced.

    Args:
        key: The original (e.g. checkpoint weight) name.
        substring_mapping: Maps an old substring to its replacement; each
            entry is applied in mapping iteration order to the running
            result, so later entries see earlier rewrites.

    Returns:
        The rewritten key (unchanged when no substring matches).
    """
    for old, new in substring_mapping.items():
        # ``str.replace`` is already a no-op when ``old`` is absent, so the
        # previous ``if old in key`` pre-check was a redundant extra scan.
        key = key.replace(old, new)
    return key
# TODO(woosuk): Move this to other place.
def get_quant_config(
model_config: ModelConfig,
load_config: LoadConfig,
packed_modules_mapping: Dict[str, List[str]],
remap_prefix: Dict[str, str] | None = None,
) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
@@ -209,38 +227,33 @@ def get_quant_config(
quant_config_file = quant_config_files[0]
with open(quant_config_file) as f:
config = json.load(f)
if remap_prefix is not None:
exclude_modules = [
replace_prefix(key, remap_prefix)
for key in config["quantization"]["exclude_modules"]
]
config["quantization"]["exclude_modules"] = exclude_modules
config["packed_modules_mapping"] = packed_modules_mapping
if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path
elif model_config.quantization == "modelopt":
if config["producer"]["name"] == "modelopt":
elif model_config.quantization.startswith("modelopt") and (
config["producer"]["name"].startswith("modelopt")
):
quant_algo = config["quantization"]["quant_algo"]
if quant_algo is None:
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
if config["quantization"]["quant_algo"] is None:
if (
model_config.hf_config.architectures[0]
!= "LlamaForCausalLMEagle3"
):
raise ValueError(
f"Invalid quant_config, quantization method: {model_config.quantization},"
f"hf architectures: {model_config.hf_config.architectures[0]}. "
)
return None
if "FP4" in config["quantization"]["quant_algo"]:
return ModelOptFp4Config.from_config(config)
else:
return quant_cls.from_config(config)
elif model_config.quantization == "modelopt_fp8":
if config["producer"]["name"] == "modelopt_fp8":
return quant_cls.from_config(config)
else:
raise ValueError(
f"Unsupported quantization config"
f" found for {model_config.quantization} in {f}."
)
elif model_config.quantization == "w8a8_int8":
config["packed_modules_mapping"] = packed_modules_mapping
return quant_cls.from_config(config)
if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
raise ValueError(
f"Invalid quant_config, quantization method: {model_config.quantization},"
f"hf architectures: {model_config.hf_config.architectures[0]}. "
)
return None
elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
return ModelOptFp8Config.from_config(config)
elif "FP4" in quant_algo:
return ModelOptFp4Config.from_config(config)
return quant_cls.from_config(config)
def find_local_hf_snapshot_dir(

View File

@@ -48,6 +48,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
replace_prefix,
replace_substrings,
)
from sglang.srt.utils import add_prefix, make_layers_non_pp
from sglang.utils import logger
@@ -155,6 +157,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
rms_norm_eps=config.rms_norm_eps,
activation=config.mamba_hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -381,16 +384,19 @@ class NemotronHModel(nn.Module):
class NemotronHForCausalLM(nn.Module):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
remap_prefix = {"backbone": "model"}
remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
*,
@@ -432,7 +438,9 @@ class NemotronHForCausalLM(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
return NemotronHModel(
config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@@ -460,21 +468,10 @@ class NemotronHForCausalLM(nn.Module):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
updated_weights = []
for name, loaded_weight in weights:
for prefix, new_key in self.remap_prefix.items():
if name.startswith(prefix):
name = name.replace(prefix, new_key)
for substr, new_key in self.remap_substr.items():
if substr in name:
name = name.replace(substr, new_key)
name = replace_prefix(name, self.remap_prefix)
name = replace_substrings(name, self.remap_substr)
updated_weights.append((name, loaded_weight))
params_dict = dict(self.named_parameters())
@@ -484,7 +481,7 @@ class NemotronHForCausalLM(nn.Module):
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
for param_name, weight_name, shard_id in self.stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)