diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 29b976a12..150fb9258 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -76,7 +76,7 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 5 --range-end 15 + python3 run_suite.py --suite minimal --range-begin 5 --range-end 16 unit-test-backend-part-3: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' @@ -96,7 +96,7 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 15 + python3 run_suite.py --suite minimal --range-begin 16 performance-test-1-gpu-part-1: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 0b2172b14..bb70366b0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -421,6 +421,7 @@ class ScheduleBatch: extend_lens: List[int] = None extend_num_tokens: int = None running_bs: int = None + decoding_reqs: List[Req] = None # Stream has_stream: bool = False diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py index 960f9c2e0..1f2af1629 100644 --- a/python/sglang/test/few_shot_gsm8k.py +++ b/python/sglang/test/few_shot_gsm8k.py @@ -76,7 +76,9 @@ def run_eval(args): def few_shot_gsm8k(s, question): s += few_shot_examples + question s += sgl.gen( - "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"] + "answer", + max_tokens=args.max_new_tokens, + stop=["Question", "Assistant:", "<|separator|>"], ) ##################################### @@ -131,6 +133,7 @@ if __name__ == "__main__": parser.add_argument("--num-shots", type=int, default=5) parser.add_argument("--data-path", type=str, default="test.jsonl") parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--max-new-tokens", type=int, default=512) parser.add_argument("--parallel", type=int, default=128) parser.add_argument("--host", type=str, default="http://127.0.0.1") parser.add_argument("--port", type=int, default=30000)