Support nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8/NVFP4 (#11866)
This commit is contained in:
@@ -90,7 +90,50 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
|
||||
ACTIVATION_SCHEMES = ["static"]
|
||||
|
||||
|
||||
class ModelOptFp8Config(QuantizationConfig):
|
||||
class ModelOptQuantConfig(QuantizationConfig):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_quant_algo: Optional[str],
|
||||
exclude_modules: Optional[List[str]],
|
||||
packed_modules_mapping: Optional[Dict[str, List[str]]],
|
||||
):
|
||||
super().__init__()
|
||||
self.packed_modules_mapping = packed_modules_mapping
|
||||
self.exclude_modules = exclude_modules or []
|
||||
self.kv_cache_quant_algo = kv_cache_quant_algo
|
||||
|
||||
def _get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
*,
|
||||
Linear: type[LinearMethodBase],
|
||||
Moe: type[FusedMoEMethodBase],
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(
|
||||
prefix, self.exclude_modules, self.packed_modules_mapping
|
||||
) or self.is_layer_excluded(prefix):
|
||||
return UnquantizedLinearMethod()
|
||||
return Linear(self)
|
||||
elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Moe(self)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class ModelOptFp8Config(ModelOptQuantConfig):
|
||||
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
|
||||
|
||||
def __init__(
|
||||
@@ -98,14 +141,14 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
kv_cache_quant_method: Optional[str] = None,
|
||||
exclude_modules: Optional[List[str]] = None,
|
||||
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
|
||||
"""
|
||||
super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
self.kv_cache_quant_method = kv_cache_quant_method
|
||||
self.exclude_modules = exclude_modules
|
||||
if is_checkpoint_fp8_serialized:
|
||||
logger.warning(
|
||||
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
|
||||
@@ -128,10 +171,6 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89 # Minimum hardware capability (e.g., Hopper GPUs).
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
|
||||
# Handle two different config formats:
|
||||
@@ -186,37 +225,27 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
is_checkpoint_fp8_serialized=True,
|
||||
kv_cache_quant_method=kv_cache_quant_method,
|
||||
exclude_modules=exclude_modules,
|
||||
packed_modules_mapping=config.get("packed_modules_mapping"),
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if self.exclude_modules and any(
|
||||
def is_layer_excluded(self, prefix: str) -> bool:
|
||||
if len(self.exclude_modules) == 0:
|
||||
return False
|
||||
return any(
|
||||
module in prefix
|
||||
or (
|
||||
prefix.startswith("language_model.")
|
||||
and module in prefix.removeprefix("language_model.")
|
||||
)
|
||||
for module in self.exclude_modules
|
||||
):
|
||||
return None
|
||||
)
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return ModelOptFp8LinearMethod(self)
|
||||
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return ModelOptFp8MoEMethod(self)
|
||||
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
return self._get_quant_method(
|
||||
layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
@@ -512,7 +541,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
|
||||
class ModelOptFp4Config(QuantizationConfig):
|
||||
class ModelOptFp4Config(ModelOptQuantConfig):
|
||||
"""Config class for FP4."""
|
||||
|
||||
def __init__(
|
||||
@@ -521,7 +550,9 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
kv_cache_quant_algo: str = None,
|
||||
group_size: int = None,
|
||||
exclude_modules: List[str] = None,
|
||||
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
|
||||
) -> None:
|
||||
super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
|
||||
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
|
||||
if is_checkpoint_nvfp4_serialized:
|
||||
logger.warning(
|
||||
@@ -529,8 +560,6 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
"format is experimental and subject to change."
|
||||
)
|
||||
self.group_size = group_size
|
||||
self.kv_cache_quant_algo = kv_cache_quant_algo
|
||||
self.exclude_modules = exclude_modules
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_config, user_quant):
|
||||
@@ -549,10 +578,6 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@staticmethod
|
||||
def common_group_size(cfg: dict) -> int:
|
||||
"""Return the unique group_size across the config; raise if missing/mismatched."""
|
||||
@@ -668,14 +693,15 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
kv_cache_quant_algo,
|
||||
group_size,
|
||||
exclude_modules,
|
||||
config.get("packed_modules_mapping"),
|
||||
)
|
||||
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
||||
def is_layer_excluded(self, prefix: str):
|
||||
import regex as re
|
||||
|
||||
fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
|
||||
prefix_split = prefix.split(".")
|
||||
for pattern in exclude_modules:
|
||||
for pattern in self.exclude_modules:
|
||||
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
|
||||
pattern_split = pattern.split(".")
|
||||
if re.fullmatch(regex_str, prefix):
|
||||
@@ -691,30 +717,17 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
|
||||
prefix, self.exclude_modules
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
return ModelOptFp4LinearMethod(self)
|
||||
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
elif isinstance(layer, FlashInferFP4MoE):
|
||||
# FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
|
||||
return ModelOptNvFp4FusedMoEMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return ModelOptNvFp4FusedMoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
Moe = (
|
||||
FlashInferFP4MoE # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
|
||||
if isinstance(layer, FlashInferFP4MoE)
|
||||
else ModelOptNvFp4FusedMoEMethod
|
||||
)
|
||||
return self._get_quant_method(
|
||||
layer, prefix, Linear=ModelOptFp4LinearMethod, Moe=Moe
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
|
||||
@@ -180,11 +180,12 @@ def _get_quantization_config(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
packed_modules_mapping: Dict[str, List[str]],
|
||||
remap_prefix: Dict[str, str] | None = None,
|
||||
) -> Optional[QuantizationConfig]:
|
||||
"""Get the quantization config."""
|
||||
if model_config.quantization is not None:
|
||||
quant_config = get_quant_config(
|
||||
model_config, load_config, packed_modules_mapping
|
||||
model_config, load_config, packed_modules_mapping, remap_prefix
|
||||
)
|
||||
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
|
||||
if quant_config is None:
|
||||
@@ -220,6 +221,7 @@ def _initialize_model(
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
|
||||
remap_prefix = getattr(model_class, "remap_prefix", None)
|
||||
if _is_npu:
|
||||
packed_modules_mapping.update(
|
||||
{
|
||||
@@ -243,7 +245,7 @@ def _initialize_model(
|
||||
)
|
||||
|
||||
quant_config = _get_quantization_config(
|
||||
model_config, load_config, packed_modules_mapping
|
||||
model_config, load_config, packed_modules_mapping, remap_prefix
|
||||
)
|
||||
|
||||
# Build kwargs conditionally
|
||||
|
||||
@@ -37,7 +37,10 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank
|
||||
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
|
||||
from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
ModelOptFp4Config,
|
||||
ModelOptFp8Config,
|
||||
)
|
||||
from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
@@ -135,11 +138,26 @@ def convert_bin_to_safetensor_file(
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
|
||||
for prefix, new_prefix in prefix_mapping.items():
|
||||
if key.startswith(prefix):
|
||||
key = key.replace(prefix, new_prefix, 1)
|
||||
return key
|
||||
|
||||
|
||||
def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
|
||||
for substr, new_substr in substring_mapping.items():
|
||||
if substr in key:
|
||||
key = key.replace(substr, new_substr)
|
||||
return key
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
packed_modules_mapping: Dict[str, List[str]],
|
||||
remap_prefix: Dict[str, str] | None = None,
|
||||
) -> QuantizationConfig:
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
|
||||
@@ -209,38 +227,33 @@ def get_quant_config(
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file) as f:
|
||||
config = json.load(f)
|
||||
if remap_prefix is not None:
|
||||
exclude_modules = [
|
||||
replace_prefix(key, remap_prefix)
|
||||
for key in config["quantization"]["exclude_modules"]
|
||||
]
|
||||
config["quantization"]["exclude_modules"] = exclude_modules
|
||||
config["packed_modules_mapping"] = packed_modules_mapping
|
||||
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
config["adapter_name_or_path"] = model_name_or_path
|
||||
elif model_config.quantization == "modelopt":
|
||||
if config["producer"]["name"] == "modelopt":
|
||||
elif model_config.quantization.startswith("modelopt") and (
|
||||
config["producer"]["name"].startswith("modelopt")
|
||||
):
|
||||
quant_algo = config["quantization"]["quant_algo"]
|
||||
if quant_algo is None:
|
||||
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
|
||||
if config["quantization"]["quant_algo"] is None:
|
||||
if (
|
||||
model_config.hf_config.architectures[0]
|
||||
!= "LlamaForCausalLMEagle3"
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid quant_config, quantization method: {model_config.quantization},"
|
||||
f"hf architectures: {model_config.hf_config.architectures[0]}. "
|
||||
)
|
||||
return None
|
||||
if "FP4" in config["quantization"]["quant_algo"]:
|
||||
return ModelOptFp4Config.from_config(config)
|
||||
else:
|
||||
return quant_cls.from_config(config)
|
||||
elif model_config.quantization == "modelopt_fp8":
|
||||
if config["producer"]["name"] == "modelopt_fp8":
|
||||
return quant_cls.from_config(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization config"
|
||||
f" found for {model_config.quantization} in {f}."
|
||||
)
|
||||
elif model_config.quantization == "w8a8_int8":
|
||||
config["packed_modules_mapping"] = packed_modules_mapping
|
||||
|
||||
return quant_cls.from_config(config)
|
||||
if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
|
||||
raise ValueError(
|
||||
f"Invalid quant_config, quantization method: {model_config.quantization},"
|
||||
f"hf architectures: {model_config.hf_config.architectures[0]}. "
|
||||
)
|
||||
return None
|
||||
elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
|
||||
return ModelOptFp8Config.from_config(config)
|
||||
elif "FP4" in quant_algo:
|
||||
return ModelOptFp4Config.from_config(config)
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def find_local_hf_snapshot_dir(
|
||||
|
||||
@@ -48,6 +48,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
replace_prefix,
|
||||
replace_substrings,
|
||||
)
|
||||
from sglang.srt.utils import add_prefix, make_layers_non_pp
|
||||
from sglang.utils import logger
|
||||
@@ -155,6 +157,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
activation=config.mamba_hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mixer",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -381,16 +384,19 @@ class NemotronHModel(nn.Module):
|
||||
|
||||
|
||||
class NemotronHForCausalLM(nn.Module):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
}
|
||||
|
||||
remap_prefix = {"backbone": "model"}
|
||||
remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
|
||||
|
||||
# LoRA specific attributes
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -432,7 +438,9 @@ class NemotronHForCausalLM(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
|
||||
return NemotronHModel(
|
||||
config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
@@ -460,21 +468,10 @@ class NemotronHForCausalLM(nn.Module):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
updated_weights = []
|
||||
for name, loaded_weight in weights:
|
||||
for prefix, new_key in self.remap_prefix.items():
|
||||
if name.startswith(prefix):
|
||||
name = name.replace(prefix, new_key)
|
||||
for substr, new_key in self.remap_substr.items():
|
||||
if substr in name:
|
||||
name = name.replace(substr, new_key)
|
||||
name = replace_prefix(name, self.remap_prefix)
|
||||
name = replace_substrings(name, self.remap_substr)
|
||||
updated_weights.append((name, loaded_weight))
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
@@ -484,7 +481,7 @@ class NemotronHForCausalLM(nn.Module):
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id in self.stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
Reference in New Issue
Block a user