Revert "Support nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8/NVFP4" (#12015)
@@ -90,50 +90,7 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
 ACTIVATION_SCHEMES = ["static"]


-class ModelOptQuantConfig(QuantizationConfig):
-    def __init__(
-        self,
-        kv_cache_quant_algo: Optional[str],
-        exclude_modules: Optional[List[str]],
-        packed_modules_mapping: Optional[Dict[str, List[str]]],
-    ):
-        super().__init__()
-        self.packed_modules_mapping = packed_modules_mapping
-        self.exclude_modules = exclude_modules or []
-        self.kv_cache_quant_algo = kv_cache_quant_algo
-
-    def _get_quant_method(
-        self,
-        layer: torch.nn.Module,
-        prefix: str,
-        *,
-        Linear: type[LinearMethodBase],
-        Moe: type[FusedMoEMethodBase],
-    ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(
-                prefix, self.exclude_modules, self.packed_modules_mapping
-            ) or self.is_layer_excluded(prefix):
-                return UnquantizedLinearMethod()
-            return Linear(self)
-        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return Moe(self)
-        return None
-
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
-
-class ModelOptFp8Config(ModelOptQuantConfig):
+class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

     def __init__(
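Aside: the removed `ModelOptQuantConfig` base class existed to share one layer-type dispatch between the FP8 and FP4 configs. The following is an illustrative, self-contained sketch of that dispatch pattern using stand-in classes; none of these names come from the patch, and the real code dispatches on sglang's `LinearBase`, `FusedMoE`, and `RadixAttention` types instead.

```python
from typing import Optional


# Stand-in layer and method types (illustrative only).
class LinearLayer: ...
class MoELayer: ...
class AttentionLayer: ...

class LinearMethod: ...
class MoeMethod: ...
class KVCacheMethod: ...


def get_quant_method(
    layer: object,
    prefix: str,
    exclude_modules: list,
    *,
    Linear: type,
    Moe: type,
) -> Optional[object]:
    """Pick a quantization method for `layer` based on its type."""
    # Layers named in the exclude list stay unquantized.
    if any(module in prefix for module in exclude_modules):
        return None
    if isinstance(layer, LinearLayer):
        return Linear()
    if isinstance(layer, AttentionLayer):
        return KVCacheMethod()
    if isinstance(layer, MoELayer):
        return Moe()
    return None


# A concrete config only chooses which method classes to plug in.
method = get_quant_method(
    LinearLayer(), "model.layers.0.qkv_proj", ["lm_head"],
    Linear=LinearMethod, Moe=MoeMethod,
)
assert isinstance(method, LinearMethod)
```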
@@ -141,14 +98,14 @@ class ModelOptFp8Config(ModelOptQuantConfig):
         is_checkpoint_fp8_serialized: bool = False,
         kv_cache_quant_method: Optional[str] = None,
         exclude_modules: Optional[List[str]] = None,
-        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
-        super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -171,6 +128,10 @@ class ModelOptFp8Config(ModelOptQuantConfig):
     def get_min_capability(cls) -> int:
         return 89  # Minimum hardware capability (e.g., Hopper GPUs).

+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         # Handle two different config formats:
@@ -225,27 +186,37 @@ class ModelOptFp8Config(ModelOptQuantConfig):
             is_checkpoint_fp8_serialized=True,
             kv_cache_quant_method=kv_cache_quant_method,
             exclude_modules=exclude_modules,
-            packed_modules_mapping=config.get("packed_modules_mapping"),
         )

-    def is_layer_excluded(self, prefix: str) -> bool:
-        if len(self.exclude_modules) == 0:
-            return False
-        return any(
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if self.exclude_modules and any(
             module in prefix
             or (
                 prefix.startswith("language_model.")
                 and module in prefix.removeprefix("language_model.")
             )
             for module in self.exclude_modules
-        )
+        ):
+            return None

-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-        return self._get_quant_method(
-            layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
-        )
+        if isinstance(layer, LinearBase):
+            return ModelOptFp8LinearMethod(self)
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []


 class ModelOptFp8LinearMethod(LinearMethodBase):
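The restored `get_quant_method` inlines the exclusion check, including the special case for prefixes nested under `language_model.`. A minimal standalone version of just that check (function name and example prefixes are illustrative):

```python
def is_excluded(prefix: str, exclude_modules: list) -> bool:
    # Excluded if any listed module name occurs in the prefix, either
    # directly or after stripping a leading "language_model." component.
    return any(
        module in prefix
        or (
            prefix.startswith("language_model.")
            and module in prefix.removeprefix("language_model.")
        )
        for module in exclude_modules
    )


assert is_excluded("language_model.lm_head", ["lm_head"])
assert not is_excluded("model.layers.0.mlp.down_proj", ["lm_head"])
```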
@@ -541,7 +512,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)


-class ModelOptFp4Config(ModelOptQuantConfig):
+class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""

     def __init__(
@@ -550,9 +521,7 @@ class ModelOptFp4Config(ModelOptQuantConfig):
         kv_cache_quant_algo: str = None,
         group_size: int = None,
         exclude_modules: List[str] = None,
-        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
-        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
@@ -560,6 +529,8 @@ class ModelOptFp4Config(ModelOptQuantConfig):
                 "format is experimental and subject to change."
             )
         self.group_size = group_size
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+        self.exclude_modules = exclude_modules

     @classmethod
     def override_quantization_method(cls, hf_quant_config, user_quant):
@@ -578,6 +549,10 @@ class ModelOptFp4Config(ModelOptQuantConfig):
     def get_min_capability(cls) -> int:
         return 100

+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
     @staticmethod
     def common_group_size(cfg: dict) -> int:
         """Return the unique group_size across the config; raise if missing/mismatched."""
@@ -693,15 +668,14 @@ class ModelOptFp4Config(ModelOptQuantConfig):
             kv_cache_quant_algo,
             group_size,
             exclude_modules,
-            config.get("packed_modules_mapping"),
         )

-    def is_layer_excluded(self, prefix: str):
+    def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re

         fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
         prefix_split = prefix.split(".")
-        for pattern in self.exclude_modules:
+        for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
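`is_layer_excluded` treats each entry of `exclude_modules` as a glob-like pattern. A standalone sketch of the glob-to-regex matching it performs; the real code imports `regex`, while this sketch uses the standard-library `re`, which behaves the same for these simple patterns:

```python
import re


def matches_exclude_pattern(prefix: str, pattern: str) -> bool:
    # "model.layers.*.mixer" becomes r"model\.layers\..*\.mixer"
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    return re.fullmatch(regex_str, prefix) is not None


assert matches_exclude_pattern("model.layers.3.mixer", "model.layers.*.mixer")
assert not matches_exclude_pattern("model.layers.3.mlp", "model.layers.*.mixer")
```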
@@ -717,17 +691,30 @@ class ModelOptFp4Config(ModelOptQuantConfig):
                 return True
         return False

-    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE

-        Moe = (
-            FlashInferFP4MoE  # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
-            if isinstance(layer, FlashInferFP4MoE)
-            else ModelOptNvFp4FusedMoEMethod
-        )
-        return self._get_quant_method(
-            layer, prefix, Linear=ModelOptFp4LinearMethod, Moe=Moe
-        )
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
+                prefix, self.exclude_modules
+            ):
+                return UnquantizedLinearMethod()
+            return ModelOptFp4LinearMethod(self)
+        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FlashInferFP4MoE):
+            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+            return ModelOptNvFp4FusedMoEMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return ModelOptNvFp4FusedMoEMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []


 class ModelOptFp4LinearMethod(LinearMethodBase):
@@ -180,12 +180,11 @@ def _get_quantization_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
-    remap_prefix: Dict[str, str] | None = None,
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(
-            model_config, load_config, packed_modules_mapping, remap_prefix
+            model_config, load_config, packed_modules_mapping
         )
         # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
         if quant_config is None:
@@ -221,7 +220,6 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
-    remap_prefix = getattr(model_class, "remap_prefix", None)
     if _is_npu:
         packed_modules_mapping.update(
             {
@@ -245,7 +243,7 @@ def _initialize_model(
         )

     quant_config = _get_quantization_config(
-        model_config, load_config, packed_modules_mapping, remap_prefix
+        model_config, load_config, packed_modules_mapping
     )

     # Build kwargs conditionally
@@ -37,10 +37,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
-from sglang.srt.layers.quantization.modelopt_quant import (
-    ModelOptFp4Config,
-    ModelOptFp8Config,
-)
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
 from sglang.utils import is_in_ci

@@ -138,26 +135,11 @@ def convert_bin_to_safetensor_file(
         raise RuntimeError(f"The output tensors do not match for key {k}")


-def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
-    for prefix, new_prefix in prefix_mapping.items():
-        if key.startswith(prefix):
-            key = key.replace(prefix, new_prefix, 1)
-    return key
-
-
-def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
-    for substr, new_substr in substring_mapping.items():
-        if substr in key:
-            key = key.replace(substr, new_substr)
-    return key
-
-
 # TODO(woosuk): Move this to other place.
 def get_quant_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
-    remap_prefix: Dict[str, str] | None = None,
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)

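For reference, the deleted `replace_prefix` and `replace_substrings` helpers were plain string remappers. A small usage sketch, re-declaring them locally and applying the Nemotron-H mappings that appear later in this revert (the example checkpoint key is illustrative):

```python
def replace_prefix(key: str, prefix_mapping: dict) -> str:
    for prefix, new_prefix in prefix_mapping.items():
        if key.startswith(prefix):
            key = key.replace(prefix, new_prefix, 1)
    return key


def replace_substrings(key: str, substring_mapping: dict) -> str:
    for substr, new_substr in substring_mapping.items():
        if substr in key:
            key = key.replace(substr, new_substr)
    return key


remap_prefix = {"backbone": "model"}
remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}

key = "backbone.layers.0.mixer.A_log"
key = replace_prefix(key, remap_prefix)
key = replace_substrings(key, remap_substr)
assert key == "model.layers.0.mixer.A"
```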
@@ -227,33 +209,38 @@ def get_quant_config(
     quant_config_file = quant_config_files[0]
     with open(quant_config_file) as f:
         config = json.load(f)
-    if remap_prefix is not None:
-        exclude_modules = [
-            replace_prefix(key, remap_prefix)
-            for key in config["quantization"]["exclude_modules"]
-        ]
-        config["quantization"]["exclude_modules"] = exclude_modules
-    config["packed_modules_mapping"] = packed_modules_mapping

     if model_config.quantization == "bitsandbytes":
         config["adapter_name_or_path"] = model_name_or_path
-    elif model_config.quantization.startswith("modelopt") and (
-        config["producer"]["name"].startswith("modelopt")
-    ):
-        quant_algo = config["quantization"]["quant_algo"]
-        if quant_algo is None:
+    elif model_config.quantization == "modelopt":
+        if config["producer"]["name"] == "modelopt":
             # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
-            if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
-                raise ValueError(
-                    f"Invalid quant_config, quantization method: {model_config.quantization},"
-                    f"hf architectures: {model_config.hf_config.architectures[0]}. "
-                )
-            return None
-        elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
-            return ModelOptFp8Config.from_config(config)
-        elif "FP4" in quant_algo:
-            return ModelOptFp4Config.from_config(config)
-    return quant_cls.from_config(config)
+            if config["quantization"]["quant_algo"] is None:
+                if (
+                    model_config.hf_config.architectures[0]
+                    != "LlamaForCausalLMEagle3"
+                ):
+                    raise ValueError(
+                        f"Invalid quant_config, quantization method: {model_config.quantization},"
+                        f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                    )
+                return None
+            if "FP4" in config["quantization"]["quant_algo"]:
+                return ModelOptFp4Config.from_config(config)
+            else:
+                return quant_cls.from_config(config)
+    elif model_config.quantization == "modelopt_fp8":
+        if config["producer"]["name"] == "modelopt_fp8":
+            return quant_cls.from_config(config)
+        else:
+            raise ValueError(
+                f"Unsupported quantization config"
+                f" found for {model_config.quantization} in {f}."
+            )
+    elif model_config.quantization == "w8a8_int8":
+        config["packed_modules_mapping"] = packed_modules_mapping

+    return quant_cls.from_config(config)


 def find_local_hf_snapshot_dir(
@@ -48,8 +48,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
-    replace_prefix,
-    replace_substrings,
 )
 from sglang.srt.utils import add_prefix, make_layers_non_pp
 from sglang.utils import logger
@@ -157,7 +155,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             activation=config.mamba_hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
         )

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -384,19 +381,16 @@ class NemotronHModel(nn.Module):


 class NemotronHForCausalLM(nn.Module):
-    stacked_params_mapping = [
-        # (param_name, shard_name, shard_id)
-        ("qkv_proj", "q_proj", "q"),
-        ("qkv_proj", "k_proj", "k"),
-        ("qkv_proj", "v_proj", "v"),
-    ]
-    packed_modules_mapping = {
-        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    }
-
     remap_prefix = {"backbone": "model"}
     remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}

+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
     def __init__(
         self,
         *,
@@ -438,9 +432,7 @@ class NemotronHForCausalLM(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
-        return NemotronHModel(
-            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
-        )
+        return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -468,10 +460,21 @@ class NemotronHForCausalLM(nn.Module):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
         updated_weights = []
         for name, loaded_weight in weights:
-            name = replace_prefix(name, self.remap_prefix)
-            name = replace_substrings(name, self.remap_substr)
+            for prefix, new_key in self.remap_prefix.items():
+                if name.startswith(prefix):
+                    name = name.replace(prefix, new_key)
+            for substr, new_key in self.remap_substr.items():
+                if substr in name:
+                    name = name.replace(substr, new_key)
             updated_weights.append((name, loaded_weight))
         params_dict = dict(self.named_parameters())

@@ -481,7 +484,7 @@ class NemotronHForCausalLM(nn.Module):
                 if name is None:
                     continue

-            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
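The `stacked_params_mapping` that moves back inside `load_weights` fuses per-shard q/k/v checkpoint tensors into the packed `qkv_proj` parameter. A simplified standalone sketch of that name rewrite (the module path in the example is illustrative, and the actual weight copy is omitted):

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def resolve_param(name: str):
    # Map a per-shard checkpoint key onto the fused parameter plus a shard id.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


assert resolve_param("model.layers.0.self_attn.k_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight",
    "k",
)
```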
@@ -373,7 +373,3 @@ def test_causal_conv1d_varlen(
     )
     unpadded_out = out[:, : out_ref_tensor.shape[-1]]
     assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
@@ -1,6 +1,5 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py

-
 from unittest.mock import patch

 import pytest
@@ -137,7 +136,3 @@ def mixer2_gated_norm_tensor_parallel(
         atol=5e-3,
         rtol=1e-3,
     )
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
@@ -1,6 +1,5 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py

-
 import pytest
 import torch
 import torch.nn.functional as F
@@ -290,7 +289,3 @@ def test_selective_state_update_with_heads_with_batch_indices(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
@@ -8,12 +8,13 @@ from einops import rearrange, repeat

 from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
 from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined
-from sglang.utils import is_in_ci

 # Added by the IBM Team, 2024

 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py

+# TODO: These take a long time to run - we should cut down on some of the parameterized matrix.
+

 # this is the segsum implementation taken from above
 def segsum(x):
@@ -190,22 +191,10 @@ def generate_continuous_batched_examples(
     )


-SINGLE_ITYPE = [torch.float32, torch.float16, torch.bfloat16]
-SINGLE_NHEADS = [3, 4, 11, 16, 32]
-SINGLE_DHEAD = [5, 8, 19, 32, 128]
-SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16), (128, 32)]
-
-if is_in_ci():
-    SINGLE_ITYPE = [torch.float32, torch.bfloat16]
-    SINGLE_NHEADS = [3, 32]
-    SINGLE_DHEAD = [5, 128]
-    SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16)]
-
-
-@pytest.mark.parametrize("itype", SINGLE_ITYPE)
-@pytest.mark.parametrize("n_heads", SINGLE_NHEADS)
-@pytest.mark.parametrize("d_head", SINGLE_DHEAD)
-@pytest.mark.parametrize("seq_len_chunk_size", SINGLE_SEQ_LEN_CHUNK_SIZE)
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
+@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
+@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
 def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
     if not torch.cuda.is_available():
         pytest.skip("CUDA device not available")
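The hunks above and below restore inline `@pytest.mark.parametrize` values in place of module-level lists that were trimmed under `is_in_ci()`. A self-contained sketch of that trimming pattern, with `is_in_ci` stubbed via an environment variable rather than imported from sglang:

```python
import os

import pytest
import torch


def is_in_ci() -> bool:
    # Stub: the real helper lives in sglang.utils.
    return os.environ.get("CI", "").lower() in ("1", "true")


SINGLE_ITYPE = [torch.float32, torch.float16, torch.bfloat16]
SINGLE_NHEADS = [3, 4, 11, 16, 32]

if is_in_ci():
    # Keep only the cheapest / most representative cases in CI.
    SINGLE_ITYPE = [torch.float32, torch.bfloat16]
    SINGLE_NHEADS = [3, 32]


@pytest.mark.parametrize("itype", SINGLE_ITYPE)
@pytest.mark.parametrize("n_heads", SINGLE_NHEADS)
def test_shapes(itype, n_heads):
    x = torch.zeros(n_heads, 4, dtype=itype)
    assert x.shape == (n_heads, 4)
```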
@@ -249,19 +238,9 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
     )


-BATCHED_ITYPE = [torch.float32, torch.float16]
-BATCHED_NHEADS = [4, 8, 13]
-BATCHED_DHEAD = [5, 16, 21, 32]
-
-if is_in_ci():
-    BATCHED_ITYPE = [torch.float32]
-    BATCHED_NHEADS = [4, 13]
-    BATCHED_DHEAD = [5, 32]
-
-
-@pytest.mark.parametrize("itype", BATCHED_ITYPE)
-@pytest.mark.parametrize("n_heads", BATCHED_NHEADS)
-@pytest.mark.parametrize("d_head", BATCHED_DHEAD)
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
+@pytest.mark.parametrize("n_heads", [4, 8, 13])
+@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
 @pytest.mark.parametrize(
     "seq_len_chunk_size_cases",
     [
@@ -600,7 +579,3 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
                 rtol=rtol,
                 msg=lambda x: f"seq{i} state " + x,
             )  # noqa: B023
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
@@ -1,7 +1,7 @@
 import unittest
 from types import SimpleNamespace

-from sglang.srt.utils import is_blackwell, kill_process_tree
+from sglang.srt.utils import kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -12,11 +12,9 @@ from sglang.test.test_utils import (


 class TestNvidiaNemotronNanoV2(CustomTestCase):
-    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
-    accuracy = 0.87
-
     @classmethod
     def setUpClass(cls):
+        cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
@@ -44,18 +42,7 @@ class TestNvidiaNemotronNanoV2(CustomTestCase):
         )
         metrics = run_eval(args)
         print(f"{metrics=}")
-        self.assertGreaterEqual(metrics["accuracy"], self.accuracy)
-
-
-class TestNvidiaNemotronNanoV2FP8(TestNvidiaNemotronNanoV2):
-    accuracy = 0.87
-    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8"
-
-
-@unittest.skipIf(not is_blackwell(), "NVFP4 only supported on blackwell")
-class TestNvidiaNemotronNanoV2NVFP4(TestNvidiaNemotronNanoV2):
-    accuracy = 0.855
-    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4"
+        self.assertGreater(metrics["accuracy"], 0.87)


 if __name__ == "__main__":
@@ -19,9 +19,6 @@ suites = {
         TestFile("hicache/test_hicache_eagle.py", 150),
         TestFile("hicache/test_hicache_mla.py", 127),
         TestFile("hicache/test_hicache_storage.py", 127),
-        TestFile("layers/attention/mamba/test_causal_conv1d.py", 25),
-        TestFile("layers/attention/mamba/test_mamba_ssm.py", 50),
-        TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 70),
         TestFile("lora/test_lora.py", 200),
         TestFile("lora/test_lora_eviction.py", 200),
         TestFile("lora/test_lora_eviction_policy.py", 200),
@@ -37,7 +34,7 @@ suites = {
         TestFile("models/test_embedding_models.py", 73),
         TestFile("models/test_encoder_embedding_models.py", 460),
         TestFile("models/test_generation_models.py", 103),
-        TestFile("models/test_nvidia_nemotron_nano_v2.py", 300),
+        TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
         TestFile("models/test_qwen_models.py", 82),
         TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
         TestFile("models/test_reward_models.py", 132),
@@ -146,7 +143,7 @@ suites = {
         TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
         TestFile("hicache/test_hicache_storage_file_backend.py", 200),
         TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
-        TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
+        TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110),
         TestFile("lora/test_lora_tp.py", 116),
         TestFile("models/test_glm4_moe_models.py", 100),
         TestFile("rl/test_update_weights_from_distributed.py", 103),