Move output processing logic from scheduler.py into a separate file (#4354)

This commit is contained in:
Lianmin Zheng
2025-03-12 16:21:49 -07:00
committed by GitHub
parent 2c3656f276
commit e35a93fa8a
6 changed files with 634 additions and 609 deletions

View File

@@ -82,7 +82,6 @@ from sglang.srt.utils import (
logger = logging.getLogger(__name__)
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
@@ -119,6 +118,7 @@ class ModelRunner:
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.page_size = server_args.page_size
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -161,6 +161,11 @@ class ModelRunner:
# Get memory before model loading
min_per_gpu_memory = self.init_torch_distributed()
# If it is a draft model tp_group can be different.
self.initialize(min_per_gpu_memory)
def initialize(self, min_per_gpu_memory: float):
server_args = self.server_args
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.server_args.enable_memory_saver
)
@@ -300,15 +305,16 @@ class ModelRunner:
min_per_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
)
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
self.tp_group = get_tp_group()
self.attention_tp_group = get_attention_tp_group()
# Check memory for tensor parallelism
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
if self.tp_size > 1:
if min_per_gpu_memory < local_gpu_memory * 0.9:
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
)
logger.info(
@@ -698,6 +704,12 @@ class ModelRunner:
)
self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
self.max_total_num_tokens = (
self.max_total_num_tokens
// self.server_args.page_size
* self.server_args.page_size
)
if self.max_total_num_tokens <= 0:
raise RuntimeError(
"Not enough memory. Please try to increase --mem-fraction-static."
@@ -783,7 +795,6 @@ class ModelRunner:
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (