From f2887498f0558494db0f05006ec6ef2e148853b8 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 10 Aug 2025 17:32:28 -0700
Subject: [PATCH] Simplify memory pool (#9033)

---
 .../sglang/srt/model_executor/model_runner.py | 82 +++++++++----------
 1 file changed, 41 insertions(+), 41 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index ee83c2d9c..7681d5fe0 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1251,30 +1251,33 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
-            self.token_to_kv_pool = AscendTokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
-                head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-            )
-        elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
-            self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                kv_lora_rank=self.model_config.kv_lora_rank,
-                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=self.num_effective_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                start_layer=self.start_layer,
-                end_layer=self.end_layer,
-            )
+        if self.server_args.attention_backend == "ascend":
+            if self.use_mla_backend:
+                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    kv_lora_rank=self.model_config.kv_lora_rank,
+                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                )
+            else:
+                self.token_to_kv_pool = AscendTokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.model_config.num_hidden_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                )
         elif self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1333,6 +1336,7 @@ class ModelRunner:
                     end_layer=self.end_layer,
                 )
 
+        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if self.page_size == 1:
                 if self.is_hybrid:
@@ -1342,8 +1346,7 @@ class ModelRunner:
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
-                        need_sort=self.server_args.disaggregation_mode
-                        in ("decode", "prefill"),
+                        need_sort=need_sort,
                     )
                 else:
                     self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
@@ -1351,29 +1354,26 @@ class ModelRunner:
                         self.max_total_num_tokens,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
-                        need_sort=self.server_args.disaggregation_mode
-                        in ("decode", "prefill"),
+                        need_sort=need_sort,
                     )
             else:
-                if _is_npu:
-                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
-                        self.max_total_num_tokens,
-                        page_size=self.page_size,
-                        dtype=self.kv_cache_dtype,
-                        device=self.device,
-                        kvcache=self.token_to_kv_pool,
-                        need_sort=self.server_args.disaggregation_mode
-                        in ("decode", "prefill"),
-                    )
-                else:
+                if not _is_npu:
                     self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                         self.max_total_num_tokens,
                         page_size=self.page_size,
                         dtype=self.kv_cache_dtype,
                         device=self.device,
                         kvcache=self.token_to_kv_pool,
-                        need_sort=self.server_args.disaggregation_mode
-                        in ("decode", "prefill"),
+                        need_sort=need_sort,
+                    )
+                else:
+                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        page_size=self.page_size,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                        need_sort=need_sort,
                    )
        else:
            assert self.is_draft_worker
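
Note for reviewers: the sketch below distills the allocator-selection flow this patch converges on, outside the diff context. It is a minimal, standalone approximation: the class names and the `need_sort` hoist mirror the diff, but the stub constructors and the `build_allocator` helper are hypothetical simplifications, not sglang's actual APIs (the real constructors also take `dtype`, `device`, `kvcache`, and more).

    # Hypothetical stand-ins for the allocator classes named in the diff; the
    # real sglang constructors take many more arguments.
    class TokenToKVPoolAllocator:
        def __init__(self, need_sort: bool):
            self.need_sort = need_sort

    class PagedTokenToKVPoolAllocator(TokenToKVPoolAllocator):
        pass

    class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
        pass

    def build_allocator(page_size: int, disaggregation_mode: str, is_npu: bool):
        # The patch hoists this repeated expression into a single local, so all
        # four constructor call sites share one `need_sort=need_sort` argument.
        need_sort = disaggregation_mode in ("decode", "prefill")
        if page_size == 1:
            return TokenToKVPoolAllocator(need_sort=need_sort)
        # The patch also flips the paged branch: the common non-NPU path comes
        # first (`if not _is_npu`) and Ascend becomes the `else` case.
        if not is_npu:
            return PagedTokenToKVPoolAllocator(need_sort=need_sort)
        return AscendPagedTokenToKVPoolAllocator(need_sort=need_sort)

    # Example: a prefill-side disaggregation worker with a paged KV cache on a
    # non-NPU device gets the paged allocator with need_sort=True.
    alloc = build_allocator(page_size=16, disaggregation_mode="prefill", is_npu=False)
    assert isinstance(alloc, PagedTokenToKVPoolAllocator) and alloc.need_sort

Both visible changes are behavior-preserving: the Ascend KV-pool selection becomes one nested conditional instead of two compound tests, and the disaggregation-mode membership test is computed once instead of at every allocator call site.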