Split the scheduler into multiple mixin classes to reduce the file size (#8483)

This commit is contained in:
Lianmin Zheng
2025-07-29 12:46:50 -07:00
committed by GitHub
parent 5973675bc3
commit a4c3b121d8
12 changed files with 869 additions and 785 deletions


@@ -170,16 +170,6 @@ class ReqState:
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr
if is_cross_node:
# Fallback to default CPU transport for multi-node
return "default"
else:
return "cuda_ipc"
class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
@@ -199,16 +189,6 @@ class TokenizerManager:
else None
)
self.crash_dump_folder = server_args.crash_dump_folder
self.crash_dump_performed = False # Flag to ensure dump is only called once
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
# Read model args
self.model_path = server_args.model_path
@@ -218,8 +198,7 @@ class TokenizerManager:
self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id
self._updating = False
self._cond = asyncio.Condition()
self.max_req_input_len = None # Will be set later in engine.py
if self.model_config.is_multimodal:
import_processors()
@@ -258,39 +237,57 @@ class TokenizerManager:
revision=server_args.revision,
)
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
)
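Editor's note on the wiring above: the tokenizer-side process pulls detokenized results from the detokenizer and pushes new work to the scheduler over ZeroMQ IPC sockets; `get_zmq_socket` is sglang's thin wrapper around this setup. A minimal standalone sketch of the same asyncio PULL/PUSH pattern with plain pyzmq (the endpoint path and payload are illustrative):

```python
import asyncio

import zmq
import zmq.asyncio

async def main():
    ctx = zmq.asyncio.Context()

    # "Scheduler" end: bind a PULL socket and wait for work.
    pull = ctx.socket(zmq.PULL)
    pull.bind("ipc:///tmp/demo_scheduler_input")  # illustrative endpoint

    # "Tokenizer" end: connect a PUSH socket and send a request.
    push = ctx.socket(zmq.PUSH)
    push.connect("ipc:///tmp/demo_scheduler_input")

    await push.send_pyobj({"rid": "req-0", "input_ids": [1, 2, 3]})
    print(await pull.recv_pyobj())

    ctx.destroy()

asyncio.run(main())
```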
# Store states
# Request states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
self.asyncio_tasks = set()
# Health check
self.health_check_failed = False
self.gracefully_exit = False
self.last_receive_tstamp = 0
# Dumping
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.crash_dump_request_list: deque[Tuple] = deque()
self.log_request_metadata = self.get_log_request_metadata()
self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None
self.asyncio_tasks = set()
self.crash_dump_request_list: deque[Tuple] = deque()
self.crash_dump_performed = False # Flag to ensure dump is only called once
# Session
self.session_futures = {} # session_id -> asyncio event
# Weight updates
# The event used to notify that the weight sync is finished.
self.model_update_lock = RWLock()
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self._is_updating = False
self._is_updating_cond = asyncio.Condition()
# LoRA
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Lock to serialize LoRA update operations.
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock()
# For pd disaggregation
# For PD disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
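The comment on `lora_update_lock` is the key design point of this block: a plain `asyncio.Lock` only serializes LoRA updates against each other, while `model_update_lock` (the `RWLock` above) is what actually gates inference during weight syncs. A minimal sketch of that difference with stand-in coroutines (all names below are illustrative):

```python
import asyncio

lora_update_lock = asyncio.Lock()

async def load_lora_adapter(name: str):
    # Updates are serialized against each other...
    async with lora_update_lock:
        print(f"loading {name}")
        await asyncio.sleep(0.1)  # stand-in for the actual adapter load

async def run_inference(rid: int):
    # ...but inference never touches this lock, so it overlaps freely.
    await asyncio.sleep(0.05)
    print(f"request {rid} done")

async def main():
    await asyncio.gather(
        load_lora_adapter("adapter_a"),
        load_lora_adapter("adapter_b"),
        *(run_inference(i) for i in range(3)),
    )

asyncio.run(main())
```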
@@ -458,17 +455,11 @@ class TokenizerManager:
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
async with self._cond:
await self._cond.wait_for(lambda: not self._updating)
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
async with self._is_updating_cond:
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
@@ -567,6 +558,12 @@ class TokenizerManager:
f"model's context length ({self.context_len} tokens)."
)
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
# Check total tokens (input + max_new_tokens)
max_new_tokens = obj.sampling_params.get("max_new_tokens")
if (
@@ -959,14 +956,14 @@ class TokenizerManager:
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def pause_generation(self):
async with self._cond:
self._updating = True
async with self._is_updating_cond:
self._is_updating = True
self.abort_request(abort_all=True)
async def continue_generation(self):
async with self._cond:
self._updating = False
self._cond.notify_all()
async with self._is_updating_cond:
self._is_updating = False
self._is_updating_cond.notify_all()
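The rename from `_cond`/`_updating` to `_is_updating_cond`/`_is_updating` makes the pairing with `generate_request` easier to follow: incoming requests block on the condition until the update finishes, and `continue_generation` flips the flag and wakes every waiter. A self-contained sketch of the same `asyncio.Condition` gate (class and method names are illustrative):

```python
import asyncio

class PausableGate:
    def __init__(self):
        self._is_updating = False
        self._is_updating_cond = asyncio.Condition()

    async def pause(self):
        async with self._is_updating_cond:
            self._is_updating = True

    async def resume(self):
        async with self._is_updating_cond:
            self._is_updating = False
            # Wake every coroutine blocked in wait_for().
            self._is_updating_cond.notify_all()

    async def wait_until_open(self):
        async with self._is_updating_cond:
            await self._is_updating_cond.wait_for(lambda: not self._is_updating)

async def main():
    gate = PausableGate()
    await gate.pause()

    async def request():
        await gate.wait_until_open()
        print("request admitted")

    task = asyncio.create_task(request())
    await asyncio.sleep(0.1)  # the request is blocked here
    await gate.resume()
    await task

asyncio.run(main())
```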
async def update_weights_from_disk(
self,
@@ -1208,14 +1205,6 @@ class TokenizerManager:
# Many DP ranks
return [res.internal_state for res in responses]
async def get_load(self) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
@@ -1224,6 +1213,14 @@ class TokenizerManager:
)
return [res.internal_state for res in responses]
async def get_load(self) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
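The `if not self.current_load_lock.locked()` guard makes this a best-effort cache refresh: if another coroutine is already fetching internal state, concurrent callers skip the fetch and return the last cached value instead of queuing behind the lock. A minimal sketch of the pattern (the state fetch is a stand-in):

```python
import asyncio

current_load_lock = asyncio.Lock()
current_load = 0

async def fetch_load() -> int:
    await asyncio.sleep(0.1)  # stand-in for querying the scheduler
    return 42

async def get_load() -> dict:
    global current_load
    # Only refresh when no one else is mid-refresh; concurrent callers
    # reuse the cached value rather than piling up behind the lock.
    if not current_load_lock.locked():
        async with current_load_lock:
            current_load = await fetch_load()
    return {"load": current_load}

async def main():
    print(await asyncio.gather(*(get_load() for _ in range(3))))

asyncio.run(main())
```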
def get_log_request_metadata(self):
max_length = None
skip_names = None
@@ -1343,11 +1340,24 @@ class TokenizerManager:
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
)
return
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
self.crash_dump_performed = True
if not self.crash_dump_folder:
return
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
self.crash_dump_performed = True
# Check if NFS directory is available
# expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
# use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
# expected_nfs_dir, os.W_OK
# )
use_nfs_dir = False
if not use_nfs_dir:
logger.error(
    "Expected NFS directory is not available or writable. Uploading to GCS."
)
data_to_dump = []
if self.crash_dump_request_list:
data_to_dump.extend(self.crash_dump_request_list)
@@ -1357,7 +1367,12 @@ class TokenizerManager:
for rid, state in self.rid_to_state.items():
if not state.finished:
unfinished_requests.append(
(state.obj, {}, state.created_time, time.time())
(
state.obj,
state.out_list[-1] if state.out_list else {},
state.created_time,
time.time(),
)
)
if unfinished_requests:
data_to_dump.extend(unfinished_requests)
@@ -1365,10 +1380,11 @@ class TokenizerManager:
if not data_to_dump:
return
object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
filename = os.path.join(
self.crash_dump_folder,
os.getenv("HOSTNAME", None),
f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
object_name,
)
os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -1383,6 +1399,24 @@ class TokenizerManager:
f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
)
def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
from google.cloud import storage
client = storage.Client()
bucket = client.bucket(bucket_name)
blob = bucket.blob(object_name)
# if_generation_match=0 is a write precondition: the upload succeeds
# only if the object does not already exist, so a dump is never overwritten.
blob.upload_from_filename(source_file_path, if_generation_match=0)
logger.error(
f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
)
if not use_nfs_dir:
_upload_file_to_gcs(
"sglang_crash_dump",
filename,
os.getenv("HOSTNAME", None) + "/" + object_name,
)
async def sigterm_watchdog(self):
while not self.gracefully_exit:
await asyncio.sleep(5)
@@ -1426,7 +1460,7 @@ class TokenizerManager:
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj)
self.last_receive_tstamp = time.perf_counter()
self.last_receive_tstamp = time.time()
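The switch from `time.perf_counter()` to `time.time()` is subtle but meaningful: `perf_counter()` uses an arbitrary per-process epoch, so its values are only comparable to other `perf_counter()` readings, whereas `time.time()` is wall-clock time and can be compared against the other wall-clock timestamps this file records (e.g. `created_time`). A quick illustration:

```python
import time

wall = time.time()          # seconds since the Unix epoch, e.g. ~1.7e9
mono = time.perf_counter()  # arbitrary epoch, often seconds since process start

# Mixing the two clocks is meaningless: an "age" computed this way would be
# wildly wrong if last_receive_tstamp held a perf_counter() value.
last_receive_tstamp = time.time()
age = time.time() - last_receive_tstamp  # correct: both are wall-clock
print(f"wall={wall:.0f} mono={mono:.3f} age={age:.6f}s")
```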
def _handle_batch_output(
self,
@@ -1697,24 +1731,13 @@ class TokenizerManager:
self.dump_requests_folder,
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
)
logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
to_dump = self.dump_request_list
self._dump_data_to_file(
data_list=self.dump_request_list,
filename=filename,
log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
)
self.dump_request_list = []
to_dump_with_server_args = {
"server_args": self.server_args,
"requests": to_dump,
}
def background_task():
os.makedirs(self.dump_requests_folder, exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump_with_server_args, f)
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
current_time = time.time()
self.crash_dump_request_list.append(
@@ -1727,6 +1750,22 @@ class TokenizerManager:
):
self.crash_dump_request_list.popleft()
def _dump_data_to_file(
self, data_list: List[Tuple], filename: str, log_message: str
):
logger.info(log_message)
to_dump_with_server_args = {
"server_args": self.server_args,
"requests": data_list.copy(),
}
def background_task():
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(to_dump_with_server_args, f)
asyncio.create_task(asyncio.to_thread(background_task))
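Two details in `_dump_data_to_file` are worth calling out: `data_list.copy()` snapshots the list because the caller clears `self.dump_request_list` immediately after scheduling, and `asyncio.create_task(asyncio.to_thread(...))` runs the blocking pickle/file I/O on a worker thread without blocking the event loop or awaiting the result. A standalone sketch of the same fire-and-forget offload (paths and data are illustrative):

```python
import asyncio
import os
import pickle
import tempfile

def blocking_dump(data: list, filename: str):
    # Ordinary blocking I/O: safe here because it runs in a worker thread.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "wb") as f:
        pickle.dump(data, f)

async def main():
    pending = [("req-1", {"ok": True})]
    filename = os.path.join(tempfile.gettempdir(), "dumps", "demo.pkl")

    # Snapshot first: the original list is cleared immediately afterwards.
    asyncio.create_task(asyncio.to_thread(blocking_dump, pending.copy(), filename))
    pending.clear()

    await asyncio.sleep(0.1)  # give the background task time to finish
    print(os.path.exists(filename))

asyncio.run(main())
```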
def _handle_abort_req(self, recv_obj):
state = self.rid_to_state[recv_obj.rid]
state.finished = True
@@ -1862,6 +1901,16 @@ class TokenizerManager:
return scores
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr
if is_cross_node:
# Fallback to default CPU transport for multi-node
return "default"
else:
return "cuda_ipc"
async def print_exception_wrapper(func):
"""
Sometimes an asyncio function does not print its exception.