diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 949d63450..634fb121f 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -90,7 +90,50 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var( ACTIVATION_SCHEMES = ["static"] -class ModelOptFp8Config(QuantizationConfig): +class ModelOptQuantConfig(QuantizationConfig): + def __init__( + self, + kv_cache_quant_algo: Optional[str], + exclude_modules: Optional[List[str]], + packed_modules_mapping: Optional[Dict[str, List[str]]], + ): + super().__init__() + self.packed_modules_mapping = packed_modules_mapping + self.exclude_modules = exclude_modules or [] + self.kv_cache_quant_algo = kv_cache_quant_algo + + def _get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + *, + Linear: type[LinearMethodBase], + Moe: type[FusedMoEMethodBase], + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix, self.exclude_modules, self.packed_modules_mapping + ) or self.is_layer_excluded(prefix): + return UnquantizedLinearMethod() + return Linear(self) + elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention): + return ModelOptFp8KVCacheMethod(self) + elif isinstance(layer, FusedMoE): + return Moe(self) + return None + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ModelOptFp8Config(ModelOptQuantConfig): """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks.""" def __init__( @@ -98,14 +141,14 @@ class ModelOptFp8Config(QuantizationConfig): is_checkpoint_fp8_serialized: bool = False, kv_cache_quant_method: Optional[str] = None, 
exclude_modules: Optional[List[str]] = None, + packed_modules_mapping: Optional[Dict[str, List[str]]] = None, ) -> None: """ Args: is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format. """ + super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping) self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized - self.kv_cache_quant_method = kv_cache_quant_method - self.exclude_modules = exclude_modules if is_checkpoint_fp8_serialized: logger.warning( "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change." @@ -128,10 +171,6 @@ class ModelOptFp8Config(QuantizationConfig): def get_min_capability(cls) -> int: return 89 # Minimum hardware capability (e.g., Hopper GPUs). - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["hf_quant_config.json"] - @classmethod def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config: # Handle two different config formats: @@ -186,37 +225,27 @@ class ModelOptFp8Config(QuantizationConfig): is_checkpoint_fp8_serialized=True, kv_cache_quant_method=kv_cache_quant_method, exclude_modules=exclude_modules, + packed_modules_mapping=config.get("packed_modules_mapping"), ) - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional[QuantizeMethodBase]: - - from sglang.srt.layers.linear import LinearBase - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - - if self.exclude_modules and any( + def is_layer_excluded(self, prefix: str) -> bool: + if len(self.exclude_modules) == 0: + return False + return any( module in prefix or ( prefix.startswith("language_model.") and module in prefix.removeprefix("language_model.") ) for module in self.exclude_modules - ): - return None + ) - if isinstance(layer, LinearBase): - return ModelOptFp8LinearMethod(self) - if self.kv_cache_quant_method and isinstance(layer, RadixAttention): - return ModelOptFp8KVCacheMethod(self) - - if isinstance(layer, FusedMoE): - 
return ModelOptFp8MoEMethod(self) - - return None - - def get_scaled_act_names(self) -> List[str]: - return [] + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + return self._get_quant_method( + layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod + ) class ModelOptFp8LinearMethod(LinearMethodBase): @@ -512,7 +541,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): return self.runner.run(dispatch_output, quant_info) -class ModelOptFp4Config(QuantizationConfig): +class ModelOptFp4Config(ModelOptQuantConfig): """Config class for FP4.""" def __init__( @@ -521,7 +550,9 @@ class ModelOptFp4Config(QuantizationConfig): kv_cache_quant_algo: str = None, group_size: int = None, exclude_modules: List[str] = None, + packed_modules_mapping: Optional[Dict[str, List[str]]] = None, ) -> None: + super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping) self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: logger.warning( @@ -529,8 +560,6 @@ class ModelOptFp4Config(QuantizationConfig): "format is experimental and subject to change." 
) self.group_size = group_size - self.kv_cache_quant_algo = kv_cache_quant_algo - self.exclude_modules = exclude_modules @classmethod def override_quantization_method(cls, hf_quant_config, user_quant): @@ -549,10 +578,6 @@ class ModelOptFp4Config(QuantizationConfig): def get_min_capability(cls) -> int: return 100 - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["hf_quant_config.json"] - @staticmethod def common_group_size(cfg: dict) -> int: """Return the unique group_size across the config; raise if missing/mismatched.""" @@ -668,14 +693,15 @@ class ModelOptFp4Config(QuantizationConfig): kv_cache_quant_algo, group_size, exclude_modules, + config.get("packed_modules_mapping"), ) - def is_layer_excluded(self, prefix: str, exclude_modules: list): + def is_layer_excluded(self, prefix: str): import regex as re fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"] prefix_split = prefix.split(".") - for pattern in exclude_modules: + for pattern in self.exclude_modules: regex_str = pattern.replace(".", r"\.").replace("*", r".*") pattern_split = pattern.split(".") if re.fullmatch(regex_str, prefix): @@ -691,30 +717,17 @@ class ModelOptFp4Config(QuantizationConfig): return True return False - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional[QuantizeMethodBase]: - from sglang.srt.layers.linear import LinearBase - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + def get_quant_method(self, layer: torch.nn.Module, prefix: str): from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE - if isinstance(layer, LinearBase): - if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded( - prefix, self.exclude_modules - ): - return UnquantizedLinearMethod() - return ModelOptFp4LinearMethod(self) - if self.kv_cache_quant_algo and isinstance(layer, RadixAttention): - return ModelOptFp8KVCacheMethod(self) - elif isinstance(layer, FlashInferFP4MoE): - # 
FlashInferFP4MoE needs the same quantization method but with compatible attribute handling - return ModelOptNvFp4FusedMoEMethod(self) - elif isinstance(layer, FusedMoE): - return ModelOptNvFp4FusedMoEMethod(self) - return None - - def get_scaled_act_names(self) -> List[str]: - return [] + Moe = ( + ModelOptNvFp4FusedMoEMethod # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling + if isinstance(layer, FlashInferFP4MoE) + else ModelOptNvFp4FusedMoEMethod + ) + return self._get_quant_method( + layer, prefix, Linear=ModelOptFp4LinearMethod, Moe=Moe + ) class ModelOptFp4LinearMethod(LinearMethodBase): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 6134f24ba..06ecb5041 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -180,11 +180,12 @@ def _get_quantization_config( model_config: ModelConfig, load_config: LoadConfig, packed_modules_mapping: Dict[str, List[str]], + remap_prefix: Dict[str, str] | None = None, ) -> Optional[QuantizationConfig]: """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config( - model_config, load_config, packed_modules_mapping + model_config, load_config, packed_modules_mapping, remap_prefix ) # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3 if quant_config is None: @@ -220,6 +221,7 @@ def _initialize_model( """Initialize a model with the given configurations.""" model_class, _ = get_model_architecture(model_config) packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {}) + remap_prefix = getattr(model_class, "remap_prefix", None) if _is_npu: packed_modules_mapping.update( { @@ -243,7 +245,7 @@ def _initialize_model( ) quant_config = _get_quantization_config( - model_config, load_config, packed_modules_mapping + model_config, load_config, packed_modules_mapping, remap_prefix ) # Build kwargs conditionally diff --git 
a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index d4585bbb3..7edd0bbe0 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -37,7 +37,10 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.dp_attention import get_attention_tp_rank from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config -from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config +from sglang.srt.layers.quantization.modelopt_quant import ( + ModelOptFp4Config, + ModelOptFp8Config, +) from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once from sglang.utils import is_in_ci @@ -135,11 +138,26 @@ def convert_bin_to_safetensor_file( raise RuntimeError(f"The output tensors do not match for key {k}") +def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str: + for prefix, new_prefix in prefix_mapping.items(): + if key.startswith(prefix): + key = key.replace(prefix, new_prefix, 1) + return key + + +def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str: + for substr, new_substr in substring_mapping.items(): + if substr in key: + key = key.replace(substr, new_substr) + return key + + # TODO(woosuk): Move this to other place. 
def get_quant_config( model_config: ModelConfig, load_config: LoadConfig, packed_modules_mapping: Dict[str, List[str]], + remap_prefix: Dict[str, str] | None = None, ) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) @@ -209,38 +227,33 @@ def get_quant_config( quant_config_file = quant_config_files[0] with open(quant_config_file) as f: config = json.load(f) + if remap_prefix is not None: + exclude_modules = [ + replace_prefix(key, remap_prefix) + for key in config["quantization"]["exclude_modules"] + ] + config["quantization"]["exclude_modules"] = exclude_modules + config["packed_modules_mapping"] = packed_modules_mapping if model_config.quantization == "bitsandbytes": config["adapter_name_or_path"] = model_name_or_path - elif model_config.quantization == "modelopt": - if config["producer"]["name"] == "modelopt": + elif model_config.quantization.startswith("modelopt") and ( + config["producer"]["name"].startswith("modelopt") + ): + quant_algo = config["quantization"]["quant_algo"] + if quant_algo is None: # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3 - if config["quantization"]["quant_algo"] is None: - if ( - model_config.hf_config.architectures[0] - != "LlamaForCausalLMEagle3" - ): - raise ValueError( - f"Invalid quant_config, quantization method: {model_config.quantization}," - f"hf architectures: {model_config.hf_config.architectures[0]}. " - ) - return None - if "FP4" in config["quantization"]["quant_algo"]: - return ModelOptFp4Config.from_config(config) - else: - return quant_cls.from_config(config) - elif model_config.quantization == "modelopt_fp8": - if config["producer"]["name"] == "modelopt_fp8": - return quant_cls.from_config(config) - else: - raise ValueError( - f"Unsupported quantization config" - f" found for {model_config.quantization} in {f}." 
- ) - elif model_config.quantization == "w8a8_int8": - config["packed_modules_mapping"] = packed_modules_mapping - - return quant_cls.from_config(config) + if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3": + raise ValueError( + f"Invalid quant_config, quantization method: {model_config.quantization}," + f"hf architectures: {model_config.hf_config.architectures[0]}. " + ) + return None + elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8": + return ModelOptFp8Config.from_config(config) + elif "FP4" in quant_algo: + return ModelOptFp4Config.from_config(config) + return quant_cls.from_config(config) def find_local_hf_snapshot_dir( diff --git a/python/sglang/srt/models/nemotron_h.py b/python/sglang/srt/models/nemotron_h.py index 9f0126c3f..eadff130f 100644 --- a/python/sglang/srt/models/nemotron_h.py +++ b/python/sglang/srt/models/nemotron_h.py @@ -48,6 +48,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, + replace_prefix, + replace_substrings, ) from sglang.srt.utils import add_prefix, make_layers_non_pp from sglang.utils import logger @@ -155,6 +157,7 @@ class NemotronHMambaDecoderLayer(nn.Module): rms_norm_eps=config.rms_norm_eps, activation=config.mamba_hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mixer", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -381,16 +384,19 @@ class NemotronHModel(nn.Module): class NemotronHForCausalLM(nn.Module): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } + remap_prefix = {"backbone": "model"} remap_substr = {"A_log": "A", "embeddings": "embed_tokens"} - # LoRA specific attributes - embedding_modules = { - "embed_tokens": 
"input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - def __init__( self, *, @@ -432,7 +438,9 @@ class NemotronHForCausalLM(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): - return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix) + return NemotronHModel( + config=config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -460,21 +468,10 @@ class NemotronHForCausalLM(nn.Module): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - updated_weights = [] for name, loaded_weight in weights: - for prefix, new_key in self.remap_prefix.items(): - if name.startswith(prefix): - name = name.replace(prefix, new_key) - for substr, new_key in self.remap_substr.items(): - if substr in name: - name = name.replace(substr, new_key) + name = replace_prefix(name, self.remap_prefix) + name = replace_substrings(name, self.remap_substr) updated_weights.append((name, loaded_weight)) params_dict = dict(self.named_parameters()) @@ -484,7 +481,7 @@ class NemotronHForCausalLM(nn.Module): if name is None: continue - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_id in self.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/test/srt/layers/attention/mamba/test_causal_conv1d.py b/test/srt/layers/attention/mamba/test_causal_conv1d.py index c56b96b4f..dd1a9a25f 100644 --- a/test/srt/layers/attention/mamba/test_causal_conv1d.py +++ 
b/test/srt/layers/attention/mamba/test_causal_conv1d.py @@ -373,3 +373,7 @@ def test_causal_conv1d_varlen( ) unpadded_out = out[:, : out_ref_tensor.shape[-1]] assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/srt/layers/attention/mamba/test_mamba2_mixer.py b/test/srt/layers/attention/mamba/test_mamba2_mixer.py index aae477db5..2252db653 100644 --- a/test/srt/layers/attention/mamba/test_mamba2_mixer.py +++ b/test/srt/layers/attention/mamba/test_mamba2_mixer.py @@ -1,5 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py + from unittest.mock import patch import pytest @@ -136,3 +137,7 @@ def mixer2_gated_norm_tensor_parallel( atol=5e-3, rtol=1e-3, ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/srt/layers/attention/mamba/test_mamba_ssm.py b/test/srt/layers/attention/mamba/test_mamba_ssm.py index 3e983a00e..4a2c9a8e2 100644 --- a/test/srt/layers/attention/mamba/test_mamba_ssm.py +++ b/test/srt/layers/attention/mamba/test_mamba_ssm.py @@ -1,5 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py + import pytest import torch import torch.nn.functional as F @@ -289,3 +290,7 @@ def test_selective_state_update_with_heads_with_batch_indices( print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py b/test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py index 493a179ee..10a7f3f80 100644 --- a/test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py +++ 
b/test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py @@ -8,13 +8,12 @@ from einops import rearrange, repeat from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined +from sglang.utils import is_in_ci # Added by the IBM Team, 2024 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py -# TODO: These take a long time to run - we should cut down on some of the parameterized matrix. - # this is the segsum implementation taken from above def segsum(x): @@ -191,10 +190,22 @@ def generate_continuous_batched_examples( ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) -@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) -@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)]) +SINGLE_ITYPE = [torch.float32, torch.float16, torch.bfloat16] +SINGLE_NHEADS = [3, 4, 11, 16, 32] +SINGLE_DHEAD = [5, 8, 19, 32, 128] +SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16), (128, 32)] + +if is_in_ci(): + SINGLE_ITYPE = [torch.float32, torch.bfloat16] + SINGLE_NHEADS = [3, 32] + SINGLE_DHEAD = [5, 128] + SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16)] + + +@pytest.mark.parametrize("itype", SINGLE_ITYPE) +@pytest.mark.parametrize("n_heads", SINGLE_NHEADS) +@pytest.mark.parametrize("d_head", SINGLE_DHEAD) +@pytest.mark.parametrize("seq_len_chunk_size", SINGLE_SEQ_LEN_CHUNK_SIZE) def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): if not torch.cuda.is_available(): pytest.skip("CUDA device not available") @@ -238,9 +249,19 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) -@pytest.mark.parametrize("n_heads", [4, 8, 13]) -@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) +BATCHED_ITYPE = [torch.float32, torch.float16] 
+BATCHED_NHEADS = [4, 8, 13] +BATCHED_DHEAD = [5, 16, 21, 32] + +if is_in_ci(): + BATCHED_ITYPE = [torch.float32] + BATCHED_NHEADS = [4, 13] + BATCHED_DHEAD = [5, 32] + + +@pytest.mark.parametrize("itype", BATCHED_ITYPE) +@pytest.mark.parametrize("n_heads", BATCHED_NHEADS) +@pytest.mark.parametrize("d_head", BATCHED_DHEAD) @pytest.mark.parametrize( "seq_len_chunk_size_cases", [ @@ -579,3 +600,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): rtol=rtol, msg=lambda x: f"seq{i} state " + x, ) # noqa: B023 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/srt/models/test_nvidia_nemotron_nano_v2.py b/test/srt/models/test_nvidia_nemotron_nano_v2.py index 2fcb6fea0..4b414fbac 100644 --- a/test/srt/models/test_nvidia_nemotron_nano_v2.py +++ b/test/srt/models/test_nvidia_nemotron_nano_v2.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_process_tree +from sglang.srt.utils import is_blackwell, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,9 +12,11 @@ from sglang.test.test_utils import ( class TestNvidiaNemotronNanoV2(CustomTestCase): + model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2" + accuracy = 0.87 + @classmethod def setUpClass(cls): - cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2" cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, @@ -42,7 +44,18 @@ class TestNvidiaNemotronNanoV2(CustomTestCase): ) metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.87) + self.assertGreaterEqual(metrics["accuracy"], self.accuracy) + + +class TestNvidiaNemotronNanoV2FP8(TestNvidiaNemotronNanoV2): + accuracy = 0.87 + model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8" + + +@unittest.skipIf(not is_blackwell(), "NVFP4 only supported on blackwell") +class TestNvidiaNemotronNanoV2NVFP4(TestNvidiaNemotronNanoV2): + accuracy = 
0.855 + model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4" if __name__ == "__main__": diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 845e22ee6..723073bf2 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -19,6 +19,9 @@ suites = { TestFile("hicache/test_hicache_eagle.py", 150), TestFile("hicache/test_hicache_mla.py", 127), TestFile("hicache/test_hicache_storage.py", 127), + TestFile("layers/attention/mamba/test_causal_conv1d.py", 25), + TestFile("layers/attention/mamba/test_mamba_ssm.py", 50), + TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 70), TestFile("lora/test_lora.py", 200), TestFile("lora/test_lora_eviction.py", 200), TestFile("lora/test_lora_eviction_policy.py", 200), @@ -34,7 +37,7 @@ suites = { TestFile("models/test_embedding_models.py", 73), TestFile("models/test_encoder_embedding_models.py", 460), TestFile("models/test_generation_models.py", 103), - TestFile("models/test_nvidia_nemotron_nano_v2.py", 180), + TestFile("models/test_nvidia_nemotron_nano_v2.py", 300), TestFile("models/test_qwen_models.py", 82), TestFile("batch_invariant/test_batch_invariant_ops.py", 10), TestFile("models/test_reward_models.py", 132), @@ -143,7 +146,7 @@ suites = { TestFile("hicache/test_hicache_storage_3fs_backend.py", 200), TestFile("hicache/test_hicache_storage_file_backend.py", 200), TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400), - TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110), + TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50), TestFile("lora/test_lora_tp.py", 116), TestFile("models/test_glm4_moe_models.py", 100), TestFile("rl/test_update_weights_from_distributed.py", 103),