Split the scheduler into multiple mixin classes to reduce the file size (#8483)
@@ -170,16 +170,6 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
-def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
-    is_cross_node = server_args.dist_init_addr
-
-    if is_cross_node:
-        # Fallback to default CPU transport for multi-node
-        return "default"
-    else:
-        return "cuda_ipc"
-
-
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
 
@@ -199,16 +189,6 @@ class TokenizerManager:
             else None
         )
-        self.crash_dump_folder = server_args.crash_dump_folder
-        self.crash_dump_performed = False  # Flag to ensure dump is only called once
 
-        # Init inter-process communication
-        context = zmq.asyncio.Context(2)
-        self.recv_from_detokenizer = get_zmq_socket(
-            context, zmq.PULL, port_args.tokenizer_ipc_name, True
-        )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
 
         # Read model args
         self.model_path = server_args.model_path
@@ -218,8 +198,7 @@ class TokenizerManager:
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
-        self._updating = False
-        self._cond = asyncio.Condition()
+        self.max_req_input_len = None  # Will be set later in engine.py
 
         if self.model_config.is_multimodal:
             import_processors()
@@ -258,39 +237,57 @@ class TokenizerManager:
             revision=server_args.revision,
         )
 
-        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
-        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
-        # serves as the source of truth for available adapters and maps user-friendly LoRA names
-        # to internally used unique LoRA IDs.
-        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
+        # Init inter-process communication
+        context = zmq.asyncio.Context(2)
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+        )
 
-        # Store states
+        # Request states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.asyncio_tasks = set()
+
+        # Health check
         self.health_check_failed = False
         self.gracefully_exit = False
        self.last_receive_tstamp = 0
+
+        # Dumping
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
-        self.crash_dump_request_list: deque[Tuple] = deque()
         self.log_request_metadata = self.get_log_request_metadata()
-        self.session_futures = {}  # session_id -> asyncio event
-        self.max_req_input_len = None
-        self.asyncio_tasks = set()
+        self.crash_dump_request_list: deque[Tuple] = deque()
+        self.crash_dump_performed = False  # Flag to ensure dump is only called once
+
+        # Session
+        self.session_futures = {}  # session_id -> asyncio event
 
+        # Weight updates
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
+        self._is_updating = False
+        self._is_updating_cond = asyncio.Condition()
+
+        # LoRA
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
         # Lock to serialize LoRA update operations.
         # Please note that, unlike `model_update_lock`, this does not block inference, allowing
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()
 
-        # For pd disaggregtion
+        # For PD disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -458,17 +455,11 @@ class TokenizerManager:
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
-        async with self._cond:
-            await self._cond.wait_for(lambda: not self._updating)
-
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()
 
-        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
-            raise ValueError(
-                "This model does not appear to be an embedding model by default. "
-                "Please add `--is-embedding` when launching the server or try another model."
-            )
+        async with self._is_updating_cond:
+            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
 
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
@@ -567,6 +558,12 @@ class TokenizerManager:
                 f"model's context length ({self.context_len} tokens)."
             )
 
+        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+            raise ValueError(
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
+            )
+
         # Check total tokens (input + max_new_tokens)
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
@@ -959,14 +956,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
     async def pause_generation(self):
-        async with self._cond:
-            self._updating = True
+        async with self._is_updating_cond:
+            self._is_updating = True
             self.abort_request(abort_all=True)
 
     async def continue_generation(self):
-        async with self._cond:
-            self._updating = False
-            self._cond.notify_all()
+        async with self._is_updating_cond:
+            self._is_updating = False
+            self._is_updating_cond.notify_all()
 
     async def update_weights_from_disk(
         self,
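The pause/continue pair above is the writer side of a classic asyncio condition-variable gate: `pause_generation` raises the `_is_updating` flag, and request handlers block in `wait_for` until `continue_generation` clears it and calls `notify_all()`. A minimal, self-contained sketch of the pattern (illustrative names, not code from this commit):

import asyncio

class UpdateGate:
    """Writers raise the flag to pause new work; waiters block until it clears."""

    def __init__(self):
        self._is_updating = False
        self._is_updating_cond = asyncio.Condition()

    async def pause(self):
        async with self._is_updating_cond:
            self._is_updating = True

    async def resume(self):
        async with self._is_updating_cond:
            self._is_updating = False
            self._is_updating_cond.notify_all()  # wake every blocked waiter

    async def wait_until_open(self):
        # The condition's lock must be held while calling wait_for().
        async with self._is_updating_cond:
            await self._is_updating_cond.wait_for(lambda: not self._is_updating)

async def demo():
    gate = UpdateGate()
    await gate.pause()
    waiter = asyncio.create_task(gate.wait_until_open())
    await asyncio.sleep(0.1)  # waiter is parked inside wait_for()
    await gate.resume()       # releases the waiter
    await waiter

asyncio.run(demo())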
@@ -1208,14 +1205,6 @@ class TokenizerManager:
             # Many DP ranks
             return [res.internal_state for res in responses]
 
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     async def set_internal_state(
         self, obj: SetInternalStateReq
     ) -> SetInternalStateReqOutput:
@@ -1224,6 +1213,14 @@ class TokenizerManager:
         )
         return [res.internal_state for res in responses]
 
+    async def get_load(self) -> dict:
+        # TODO(lsyin): fake load report server
+        if not self.current_load_lock.locked():
+            async with self.current_load_lock:
+                internal_state = await self.get_internal_state()
+                self.current_load = internal_state[0]["load"]
+        return {"load": self.current_load}
+
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1343,11 +1340,24 @@ class TokenizerManager:
                 "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
             )
             return
-        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
-        self.crash_dump_performed = True
 
         if not self.crash_dump_folder:
             return
+
+        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
+        self.crash_dump_performed = True
+
+        # Check if NFS directory is available
+        # expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
+        # use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
+        #     expected_nfs_dir, os.W_OK
+        # )
+        use_nfs_dir = False
+        if not use_nfs_dir:
+            logger.error(
+                f"Expected NFS directory is not available or writable. Uploading to GCS."
+            )
+
         data_to_dump = []
         if self.crash_dump_request_list:
             data_to_dump.extend(self.crash_dump_request_list)
@@ -1357,7 +1367,12 @@ class TokenizerManager:
         for rid, state in self.rid_to_state.items():
             if not state.finished:
                 unfinished_requests.append(
-                    (state.obj, {}, state.created_time, time.time())
+                    (
+                        state.obj,
+                        state.out_list[-1] if state.out_list else {},
+                        state.created_time,
+                        time.time(),
+                    )
                 )
         if unfinished_requests:
             data_to_dump.extend(unfinished_requests)
@@ -1365,10 +1380,11 @@ class TokenizerManager:
         if not data_to_dump:
             return
 
+        object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
         filename = os.path.join(
             self.crash_dump_folder,
             os.getenv("HOSTNAME", None),
-            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
+            object_name,
         )
 
         os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -1383,6 +1399,24 @@ class TokenizerManager:
             f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
         )
 
+        def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
+            from google.cloud import storage
+
+            client = storage.Client()
+            bucket = client.bucket(bucket_name)
+            blob = bucket.blob(object_name)
+            blob.upload_from_filename(source_file_path, if_generation_match=0)
+            logger.error(
+                f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
+            )
+
+        if not use_nfs_dir:
+            _upload_file_to_gcs(
+                "sglang_crash_dump",
+                filename,
+                os.getenv("HOSTNAME", None) + "/" + object_name,
+            )
+
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
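A note on the GCS call above: in the google-cloud-storage client, `if_generation_match=0` makes `upload_from_filename` conditional on the object not already existing (generation 0 means "no live object"), so a repeated dump under the same object name fails with a precondition error rather than silently overwriting an earlier upload. The `sglang_crash_dump` bucket name is hardcoded at the call site.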
@@ -1426,7 +1460,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
-            self.last_receive_tstamp = time.perf_counter()
+            self.last_receive_tstamp = time.time()
 
     def _handle_batch_output(
         self,
@@ -1697,24 +1731,13 @@ class TokenizerManager:
             self.dump_requests_folder,
             datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
         )
-        logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
-
-        to_dump = self.dump_request_list
+        self._dump_data_to_file(
+            data_list=self.dump_request_list,
+            filename=filename,
+            log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
+        )
         self.dump_request_list = []
 
-        to_dump_with_server_args = {
-            "server_args": self.server_args,
-            "requests": to_dump,
-        }
-
-        def background_task():
-            os.makedirs(self.dump_requests_folder, exist_ok=True)
-            with open(filename, "wb") as f:
-                pickle.dump(to_dump_with_server_args, f)
-
-        # Schedule the task to run in the background without awaiting it
-        asyncio.create_task(asyncio.to_thread(background_task))
-
     def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
         current_time = time.time()
         self.crash_dump_request_list.append(
@@ -1727,6 +1750,22 @@ class TokenizerManager:
         ):
             self.crash_dump_request_list.popleft()
 
+    def _dump_data_to_file(
+        self, data_list: List[Tuple], filename: str, log_message: str
+    ):
+        logger.info(log_message)
+        to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": data_list.copy(),
+        }
+
+        def background_task():
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            with open(filename, "wb") as f:
+                pickle.dump(to_dump_with_server_args, f)
+
+        asyncio.create_task(asyncio.to_thread(background_task))
+
     def _handle_abort_req(self, recv_obj):
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
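The new `_dump_data_to_file` helper keeps the event loop responsive: the blocking `os.makedirs`/`pickle` work runs via `asyncio.to_thread` (Python 3.9+), which executes the callable in the default thread-pool executor, and wrapping it in `asyncio.create_task` makes the dump fire-and-forget. A stripped-down sketch of the same pattern (illustrative names, not code from this commit):

import asyncio
import os
import pickle

def _blocking_dump(filename: str, payload: dict) -> None:
    # Runs in a worker thread, so blocking file I/O is fine here.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "wb") as f:
        pickle.dump(payload, f)

def schedule_dump(filename: str, payload: dict) -> None:
    # Must be called from within a running event loop; the task is
    # deliberately not awaited, so the dump proceeds in the background.
    asyncio.create_task(asyncio.to_thread(_blocking_dump, filename, payload))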
@@ -1862,6 +1901,16 @@ class TokenizerManager:
         return scores
 
 
+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 async def print_exception_wrapper(func):
     """
    Sometimes an asyncio function does not print exception.
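On the relocated `_determine_tensor_transport_mode`: `dist_init_addr` is typically only set for multi-node deployments, so its truthiness doubles as a cross-node flag. CUDA IPC handles can only be shared between processes on the same host, which is why cross-node setups fall back to the "default" CPU transport.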