[Fix] Fix major performance bug in certain cases (#1563)
Co-authored-by: hnyls2002 <hnyls2002@gmail.com>

.github/workflows/pr-test.yml
@@ -130,6 +130,12 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default
 
+      - name: Benchmark Offline Throughput (Non-streaming, small batch size)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+
   performance-test-1-gpu-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
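
The new CI step mirrors the existing offline-throughput step above it and points at the test added at the bottom of this diff. As a local repro, a sketch along these lines (just wrapping the step's own shell commands in `subprocess`; assumes you run it from a sglang checkout) should behave like the CI step:

```python
import subprocess

# Equivalent of the new CI step: `cd test/srt` followed by the unittest
# invocation, failing loudly on a non-zero exit code as CI would.
subprocess.run(
    [
        "python3",
        "-m",
        "unittest",
        "test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size",
    ],
    cwd="test/srt",
    check=True,
)
```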

python/sglang/bench_serving.py
@@ -845,6 +845,7 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)
 
     if args.dataset_name == "sharegpt":
+        assert args.random_input_len is None and args.random_output_len is None
         input_requests = sample_sharegpt_requests(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
@@ -852,6 +853,7 @@ def run_benchmark(args_: argparse.Namespace):
             fixed_output_len=args.sharegpt_output_len,
         )
     elif args.dataset_name == "random":
+        assert args.random_input_len is not None and args.random_output_len is not None
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -964,13 +966,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
-        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
         type=int,
-        default=128,
         help="Number of output tokens per request, used only for random dataset.",
     )
     parser.add_argument(
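
Dropping the `default=1024`/`default=128` values is what makes the new asserts meaningful: an argparse flag without a default parses to `None`, so leaving both flags unset is now distinguishable from setting them explicitly. A minimal standalone sketch of the resulting validation (not the actual bench_serving code; the parser below only mirrors the three flags this diff touches):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset-name", default="sharegpt")
# No default: an unset flag stays None, which signals "not requested".
parser.add_argument("--random-input-len", type=int)
parser.add_argument("--random-output-len", type=int)
args = parser.parse_args(["--dataset-name", "sharegpt"])

if args.dataset_name == "sharegpt":
    # Length flags are meaningless for a ShareGPT trace; fail loudly
    # instead of silently ignoring them.
    assert args.random_input_len is None and args.random_output_len is None
elif args.dataset_name == "random":
    # The random dataset cannot synthesize requests without both lengths.
    assert args.random_input_len is not None and args.random_output_len is not None
```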

python/sglang/srt/managers/scheduler.py
@@ -222,7 +222,7 @@ class Scheduler:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.do_not_get_new_batch = False
+        self.batch_is_full = False
 
     def event_loop(self):
         while True:
@@ -261,12 +261,10 @@ class Scheduler:
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(
                 recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
             ):
                 self.handle_embedding_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
@@ -279,11 +277,12 @@ class Scheduler:
 
     @torch.inference_mode()
     def forward_step(self):
-        if self.do_not_get_new_batch and self.current_inflight_req is None:
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
             new_batch = None
         else:
             new_batch = self.get_new_prefill_batch()
-        self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch
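
This hunk is the core of the fix. The old `do_not_get_new_batch` flag was a one-shot hint: it was cleared by every incoming request and after every scheduling attempt, so a saturated scheduler could still walk the whole waiting queue on almost every decode step. The new `batch_is_full` flag is sticky: it is set whenever scheduling fails for capacity reasons (the hunks below) and cleared only when a request finishes (the `handle_finished_requests` hunk further down). A toy model of the gate, with all the real Scheduler machinery elided:

```python
# A toy model of the forward_step gate (names mirror the diff; nothing else
# about the real Scheduler is represented here).
class SchedulerGate:
    def __init__(self):
        self.batch_is_full = False
        self.waiting_queue: list = []
        self.current_inflight_req = None

    def should_try_prefill(self) -> bool:
        # Mirrors the new condition: skip get_new_prefill_batch() when the
        # batch is known to be full or nothing is waiting, unless a chunked
        # (inflight) request still needs to be continued.
        return not (
            (self.batch_is_full or len(self.waiting_queue) == 0)
            and self.current_inflight_req is None
        )

    def on_request_finished(self):
        # A finished request frees capacity, so scheduling may succeed again.
        self.batch_is_full = False
```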
@@ -447,6 +446,7 @@ class Scheduler:
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
+            self.batch_is_full = True
             return None
 
         # Get priority queue
@@ -490,9 +490,11 @@ class Scheduler:
                 )
                 > self.max_loras_per_batch
             ):
+                self.batch_is_full = True
                 break
 
             if adder.no_remaining_tokens():
+                self.batch_is_full = True
                 break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
@@ -500,6 +502,7 @@ class Scheduler:
                 not res
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
+                self.batch_is_full = True
                 break
 
         can_run_list = adder.can_run_list
@@ -810,9 +813,6 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        if not has_finished:
-            self.do_not_get_new_batch = True
-
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):
@@ -833,6 +833,8 @@ class Scheduler:
         for i, req in enumerate(batch.reqs):
             if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
+            else:
+                self.batch_is_full = False
 
             if req.finished() or (
                 req.stream
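
A back-of-the-envelope illustration of why the sticky flag matters: under the old scheme, every arriving request cleared the skip flag, so a server pinned at `--max-running-requests` re-scanned the waiting queue on nearly every decode step. The rates below (one arrival per step, one finish per 50 steps) are made-up numbers purely for illustration:

```python
def scans_old(steps=1000, arrive_every=1, finish_every=50):
    # Old scheme: do_not_get_new_batch is cleared by every arrival and only
    # set after a decode step in which nothing finished.
    skip, scans = False, 0
    for step in range(1, steps + 1):
        if step % arrive_every == 0:
            skip = False                 # a new request arrives
        if not skip:
            scans += 1                   # waiting queue scanned, batch full
        if step % finish_every != 0:
            skip = True                  # nothing finished this step
    return scans


def scans_new(steps=1000, finish_every=50):
    # New scheme: batch_is_full stays set until some request finishes.
    full, scans = False, 0
    for step in range(1, steps + 1):
        if not full:
            scans += 1                   # one scan, which fails for capacity
            full = True
        if step % finish_every == 0:
            full = False                 # a slot frees up
    return scans


print(scans_old(), scans_new())  # 1000 vs. 20 scans over the same window
```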

python/sglang/test/test_utils.py
@@ -514,7 +514,16 @@ def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
 
 
-def run_bench_serving(model, num_prompts, request_rate, other_server_args):
+def run_bench_serving(
+    model,
+    num_prompts,
+    request_rate,
+    other_server_args,
+    dataset_name="random",
+    random_input_len=4096,
+    random_output_len=2048,
+    disable_stream=False,
+):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
     process = popen_launch_server(
@@ -530,21 +539,21 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):
         base_url=base_url,
         host=None,
         port=None,
-        dataset_name="random",
+        dataset_name=dataset_name,
         dataset_path="",
         model=None,
         tokenizer=None,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
-        random_input_len=4096,
+        random_input_len=random_input_len,
-        random_output_len=2048,
+        random_output_len=random_output_len,
         random_range_ratio=0.0,
         request_rate=request_rate,
         multi=None,
         seed=0,
         output_file=None,
         disable_tqdm=False,
-        disable_stream=False,
+        disable_stream=disable_stream,
         disable_ignore_eos=False,
         extra_request_body=None,
     )
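
Design note: every new parameter defaults to the value the helper previously hard-coded (`random` dataset, 4096/2048 token lengths, streaming enabled), so existing call sites keep their behavior while new callers opt in explicitly. A hedged usage sketch (the model id is a placeholder, not something from this diff):

```python
# Old-style call: behaves exactly as before this change.
res = run_bench_serving(
    model="some-org/some-model",  # placeholder model id
    num_prompts=300,
    request_rate=float("inf"),
    other_server_args=[],
)

# New-style call: ShareGPT trace, non-streaming, matching the test below.
res = run_bench_serving(
    model="some-org/some-model",  # placeholder model id
    num_prompts=200,
    request_rate=float("inf"),
    other_server_args=["--max-running-requests", "10"],
    dataset_name="sharegpt",
    random_input_len=None,   # must be None for sharegpt (see asserts above)
    random_output_len=None,
    disable_stream=True,
)
print(res["output_throughput"])
```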

test/srt/test_bench_serving.py
@@ -20,7 +20,22 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2830
+
+    def test_offline_throughput_non_stream_small_batch_size(self):
+        res = run_bench_serving(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            num_prompts=200,
+            request_rate=float("inf"),
+            dataset_name="sharegpt",
+            random_input_len=None,
+            random_output_len=None,
+            disable_stream=True,
+            other_server_args=["--max-running-requests", "10"],
+        )
+
+        if is_in_ci():
+            assert res["output_throughput"] > 1000
 
     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
@@ -31,7 +46,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2800
+            assert res["output_throughput"] > 2880
 
     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(
@@ -58,7 +73,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2930
 
     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(