[CI] balance unit tests (#1977)
This commit is contained in:
12
.github/workflows/pr-test.yml
vendored
12
.github/workflows/pr-test.yml
vendored
@@ -47,10 +47,10 @@ jobs:
|
|||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 20
|
timeout-minutes: 25
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 0 --range-end 6
|
python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
|
||||||
|
|
||||||
unit-test-backend-part-2:
|
unit-test-backend-part-2:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -64,10 +64,10 @@ jobs:
|
|||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 20
|
timeout-minutes: 25
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 6 --range-end 14
|
python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
|
||||||
|
|
||||||
unit-test-backend-part-3:
|
unit-test-backend-part-3:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -81,7 +81,7 @@ jobs:
|
|||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 20
|
timeout-minutes: 25
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 14 --range-end 20
|
python3 run_suite.py --suite minimal --range-begin 14 --range-end 20
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 20
|
timeout-minutes: 25
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite minimal --range-begin 20
|
python3 run_suite.py --suite minimal --range-begin 20
|
||||||
|
|||||||
@@ -114,9 +114,16 @@ async def health() -> Response:
|
|||||||
@app.get("/health_generate")
|
@app.get("/health_generate")
|
||||||
async def health_generate(request: Request) -> Response:
|
async def health_generate(request: Request) -> Response:
|
||||||
"""Check the health of the inference server by generating one token."""
|
"""Check the health of the inference server by generating one token."""
|
||||||
gri = GenerateReqInput(
|
|
||||||
text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
|
if tokenizer_manager.is_generation:
|
||||||
)
|
gri = GenerateReqInput(
|
||||||
|
input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
gri = EmbeddingReqInput(
|
||||||
|
input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for _ in tokenizer_manager.generate_request(gri, request):
|
async for _ in tokenizer_manager.generate_request(gri, request):
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -442,7 +442,7 @@ def popen_launch_server(
|
|||||||
"Content-Type": "application/json; charset=utf-8",
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
"Authorization": f"Bearer {api_key}",
|
"Authorization": f"Bearer {api_key}",
|
||||||
}
|
}
|
||||||
response = requests.get(f"{base_url}/v1/models", headers=headers)
|
response = requests.get(f"{base_url}/health_generate", headers=headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return process
|
return process
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
@@ -637,8 +637,8 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
|
|||||||
return rouge_l_scores
|
return rouge_l_scores
|
||||||
|
|
||||||
|
|
||||||
STDOUT_FILENAME = "stdout.txt"
|
|
||||||
STDERR_FILENAME = "stderr.txt"
|
STDERR_FILENAME = "stderr.txt"
|
||||||
|
STDOUT_FILENAME = "stdout.txt"
|
||||||
|
|
||||||
|
|
||||||
def read_output(output_lines):
|
def read_output(output_lines):
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
STDERR_FILENAME,
|
||||||
|
STDOUT_FILENAME,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,8 +28,8 @@ class TestLargeMaxNewTokens(unittest.TestCase):
|
|||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
cls.api_key = "sk-123456"
|
cls.api_key = "sk-123456"
|
||||||
|
|
||||||
cls.stdout = open("stdout.txt", "w")
|
cls.stdout = open(STDOUT_FILENAME, "w")
|
||||||
cls.stderr = open("stderr.txt", "w")
|
cls.stderr = open(STDERR_FILENAME, "w")
|
||||||
|
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model,
|
cls.model,
|
||||||
@@ -53,8 +55,8 @@ class TestLargeMaxNewTokens(unittest.TestCase):
|
|||||||
kill_child_process(cls.process.pid, include_self=True)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
cls.stdout.close()
|
cls.stdout.close()
|
||||||
cls.stderr.close()
|
cls.stderr.close()
|
||||||
os.remove("stdout.txt")
|
os.remove(STDOUT_FILENAME)
|
||||||
os.remove("stderr.txt")
|
os.remove(STDERR_FILENAME)
|
||||||
|
|
||||||
def run_chat_completion(self):
|
def run_chat_completion(self):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
@@ -84,7 +86,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
|
|||||||
pt = 0
|
pt = 0
|
||||||
while pt >= 0:
|
while pt >= 0:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
lines = open("stderr.txt").readlines()
|
lines = open(STDERR_FILENAME).readlines()
|
||||||
for line in lines[pt:]:
|
for line in lines[pt:]:
|
||||||
print(line, end="", flush=True)
|
print(line, end="", flush=True)
|
||||||
if f"#running-req: {num_requests}" in line:
|
if f"#running-req: {num_requests}" in line:
|
||||||
|
|||||||
Reference in New Issue
Block a user