Improve the structure of CI (#911)

This commit is contained in:
Ying Sheng
2024-08-03 23:09:21 -07:00
committed by GitHub
parent 539856455d
commit 995af5a54b
29 changed files with 451 additions and 237 deletions

View File

@@ -12,6 +12,7 @@ import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
from typing import Union
import numpy as np
import requests
@@ -25,7 +26,7 @@ def get_exception_traceback():
return err_str
def is_same_type(values):
def is_same_type(values: list):
"""Return whether the elements in values are of the same type."""
if len(values) <= 1:
return True
@@ -45,7 +46,7 @@ def read_jsonl(filename: str):
return rets
def dump_state_text(filename, states, mode="w"):
def dump_state_text(filename: str, states: list, mode: str = "w"):
"""Dump program state in a text file."""
from sglang.lang.interpreter import ProgramState
@@ -105,7 +106,7 @@ def http_request(
return HttpResponse(e)
def encode_image_base64(image_path):
def encode_image_base64(image_path: Union[str, bytes]):
"""Encode an image in base64."""
if isinstance(image_path, str):
with open(image_path, "rb") as image_file:
@@ -144,7 +145,7 @@ def encode_frame(frame):
return frame_bytes
def encode_video_base64(video_path, num_frames=16):
def encode_video_base64(video_path: str, num_frames: int = 16):
import cv2 # pip install opencv-python-headless
cap = cv2.VideoCapture(video_path)
@@ -190,7 +191,7 @@ def encode_video_base64(video_path, num_frames=16):
return video_base64
def _is_chinese_char(cp):
def _is_chinese_char(cp: int):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
@@ -215,7 +216,7 @@ def _is_chinese_char(cp):
return False
def find_printable_text(text):
def find_printable_text(text: str):
"""Returns the longest printable substring of text that contains only entire words."""
# Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
@@ -234,26 +235,7 @@ def find_printable_text(text):
return text[: text.rfind(" ") + 1]
def run_with_timeout(func, args=(), kwargs=None, timeout=None):
"""Run a function with timeout."""
ret_value = []
def _target_func():
ret_value.append(func(*args, **(kwargs or {})))
t = threading.Thread(target=_target_func)
t.start()
t.join(timeout=timeout)
if t.is_alive():
raise TimeoutError()
if not ret_value:
raise RuntimeError()
return ret_value[0]
def graceful_registry(sub_module_name):
def graceful_registry(sub_module_name: str):
def graceful_shutdown(signum, frame):
logger.info(
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
@@ -265,7 +247,9 @@ def graceful_registry(sub_module_name):
class LazyImport:
def __init__(self, module_name, class_name):
"""Lazy import to make `import sglang` run faster."""
def __init__(self, module_name: str, class_name: str):
self.module_name = module_name
self.class_name = class_name
self._module = None
@@ -276,7 +260,7 @@ class LazyImport:
self._module = getattr(module, self.class_name)
return self._module
def __getattr__(self, name):
def __getattr__(self, name: str):
module = self._load()
return getattr(module, name)