Clean up allocators (#9134)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Lianmin Zheng
2025-08-13 13:56:04 -07:00
committed by GitHub
parent 2f20f43026
commit 9e426466af
16 changed files with 288 additions and 295 deletions

View File

@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
AscendPagedTokenToKVPoolAllocator,
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
AscendMLAPagedTokenToKVPool,
AscendTokenToKVPool,
@@ -176,10 +176,6 @@ class ModelRunner:
self.mem_fraction_static = mem_fraction_static
self.device = server_args.device
self.gpu_id = gpu_id
# Apply the rank zero filter to logger
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
logger.addFilter(RankZeroFilter(tp_rank == 0))
self.tp_rank = tp_rank
self.tp_size = tp_size
self.moe_ep_rank = moe_ep_rank
@@ -205,15 +201,17 @@ class ModelRunner:
self.is_hybrid = model_config.is_hybrid
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
self.attention_chunk_size = model_config.attention_chunk_size
self.forward_pass_id = 0
# Apply the rank zero filter to logger
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
logger.addFilter(RankZeroFilter(tp_rank == 0))
if server_args.show_time_cost:
enable_show_time_cost()
# Model-specific adjustment
self.model_specific_adjustment()
if server_args.show_time_cost:
enable_show_time_cost()
# Global vars
global_server_args_dict.update(
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@@ -221,8 +219,6 @@ class ModelRunner:
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
| {
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
}
@@ -242,13 +238,15 @@ class ModelRunner:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
# If it is a draft model, tp_group can be different
# Initialize the model runner
self.initialize(min_per_gpu_memory)
# temporary cached values
# Temporary cached values
self.support_pp = (
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
)
# For weight updates
self._model_update_group = {}
def initialize(self, min_per_gpu_memory: float):
@@ -277,6 +275,7 @@ class ModelRunner:
)
)
# Expert parallelism
self.eplb_manager = (
EPLBManager(self)
if self.server_args.enable_eplb and (not self.is_draft_worker)
@@ -1160,6 +1159,7 @@ class ModelRunner:
max_num_reqs: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
# Determine the kv cache dtype
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
@@ -1178,6 +1178,8 @@ class ModelRunner:
)
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if max_num_reqs is None:
max_num_reqs = min(
@@ -1190,9 +1192,6 @@ class ModelRunner:
4096,
)
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -1239,6 +1238,7 @@ class ModelRunner:
"Not enough memory. Please try to increase --mem-fraction-static."
)
# Initialize req_to_token_pool
if self.req_to_token_pool is None:
if self.server_args.disaggregation_mode == "decode":
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
@@ -1264,6 +1264,7 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
# Initialize token_to_kv_pool
if self.server_args.attention_backend == "ascend":
if self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1349,28 +1350,44 @@ class ModelRunner:
end_layer=self.end_layer,
)
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
max_num_extend_tokens = (
self.server_args.chunked_prefill_size
if self.server_args.chunked_prefill_size > 0
else self.server_args.max_prefill_tokens
)
if self.token_to_kv_pool_allocator is None:
if self.page_size == 1:
if self.is_hybrid:
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
self.full_max_total_num_tokens,
self.swa_max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
if self.server_args.attention_backend == "ascend":
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
if not _is_npu:
if self.page_size == 1:
if self.is_hybrid:
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
self.full_max_total_num_tokens,
self.swa_max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
assert not self.is_hybrid
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
@@ -1378,15 +1395,7 @@ class ModelRunner:
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
max_num_extend_tokens=max_num_extend_tokens,
)
else:
assert self.is_draft_worker
@@ -1554,15 +1563,13 @@ class ModelRunner:
)
return TRTLLMHAAttnBackend(self)
elif backend_str == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
)
logger.info(f"Intel AMX attention backend is enabled.")
return IntelAMXAttnBackend(self)
elif self.server_args.attention_backend == "dual_chunk_flash_attn":
elif backend_str == "dual_chunk_flash_attn":
from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
DualChunkFlashAttentionBackend,
)
@@ -1606,6 +1613,7 @@ class ModelRunner:
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner = CudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.cuda_graph_mem_usage = before_mem - after_mem
logger.info(