Move output processing logic from scheduler.py into a separate file (#4354)
@@ -82,7 +82,6 @@ from sglang.srt.utils import (

 logger = logging.getLogger(__name__)

 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
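
Note that os.getenv returns the raw environment value as a string (or the default, None, when the variable is unset), so any consumer of SGLANG_CI_SMALL_KV_SIZE has to cast it before using it as a token count. A minimal sketch of that pattern:

    import os

    # os.getenv yields a str or None; cast before doing arithmetic with it.
    small_kv = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
    if small_kv is not None:
        max_total_tokens = int(small_kv)  # e.g. "4096" -> 4096
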
@@ -119,6 +118,7 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
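
SpeculativeAlgorithm.from_string maps the --speculative-algorithm server argument onto an enum member. Its implementation is not part of this diff; a common shape for such a helper, shown here with an illustrative stand-in enum, is:

    from enum import Enum, auto

    class SpecAlgo(Enum):  # stand-in for SpeculativeAlgorithm, for illustration
        NONE = auto()
        EAGLE = auto()

        @classmethod
        def from_string(cls, name):
            # Treat a missing argument as "no speculative decoding".
            return cls.NONE if name is None else cls[name.upper()]

With this shape, SpecAlgo.from_string(None) is SpecAlgo.NONE and SpecAlgo.from_string("EAGLE") is SpecAlgo.EAGLE.
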
@@ -161,6 +161,11 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

+        # If it is a draft model tp_group can be different.
+        self.initialize(min_per_gpu_memory)
+
+    def initialize(self, min_per_gpu_memory: float):
+        server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
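
The hunk above introduces two-phase construction: __init__ measures free memory and then delegates to the new initialize() method, and the comment hints that a draft-model runner (whose tp_group can differ) goes through the same entry point. A minimal sketch of the pattern, with illustrative names rather than the real ModelRunner internals:

    class RunnerSketch:  # illustrative, not the actual ModelRunner
        def __init__(self, server_args):
            self.server_args = server_args
            # Phase 1: set up distributed state and measure free memory
            # before any weights are loaded.
            min_per_gpu_memory = self.init_torch_distributed()
            # Phase 2: everything else, reusable by a draft-model variant.
            self.initialize(min_per_gpu_memory)

        def init_torch_distributed(self) -> float:
            return 40.0  # placeholder for the measured free memory (GB)

        def initialize(self, min_per_gpu_memory: float) -> None:
            ...  # model loading, memory pools, attention backend, etc.
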
@@ -300,15 +305,16 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
-        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                 )

         logger.info(
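
The reworked message now reports the actual numbers: min_per_gpu_memory is the smallest free memory across the tensor-parallel group, and the check fails when it drops below 90% of this rank's local free memory, which usually means another process is occupying one of the GPUs. The same guard in isolation, with the tolerance factored out as a parameter for illustration:

    def check_memory_balance(min_per_gpu_memory: float,
                             local_gpu_memory: float,
                             tol: float = 0.9) -> None:
        # Mirrors the check above: the worst-off GPU in the TP group should
        # have roughly as much free memory as this one.
        if min_per_gpu_memory < local_gpu_memory * tol:
            raise ValueError(
                "The memory capacity is unbalanced. Some GPUs may be occupied "
                f"by other processes. {min_per_gpu_memory=}, {local_gpu_memory=}"
            )
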
@@ -698,6 +704,12 @@ class ModelRunner:
             )
         self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

+        self.max_total_num_tokens = (
+            self.max_total_num_tokens
+            // self.server_args.page_size
+            * self.server_args.page_size
+        )
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
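
The added block floors max_total_num_tokens to a multiple of page_size: since the KV cache is handed out in whole pages, a partial trailing page would be unusable. Integer floor-division followed by multiplication is the standard idiom, e.g. with made-up numbers:

    max_total_num_tokens = 32_775   # hypothetical token budget
    page_size = 16
    rounded = max_total_num_tokens // page_size * page_size
    assert rounded == 32_768 and rounded % page_size == 0
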
@@ -783,7 +795,6 @@ class ModelRunner:
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
-
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
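
Creating plan_stream_for_flashinfer as a dedicated torch.cuda.Stream suggests that, under EAGLE speculative decoding, FlashInfer's plan step is issued on a side stream rather than the default one. A generic sketch of that side-stream pattern (plan_fn is a hypothetical callable; requires a CUDA device):

    import torch

    plan_stream = torch.cuda.Stream()

    def run_plan_overlapped(plan_fn):
        # Issue the planning work on the side stream so the default stream
        # keeps executing, then make the default stream wait for the result
        # before any kernel that depends on the plan is launched.
        with torch.cuda.stream(plan_stream):
            plan_fn()
        torch.cuda.current_stream().wait_stream(plan_stream)
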