diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index febf94735..14da05af1 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -170,6 +170,7 @@ class ModelConfig: if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM": self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP" + self.hf_config.num_nextn_predict_layers = 1 # Check model type self.is_generation = is_generation_model( diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index d465baeb4..13eac206f 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -185,10 +185,9 @@ class LogitsMetadata: ) else: dp_local_start_pos = cumtokens[dp_rank - 1] - dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank] self.dp_local_start_pos = dp_local_start_pos - self.dp_local_num_tokens = dp_local_num_tokens + self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank] hidden_size = get_dp_hidden_size() dtype = get_dp_dtype() diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 7de38aabd..5b0f8a714 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -52,6 +52,10 @@ if _is_npu: import torch_npu +def get_tensor_size_bytes(t: torch.Tensor): + return np.prod(t.shape) * t.dtype.itemsize + + class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" @@ -158,16 +162,23 @@ class MambaPool: intermediate_ssm_state_cache, intermediate_conv_window_cache, ) + logger.info( + f"Mamba Cache is allocated. " + f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " + f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " + f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB " + f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB " + ) else: self.mamba_cache = (conv_state, temporal_state) + logger.info( + f"Mamba Cache is allocated. " + f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " + f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " + ) self.size = size self.free_slots = list(range(size)) self.mem_usage = self.get_mamba_size() / GB - logger.info( - f"Mamba Cache is allocated. " - f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, " - f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB " - ) def get_mamba_params_all_layers(self): return [self.mamba_cache[i] for i in range(len(self.mamba_cache))] @@ -176,10 +187,7 @@ class MambaPool: return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))] def get_mamba_size(self): - return ( - np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize - + np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize - ) + return sum(get_tensor_size_bytes(t) for t in self.mamba_cache) def available_size(self): return len(self.free_slots) @@ -492,10 +500,10 @@ class MHATokenToKVPool(KVCache): assert hasattr(self, "v_buffer") k_size_bytes = 0 for k_cache in self.k_buffer: - k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize + k_size_bytes += get_tensor_size_bytes(k_cache) v_size_bytes = 0 for v_cache in self.v_buffer: - v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize + v_size_bytes += get_tensor_size_bytes(v_cache) return k_size_bytes, v_size_bytes # for disagg @@ -1077,7 +1085,7 @@ class MLATokenToKVPool(KVCache): assert hasattr(self, "kv_buffer") kv_size_bytes = 0 for kv_cache in self.kv_buffer: - kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize + kv_size_bytes += get_tensor_size_bytes(kv_cache) return kv_size_bytes # for disagg @@ -1240,9 +1248,9 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool): assert hasattr(self, "v_buffer") kv_size_bytes = 0 for k_cache in self.k_buffer: - kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize + kv_size_bytes += get_tensor_size_bytes(k_cache) for v_cache in self.v_buffer: - kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize + kv_size_bytes += get_tensor_size_bytes(v_cache) return kv_size_bytes def get_kv_buffer(self, layer_id: int): diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py index a9da0867d..b123efcf8 100644 --- a/python/sglang/srt/models/qwen3_next_mtp.py +++ b/python/sglang/srt/models/qwen3_next_mtp.py @@ -85,8 +85,11 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM): if input_embeds is None: input_embeds = self.model.embed_tokens(input_ids) - input_embeds = self.pre_fc_norm_embedding(input_embeds) - hidden_states = self.pre_fc_norm_hidden(forward_batch.spec_info.hidden_states) + hidden_states = forward_batch.spec_info.hidden_states + # Some idle batch has 0 batch size. GemmaRMSNorm.forward would fail due to bs=0. + if not forward_batch.forward_mode.is_idle(): + input_embeds = self.pre_fc_norm_embedding(input_embeds) + hidden_states = self.pre_fc_norm_hidden(hidden_states) hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1)) hidden_states = self.model(