From 287d07a669d3fd0b0650959b0e35c8e886513824 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 20 Jan 2025 20:25:13 -0800 Subject: [PATCH] Misc fixes for eagle (flush_cache, CPU overhead) (#3014) --- python/sglang/bench_offline_throughput.py | 28 +++--- python/sglang/bench_serving.py | 93 ++++++++++--------- python/sglang/srt/managers/scheduler.py | 11 ++- .../srt/model_executor/forward_batch_info.py | 8 +- python/sglang/srt/server.py | 4 +- python/sglang/srt/speculative/eagle_utils.py | 43 ++++++--- python/sglang/srt/speculative/eagle_worker.py | 24 +++-- python/sglang/srt/utils.py | 7 ++ python/sglang/test/test_programs.py | 3 +- python/sglang/test/test_utils.py | 7 +- test/lang/test_srt_backend.py | 1 + 11 files changed, 133 insertions(+), 96 deletions(-) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index b0a715e61..9d56ff07c 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -49,12 +49,13 @@ class BenchArgs: gsp_system_prompt_len: int = 2048 gsp_question_len: int = 128 gsp_output_len: int = 256 + seed: int = 1 disable_ignore_eos: bool = False extra_request_body: Optional[str] = None - seed: int = 1 + apply_chat_template: bool = False + profile: bool = False skip_warmup: bool = False do_not_exit: bool = False - profile: bool = False @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -141,20 +142,31 @@ class BenchArgs: default=BenchArgs.gsp_output_len, help="Target length in tokens for outputs in generated-shared-prefix dataset", ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--disable-ignore-eos", - type=bool, - default=BenchArgs.disable_ignore_eos, + action="store_true", help="Disable ignore EOS token", ) parser.add_argument( "--extra-request-body", metavar='{"key1": "value1", "key2": "value2"}', type=str, + default=BenchArgs.extra_request_body, help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) - parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) parser.add_argument( "--skip-warmup", action="store_true", @@ -165,12 +177,6 @@ class BenchArgs: action="store_true", help="Do not exit the program. This is useful for nsys profile with --duration and --delay.", ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. 
The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", - ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 991b4ddcf..10ce965be 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -453,6 +453,7 @@ def get_dataset(args, tokenizer): tokenizer=tokenizer, fixed_output_len=args.sharegpt_output_len, context_len=args.sharegpt_context_len, + apply_chat_template=args.apply_chat_template, ) elif args.dataset_name == "random": input_requests = sample_random_requests( @@ -517,6 +518,7 @@ class BenchmarkMetrics: median_e2e_latency_ms: float std_e2e_latency_ms: float p99_e2e_latency_ms: float + concurrency: float SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" @@ -562,6 +564,7 @@ def sample_sharegpt_requests( tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, context_len: Optional[int] = None, + apply_chat_template=False, ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -592,6 +595,15 @@ def sample_sharegpt_requests( # Tokenize the prompts and completions. prompt = dataset[i][0] + + if apply_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + prompt = prompt.replace(tokenizer.bos_token, "") + prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) @@ -600,7 +612,7 @@ def sample_sharegpt_requests( len(completion_token_ids) if fixed_output_len is None else fixed_output_len ) - if prompt_len < 1 or output_len < 1: + if prompt_len < 2 or output_len < 2: # Prune too short sequences. 
continue @@ -880,6 +892,7 @@ def calculate_metrics( median_e2e_latency_ms=np.median(e2e_latencies) * 1000, std_e2e_latency_ms=np.std(e2e_latencies) * 1000, p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000, + concurrency=np.sum(e2e_latencies) / dur_s, ) return metrics, output_lens @@ -1031,6 +1044,7 @@ async def benchmark( "Total token throughput (tok/s):", metrics.total_throughput ) ) + print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency)) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print( "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) @@ -1062,13 +1076,24 @@ async def benchmark( and metrics.output_throughput is not None ): result = { + # Arguments "backend": args.backend, "dataset_name": args.dataset_name, "request_rate": request_rate, "max_concurrency": max_concurrency, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + # Results + "duration": benchmark_duration, + "completed": metrics.completed, "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, "median_e2e_latency_ms": metrics.median_e2e_latency_ms, "std_e2e_latency_ms": metrics.std_e2e_latency_ms, @@ -1085,14 +1110,7 @@ async def benchmark( "median_itl_ms": metrics.median_itl_ms, "std_itl_ms": metrics.std_itl_ms, "p99_itl_ms": metrics.p99_itl_ms, - "input_throughput": metrics.input_throughput, - "output_throughput": metrics.output_throughput, - "sharegpt_output_len": args.sharegpt_output_len, - "random_input_len": args.random_input_len, - "random_output_len": args.random_output_len, - "random_range_ratio": args.random_range_ratio, - "duration": benchmark_duration, - "completed": metrics.completed, + "concurrency": metrics.concurrency, } else: print(f"Error running benchmark for request rate: {request_rate}") @@ -1112,36 +1130,16 @@ async def benchmark( with open(output_file_name, "a") as file: file.write(json.dumps(result) + "\n") - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "total_output_tokens_retokenized": metrics.total_output_retokenized, - "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, - "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "std_ttft_ms": metrics.std_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "std_tpot_ms": metrics.std_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "std_itl_ms": metrics.std_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - "mean_e2e_latency_ms": 
metrics.mean_e2e_latency_ms, - "median_e2e_latency_ms": metrics.median_e2e_latency_ms, - } + result.update( + { + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + ) return result @@ -1422,7 +1420,6 @@ if __name__ == "__main__": "actual request rate may be lower than specified with --request-rate, " "if the server is not processing requests fast enough to keep up.", ) - parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--multi", action="store_true", @@ -1445,16 +1442,17 @@ if __name__ == "__main__": action="store_true", help="Disable streaming mode.", ) - parser.add_argument( - "--disable-ignore-eos", - action="store_true", - help="Disable ignoring EOS.", - ) parser.add_argument( "--return-logprob", action="store_true", help="Return logprob.", ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--disable-ignore-eos", + action="store_true", + help="Disable ignoring EOS.", + ) parser.add_argument( "--extra-request-body", metavar='{"key1": "value1", "key2": "value2"}', @@ -1462,6 +1460,11 @@ if __name__ == "__main__": help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) parser.add_argument( "--profile", action="store_true", diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fba8a67ec..85bd1c2a4 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1023,7 +1023,7 @@ class Scheduler: ) # Check for jump-forward - if not self.disable_jump_forward: + if not self.disable_jump_forward and batch.has_grammar: jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) self.waiting_queue.extend(jump_forward_reqs) if batch.is_empty(): @@ -1564,6 +1564,15 @@ class Scheduler: self.grammar_backend.reset() self.req_to_token_pool.clear() self.token_to_kv_pool.clear() + + if not self.spec_algorithm.is_none(): + self.draft_worker.model_runner.req_to_token_pool.clear() + self.draft_worker.model_runner.token_to_kv_pool.clear() + + self.num_generated_tokens = 0 + self.forward_ct_decode = 0 + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 torch.cuda.empty_cache() logger.info("Cache flushed successfully!") if_success = True diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 354408ab3..8ef5c57b8 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -282,6 +282,9 @@ class ForwardBatch: can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, lora_paths=batch.lora_paths, sampling_info=batch.sampling_info, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool=model_runner.token_to_kv_pool, + attn_backend=model_runner.attn_backend, spec_algorithm=batch.spec_algorithm, spec_info=batch.spec_info, capture_hidden_mode=batch.capture_hidden_mode, @@ -336,11 +339,6 @@ class ForwardBatch: if model_runner.model_is_mrope: ret.compute_mrope_positions(model_runner, batch) - # Init 
attention information - ret.req_to_token_pool = model_runner.req_to_token_pool - ret.token_to_kv_pool = model_runner.token_to_kv_pool - ret.attn_backend = model_runner.attn_backend - # Init lora information if model_runner.server_args.lora_paths is not None: model_runner.lora_manager.prepare_lora_batch(ret) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 8b0c56186..869a984d0 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -12,7 +12,7 @@ # limitations under the License. # ============================================================================== -# Some shortcuts for backward compatbility. +# Some shortcuts for backward compatibility. # They will be removed in new versions. from sglang.srt.entrypoints.engine import Engine -from sglang.srt.entrypoints.http_server import launch_server +from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index ac16f6c53..049ba2275 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices( class EAGLEDraftInput(SpecInfo): def __init__(self): self.prev_mode = ForwardMode.DECODE - self.sample_output = None self.scores: torch.Tensor = None self.score_list: List[torch.Tensor] = [] @@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo): self.cache_list: List[torch.Tenor] = [] self.iter = 0 + # shape: (b, hidden_size) self.hidden_states: torch.Tensor = None + # shape: (b,) self.verified_id: torch.Tensor = None + # shape: (b, vocab_size) + self.sample_output: torch.Tensor = None + self.positions: torch.Tensor = None self.accept_length: torch.Tensor = None - self.has_finished: bool = False - self.unfinished_index: List[int] = None + self.accept_length_cpu: List[int] = None def load_server_args(self, server_args: ServerArgs): self.topk: int = server_args.speculative_eagle_topk @@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo): :pre_len ] = req.prefix_indices - batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = ( out_cache_loc[pt : pt + req.extend_input_len] ) @@ -295,7 +298,9 @@ class EAGLEDraftInput(SpecInfo): self.cache_list.append(batch.out_cache_loc) self.positions = ( batch.seq_lens[:, None] - + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter + + torch.full( + [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long + ) ).flatten() bs = len(batch.seq_lens) @@ -312,24 +317,25 @@ class EAGLEDraftInput(SpecInfo): def prepare_extend_after_decode(self, batch: ScheduleBatch): batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) - batch.extend_lens = (self.accept_length + 1).tolist() + accept_length_cpu = batch.spec_info.accept_length_cpu + batch.extend_lens = [x + 1 for x in accept_length_cpu] + batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + seq_lens_cpu = batch.seq_lens.tolist() pt = 0 - seq_lens = batch.seq_lens.tolist() - i = 0 - for req in batch.reqs: if req.finished(): continue # assert seq_len - pre_len == req.extend_input_len - input_len = self.accept_length[i] + 1 - seq_len = seq_lens[i] + input_len = batch.extend_lens[i] + seq_len = seq_lens_cpu[i] batch.req_to_token_pool.req_to_token[req.req_pool_idx][ seq_len - input_len : seq_len ] = batch.out_cache_loc[pt : pt + input_len] pt += input_len i 
+= 1 + assert pt == batch.out_cache_loc.shape[0] self.positions = torch.empty_like(self.verified_id) new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) @@ -345,7 +351,7 @@ class EAGLEDraftInput(SpecInfo): triton.next_power_of_2(self.spec_steps + 1), ) - batch.seq_lens_sum = sum(batch.seq_lens) + batch.seq_lens_sum = sum(seq_lens_cpu) batch.input_ids = self.verified_id self.verified_id = new_verified_id @@ -573,6 +579,8 @@ class EagleVerifyInput(SpecInfo): finished_extend_len = {} # {rid:accept_length + 1} accept_index_cpu = accept_index.tolist() predict_cpu = predict.tolist() + has_finished = False + # iterate every accepted token and check if req has finished after append the token # should be checked BEFORE free kv cache slots for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): @@ -586,7 +594,7 @@ class EagleVerifyInput(SpecInfo): finished_extend_len[req.rid] = j + 1 req.check_finished() if req.finished(): - draft_input.has_finished = True + has_finished = True # set all tokens after finished token to -1 and break accept_index[i, j + 1 :] = -1 break @@ -600,7 +608,6 @@ class EagleVerifyInput(SpecInfo): accept_index = accept_index[accept_index != -1] accept_length_cpu = accept_length.tolist() verified_id = predict[accept_index] - verified_id_cpu = verified_id.tolist() evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False @@ -622,7 +629,13 @@ class EagleVerifyInput(SpecInfo): draft_input.verified_id = predict[new_accept_index] draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] draft_input.accept_length = accept_length[unfinished_index] - draft_input.unfinished_index = unfinished_index + draft_input.accept_length_cpu = [ + accept_length_cpu[i] for i in unfinished_index + ] + if has_finished: + draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] + else: + draft_input.seq_lens_for_draft_extend = batch.seq_lens logits_output.next_token_logits = logits_output.next_token_logits[accept_index] return ( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 2a6ec9604..06a4372fc 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.eagle_utils import EAGLEDraftInput +from sglang.srt.utils import rank0_print class EAGLEWorker(TpModelWorker): @@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker): def forward_draft_decode(self, batch: ScheduleBatch): batch.spec_info.prepare_for_decode(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) def forward_draft_extend(self, batch: ScheduleBatch): self._set_mem_pool(batch, self.model_runner) batch.spec_info.prepare_for_extend(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST 
logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) self._set_mem_pool(batch, self.target_worker.model_runner) @@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker): batch.req_to_token_pool = runner.req_to_token_pool def forward_draft_extend_after_decode(self, batch: ScheduleBatch): + seq_lens_backup = batch.seq_lens + self._set_mem_pool(batch, self.model_runner) batch.forward_mode = ForwardMode.DRAFT_EXTEND - if batch.spec_info.has_finished: - index = batch.spec_info.unfinished_index - seq_lens = batch.seq_lens - batch.seq_lens = batch.seq_lens[index] - batch.spec_info.prepare_extend_after_decode(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) - - batch.spec_info.hidden_states = logits_output.hidden_states self.capture_for_decode(logits_output, forward_batch) - batch.forward_mode = ForwardMode.DECODE - if batch.spec_info.has_finished: - batch.seq_lens = seq_lens self._set_mem_pool(batch, self.target_worker.model_runner) + # Restore backup. + # This is because `seq_lens` can be modified in `prepare_extend_after_decode` + batch.forward_mode = ForwardMode.DECODE + batch.seq_lens = seq_lens_backup + def capture_for_decode( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 4614114b4..23dcb43d2 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1442,3 +1442,10 @@ def is_valid_ipv6_address(address: str) -> bool: return True except ValueError: return False + + +def rank0_print(msg: str): + from sglang.srt.distributed import get_tensor_model_parallel_rank + + if get_tensor_model_parallel_rank() == 0: + print(msg, flush=True) diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 361bbaed0..088cb0d0a 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -535,7 +535,8 @@ def test_hellaswag_select(): # Compute accuracy accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels)) - assert np.abs(accuracy_gen - accuracy) < 0.1 + print(f"{accuracy=}, {accuracy_gen=}") + assert np.abs(accuracy_gen - accuracy) < 0.05 assert np.abs(latency_gen - latency) < 1 return accuracy, latency diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index c1437074f..ad8ff6cbf 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -567,15 +567,16 @@ def run_bench_serving( random_range_ratio=0.0, request_rate=request_rate, multi=None, - seed=0, output_file=None, disable_tqdm=False, disable_stream=disable_stream, - disable_ignore_eos=False, return_logprob=False, - lora_name=None, + seed=0, + disable_ignore_eos=False, extra_request_body=None, + apply_chat_template=False, profile=None, + lora_name=None, ) try: diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index 0d7cc9105..a4b1b88a2 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -1,6 +1,7 @@ """ Usage: python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens +python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select """ import unittest
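
Note (editor's addition, not part of the patch): among other changes, this patch adds an --apply-chat-template flag to bench_serving.py and bench_offline_throughput.py, which wraps each raw ShareGPT prompt with the tokenizer's chat template before the request is sent. The following is a minimal standalone sketch of that templating step only, assuming the Hugging Face transformers library; the model name is illustrative, not something the patch prescribes.

from transformers import AutoTokenizer

# Illustrative model name; any chat-tuned model that ships a chat template works.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompt = "Explain speculative decoding in one sentence."

# Render the single-turn conversation as text, appending the generation prompt,
# mirroring the call added to sample_sharegpt_requests in this patch.
templated = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    tokenize=False,
)

# The benchmark strips the BOS token from the templated text, presumably to
# avoid a duplicate BOS when the server tokenizes the prompt again.
if tokenizer.bos_token:
    templated = templated.replace(tokenizer.bos_token, "")

print(templated)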