model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
This commit is contained in:
Netanel Haber
2025-10-08 19:37:38 +03:00
committed by GitHub
parent c882b5ae75
commit d6837aea4d
35 changed files with 3280 additions and 854 deletions

View File

@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
@@ -354,8 +355,9 @@ class ModelRunner:
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid = self.model_config.is_hybrid = True
if self.is_hybrid_gdn:
logger.warning("Hybrid GDN model detected, disable radix cache")
if config := self.mambaish_config:
class_name = config.__class__.__name__
logger.warning(f"{class_name} model detected, disable radix cache")
self.server_args.disable_radix_cache = True
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
@@ -364,6 +366,7 @@ class ModelRunner:
)
else:
self.server_args.max_mamba_cache_size = 512
if self.hybrid_gdn_config is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_mamba_cache_size
// (
@@ -1267,8 +1270,8 @@ class ModelRunner:
"num_nextn_predict_layers",
self.num_effective_layers,
)
elif self.is_hybrid_gdn:
num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
elif config := self.mambaish_config:
num_layers = len(config.full_attention_layer_ids)
else:
num_layers = self.num_effective_layers
if self.use_mla_backend:
@@ -1288,22 +1291,32 @@ class ModelRunner:
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
if self.is_hybrid_gdn:
if config := self.mambaish_config:
rest_memory -= (
self.server_args.max_mamba_cache_size
* self.model_config.hf_config.mamba_cache_per_req
* config.mamba2_cache_params.mamba_cache_per_req
/ (1 << 30)
)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
@property
def is_hybrid_gdn(self):
return self.model_config.hf_config.architectures[0] in [
"Qwen3NextForCausalLM",
"Qwen3NextForCausalLMMTP",
"FalconH1ForCausalLM",
]
def hybrid_gdn_config(self):
# Return the HF config object when the loaded model is a hybrid
# gated-delta-net (Qwen3Next) model; otherwise None.
# NOTE(review): accessed without parentheses elsewhere in this diff
# (`self.hybrid_gdn_config is not None`), so this is presumably wrapped
# by the @property decorator shown as context just above this hunk.
config = self.model_config.hf_config
if isinstance(config, Qwen3NextConfig):
return config
return None
@property
def mamba2_config(self):
# Return the HF config object when the loaded model is a Mamba2-style
# hybrid (FalconH1 or NemotronH); None for every other architecture.
config = self.model_config.hf_config
# PEP 604 union (`X | Y`) inside isinstance() requires Python >= 3.10.
if isinstance(config, FalconH1Config | NemotronHConfig):
return config
return None
@property
def mambaish_config(self):
# Any config exposing mamba2 cache parameters: prefer the pure Mamba2
# variants, fall back to the hybrid-GDN (Qwen3Next) config.  Both
# operands are None for non-mamba models, so this then yields None.
return self.mamba2_config or self.hybrid_gdn_config
def set_num_token_hybrid(self):
if (
@@ -1438,7 +1451,7 @@ class ModelRunner:
),
4096,
)
if self.is_hybrid_gdn:
if self.mambaish_config is not None:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
@@ -1519,26 +1532,14 @@ class ModelRunner:
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
elif self.is_hybrid_gdn:
config = self.model_config.hf_config
(
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
) = config.hybrid_gdn_params
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
conv_state_shape=conv_state_shape,
temporal_state_shape=temporal_state_shape,
conv_dtype=conv_dtype,
ssm_dtype=ssm_dtype,
mamba_layers=mamba_layers,
cache_params=config.mamba2_cache_params,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
)
else:
@@ -1640,7 +1641,7 @@ class ModelRunner:
enable_kvcache_transpose=False,
device=self.device,
)
elif self.is_hybrid_gdn:
elif config := self.mambaish_config:
self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size,
size=self.max_total_num_tokens,
@@ -1651,9 +1652,7 @@ class ModelRunner:
head_dim=self.model_config.head_dim,
# if draft worker, we only need 1 attention layer's kv pool
full_attention_layer_ids=(
[0]
if self.is_draft_worker
else self.model_config.hf_config.full_attention_layer_ids
[0] if self.is_draft_worker else config.full_attention_layer_ids
),
enable_kvcache_transpose=False,
device=self.device,
@@ -1681,7 +1680,8 @@ class ModelRunner:
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if _is_npu and (
self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
self.server_args.attention_backend == "ascend"
or self.hybrid_gdn_config is not None
):
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,