model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)
Signed-off-by: Netanel Haber <nhaber@nvidia.com>
@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -354,8 +355,9 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True
 
-        if self.is_hybrid_gdn:
-            logger.warning("Hybrid GDN model detected, disable radix cache")
+        if config := self.mambaish_config:
+            class_name = config.__class__.__name__
+            logger.warning(f"{class_name} model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
             if self.server_args.max_mamba_cache_size is None:
                 if self.server_args.max_running_requests is not None:
@@ -364,6 +366,7 @@ class ModelRunner:
                     )
                 else:
                     self.server_args.max_mamba_cache_size = 512
+            if self.hybrid_gdn_config is not None:
                 self.server_args.max_mamba_cache_size = (
                     self.server_args.max_mamba_cache_size
                     // (
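
Note (reviewer sketch, not part of the patch): the guard above uses the walrus operator, so `config` is bound to whichever Mamba-like HF config matches and the whole block is skipped for plain attention models. The snippet below illustrates the pattern with stand-in classes instead of the real sglang imports; it mirrors the `mambaish_config` property added later in this diff.

from typing import Optional

class Qwen3NextConfig:  # stand-in for sglang.srt.configs.Qwen3NextConfig
    pass

class NemotronHConfig:  # stand-in for the config class this patch adds for Nemotron-H models
    pass

def mambaish_config(hf_config) -> Optional[object]:
    # Return the config itself for Mamba-like models, None otherwise.
    if isinstance(hf_config, (Qwen3NextConfig, NemotronHConfig)):
        return hf_config
    return None

hf_config = NemotronHConfig()
if config := mambaish_config(hf_config):
    class_name = config.__class__.__name__
    print(f"{class_name} model detected, disable radix cache")
    # -> "NemotronHConfig model detected, disable radix cache"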
@@ -1267,8 +1270,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
-        elif self.is_hybrid_gdn:
-            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
+        elif config := self.mambaish_config:
+            num_layers = len(config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
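
Note (not part of the patch): for these hybrid models the paged KV cache is sized over the full-attention layers only, since the Mamba/linear-attention layers keep their state in the separate Mamba cache rather than in paged KV. A toy example with made-up layer ids:

# Made-up numbers: a hybrid model where only a handful of layers use full attention.
full_attention_layer_ids = [3, 11, 19, 27, 35, 43]  # hypothetical ids from the config
num_layers = len(full_attention_layer_ids)
print(num_layers)  # 6 -> KV cells are allocated for 6 layers, not the model's full depth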
@@ -1288,22 +1291,32 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        if self.is_hybrid_gdn:
+        if config := self.mambaish_config:
             rest_memory -= (
                 self.server_args.max_mamba_cache_size
-                * self.model_config.hf_config.mamba_cache_per_req
+                * config.mamba2_cache_params.mamba_cache_per_req
                 / (1 << 30)
             )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
     @property
-    def is_hybrid_gdn(self):
-        return self.model_config.hf_config.architectures[0] in [
-            "Qwen3NextForCausalLM",
-            "Qwen3NextForCausalLMMTP",
-            "FalconH1ForCausalLM",
-        ]
+    def hybrid_gdn_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, Qwen3NextConfig):
+            return config
+        return None
+
+    @property
+    def mamba2_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, FalconH1Config | NemotronHConfig):
+            return config
+        return None
+
+    @property
+    def mambaish_config(self):
+        return self.mamba2_config or self.hybrid_gdn_config
 
     def set_num_token_hybrid(self):
         if (
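
Note (worked example, not from the patch): the new properties make `mambaish_config` return whichever config matches (FalconH1/NemotronH via `mamba2_config`, Qwen3Next via `hybrid_gdn_config`, else None), and that is what the walrus guards throughout this diff test. In the memory estimate above, `mamba_cache_per_req` is a byte count and `rest_memory` is tracked in GiB, so the subtraction reserves recurrent state for up to `max_mamba_cache_size` requests before the token KV pool is sized. With made-up numbers:

# All numbers below are hypothetical, for illustrating the accounting only.
max_mamba_cache_size = 512            # requests that may hold Mamba state (default set earlier)
mamba_cache_per_req = 40 * (1 << 20)  # 40 MiB of conv + SSM state per request
rest_memory = 60.0                    # GiB left for caches

rest_memory -= max_mamba_cache_size * mamba_cache_per_req / (1 << 30)  # 60 - 20 = 40 GiB
cell_size = 64 * 1024                 # bytes per token cell across full-attention layers
max_num_token = int(rest_memory * (1 << 30) // cell_size)
print(rest_memory, max_num_token)     # 40.0 GiB -> 655360 tokens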
@@ -1438,7 +1451,7 @@ class ModelRunner:
             ),
             4096,
         )
-        if self.is_hybrid_gdn:
+        if self.mambaish_config is not None:
             max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
 
         if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
@@ -1519,26 +1532,14 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
-        elif self.is_hybrid_gdn:
-            config = self.model_config.hf_config
-            (
-                conv_state_shape,
-                temporal_state_shape,
-                conv_dtype,
-                ssm_dtype,
-                mamba_layers,
-            ) = config.hybrid_gdn_params
+        elif config := self.mambaish_config:
             self.req_to_token_pool = HybridReqToTokenPool(
                 size=max_num_reqs,
                 max_context_len=self.model_config.context_len
                 + extra_max_context_len,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
-                conv_state_shape=conv_state_shape,
-                temporal_state_shape=temporal_state_shape,
-                conv_dtype=conv_dtype,
-                ssm_dtype=ssm_dtype,
-                mamba_layers=mamba_layers,
+                cache_params=config.mamba2_cache_params,
                 speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             )
         else:
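
Note (hypothetical sketch, not the real sglang definition): `cache_params=config.mamba2_cache_params` bundles what the old code threaded through as five separate arguments (conv/SSM state shapes and dtypes plus the Mamba layer ids). The actual dataclass lives in sglang and its fields may differ; the sketch below only shows the kind of object the pool now receives, including a `mamba_cache_per_req` byte count like the one used in the memory accounting earlier.

import math
from dataclasses import dataclass
from typing import Tuple

import torch

def _dtype_size(dt: torch.dtype) -> int:
    # Bytes per element for a dtype.
    return torch.empty((), dtype=dt).element_size()

@dataclass(frozen=True)
class Mamba2CacheParams:  # hypothetical shape of the bundle, for illustration only
    conv_state_shape: Tuple[int, ...]      # per-layer conv state shape
    temporal_state_shape: Tuple[int, ...]  # per-layer SSM state shape
    conv_dtype: torch.dtype
    ssm_dtype: torch.dtype
    mamba_layers: Tuple[int, ...]          # layer ids that carry Mamba state

    @property
    def mamba_cache_per_req(self) -> int:
        # Rough bytes of recurrent state one request pins across all Mamba layers;
        # the real accounting in sglang may differ in detail.
        conv_bytes = math.prod(self.conv_state_shape) * _dtype_size(self.conv_dtype)
        ssm_bytes = math.prod(self.temporal_state_shape) * _dtype_size(self.ssm_dtype)
        return (conv_bytes + ssm_bytes) * len(self.mamba_layers)

Each supported config (FalconH1Config, NemotronHConfig, Qwen3NextConfig) would then expose such a bundle as `mamba2_cache_params`, so the request pool and the memory estimate read from a single source instead of per-model tuples.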
@@ -1640,7 +1641,7 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
-        elif self.is_hybrid_gdn:
+        elif config := self.mambaish_config:
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1651,9 +1652,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 # if draft worker, we only need 1 attention layer's kv pool
                 full_attention_layer_ids=(
-                    [0]
-                    if self.is_draft_worker
-                    else self.model_config.hf_config.full_attention_layer_ids
+                    [0] if self.is_draft_worker else config.full_attention_layer_ids
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
@@ -1681,7 +1680,8 @@ class ModelRunner:
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if _is_npu and (
-                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+                self.server_args.attention_backend == "ascend"
+                or self.hybrid_gdn_config is not None
             ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,