Simplify our docs with complicated functions into utils (#1807)

Co-authored-by: Chayenne <zhaochenyang@ucla.edu>
This commit is contained in:
Chayenne
2024-10-26 10:44:11 -07:00
committed by GitHub
parent 9084a86445
commit ced362f7c6
5 changed files with 159 additions and 103 deletions

View File

@@ -1,12 +1,15 @@
"""Common utilities."""
import base64
import gc
import importlib
import json
import logging
import os
import signal
import subprocess
import sys
import time
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
@@ -16,6 +19,7 @@ from typing import Optional, Union
import numpy as np
import requests
import torch
from tqdm import tqdm
logger = logging.getLogger(__name__)
@@ -294,3 +298,80 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
bar.update(len(chunk))
return filename
def execute_shell_command(command: str) -> subprocess.Popen:
"""
Execute a shell command and return the process handle
Args:
command: Shell command as a string (can include \ line continuations)
Returns:
subprocess.Popen: Process handle
"""
# Replace \ newline with space and split
command = command.replace("\\\n", " ").replace("\\", " ")
parts = command.split()
return subprocess.Popen(
parts,
text=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
def wait_for_server(base_url: str, timeout: int = None) -> None:
"""Wait for the server to be ready by polling the /v1/models endpoint.
Args:
base_url: The base URL of the server
timeout: Maximum time to wait in seconds. None means wait forever.
"""
start_time = time.time()
while True:
try:
response = requests.get(
f"{base_url}/v1/models",
headers={"Authorization": "Bearer None"},
)
if response.status_code == 200:
break
if timeout and time.time() - start_time > timeout:
raise TimeoutError("Server did not become ready within timeout period")
except requests.exceptions.RequestException:
time.sleep(1)
def terminate_process(process):
"""Safely terminate a process and clean up GPU memory.
Args:
process: subprocess.Popen object to terminate
"""
try:
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
if os.name != "nt":
try:
pgid = os.getpgid(process.pid)
os.killpg(pgid, signal.SIGTERM)
time.sleep(1)
if process.poll() is None:
os.killpg(pgid, signal.SIGKILL)
except ProcessLookupError:
pass
else:
process.kill()
process.wait()
except Exception as e:
print(f"Warning: {e}")
finally:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
time.sleep(2)