Clean up allocators (#9134)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import (
|
||||
AscendPagedTokenToKVPoolAllocator,
|
||||
BaseTokenToKVPoolAllocator,
|
||||
PagedTokenToKVPoolAllocator,
|
||||
SWATokenToKVPoolAllocator,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
AscendMLAPagedTokenToKVPool,
|
||||
AscendTokenToKVPool,
|
||||
@@ -176,10 +176,6 @@ class ModelRunner:
|
||||
self.mem_fraction_static = mem_fraction_static
|
||||
self.device = server_args.device
|
||||
self.gpu_id = gpu_id
|
||||
|
||||
# Apply the rank zero filter to logger
|
||||
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
|
||||
logger.addFilter(RankZeroFilter(tp_rank == 0))
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.moe_ep_rank = moe_ep_rank
|
||||
@@ -205,15 +201,17 @@ class ModelRunner:
|
||||
self.is_hybrid = model_config.is_hybrid
|
||||
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
|
||||
self.attention_chunk_size = model_config.attention_chunk_size
|
||||
|
||||
self.forward_pass_id = 0
|
||||
|
||||
# Apply the rank zero filter to logger
|
||||
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
|
||||
logger.addFilter(RankZeroFilter(tp_rank == 0))
|
||||
if server_args.show_time_cost:
|
||||
enable_show_time_cost()
|
||||
|
||||
# Model-specific adjustment
|
||||
self.model_specific_adjustment()
|
||||
|
||||
if server_args.show_time_cost:
|
||||
enable_show_time_cost()
|
||||
|
||||
# Global vars
|
||||
global_server_args_dict.update(
|
||||
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
|
||||
@@ -221,8 +219,6 @@ class ModelRunner:
|
||||
# TODO it is indeed not a "server args"
|
||||
"use_mla_backend": self.use_mla_backend,
|
||||
"speculative_algorithm": self.spec_algorithm,
|
||||
}
|
||||
| {
|
||||
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
|
||||
"deepep_mode": DeepEPMode(server_args.deepep_mode),
|
||||
}
|
||||
@@ -242,13 +238,15 @@ class ModelRunner:
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
|
||||
|
||||
# If it is a draft model, tp_group can be different
|
||||
# Initialize the model runner
|
||||
self.initialize(min_per_gpu_memory)
|
||||
|
||||
# temporary cached values
|
||||
# Temporary cached values
|
||||
self.support_pp = (
|
||||
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
|
||||
)
|
||||
|
||||
# For weight updates
|
||||
self._model_update_group = {}
|
||||
|
||||
def initialize(self, min_per_gpu_memory: float):
|
||||
@@ -277,6 +275,7 @@ class ModelRunner:
|
||||
)
|
||||
)
|
||||
|
||||
# Expert parallelism
|
||||
self.eplb_manager = (
|
||||
EPLBManager(self)
|
||||
if self.server_args.enable_eplb and (not self.is_draft_worker)
|
||||
@@ -1160,6 +1159,7 @@ class ModelRunner:
|
||||
max_num_reqs: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
# Determine the kv cache dtype
|
||||
if self.server_args.kv_cache_dtype == "auto":
|
||||
self.kv_cache_dtype = self.dtype
|
||||
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
||||
@@ -1178,6 +1178,8 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
||||
if SGLANG_CI_SMALL_KV_SIZE:
|
||||
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
|
||||
|
||||
if max_num_reqs is None:
|
||||
max_num_reqs = min(
|
||||
@@ -1190,9 +1192,6 @@ class ModelRunner:
|
||||
4096,
|
||||
)
|
||||
|
||||
if SGLANG_CI_SMALL_KV_SIZE:
|
||||
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
if self.is_draft_worker:
|
||||
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||
@@ -1239,6 +1238,7 @@ class ModelRunner:
|
||||
"Not enough memory. Please try to increase --mem-fraction-static."
|
||||
)
|
||||
|
||||
# Initialize req_to_token_pool
|
||||
if self.req_to_token_pool is None:
|
||||
if self.server_args.disaggregation_mode == "decode":
|
||||
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
||||
@@ -1264,6 +1264,7 @@ class ModelRunner:
|
||||
# Draft worker shares req_to_token_pool with the target worker.
|
||||
assert self.is_draft_worker
|
||||
|
||||
# Initialize token_to_kv_pool
|
||||
if self.server_args.attention_backend == "ascend":
|
||||
if self.use_mla_backend:
|
||||
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
|
||||
@@ -1349,28 +1350,44 @@ class ModelRunner:
|
||||
end_layer=self.end_layer,
|
||||
)
|
||||
|
||||
# Initialize token_to_kv_pool_allocator
|
||||
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
||||
max_num_extend_tokens = (
|
||||
self.server_args.chunked_prefill_size
|
||||
if self.server_args.chunked_prefill_size > 0
|
||||
else self.server_args.max_prefill_tokens
|
||||
)
|
||||
if self.token_to_kv_pool_allocator is None:
|
||||
if self.page_size == 1:
|
||||
if self.is_hybrid:
|
||||
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
|
||||
self.full_max_total_num_tokens,
|
||||
self.swa_max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=need_sort,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=need_sort,
|
||||
)
|
||||
if self.server_args.attention_backend == "ascend":
|
||||
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=need_sort,
|
||||
)
|
||||
else:
|
||||
if not _is_npu:
|
||||
if self.page_size == 1:
|
||||
if self.is_hybrid:
|
||||
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
|
||||
self.full_max_total_num_tokens,
|
||||
self.swa_max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=need_sort,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=need_sort,
|
||||
)
|
||||
else:
|
||||
assert not self.is_hybrid
|
||||
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
@@ -1378,15 +1395,7 @@ class ModelRunner:
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=need_sort,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
kvcache=self.token_to_kv_pool,
|
||||
need_sort=need_sort,
|
||||
max_num_extend_tokens=max_num_extend_tokens,
|
||||
)
|
||||
else:
|
||||
assert self.is_draft_worker
|
||||
@@ -1554,15 +1563,13 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
return TRTLLMHAAttnBackend(self)
|
||||
|
||||
elif backend_str == "intel_amx":
|
||||
from sglang.srt.layers.attention.intel_amx_backend import (
|
||||
IntelAMXAttnBackend,
|
||||
)
|
||||
|
||||
logger.info(f"Intel AMX attention backend is enabled.")
|
||||
return IntelAMXAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "dual_chunk_flash_attn":
|
||||
elif backend_str == "dual_chunk_flash_attn":
|
||||
from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
|
||||
DualChunkFlashAttentionBackend,
|
||||
)
|
||||
@@ -1606,6 +1613,7 @@ class ModelRunner:
|
||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||
)
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
self.cuda_graph_mem_usage = before_mem - after_mem
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user