From ac2387279ea98f28558e7a75972847b0ef9e8515 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 3 Mar 2025 00:12:04 -0800 Subject: [PATCH] Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988) Co-authored-by: SangBin Cho Co-authored-by: dhou-xai Co-authored-by: Hanming Lu --- .../benchmark_torch_compile_fused_moe.py | 11 +- ...nchmark_vllm_vs_sglang_fused_moe_triton.py | 9 + .../tuning_fused_moe_triton.py | 9 + docs/backend/native_api.ipynb | 5 +- docs/backend/server_arguments.md | 1 - python/pyproject.toml | 5 +- python/sglang/README.md | 4 +- python/sglang/bench_offline_throughput.py | 19 + python/sglang/bench_serving.py | 179 ++-- python/sglang/global_config.py | 4 +- .../sglang/lang/backend/runtime_endpoint.py | 7 +- python/sglang/srt/configs/load_config.py | 5 +- python/sglang/srt/configs/model_config.py | 27 +- .../srt/constrained/xgrammar_backend.py | 24 +- python/sglang/srt/entrypoints/engine.py | 29 +- python/sglang/srt/entrypoints/http_server.py | 162 +++- python/sglang/srt/entrypoints/verl_engine.py | 2 + .../sglang/srt/layers/attention/__init__.py | 19 +- .../layers/attention/flashinfer_backend.py | 16 +- .../srt/layers/attention/triton_backend.py | 8 +- python/sglang/srt/layers/dp_attention.py | 145 +++- python/sglang/srt/layers/layernorm.py | 2 +- python/sglang/srt/layers/linear.py | 3 +- python/sglang/srt/layers/logits_processor.py | 316 ++++++- .../sglang/srt/layers/moe/ep_moe/kernels.py | 67 ++ python/sglang/srt/layers/moe/ep_moe/layer.py | 12 + .../sglang/srt/layers/moe/fused_moe_native.py | 2 + .../layers/moe/fused_moe_triton/fused_moe.py | 41 +- .../srt/layers/moe/fused_moe_triton/layer.py | 20 +- python/sglang/srt/layers/quantization/fp8.py | 6 +- python/sglang/srt/layers/rotary_embedding.py | 1 - python/sglang/srt/layers/sampler.py | 107 ++- .../srt/layers/vocab_parallel_embedding.py | 2 +- .../sglang/srt/managers/configure_logging.py | 3 +- .../srt/managers/data_parallel_controller.py | 3 + .../srt/managers/detokenizer_manager.py | 42 +- python/sglang/srt/managers/io_struct.py | 104 ++- python/sglang/srt/managers/schedule_batch.py | 255 ++++-- python/sglang/srt/managers/schedule_policy.py | 25 +- python/sglang/srt/managers/scheduler.py | 784 ++++++++++++++---- .../sglang/srt/managers/session_controller.py | 8 +- .../sglang/srt/managers/tokenizer_manager.py | 301 +++++-- python/sglang/srt/managers/tp_worker.py | 7 +- .../srt/managers/tp_worker_overlap_thread.py | 5 +- python/sglang/srt/mem_cache/chunk_cache.py | 17 +- python/sglang/srt/metrics/collector.py | 240 +++++- .../srt/model_executor/cuda_graph_runner.py | 22 +- .../srt/model_executor/forward_batch_info.py | 26 +- .../sglang/srt/model_executor/model_runner.py | 110 ++- .../sglang/srt/model_loader/weight_utils.py | 37 +- .../srt/sampling/penaltylib/__init__.py | 10 +- .../sampling/penaltylib/frequency_penalty.py | 66 ++ .../{penalizers => }/min_new_tokens.py | 38 +- .../srt/sampling/penaltylib/orchestrator.py | 231 +----- .../penalizers/frequency_penalty.py | 75 -- .../penaltylib/penalizers/presence_penalty.py | 74 -- .../penalizers/repetition_penalty.py | 85 -- .../sampling/penaltylib/presence_penalty.py | 66 ++ .../srt/sampling/sampling_batch_info.py | 237 ++---- python/sglang/srt/server_args.py | 102 ++- python/sglang/srt/utils.py | 41 +- python/sglang/srt/warmup.py | 47 ++ python/sglang/test/runners.py | 196 ++++- python/sglang/test/send_one.py | 88 ++ python/sglang/test/test_utils.py | 136 ++- scripts/killall_sglang.sh | 17 +- scripts/playground/bench_speculative.py | 257 ++++++ sgl-kernel/src/sgl-kernel/__init__.py | 6 +- test/lang/test_srt_backend.py | 4 +- test/srt/run_suite.py | 10 +- .../penalizers/test_frequency_penalty.py | 97 --- .../penalizers/test_min_new_tokens.py | 152 ---- .../penalizers/test_presence_penalty.py | 93 --- .../penalizers/test_repetition_penalty.py | 87 -- .../test_srt_endpoint_with_penalizers.py | 114 --- test/srt/test_bench_serving.py | 12 +- test/srt/test_eval_accuracy_large.py | 15 + test/srt/test_health_check.py | 27 + test/srt/test_hidden_states.py | 2 +- test/srt/test_metrics.py | 2 + test/srt/test_mla.py | 2 +- test/srt/test_penalty.py | 91 ++ test/srt/test_session_control.py | 78 +- test/srt/test_skip_tokenizer_init.py | 129 ++- test/srt/test_srt_endpoint.py | 184 ++-- test/srt/test_verl_engine.py | 2 +- 86 files changed, 4116 insertions(+), 2015 deletions(-) create mode 100644 python/sglang/srt/sampling/penaltylib/frequency_penalty.py rename python/sglang/srt/sampling/penaltylib/{penalizers => }/min_new_tokens.py (70%) delete mode 100644 python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py delete mode 100644 python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py delete mode 100644 python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py create mode 100644 python/sglang/srt/sampling/penaltylib/presence_penalty.py create mode 100644 python/sglang/srt/warmup.py create mode 100644 python/sglang/test/send_one.py create mode 100644 scripts/playground/bench_speculative.py delete mode 100644 test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py delete mode 100644 test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py delete mode 100644 test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py delete mode 100644 test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py delete mode 100644 test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py create mode 100644 test/srt/test_health_check.py create mode 100644 test/srt/test_penalty.py diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py index ce5a3399a..6f8fcc01b 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py @@ -30,11 +30,20 @@ def get_model_config(model_name: str, tp_size: int): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] in [ + "Grok1ForCausalLM", + "Grok1ImgGen", + "Grok1AForCausalLM", + ]: + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size - shard_intermediate_size = 2 * intermediate_size // args.tp_size + shard_intermediate_size = 2 * intermediate_size // tp_size else: # Default: Mixtral E = config.num_local_experts diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py index 6a4605eb5..0760116fb 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py @@ -35,6 +35,15 @@ def get_model_config(model_name: str, tp_size: int): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] in [ + "Grok1ForCausalLM", + "Grok1ImgGen", + "Grok1AForCausalLM", + ]: + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size else: # Default: Mixtral E = config.num_local_experts diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index a14cf5ee9..b1fd94c09 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -397,6 +397,15 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] in [ + "Grok1ForCausalLM", + "Grok1ImgGen", + "Grok1AForCausalLM", + ]: + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Default: Mixtral E = config.num_local_experts diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index cf76871a7..b75eb8205 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -210,8 +210,7 @@ "response = requests.post(url, json=data)\n", "print_highlight(response.text)\n", "assert response.json()[\"success\"] is True\n", - "assert response.json()[\"message\"] == \"Succeeded to update model weights.\"\n", - "assert response.json().keys() == {\"success\", \"message\"}" + "assert response.json()[\"message\"] == \"Succeeded to update model weights.\"" ] }, { @@ -411,7 +410,7 @@ " },\n", ")\n", "output = response.json()\n", - "output_tokens = output[\"token_ids\"]\n", + "output_tokens = output[\"output_ids\"]\n", "\n", "output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n", "print_highlight(f\"Tokenized Output: {output_tokens}\")\n", diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 9a28167a9..0c4d7840e 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -96,7 +96,6 @@ Please consult the documentation below to learn more about the parameters you ma * `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine. * `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance. * `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU. -* `prefill_only_one_req`: When this flag is turned on, the engine prefills only one request at a time. ## Other runtime options diff --git a/python/pyproject.toml b/python/pyproject.toml index 800febc42..cc0c5de3a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -96,7 +96,10 @@ dev_cpu = ["sglang[all_cpu]", "sglang[test]"] "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" [tool.setuptools.package-data] -"sglang" = ["srt/layers/moe/fused_moe_triton/configs/*.json", "srt/layers/quantization/configs/*.json"] +"sglang" = [ + "srt/layers/moe/fused_moe_triton/configs/*.json", + "srt/layers/quantization/configs/*.json", +] [tool.setuptools.packages.find] exclude = [ diff --git a/python/sglang/README.md b/python/sglang/README.md index e5fe16005..9eab54601 100644 --- a/python/sglang/README.md +++ b/python/sglang/README.md @@ -8,8 +8,10 @@ - `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server. - `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server. - `bench_serving.py`: Benchmark online serving with dynamic requests. -- `check_env.py`: Check the environment variables. +- `check_env.py`: Check the environment variables and dependencies. - `global_config.py`: The global configs and constants. - `launch_server.py`: The entry point for launching the local server. - `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset. +- `profiler.py`: Profile a running server. - `utils.py`: Common utilities. +- `version.py`: Version info. diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index 9d56ff07c..69bbc3e4d 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -56,6 +56,7 @@ class BenchArgs: profile: bool = False skip_warmup: bool = False do_not_exit: bool = False + prompt_suffix: str = "" @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -177,6 +178,12 @@ class BenchArgs: action="store_true", help="Do not exit the program. This is useful for nsys profile with --duration and --delay.", ) + parser.add_argument( + "--prompt-suffix", + type=str, + default="", + help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -216,6 +223,10 @@ def throughput_test_once( ] if profile: + assert ( + "SGLANG_TORCH_PROFILER_DIR" in os.environ + ), "Please set SGLANG_TORCH_PROFILER_DIR." + os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True) backend.start_profile() st = time.perf_counter() @@ -229,6 +240,8 @@ def throughput_test_once( if backend_name == "runtime": gen_out = json.loads(gen_out) + server_info = backend.get_server_info() + measurement_results["total_latency"] = latency measurement_results["total_output_tokens"] = sum( o["meta_info"]["completion_tokens"] for o in gen_out @@ -246,6 +259,7 @@ def throughput_test_once( measurement_results["total_input_tokens"] + measurement_results["total_output_tokens"] ) / latency + measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"] return measurement_results @@ -361,6 +375,11 @@ def throughput_test( print( "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"]) ) + print( + "{:<40} {:<10.2f}".format( + "Last generation throughput (tok/s):", result["last_gen_throughput"] + ) + ) print( "{:<40} {:<10.2f}".format( "Request throughput (req/s):", result["request_throughput"] diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 41e1a6109..814ec40de 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -8,7 +8,6 @@ Usage: python3 -m sglang.bench_serving --backend sglang --num-prompt 10 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5 -python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi """ import argparse @@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str: return text[len(prefix) :] if text.startswith(prefix) else text +def remove_suffix(text: str, suffix: str) -> str: + return text[: -len(suffix)] if text.endswith(suffix) else text + + def get_auth_headers() -> Dict[str, str]: api_key = os.environ.get("OPENAI_API_KEY") if api_key: @@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]: return {} -# trt llm not support ignore_eos +# trt llm does not support ignore_eos # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 async def async_request_trt_llm( request_func_input: RequestFuncInput, @@ -179,6 +182,7 @@ async def async_request_openai_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" + output_len = request_func_input.output_len ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st @@ -215,11 +219,14 @@ async def async_request_openai_completions( most_recent_timestamp = timestamp generated_text += data["choices"][0]["text"] + output_len = data.get("usage", {}).get( + "completion_tokens", output_len + ) output.generated_text = generated_text output.success = True output.latency = latency - output.output_len = request_func_input.output_len + output.output_len = output_len else: output.error = response.reason or "" output.success = False @@ -339,9 +346,11 @@ async def async_request_sglang_generate( output.prompt_len = request_func_input.prompt_len generated_text = "" + output_len = request_func_input.output_len ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st + last_output_len = 0 try: async with session.post( url=api_url, json=payload, headers=headers @@ -365,6 +374,9 @@ async def async_request_sglang_generate( # want to check a token was generated if data["text"]: timestamp = time.perf_counter() + generated_text = data["text"] + output_len = data["meta_info"]["completion_tokens"] + # First token if ttft == 0.0: ttft = time.perf_counter() - st @@ -372,7 +384,13 @@ async def async_request_sglang_generate( # Decoding phase else: - output.itl.append(timestamp - most_recent_timestamp) + num_new_tokens = output_len - last_output_len + if num_new_tokens == 0: + continue + adjust_itl = ( + timestamp - most_recent_timestamp + ) / num_new_tokens + output.itl.extend([adjust_itl] * num_new_tokens) most_recent_timestamp = timestamp generated_text = data["text"] @@ -380,7 +398,7 @@ async def async_request_sglang_generate( output.generated_text = generated_text output.success = True output.latency = latency - output.output_len = request_func_input.output_len + output.output_len = output_len else: output.error = response.reason or "" output.success = False @@ -388,6 +406,7 @@ async def async_request_sglang_generate( output.success = False exc_info = sys.exc_info() output.error = "".join(traceback.format_exception(*exc_info)) + print(f"{output.error=}") if pbar: pbar.update(1) @@ -461,6 +480,7 @@ def get_dataset(args, tokenizer): tokenizer=tokenizer, fixed_output_len=args.sharegpt_output_len, context_len=args.sharegpt_context_len, + prompt_suffix=args.prompt_suffix, apply_chat_template=args.apply_chat_template, ) elif args.dataset_name == "random": @@ -521,7 +541,9 @@ class BenchmarkMetrics: mean_itl_ms: float median_itl_ms: float std_itl_ms: float + p95_itl_ms: float p99_itl_ms: float + max_itl_ms: float mean_e2e_latency_ms: float median_e2e_latency_ms: float std_e2e_latency_ms: float @@ -572,6 +594,7 @@ def sample_sharegpt_requests( tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, context_len: Optional[int] = None, + prompt_suffix: Optional[str] = "", apply_chat_template=False, ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: @@ -584,11 +607,19 @@ def sample_sharegpt_requests( # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) + # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] + dataset = [ + data + for data in dataset + if len(data.get("conversations", data.get("conversation", []))) >= 2 + ] # Only keep the first two turns of each conversation. dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) + ( + data.get("conversations", data.get("conversation", []))[0]["value"], + data.get("conversations", data.get("conversation", []))[1]["value"], + ) for data in dataset ] @@ -603,6 +634,8 @@ def sample_sharegpt_requests( # Tokenize the prompts and completions. prompt = dataset[i][0] + if prompt_suffix: + prompt = prompt if apply_chat_template: prompt = tokenizer.apply_chat_template( @@ -666,10 +699,17 @@ def sample_random_requests( with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] + dataset = [ + data + for data in dataset + if len(data.get("conversations", data.get("conversation", []))) >= 2 + ] # Only keep the first two turns of each conversation. dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) + ( + data.get("conversations", data.get("conversation", []))[0]["value"], + data.get("conversations", data.get("conversation", []))[1]["value"], + ) for data in dataset ] # Shuffle the dataset. @@ -895,7 +935,9 @@ def calculate_metrics( mean_itl_ms=np.mean(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, + p95_itl_ms=np.percentile(itls or 0, 95) * 1000, p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + max_itl_ms=np.max(itls or 0) * 1000, mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, median_e2e_latency_ms=np.median(e2e_latencies) * 1000, std_e2e_latency_ms=np.std(e2e_latencies) * 1000, @@ -919,6 +961,7 @@ async def benchmark( lora_name: str, extra_request_body: Dict[str, Any], profile: bool, + pd_seperated: bool = False, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -1004,6 +1047,17 @@ async def benchmark( if pbar is not None: pbar.close() + if "sglang" in backend: + server_info = requests.get(base_url + "/get_server_info") + if pd_seperated: + accept_length = server_info.json()["decode"][0].get( + "avg_spec_accept_length", None + ) + else: + accept_length = server_info.json().get("avg_spec_accept_length", None) + else: + accept_length = None + # Compute metrics and print results benchmark_duration = time.perf_counter() - benchmark_start_time metrics, output_lens = calculate_metrics( @@ -1053,6 +1107,8 @@ async def benchmark( ) ) print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency)) + if accept_length: + print("{:<40} {:<10.2f}".format("Accept length:", accept_length)) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print( "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) @@ -1066,16 +1122,12 @@ async def benchmark( print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) - print( - "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-") - ) - print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) - print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) - print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) - print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-")) print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms)) print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms)) print("=" * 50) if ( @@ -1117,8 +1169,10 @@ async def benchmark( "mean_itl_ms": metrics.mean_itl_ms, "median_itl_ms": metrics.median_itl_ms, "std_itl_ms": metrics.std_itl_ms, + "p95_itl_ms": metrics.p95_itl_ms, "p99_itl_ms": metrics.p99_itl_ms, "concurrency": metrics.concurrency, + "accept_length": accept_length, } else: print(f"Error running benchmark for request rate: {request_rate}") @@ -1151,14 +1205,6 @@ async def benchmark( return result -def parse_request_rate_range(request_rate_range): - if len(request_rate_range.split(",")) == 3: - start, stop, step = map(int, request_rate_range.split(",")) - return list(range(start, stop, step)) - else: - return list(map(int, request_rate_range.split(","))) - - def check_chat_template(model_path): try: tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) @@ -1168,6 +1214,12 @@ def check_chat_template(model_path): return False +def set_global_args(args_: argparse.Namespace): + """Set the global args.""" + global args + args = args_ + + def run_benchmark(args_: argparse.Namespace): global args args = args_ @@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace): if not hasattr(args, "max_concurrency"): args.max_concurrency = None + print(f"benchmark_args={args}") + # Set global environments set_ulimit() random.seed(args.seed) @@ -1272,49 +1326,26 @@ def run_benchmark(args_: argparse.Namespace): backend = args.backend model_id = args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model - tokenizer = get_tokenizer(tokenizer_id) - input_requests = get_dataset(args, tokenizer) - if not args.multi: - return asyncio.run( - benchmark( - backend=backend, - api_url=api_url, - base_url=base_url, - model_id=model_id, - tokenizer=tokenizer, - input_requests=input_requests, - request_rate=args.request_rate, - max_concurrency=args.max_concurrency, - disable_tqdm=args.disable_tqdm, - lora_name=args.lora_name, - extra_request_body=extra_request_body, - profile=args.profile, - ) + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + max_concurrency=args.max_concurrency, + disable_tqdm=args.disable_tqdm, + lora_name=args.lora_name, + extra_request_body=extra_request_body, + profile=args.profile, + pd_seperated=args.pd_seperated, ) - else: - # Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts - request_rates = parse_request_rate_range(args.request_rate_range) - - for rate in request_rates: - asyncio.run( - benchmark( - backend=backend, - api_url=api_url, - base_url=base_url, - model_id=model_id, - tokenizer=tokenizer, - input_requests=input_requests, - request_rate=rate, - max_concurrency=args.max_concurrency, - disable_tqdm=args.disable_tqdm, - lora_name=args.lora_name, - extra_request_body=extra_request_body, - profile=args.profile, - ) - ) + ) def set_ulimit(target_soft_limit=65535): @@ -1428,17 +1459,6 @@ if __name__ == "__main__": "actual request rate may be lower than specified with --request-rate, " "if the server is not processing requests fast enough to keep up.", ) - parser.add_argument( - "--multi", - action="store_true", - help="Use request rate range rather than single value.", - ) - parser.add_argument( - "--request-rate-range", - type=str, - default="2,34,2", - help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.", - ) parser.add_argument("--output-file", type=str, help="Output JSONL file name.") parser.add_argument( "--disable-tqdm", @@ -1485,6 +1505,17 @@ if __name__ == "__main__": default=None, help="The name of LoRA adapter", ) + parser.add_argument( + "--prompt-suffix", + type=str, + default="", + help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.", + ) + parser.add_argument( + "--pd-seperated", + action="store_true", + help="Benchmark PD disaggregation server", + ) group = parser.add_argument_group("generated-shared-prefix dataset arguments") group.add_argument( diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index ac034ec0a..bb4da0fec 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -34,11 +34,9 @@ class GlobalConfig: self.skip_special_tokens_in_output = True self.spaces_between_special_tokens_in_out = True - # Interpreter optimization configs + # Language frontend interpreter optimization configs self.enable_precache_with_tracing = True self.enable_parallel_encoding = True - self.enable_flashinfer_mla = False - global_config = GlobalConfig() diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 1cd3d5246..3a2bf79b0 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -329,7 +329,12 @@ class RuntimeEndpoint(BaseBackend): def compute_normalized_prompt_logprobs(input_logprobs): values = [x[0] for x in input_logprobs if x[0]] - return sum(values) / len(values) + try: + return sum(values) / len(values) + except TypeError: + print(f"{input_logprobs=}", flush=True) + print(f"{input_logprobs[0]=}", flush=True) + exit(-1) class Runtime: diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 6cb35ab47..c8521910e 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum): BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" LAYERED = "layered" + JAX = "jax" @dataclass @@ -42,13 +43,15 @@ class LoadConfig: ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - + decryption_key_file: If set, decrypts the output files with a password read + from this file (after PBKDF2). """ load_format: Union[str, LoadFormat] = LoadFormat.AUTO download_dir: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None + decryption_key_file: Optional[str] = None def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 86085284d..8880288f1 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -44,6 +44,7 @@ class ModelConfig: is_embedding: Optional[bool] = None, dtype: str = "auto", quantization: Optional[str] = None, + override_config_file: Optional[str] = None, ) -> None: self.model_path = model_path self.revision = revision @@ -51,11 +52,16 @@ class ModelConfig: # Parse args self.model_override_args = json.loads(model_override_args) + kwargs = {} + if override_config_file and override_config_file.strip(): + kwargs["_configuration_file"] = override_config_file.strip() + self.hf_config = get_config( model_path, trust_remote_code=trust_remote_code, revision=revision, model_override_args=self.model_override_args, + **kwargs, ) self.hf_text_config = get_hf_text_config(self.hf_config) @@ -64,6 +70,9 @@ class ModelConfig: self.hf_config.architectures, is_embedding ) self.is_multimodal = is_multimodal_model(self.hf_config.architectures) + self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures) + self.is_image_gen = is_image_gen_model(self.hf_config.architectures) + self.is_audio_model = is_audio_model(self.hf_config.architectures) self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -71,7 +80,9 @@ class ModelConfig: derived_context_len = get_context_length(self.hf_text_config) if context_length is not None: if context_length > derived_context_len: - if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"): + if get_bool_env_var( + "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False" + ): logger.warning( f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " f"This may lead to incorrect model outputs or CUDA errors." @@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]): or "LlavaQwenForCausalLM" in model_architectures or "LlavaMistralForCausalLM" in model_architectures or "LlavaVidForCausalLM" in model_architectures + or "Grok1VForCausalLM" in model_architectures + or "Grok1AForCausalLM" in model_architectures or "MllamaForConditionalGeneration" in model_architectures or "Qwen2VLForConditionalGeneration" in model_architectures or "Qwen2_5_VLForConditionalGeneration" in model_architectures @@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]): return False +def is_multimodal_gen_model(model_architectures: List[str]): + return False + + +def is_image_gen_model(model_architectures: List[str]): + return False + + +def is_audio_model(model_architectures: List[str]): + return False + + def is_encoder_decoder_model(model_architectures: List[str]): return "MllamaForConditionalGeneration" in model_architectures diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index d00ba1428..b7f9a15e9 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -15,7 +15,7 @@ import json import logging -from typing import List, Tuple +from typing import List, Optional, Tuple, Union import torch from xgrammar import ( @@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200 class XGrammarGrammar(BaseGrammarObject): def __init__( - self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar + self, + matcher: GrammarMatcher, + vocab_size: int, + ctx: CompiledGrammar, + override_stop_tokens: Optional[Union[List[int], int]], ) -> None: self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx + self.override_stop_tokens = override_stop_tokens self.finished = False def accept_token(self, token: int): @@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject): apply_token_bitmask_inplace(logits, vocab_mask) def copy(self): - matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) - return XGrammarGrammar(matcher, self.vocab_size, self.ctx) + matcher = GrammarMatcher( + self.ctx, + max_rollback_tokens=MAX_ROLLBACK_TOKENS, + override_stop_tokens=self.override_stop_tokens, + ) + return XGrammarGrammar( + matcher, self.vocab_size, self.ctx, self.override_stop_tokens + ) class XGrammarGrammarBackend(BaseGrammarBackend): @@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend): tokenizer_info = TokenizerInfo.from_huggingface( tokenizer, vocab_size=vocab_size ) + override_stop_tokens = None + self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info) self.vocab_size = vocab_size + self.override_stop_tokens = override_stop_tokens def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: @@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): raise ValueError(f"Invalid key_type: {key_type}") matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) - return XGrammarGrammar(matcher, self.vocab_size, ctx) + return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens) def reset(self): if self.grammar_compiler: diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 68bdf2cba..65c1f1b85 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -121,6 +121,7 @@ class Engine: return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, lora_path: Optional[List[Optional[str]]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None, return_hidden_states: bool = False, @@ -142,6 +143,7 @@ class Engine: return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, + token_ids_logprob=token_ids_logprob, lora_path=lora_path, modalities=modalities_list, custom_logit_processor=custom_logit_processor, @@ -179,6 +181,7 @@ class Engine: return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, lora_path: Optional[List[Optional[str]]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None, stream: bool = False, @@ -195,6 +198,7 @@ class Engine: return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, + token_ids_logprob=token_ids_logprob, lora_path=lora_path, stream=stream, custom_logit_processor=custom_logit_processor, @@ -226,15 +230,22 @@ class Engine: kill_process_tree(os.getpid(), include_parent=False) def start_profile(self): - self.tokenizer_manager.start_profile() + loop = asyncio.get_event_loop() + loop.run_until_complete(self.tokenizer_manager.start_profile()) def stop_profile(self): self.tokenizer_manager.stop_profile() def get_server_info(self): + loop = asyncio.get_event_loop() + internal_states = loop.run_until_complete( + self.tokenizer_manager.get_internal_state() + ) + return { - **dataclasses.asdict(self.tokenizer_manager.server_args), # server args + **dataclasses.asdict(self.tokenizer_manager.server_args), **self.scheduler_info, + **internal_states, "version": __version__, } @@ -323,6 +334,7 @@ def _set_envs_and_config(server_args: ServerArgs): os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + os.environ["CUDA_MODULE_LOADING"] = "AUTO" # Set prometheus env vars if server_args.enable_metrics: @@ -346,12 +358,23 @@ def _set_envs_and_config(server_args: ServerArgs): "at https://docs.flashinfer.ai/installation.html.", ) + def sigchld_handler(signum, frame): + pid, exitcode = os.waitpid(0, os.WNOHANG) + if exitcode != 0: + logger.warning( + "Child process unexpectedly failed with an exit code %d. pid=%d", + exitcode, + pid, + ) + + signal.signal(signal.SIGCHLD, sigchld_handler) + # Register the signal handler. # The child processes will send SIGQUIT to this process when any error happens # This process then clean up the whole process tree def sigquit_handler(signum, frame): logger.error( - "Received sigquit from a child proces. It usually means the child failed." + "Received sigquit from a child process. It usually means the child failed." ) kill_process_tree(os.getpid()) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 4873306c6..f29a81cb4 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -25,11 +25,14 @@ import os import threading import time from http import HTTPStatus -from typing import AsyncIterator, Dict, Optional +from typing import AsyncIterator, Callable, Dict, Optional # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) +from contextlib import asynccontextmanager + +import numpy as np import orjson import requests import uvicorn @@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqInput, OpenSessionReqInput, ParseFunctionCallReq, + ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, + SetInternalStateReq, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, VertexGenerateReqInput, @@ -78,22 +83,13 @@ from sglang.srt.utils import ( kill_process_tree, set_uvicorn_logging_configs, ) +from sglang.srt.warmup import execute_warmups from sglang.utils import get_exception_traceback from sglang.version import __version__ logger = logging.getLogger(__name__) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -# Fast API -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - # Store global states @dataclasses.dataclass @@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState): _global_state = global_state +@asynccontextmanager +async def lifespan(fast_api_app: FastAPI): + server_args: ServerArgs = fast_api_app.server_args + if server_args.warmups is not None: + await execute_warmups( + server_args.warmups.split(","), _global_state.tokenizer_manager + ) + logger.info("Warmup ended") + + warmup_thread = getattr(fast_api_app, "warmup_thread", None) + if warmup_thread is not None: + warmup_thread.start() + yield + + +# Fast API +app = FastAPI(lifespan=lifespan) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20)) + + ##### Native API endpoints ##### @@ -123,24 +147,48 @@ async def health() -> Response: async def health_generate(request: Request) -> Response: """Check the health of the inference server by generating one token.""" - sampling_params = {"max_new_tokens": 1, "temperature": 0.7} + sampling_params = {"max_new_tokens": 1, "temperature": 0.0} + rid = f"HEALTH_CHECK_{time.time()}" - if _global_state.tokenizer_manager.is_generation: + if _global_state.tokenizer_manager.is_image_gen: + raise NotImplementedError() + elif _global_state.tokenizer_manager.is_generation: gri = GenerateReqInput( - input_ids=[0], sampling_params=sampling_params, log_metrics=False + rid=rid, + input_ids=[0], + sampling_params=sampling_params, + log_metrics=False, ) else: gri = EmbeddingReqInput( - input_ids=[0], sampling_params=sampling_params, log_metrics=False + rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False ) - try: + async def gen(): async for _ in _global_state.tokenizer_manager.generate_request(gri, request): break - return Response(status_code=200) - except Exception as e: - logger.exception(e) - return Response(status_code=503) + + tic = time.time() + task = asyncio.create_task(gen()) + while time.time() < tic + HEALTH_CHECK_TIMEOUT: + await asyncio.sleep(1) + if _global_state.tokenizer_manager.last_receive_tstamp > tic: + task.cancel() + _global_state.tokenizer_manager.rid_to_state.pop(rid, None) + return Response(status_code=200) + + task.cancel() + tic_time = time.strftime("%H:%M:%S", time.localtime(tic)) + last_receive_time = time.strftime( + "%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp) + ) + logger.error( + f"Health check failed. Server couldn't get a response from detokenizer for last " + f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. " + f"last_heartbeat time: {last_receive_time}" + ) + _global_state.tokenizer_manager.rid_to_state.pop(rid, None) + return Response(status_code=503) @app.get("/get_model_info") @@ -156,13 +204,21 @@ async def get_model_info(): @app.get("/get_server_info") async def get_server_info(): + internal_states = await _global_state.tokenizer_manager.get_internal_state() return { **dataclasses.asdict(_global_state.tokenizer_manager.server_args), **_global_state.scheduler_info, + **internal_states, "version": __version__, } +@app.api_route("/set_internal_state", methods=["POST", "PUT"]) +async def set_internal_state(obj: SetInternalStateReq, request: Request): + res = await _global_state.tokenizer_manager.set_internal_state(obj) + return res + + # fastapi implicitly converts json in the request to obj (dataclass) @app.api_route("/generate", methods=["POST", "PUT"]) async def generate_request(obj: GenerateReqInput, request: Request): @@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): ) + b"\n\n" except ValueError as e: out = {"error": {"message": str(e)}} + logger.error(f"Error: {e}") yield b"data: " + orjson.dumps( out, option=orjson.OPT_NON_STR_KEYS ) + b"\n\n" @@ -236,9 +293,14 @@ async def flush_cache(): @app.api_route("/start_profile", methods=["GET", "POST"]) -async def start_profile_async(): +async def start_profile_async(obj: Optional[ProfileReqInput] = None): """Start profiling.""" - _global_state.tokenizer_manager.start_profile() + if obj is None: + obj = ProfileReqInput() + + await _global_state.tokenizer_manager.start_profile( + obj.output_dir, obj.num_steps, obj.activities + ) return Response( content="Start profiling.\n", status_code=200, @@ -257,11 +319,15 @@ async def stop_profile_async(): @app.post("/update_weights_from_disk") async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): - """Update the weights from disk in-place without re-launching the server.""" - success, message = await _global_state.tokenizer_manager.update_weights_from_disk( - obj, request + """Update the weights from disk inplace without re-launching the server.""" + success, message, num_paused_requests = ( + await _global_state.tokenizer_manager.update_weights_from_disk(obj, request) ) - content = {"success": success, "message": message} + content = { + "success": success, + "message": message, + "num_paused_requests": num_paused_requests, + } if success: return ORJSONResponse( content, @@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): async def release_memory_occupation( obj: ReleaseMemoryOccupationReqInput, request: Request ): - """Release GPU occupation temporarily""" + """Release GPU memory occupation temporarily.""" try: await _global_state.tokenizer_manager.release_memory_occupation(obj, request) except Exception as e: @@ -334,7 +400,7 @@ async def release_memory_occupation( async def resume_memory_occupation( obj: ResumeMemoryOccupationReqInput, request: Request ): - """Resume GPU occupation""" + """Resume GPU memory occupation.""" try: await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) except Exception as e: @@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request): @app.api_route("/close_session", methods=["GET", "POST"]) async def close_session(obj: CloseSessionReqInput, request: Request): - """Close the session""" + """Close the session.""" try: await _global_state.tokenizer_manager.close_session(obj, request) return Response(status_code=200) @@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request): @app.api_route("/configure_logging", methods=["GET", "POST"]) async def configure_logging(obj: ConfigureLoggingReq, request: Request): - """Close the session""" + """Configure the request logging options.""" _global_state.tokenizer_manager.configure_logging(obj) return Response(status_code=200) @@ -511,6 +577,7 @@ def _create_error_response(e): def launch_server( server_args: ServerArgs, pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, + launch_callback: Optional[Callable[[], None]] = None, ): """ Launch SRT (SGLang Runtime) Server. @@ -544,21 +611,23 @@ def launch_server( add_prometheus_middleware(app) enable_func_timer() - # Send a warmup request - t = threading.Thread( + # Send a warmup request - we will create the thread launch it + # in the lifespan after all other warmups have fired. + warmup_thread = threading.Thread( target=_wait_and_warmup, args=( server_args, pipe_finish_writer, _global_state.tokenizer_manager.image_token_id, + launch_callback, ), ) - t.start() + app.warmup_thread = warmup_thread try: # Update logging configs set_uvicorn_logging_configs() - + app.server_args = server_args # Listen for HTTP requests uvicorn.run( app, @@ -569,10 +638,15 @@ def launch_server( loop="uvloop", ) finally: - t.join() + warmup_thread.join() -def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): +def _wait_and_warmup( + server_args: ServerArgs, + pipe_finish_writer: Optional[multiprocessing.connection.Connection], + image_token_text: str, + launch_callback: Optional[Callable[[], None]] = None, +): headers = {} url = server_args.url() if server_args.api_key: @@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): else: json_data["text"] = "The capital city of France is" + # Debug dumping + if server_args.debug_tensor_dump_input_file: + json_data.pop("text", None) + json_data["input_ids"] = np.load( + server_args.debug_tensor_dump_input_file + ).tolist() + json_data["sampling_params"]["max_new_tokens"] = 0 + try: - for _ in range(server_args.dp_size): + for i in range(server_args.dp_size): res = requests.post( url + request_name, json=json_data, @@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): if server_args.delete_ckpt_after_loading: delete_directory(server_args.model_path) + + if server_args.debug_tensor_dump_input_file: + kill_process_tree(os.getpid()) + + if launch_callback is not None: + launch_callback() diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index 3892242c9..76927a745 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -60,6 +60,7 @@ class VerlEngine: return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, lora_path: Optional[List[Optional[str]]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None, ) -> Dict: @@ -76,6 +77,7 @@ class VerlEngine: return_logprob=return_logprob, logprob_start_len=logprob_start_len, top_logprobs_num=top_logprobs_num, + token_ids_logprob=token_ids_logprob, lora_path=lora_path, custom_logit_processor=custom_logit_processor, ) diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 745598643..e0e688ce5 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -1,14 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import torch if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode - from sglang.srt.speculative.spec_info import SpecInfo + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput class AttentionBackend(ABC): @@ -31,7 +31,7 @@ class AttentionBackend(ABC): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() @@ -44,7 +44,7 @@ class AttentionBackend(ABC): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() @@ -64,7 +64,14 @@ class AttentionBackend(ABC): ): """Run forward on an attention layer.""" if forward_batch.forward_mode.is_decode(): - return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache) + return self.forward_decode( + q, + k, + v, + layer, + forward_batch, + save_kv_cache=save_kv_cache, + ) else: return self.forward_extend( q, @@ -72,7 +79,7 @@ class AttentionBackend(ABC): v, layer, forward_batch, - save_kv_cache, + save_kv_cache=save_kv_cache, ) def forward_decode( diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index a3b890219..39bba1125 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend): model_runner: ModelRunner, skip_prefill: bool = False, kv_indptr_buf: Optional[torch.Tensor] = None, + kv_last_page_len_buf: Optional[torch.Tensor] = None, ): super().__init__() @@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend): assert self.num_wrappers == 1 self.kv_indptr = [kv_indptr_buf] - self.kv_last_page_len = torch.ones( - (max_bs,), dtype=torch.int32, device=model_runner.device - ) + if kv_last_page_len_buf is None: + self.kv_last_page_len = torch.ones( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + else: + assert self.num_wrappers == 1 + self.kv_last_page_len = kv_last_page_len_buf + self.qo_indptr = [ torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) for _ in range(self.num_wrappers) @@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend: dtype=torch.int32, device=model_runner.device, ) + self.kv_last_page_len = torch.ones( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) self.attn_backends = [] for i in range(self.speculative_num_steps): self.attn_backends.append( @@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend: model_runner, skip_prefill=True, kv_indptr_buf=self.kv_indptr[i], + kv_last_page_len_buf=self.kv_last_page_len, ) ) self.max_context_len = self.attn_backends[0].max_context_len diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 7bb6615ed..ab28b84ae 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import torch import triton @@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput class TritonAttnBackend(AttentionBackend): @@ -232,7 +232,7 @@ class TritonAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): assert encoder_lens is None, "Not supported" @@ -310,7 +310,7 @@ class TritonAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): # NOTE: encoder_lens expected to be zeros or None if forward_mode.is_decode_or_idle(): diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 36b87ca0b..f8b756f52 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -1,6 +1,21 @@ -import torch +from __future__ import annotations -from sglang.srt.distributed import GroupCoordinator, get_tp_group +import functools +from typing import TYPE_CHECKING, Union + +import torch +import triton +import triton.language as tl + +from sglang.srt.distributed import ( + GroupCoordinator, + get_tensor_model_parallel_world_size, + get_tp_group, + tensor_model_parallel_all_reduce, +) + +if TYPE_CHECKING: + from sglang.srt.model_executor.forward_batch_info import ForwardBatch _ATTN_TP_GROUP = None _ATTN_TP_RANK = None @@ -69,3 +84,129 @@ def get_attention_dp_rank(): def get_attention_dp_size(): assert _DP_SIZE is not None, "dp attention not initialized!" return _DP_SIZE + + +def get_dp_local_info(forward_batch: ForwardBatch): + dp_rank = get_attention_dp_rank() + + if forward_batch.dp_local_start_pos is None: + cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0) + if dp_rank == 0: + local_start_pos = torch.zeros_like(cumtokens[0]) + else: + local_start_pos = cumtokens[dp_rank - 1] + local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank] + + forward_batch.dp_local_start_pos = local_start_pos + forward_batch.dp_local_num_tokens = local_num_tokens + + return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens + + +@triton.jit +def memcpy_triton_kernel( + dst_ptr, + src_ptr, + offset_ptr, + sz_ptr, + offset_src, + chunk_size, # multiplied for offset and sz + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0).to(tl.int64) + offset = tl.load(offset_ptr).to(tl.int64) * chunk_size + sz = tl.load(sz_ptr).to(tl.int64) * chunk_size + + start_index = pid * BLOCK_SIZE + offs = tl.arange(0, BLOCK_SIZE) + mask = start_index + offs < sz + + if offset_src: + data = tl.load(src_ptr + offset + start_index + offs, mask=mask) + tl.store(dst_ptr + start_index + offs, data, mask=mask) + else: + data = tl.load(src_ptr + start_index + offs, mask=mask) + tl.store(dst_ptr + offset + start_index + offs, data, mask=mask) + + +def prod(x): + return functools.reduce(lambda a, b: a * b, x, 1) + + +def memcpy_triton(dst, src, dim, offset, sz, offset_src): + max_size = min(src.numel(), dst.numel()) + assert dim == 0, "dim != 0 unsupported" + assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape" + chunk_size = prod(src.shape[1:]) + BLOCK_SIZE = 8192 + grid = (triton.cdiv(max_size, BLOCK_SIZE),) + + memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE) + + +def dp_gather( + global_tokens: torch.Tensor, + local_tokens: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: Union[str, int], +): + local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) + + global_tokens.fill_(0) + assert local_tokens.is_contiguous() + assert global_tokens.is_contiguous() + if local_tokens.shape[0] > 0 and ( + layer_id != "embedding" or get_attention_tp_rank() == 0 + ): + assert ( + global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr() + ), "aliasing between global_tokens and local_tokens not allowed" + memcpy_triton( + global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False + ) + + # Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce. + NUM_GPUS_PER_NODE = 8 + if ( + not local_tokens.dtype.is_floating_point + and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE + ): + torch.ops.sglang.inplace_all_reduce( + global_tokens, group_name=get_tp_group().unique_name + ) + else: + global_tokens = tensor_model_parallel_all_reduce(global_tokens) + + +def dp_scatter( + local_tokens: torch.Tensor, # output + global_tokens: torch.Tensor, # input + forward_batch: ForwardBatch, +): + # local_num_tokens is not necessarily the same as local_tokens.shape[0], + # since local_tokens may be padded for cuda graph + local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) + local_tokens.fill_(0) + assert local_tokens.is_contiguous() + assert global_tokens.is_contiguous() + if local_tokens.shape[0] > 0: + assert ( + local_tokens.untyped_storage().data_ptr() + != global_tokens.untyped_storage().data_ptr() + ), "aliasing between local_tokens and global_tokens not allowed" + memcpy_triton( + local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True + ) + + +def get_do_logits_dp_scatter(forward_batch: ForwardBatch): + def do_logits_dp_scatter(logits: torch.Tensor): + local_logits = torch.empty( + (forward_batch.input_ids.shape[0], *logits.shape[1:]), + dtype=logits.dtype, + device=logits.device, + ) + dp_scatter(local_logits, logits, forward_batch) + return local_logits + + return do_logits_dp_scatter diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index e3b23a2a9..289d75b36 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -69,7 +69,7 @@ class RMSNorm(CustomOp): variance = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + self.variance_epsilon) - x = x.to(orig_dtype) * self.weight + x = (x * self.weight).to(orig_dtype) if residual is None: return x else: diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index f8f2d9f6d..919bcced3 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -426,13 +426,14 @@ class ColumnParallelLinear(LinearBase): from sglang.srt.layers.parameter import _ColumnvLLMParameter if isinstance(param, _ColumnvLLMParameter): - # FIXME: why would we need this special case? param.load_column_parallel_weight( loaded_weight, tp_rank=self.tp_rank, use_presharded_weights=self.use_presharded_weights, ) else: + # FIXME: This branch is needed to load deepseek v3 awq. + # However, we should fix this and avoid the branching here. param.load_column_parallel_weight(loaded_weight) def forward(self, input_): diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 08ee5a350..ec47912ef 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -26,12 +26,19 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) +from sglang.srt.layers.dp_attention import ( + dp_gather, + dp_scatter, + get_attention_dp_rank, + get_attention_dp_size, +) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, ForwardMode, ) +from sglang.srt.utils import dump_to_file logger = logging.getLogger(__name__) @@ -51,6 +58,9 @@ class LogitsProcessorOutput: # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k] next_token_top_logprobs_val: Optional[List] = None next_token_top_logprobs_idx: Optional[List] = None + # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids) + next_token_token_ids_logprobs_val: Optional[List] = None + next_token_token_ids_logprobs_idx: Optional[List] = None ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor # The logprobs of input tokens. shape: [#token] @@ -58,6 +68,9 @@ class LogitsProcessorOutput: # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] input_top_logprobs_val: List = None input_top_logprobs_idx: List = None + # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids) + input_token_ids_logprobs_val: Optional[List] = None + input_token_ids_logprobs_idx: Optional[List] = None @dataclasses.dataclass @@ -67,43 +80,107 @@ class LogitsMetadata: extend_return_logprob: bool = False extend_return_top_logprob: bool = False + extend_token_ids_logprob: bool = False extend_seq_lens: Optional[torch.Tensor] = None extend_seq_lens_cpu: Optional[List[int]] = None extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_logprob_pruned_lens_cpu: Optional[List[int]] = None top_logprobs_nums: Optional[List[int]] = None + extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None + token_ids_logprobs: Optional[List[List[int]]] = None + + # logits and logprobs post processing + temp_scaled_logprobs: bool = False + temperature: torch.Tensor = None + top_p_normalized_logprobs: bool = False + top_p: torch.Tensor = None + + # DP attention metadata. Not needed when DP attention is not used. + # Number of tokens in the request. + global_num_tokens_gpu: Optional[torch.Tensor] = None + # The start position of local hidden states. + dp_local_start_pos: Optional[torch.Tensor] = None + dp_local_num_tokens: Optional[torch.Tensor] = None + gathered_buffer: Optional[torch.Tensor] = None + # Buffer to gather logits from all ranks. + forward_batch_gathered_buffer: Optional[torch.Tensor] = None + # Number of tokens to sample per DP rank + global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None + global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None + + # for padding + padded_static_len: int = -1 @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): - if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob: - extend_return_logprob = True + if ( + forward_batch.forward_mode.is_extend() + and forward_batch.return_logprob + and not forward_batch.forward_mode.is_target_verify() + ): extend_return_top_logprob = any( x > 0 for x in forward_batch.top_logprobs_nums ) - extend_logprob_pruned_lens_cpu = [ - extend_len - start_len - for extend_len, start_len in zip( - forward_batch.extend_seq_lens_cpu, - forward_batch.extend_logprob_start_lens_cpu, - ) - ] + extend_token_ids_logprob = any( + x is not None for x in forward_batch.token_ids_logprobs + ) + extend_return_logprob = False + extend_logprob_pruned_lens_cpu = [] + for extend_len, start_len in zip( + forward_batch.extend_seq_lens_cpu, + forward_batch.extend_logprob_start_lens_cpu, + ): + if extend_len - start_len > 0: + extend_return_logprob = True + extend_logprob_pruned_lens_cpu.append(extend_len - start_len) else: extend_return_logprob = extend_return_top_logprob = ( - extend_logprob_pruned_lens_cpu - ) = False + extend_token_ids_logprob + ) = extend_logprob_pruned_lens_cpu = False return cls( forward_mode=forward_batch.forward_mode, capture_hidden_mode=forward_batch.capture_hidden_mode, extend_return_logprob=extend_return_logprob, extend_return_top_logprob=extend_return_top_logprob, + extend_token_ids_logprob=extend_token_ids_logprob, extend_seq_lens=forward_batch.extend_seq_lens, extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, top_logprobs_nums=forward_batch.top_logprobs_nums, + token_ids_logprobs=forward_batch.token_ids_logprobs, + extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu, + padded_static_len=forward_batch.padded_static_len, ) + def compute_dp_attention_metadata(self, hidden_states: torch.Tensor): + if self.global_num_tokens_for_logprob_cpu is None: + # we are capturing cuda graph + return + + cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0) + dp_rank = get_attention_dp_rank() + if dp_rank == 0: + dp_local_start_pos = torch.zeros_like( + self.global_num_tokens_for_logprob_gpu[0] + ) + else: + dp_local_start_pos = cumtokens[dp_rank - 1] + dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank] + gathered_buffer = torch.zeros( + ( + sum(self.global_num_tokens_for_logprob_cpu), + hidden_states.shape[1], + ), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + self.dp_local_start_pos = dp_local_start_pos + self.dp_local_num_tokens = dp_local_num_tokens + self.gathered_buffer = gathered_buffer + class LogitsProcessor(nn.Module): def __init__( @@ -115,6 +192,9 @@ class LogitsProcessor(nn.Module): self.do_tensor_parallel_all_gather = ( not skip_all_gather and get_tensor_model_parallel_world_size() > 1 ) + self.do_tensor_parallel_all_gather_dp_attn = ( + self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1 + ) self.final_logit_softcapping = getattr( self.config, "final_logit_softcapping", None ) @@ -124,6 +204,12 @@ class LogitsProcessor(nn.Module): ): self.final_logit_softcapping = None + from sglang.srt.managers.schedule_batch import global_server_args_dict + + self.debug_tensor_dump_output_folder = global_server_args_dict[ + "debug_tensor_dump_output_folder" + ] + def forward( self, input_ids, @@ -141,30 +227,74 @@ class LogitsProcessor(nn.Module): ): pruned_states = hidden_states sample_indices = None + input_logprob_indices = None elif ( logits_metadata.forward_mode.is_extend() and not logits_metadata.extend_return_logprob ): # Prefill without input logprobs. - last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 + if logits_metadata.padded_static_len < 0: + last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 + else: + # If padding_static length is 5 and extended_seq_lens is [2, 3], + # then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p] + # and this retrieves t01 and t12, which are the valid last tokens + idx = torch.arange( + len(logits_metadata.extend_seq_lens), + device=logits_metadata.extend_seq_lens.device, + ) + last_index = ( + idx * logits_metadata.padded_static_len + + logits_metadata.extend_seq_lens + - 1 + ) pruned_states = hidden_states[last_index] sample_indices = None + input_logprob_indices = None else: - # Slice the requested tokens to compute logprob + # Input logprobs are required. + # Find 3 different indices. + # 1. pruned_states: hidden states that we want logprobs from. + # 2. sample_indices: Indices that have sampled tokens. + # 3. input_logprob_indices: Indices that have input logprob tokens. sample_index_pt = -1 sample_indices = [] - pt, pruned_states, pruned_input_ids = 0, [], [] - for start_len, extend_len in zip( + input_logprob_indices_pt = 0 + input_logprob_indices = [] + pt, pruned_states = 0, [] + for extend_logprob_start_len, extend_len in zip( logits_metadata.extend_logprob_start_lens_cpu, logits_metadata.extend_seq_lens_cpu, ): + # It can happen in chunked prefill. We still need to sample 1 token, + # But we don't want to include it in input logprob. + if extend_len == extend_logprob_start_len: + start_len = extend_logprob_start_len - 1 + else: + start_len = extend_logprob_start_len + + # We always need at least 1 token to sample because that's required + # by a caller. + assert extend_len > start_len pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) + pt += extend_len sample_index_pt += extend_len - start_len sample_indices.append(sample_index_pt) - pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) - pt += extend_len + input_logprob_indices.extend( + [ + input_logprob_indices_pt + i + for i in range(extend_len - extend_logprob_start_len) + ] + ) + input_logprob_indices_pt += extend_len - start_len pruned_states = torch.cat(pruned_states) + sample_indices = torch.tensor( + sample_indices, device=pruned_states.device, dtype=torch.int64 + ) + input_logprob_indices = torch.tensor( + input_logprob_indices, device=pruned_states.device, dtype=torch.int64 + ) # Compute logits for both input and sampled tokens. logits = self._get_logits(pruned_states, lm_head, logits_metadata) @@ -172,28 +302,51 @@ class LogitsProcessor(nn.Module): logits[sample_indices] if sample_indices is not None else logits ) - if ( - not logits_metadata.extend_return_logprob - or logits_metadata.capture_hidden_mode.need_capture() - ): + if self.debug_tensor_dump_output_folder: + assert ( + not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1 + ), "dp attention + sharded lm_head doesn't support full logits" + full_logits = self._get_logits(hidden_states, lm_head, logits_metadata) + dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits) + + hidden_states_to_store: Optional[torch.Tensor] = None + if logits_metadata.capture_hidden_mode.need_capture(): + if logits_metadata.capture_hidden_mode.is_full(): + hidden_states_to_store = hidden_states + elif logits_metadata.capture_hidden_mode.is_last(): + # Get the last token hidden states. If sample_indices is None, + # pruned states only contain the last tokens already. + hidden_states_to_store = ( + pruned_states[sample_indices] if sample_indices else pruned_states + ) + else: + assert False, "Should never reach" + + if not logits_metadata.extend_return_logprob: # Decode mode or extend mode without return_logprob. return LogitsProcessorOutput( next_token_logits=sampled_logits, - hidden_states=( - hidden_states - if logits_metadata.capture_hidden_mode.is_full() - else ( - pruned_states - if logits_metadata.capture_hidden_mode.is_last() - else None - ) - ), + hidden_states=hidden_states_to_store, ) else: - input_logprobs = logits + input_logprobs = logits[input_logprob_indices] del hidden_states, logits # Normalize the logprob w/o temperature, top-p + pruned_lens = torch.tensor( + logits_metadata.extend_logprob_pruned_lens_cpu, + device=input_logprobs.device, + ) + if logits_metadata.temp_scaled_logprobs: + logits_metadata.temperature = torch.repeat_interleave( + logits_metadata.temperature.view(-1), + pruned_lens, + ).view(-1, 1) + if logits_metadata.top_p_normalized_logprobs: + logits_metadata.top_p = torch.repeat_interleave( + logits_metadata.top_p, + pruned_lens, + ) input_logprobs = self.compute_temp_top_p_normalized_logprobs( input_logprobs, logits_metadata ) @@ -207,14 +360,18 @@ class LogitsProcessor(nn.Module): else: input_top_logprobs_val = input_top_logprobs_idx = None + # Get the logprob of given token id + if logits_metadata.extend_token_ids_logprob: + ( + input_token_ids_logprobs_val, + input_token_ids_logprobs_idx, + ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata) + else: + input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None + input_token_logprobs = input_logprobs[ torch.arange(input_logprobs.shape[0], device=input_logprobs.device), - torch.cat( - [ - torch.cat(pruned_input_ids)[1:], - torch.tensor([0], device=input_logprobs.device), - ] - ), + logits_metadata.extend_input_logprob_token_ids_gpu, ] return LogitsProcessorOutput( @@ -222,6 +379,9 @@ class LogitsProcessor(nn.Module): input_token_logprobs=input_token_logprobs, input_top_logprobs_val=input_top_logprobs_val, input_top_logprobs_idx=input_top_logprobs_idx, + hidden_states=hidden_states_to_store, + input_token_ids_logprobs_val=input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, ) def _get_logits( @@ -231,10 +391,24 @@ class LogitsProcessor(nn.Module): logits_metadata: LogitsMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Get logits from hidden_states.""" + """Get logits from hidden_states. + + If sampled_logits_only is True, it means hidden_states only contain the + last position (e.g., extend without input logprobs). The caller should + guarantee the given hidden_states follow this constraint. + """ + if self.do_tensor_parallel_all_gather_dp_attn: + logits_metadata.compute_dp_attention_metadata(hidden_states) + hidden_states, local_hidden_states = ( + logits_metadata.gathered_buffer, + hidden_states.clone(), + ) + dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding") if hasattr(lm_head, "weight"): - logits = torch.matmul(hidden_states, lm_head.weight.T) + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T + ) else: # GGUF models logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) @@ -245,6 +419,17 @@ class LogitsProcessor(nn.Module): if self.do_tensor_parallel_all_gather: logits = tensor_model_parallel_all_gather(logits) + if self.do_tensor_parallel_all_gather_dp_attn: + logits, global_logits = ( + torch.empty( + (local_hidden_states.shape[0], logits.shape[1]), + device=logits.device, + dtype=logits.dtype, + ), + logits, + ) + dp_scatter(logits, global_logits, logits_metadata) + logits = logits[:, : self.config.vocab_size].float() if self.final_logit_softcapping: @@ -272,21 +457,66 @@ class LogitsProcessor(nn.Module): continue input_top_logprobs_val.append( - [values[pt + j][:k] for j in range(pruned_len - 1)] + [values[pt + j][:k] for j in range(pruned_len)] ) input_top_logprobs_idx.append( - [indices[pt + j][:k] for j in range(pruned_len - 1)] + [indices[pt + j][:k] for j in range(pruned_len)] ) pt += pruned_len return input_top_logprobs_val, input_top_logprobs_idx + @staticmethod + def get_token_ids_logprobs( + all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata + ): + input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], [] + pt = 0 + for token_ids, pruned_len in zip( + logits_metadata.token_ids_logprobs, + logits_metadata.extend_logprob_pruned_lens_cpu, + ): + if pruned_len <= 0: + input_token_ids_logprobs_val.append([]) + input_token_ids_logprobs_idx.append([]) + continue + + input_token_ids_logprobs_val.append( + [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)] + ) + input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)]) + pt += pruned_len + + return input_token_ids_logprobs_val, input_token_ids_logprobs_idx + @staticmethod def compute_temp_top_p_normalized_logprobs( last_logits: torch.Tensor, logits_metadata: LogitsMetadata ) -> torch.Tensor: - # TODO: Implement the temp and top-p normalization - return torch.nn.functional.log_softmax(last_logits, dim=-1) + """ + compute logprobs for the output token from the given logits. + + Returns: + torch.Tensor: logprobs from logits + """ + # Scale logits if temperature scaling is enabled + if logits_metadata.temp_scaled_logprobs: + last_logits = last_logits / logits_metadata.temperature + + # Normalize logprobs if top_p normalization is enabled + # NOTE: only normalize logprobs when top_p is set and not equal to 1.0 + if ( + logits_metadata.top_p_normalized_logprobs + and (logits_metadata.top_p != 1.0).any() + ): + from sglang.srt.layers.sampler import top_p_normalize_probs_torch + + probs = torch.softmax(last_logits, dim=-1) + del last_logits + probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p) + return torch.log(probs) + else: + return torch.nn.functional.log_softmax(last_logits, dim=-1) @triton.jit diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index ae7d13ea5..85f791889 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -144,6 +144,73 @@ def silu_and_mul_triton_kernel( tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) +@triton.jit +def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def gelu_and_mul_triton_kernel( + gateup_output, + down_input, + hidden_size, + reorder_topk_ids, + scales, + start_expert_id, + end_expert_id, + BLOCK_SIZE: tl.constexpr, +): + InDtype = gateup_output.dtype.element_ty + OutDtype = down_input.dtype.element_ty + + half_hidden_size = hidden_size // 2 + + pid = tl.program_id(0) + expert_id = tl.load(reorder_topk_ids + pid) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + gateup_output_ptr = gateup_output + pid * hidden_size + gate_output_ptr = gateup_output_ptr + up_output_ptr = gateup_output_ptr + half_hidden_size + down_input_ptr = down_input + pid * half_hidden_size + + if scales is not None: + scale = tl.load(scales + expert_id - start_expert_id) + scale = (1 / scale).to(InDtype) + else: + scale = 1 + + for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < half_hidden_size + + gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) + up_output = tl.load(up_output_ptr + offset, mask=mask) + + # gelu & mul & quantize + # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html + # sqrt(2/pi) + kAlpha = 0.7978845608028654 + gate_output = ( + 0.5 + * gate_output + * ( + 1 + + tanh( + kAlpha + * ( + gate_output + + 0.044715 * gate_output * gate_output * gate_output + ) + ) + ) + ) + gate_output = gate_output.to(InDtype) + + gelu_mul_output = gate_output * up_output * scale + gelu_mul_output = gelu_mul_output.to(OutDtype) + tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask) + + @triton.jit def post_reorder_triton_kernel( down_output_ptr, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 7468c0b91..1c1537810 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -11,6 +11,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.moe.ep_moe.kernels import ( + gelu_and_mul_triton_kernel, grouped_gemm_triton, post_reorder_triton_kernel, pre_reorder_triton_kernel, @@ -296,6 +297,17 @@ class EPMoE(torch.nn.Module): self.end_expert_id, BLOCK_SIZE=512, ) + elif self.activation == "gelu": + gelu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + self.start_expert_id, + self.end_expert_id, + BLOCK_SIZE=512, + ) else: raise ValueError(f"Unsupported activation: {self.activation=}") diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 042c0a52c..97299baa2 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -24,6 +24,8 @@ def fused_moe_forward_native( custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + inplace: bool = True, + no_combine: bool = False, ) -> torch.Tensor: topk_weights, topk_ids = select_experts( hidden_states=x, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 805a43e45..92f46f009 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -23,7 +23,7 @@ from sglang.srt.utils import ( is_hip, ) -is_hip_flag = is_hip() +is_hip_ = is_hip() logger = logging.getLogger(__name__) @@ -487,6 +487,7 @@ def invoke_fused_moe_kernel( use_int8_w8a8: bool, use_int8_w8a16: bool, block_shape: Optional[List[int]] = None, + no_combine: bool = False, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -646,7 +647,7 @@ def get_default_config( "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 2 if is_hip_flag else 4, + "num_stages": 2 if is_hip_ else 4, } if M <= E: config = { @@ -655,7 +656,7 @@ def get_default_config( "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2 if is_hip_flag else 4, + "num_stages": 2 if is_hip_ else 4, } else: # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1] @@ -665,7 +666,7 @@ def get_default_config( "BLOCK_SIZE_K": block_shape[1], "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 2 if is_hip_flag else 3, + "num_stages": 2 if is_hip_ else 3, } else: config = { @@ -814,6 +815,7 @@ def outplace_fused_experts( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, + no_combine: bool = False, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -831,6 +833,7 @@ def outplace_fused_experts( a1_scale, a2_scale, block_shape, + no_combine=no_combine, ) @@ -849,6 +852,7 @@ def outplace_fused_experts_fake( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, + no_combine: bool = False, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -877,8 +881,10 @@ def fused_experts( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, + no_combine: bool = False, ): if inplace: + assert not no_combine, "no combine + inplace makes no sense" torch.ops.sglang.inplace_fused_experts( hidden_states, w1, @@ -912,6 +918,7 @@ def fused_experts( a1_scale, a2_scale, block_shape, + no_combine=no_combine, ) @@ -931,6 +938,7 @@ def fused_experts_impl( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, + no_combine: bool = False, ): padded_size = padding_size if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None: @@ -987,7 +995,14 @@ def fused_experts_impl( compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - if inplace: + if no_combine: + assert not inplace + out_hidden_states = torch.empty( + (num_tokens, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + elif inplace: out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) @@ -1057,7 +1072,11 @@ def fused_experts_impl( invoke_fused_moe_kernel( intermediate_cache2, w2, - intermediate_cache3, + ( + intermediate_cache3 + if not no_combine and topk_ids.shape[1] != 1 + else out_hidden_states[begin_chunk_idx:end_chunk_idx] + ), a2_scale, w2_scale, curr_topk_weights, @@ -1075,16 +1094,16 @@ def fused_experts_impl( block_shape=block_shape, ) - if is_hip_flag: + if no_combine: + pass + elif is_hip_: ops.moe_sum( intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], ) else: if topk_ids.shape[1] == 1: - out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_( - intermediate_cache3[:, 0] - ) + pass # we write directly into out_hidden_states elif topk_ids.shape[1] == 2: torch.add( intermediate_cache3[:, 0], @@ -1122,6 +1141,7 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, + no_combine: bool = False, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1191,4 +1211,5 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape, + no_combine=no_combine, ) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 4a944fb85..cf9a706b8 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -125,6 +125,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + inplace: bool = True, + no_combine: bool = False, ) -> torch.Tensor: return self.forward( x=x, @@ -138,6 +140,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function=custom_routing_function, correction_bias=correction_bias, activation=activation, + inplace=inplace, + no_combine=no_combine, ) def forward_cuda( @@ -153,6 +157,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + inplace: bool = True, + no_combine: bool = False, ) -> torch.Tensor: topk_weights, topk_ids = select_experts( hidden_states=x, @@ -171,6 +177,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): from aiter.fused_moe import fused_experts_ck assert activation == "silu", f"{activation=} is not supported." + assert not no_combine, "unsupported" return fused_experts_ck( hidden_states=x, @@ -186,8 +193,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, + inplace=inplace and not no_combine, activation=activation, + no_combine=no_combine, ) def forward_cpu( @@ -202,6 +210,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + inplace: bool = True, ) -> torch.Tensor: return moe_forward_native( layer, @@ -241,6 +250,7 @@ class FusedMoE(torch.nn.Module): reduce_results: Whether to all all_reduce on the output of the layer renomalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. + inplace: suggestion to compute inplace (modify input activation). """ def __init__( @@ -262,6 +272,8 @@ class FusedMoE(torch.nn.Module): correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", use_presharded_weights: bool = False, + inplace: bool = True, + no_combine: bool = False, ): super().__init__() @@ -285,6 +297,9 @@ class FusedMoE(torch.nn.Module): self.custom_routing_function = custom_routing_function self.correction_bias = correction_bias self.activation = activation + self.use_presharded_weights = use_presharded_weights + self.inplace = inplace + self.no_combine = no_combine if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -304,7 +319,6 @@ class FusedMoE(torch.nn.Module): params_dtype=params_dtype, weight_loader=self.weight_loader, ) - self.use_presharded_weights = use_presharded_weights def _load_per_tensor_weight_scale( self, @@ -598,6 +612,8 @@ class FusedMoE(torch.nn.Module): custom_routing_function=self.custom_routing_function, correction_bias=self.correction_bias, activation=self.activation, + inplace=self.inplace, + no_combine=self.no_combine, ) if self.reduce_results and self.tp_size > 1: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 9a038f384..2707026e8 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -771,6 +771,8 @@ class Fp8MoEMethod: custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + inplace: bool = True, + no_combine: bool = False, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.topk import select_experts @@ -793,6 +795,7 @@ class Fp8MoEMethod: from aiter.fused_moe import fused_experts_ck assert activation == "silu", f"{activation=} is not supported." + assert not no_combine, f"{no_combine=} is not supported." return fused_experts_ck( x, @@ -823,7 +826,7 @@ class Fp8MoEMethod: layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, + inplace=inplace and not no_combine, activation=activation, use_fp8_w8a8=True, w1_scale=( @@ -839,6 +842,7 @@ class Fp8MoEMethod: a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.quant_config.weight_block_size, + no_combine=no_combine, ) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index ef8a96c98..c31c2e0b5 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -707,7 +707,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): cos = freqs.cos() * self.mscale sin = freqs.sin() * self.mscale cache = torch.cat((cos, sin), dim=-1) - print("Cache shape", cache.shape) return cache def forward( diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 720e25984..f471626e1 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import List, Optional import torch import torch.distributed as dist @@ -41,7 +41,21 @@ class Sampler(nn.Module): sampling_info: SamplingBatchInfo, return_logprob: bool, top_logprobs_nums: List[int], + token_ids_logprobs: List[List[int]], + batch_next_token_ids: Optional[torch.Tensor] = None, ): + """Run a sampler & compute logprobs and update logits_output accordingly. + + Args: + logits_output: The logits from the model forward + sampling_info: Metadata for sampling + return_logprob: If set, store the output logprob information to + logits_output + top_logprobs_nums: Number of top lobprobs per sequence in a batch + batch_next_token_ids: next token IDs. If set, skip sampling and only + compute output logprobs It is used for speculative decoding which + performs sampling in draft workers. + """ logits = logits_output.next_token_logits # Apply the custom logit processors if registered in the sampling info. @@ -58,13 +72,15 @@ class Sampler(nn.Module): if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling - batch_next_token_ids = torch.argmax(logits, -1) + if batch_next_token_ids is None: + batch_next_token_ids = torch.argmax(logits, -1) if return_logprob: logprobs = torch.nn.functional.log_softmax(logits, dim=-1) else: # Post process logits logits.div_(sampling_info.temperatures) - probs = torch.softmax(logits, dim=-1) + logits[:] = torch.softmax(logits, dim=-1) + probs = logits del logits if global_server_args_dict["sampling_backend"] == "flashinfer": @@ -78,38 +94,43 @@ class Sampler(nn.Module): top_p_normalize_probs_torch(probs, sampling_info.top_ps) ).clamp(min=torch.finfo(probs.dtype).min) - max_top_k_round, batch_size = 32, probs.shape[0] - uniform_samples = torch.rand( - (max_top_k_round, batch_size), device=probs.device - ) - if sampling_info.need_min_p_sampling: - probs = top_k_renorm_prob(probs, sampling_info.top_ks) - probs = top_p_renorm_prob(probs, sampling_info.top_ps) - batch_next_token_ids = min_p_sampling_from_probs( - probs, uniform_samples, sampling_info.min_ps - ) - else: - batch_next_token_ids, success = top_k_top_p_sampling_from_probs( - probs, - uniform_samples, - sampling_info.top_ks, - sampling_info.top_ps, - filter_apply_order="joint", + if batch_next_token_ids is None: + max_top_k_round, batch_size = 32, probs.shape[0] + uniform_samples = torch.rand( + (max_top_k_round, batch_size), device=probs.device ) + if sampling_info.need_min_p_sampling: + probs = top_k_renorm_prob(probs, sampling_info.top_ks) + probs = top_p_renorm_prob(probs, sampling_info.top_ps) + batch_next_token_ids = min_p_sampling_from_probs( + probs, uniform_samples, sampling_info.min_ps + ) + else: + batch_next_token_ids, success = top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + sampling_info.top_ks, + sampling_info.top_ps, + filter_apply_order="joint", + ) - if self.use_nan_detection and not torch.all(success): - logger.warning("Detected errors during sampling!") - batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + if self.use_nan_detection and not torch.all(success): + logger.warning("Detected errors during sampling!") + batch_next_token_ids = torch.zeros_like( + batch_next_token_ids + ) elif global_server_args_dict["sampling_backend"] == "pytorch": - # A slower fallback implementation with torch native operations. - batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( - probs, - sampling_info.top_ks, - sampling_info.top_ps, - sampling_info.min_ps, - sampling_info.need_min_p_sampling, - ) + if batch_next_token_ids is None: + # A slower fallback implementation with torch native operations. + batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( + probs, + sampling_info.top_ks, + sampling_info.top_ps, + sampling_info.min_ps, + sampling_info.need_min_p_sampling, + ) + if return_logprob: # clamp to avoid -inf logprobs = torch.log( @@ -128,6 +149,12 @@ class Sampler(nn.Module): logits_output.next_token_top_logprobs_idx, ) = get_top_logprobs(logprobs, top_logprobs_nums) + if any(x is not None for x in token_ids_logprobs): + ( + logits_output.next_token_token_ids_logprobs_val, + logits_output.next_token_token_ids_logprobs_idx, + ) = get_token_ids_logprobs(logprobs, token_ids_logprobs) + logits_output.next_token_logprobs = logprobs[ torch.arange(len(batch_next_token_ids), device=sampling_info.device), batch_next_token_ids, @@ -223,6 +250,10 @@ def top_p_normalize_probs_torch( def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): + assert len(top_logprobs_nums) == logprobs.shape[0], ( + len(top_logprobs_nums), + logprobs.shape[0], + ) max_k = max(top_logprobs_nums) ret = logprobs.topk(max_k, dim=1) values = ret.values.tolist() @@ -234,3 +265,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): output_top_logprobs_val.append(values[i][:k]) output_top_logprobs_idx.append(indices[i][:k]) return output_top_logprobs_val, output_top_logprobs_idx + + +def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]): + output_token_ids_logprobs_val = [] + output_token_ids_logprobs_idx = [] + for i, token_ids in enumerate(token_ids_logprobs): + if token_ids is not None: + output_token_ids_logprobs_val.append(logprobs[i, token_ids].tolist()) + output_token_ids_logprobs_idx.append(token_ids) + else: + output_token_ids_logprobs_val.append([]) + output_token_ids_logprobs_idx.append([]) + + return output_token_ids_logprobs_val, output_token_ids_logprobs_idx diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index ed9d67ef9..22229b643 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -457,7 +457,7 @@ class VocabParallelEmbedding(torch.nn.Module): assert loaded_weight.shape[output_dim] == ( self.org_vocab_size // (self.tp_size if self.use_presharded_weights else 1) - ) + ), f"{self.org_vocab_size=} {self.use_presharded_weights=} {loaded_weight.shape[output_dim]=}" # Copy the data. if not self.use_presharded_weights: diff --git a/python/sglang/srt/managers/configure_logging.py b/python/sglang/srt/managers/configure_logging.py index 187af4d9c..d331990ff 100644 --- a/python/sglang/srt/managers/configure_logging.py +++ b/python/sglang/srt/managers/configure_logging.py @@ -28,6 +28,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--url", type=str, default="http://localhost:30000") parser.add_argument("--log-requests", action="store_true") + parser.add_argument("--log-requests-level", type=int, default=2) parser.add_argument( "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump" ) @@ -38,7 +39,7 @@ if __name__ == "__main__": args.url + "/configure_logging", json={ "log_requests": args.log_requests, - "log_requests_level": 1, # Log full requests + "log_requests_level": args.log_requests_level, # Log full requests "dump_requests_folder": args.dump_requests_folder, "dump_requests_threshold": args.dump_requests_threshold, }, diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index f1d669fc8..8a4019f83 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -198,6 +198,8 @@ class DataParallelController: self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"] self.max_req_input_len = scheduler_info[0]["max_req_input_len"] + print(f"{scheduler_info=}") + def round_robin_scheduler(self, req): self.workers[self.round_robin_counter].send_pyobj(req) self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers) @@ -220,6 +222,7 @@ class DataParallelController: TokenizedEmbeddingReqInput, ), ): + logger.info("dispatching") self.dispatching(recv_req) else: # Send other control messages to first worker of tp group diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index aa5c6dba8..17bc6e3b3 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -14,6 +14,7 @@ """DetokenizerManager is a process that detokenizes the token ids.""" import dataclasses +import json import logging import os import signal @@ -27,11 +28,16 @@ import zmq from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.io_struct import ( BatchEmbeddingOut, + BatchMultimodalDecodeReq, BatchStrOut, BatchTokenIDOut, ) from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import configure_logger, get_zmq_socket +from sglang.srt.utils import ( + configure_logger, + get_zmq_socket, + kill_itself_when_parent_died, +) from sglang.utils import ( TypeBasedDispatcher, find_printable_text, @@ -86,14 +92,23 @@ class DetokenizerManager: ) self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) + self.is_dummy = server_args.load_format == "dummy" self._request_dispatcher = TypeBasedDispatcher( [ (BatchEmbeddingOut, self.handle_batch_embedding_out), (BatchTokenIDOut, self.handle_batch_token_id_out), + (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req), ] ) + def event_loop(self): + """The event loop that handles requests""" + while True: + recv_obj = self.recv_from_scheduler.recv_pyobj() + output = self._request_dispatcher(recv_obj) + self.send_to_tokenizer.send_pyobj(output) + def trim_matched_stop( self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool ): @@ -117,14 +132,6 @@ class DetokenizerManager: return output[:-1] return output - def event_loop(self): - """The event loop that handles requests""" - - while True: - recv_obj = self.recv_from_scheduler.recv_pyobj() - output = self._request_dispatcher(recv_obj) - self.send_to_tokenizer.send_pyobj(output) - def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut): # If it is embedding model, no detokenization is needed. return recv_obj @@ -173,7 +180,6 @@ class DetokenizerManager: # Incremental decoding output_strs = [] - finished_reqs = [] for i in range(bs): try: s = self.decode_status[recv_obj.rids[i]] @@ -196,8 +202,6 @@ class DetokenizerManager: new_text = "" else: new_text = find_printable_text(new_text) - else: - finished_reqs.append(recv_obj.rids[i]) output_strs.append( self.trim_matched_stop( @@ -207,7 +211,7 @@ class DetokenizerManager: ) ) - out = BatchStrOut( + return BatchStrOut( rids=recv_obj.rids, finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, @@ -223,14 +227,15 @@ class DetokenizerManager: input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, output_top_logprobs_val=recv_obj.output_top_logprobs_val, output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, + input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx, + output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val, + output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, output_hidden_states=recv_obj.output_hidden_states, ) - # remove decodestatus for completed requests - for rid in finished_reqs: - self.decode_status.pop(rid) - - return out + def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): + raise NotImplementedError() class LimitedCapacityDict(OrderedDict): @@ -250,6 +255,7 @@ def run_detokenizer_process( server_args: ServerArgs, port_args: PortArgs, ): + kill_itself_when_parent_died() setproctitle.setproctitle("sglang::detokenizer") configure_logger(server_args) parent_process = psutil.Process().parent() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e105ba943..9c4034c24 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -16,10 +16,11 @@ The definition of objects transfered between different processes (TokenizerManager, DetokenizerManager, Controller). """ +import copy import uuid from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams @@ -55,6 +56,8 @@ class GenerateReqInput: logprob_start_len: Optional[Union[List[int], int]] = None # If return logprobs, the number of top logprobs to return at each position. top_logprobs_num: Optional[Union[List[int], int]] = None + # If return logprobs, the token ids to return logprob for. + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None # Whether to detokenize tokens in text in the returned logprobs. return_text_in_logprobs: bool = False # Whether to stream output. @@ -146,6 +149,8 @@ class GenerateReqInput: self.logprob_start_len = -1 if self.top_logprobs_num is None: self.top_logprobs_num = 0 + if not self.token_ids_logprob: # covers both None and [] + self.token_ids_logprob = None else: if self.parallel_sample_num == 1: num = self.batch_size @@ -191,6 +196,17 @@ class GenerateReqInput: else: assert self.parallel_sample_num == 1 + if not self.token_ids_logprob: # covers both None and [] + self.token_ids_logprob = [None] * num + elif not isinstance(self.token_ids_logprob, list): + self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)] + elif not isinstance(self.token_ids_logprob[0], list): + self.token_ids_logprob = [ + copy.deepcopy(self.token_ids_logprob) for _ in range(num) + ] + else: + assert self.parallel_sample_num == 1 + if self.custom_logit_processor is None: self.custom_logit_processor = [None] * num elif not isinstance(self.custom_logit_processor, list): @@ -198,6 +214,12 @@ class GenerateReqInput: else: assert self.parallel_sample_num == 1 + # Other checks + if self.session_params is not None: + assert isinstance(self.session_params, dict) or isinstance( + self.session_params[0], dict + ) + def regenerate_rid(self): self.rid = uuid.uuid4().hex return self.rid @@ -212,6 +234,7 @@ class GenerateReqInput: return_logprob=self.return_logprob[i], logprob_start_len=self.logprob_start_len[i], top_logprobs_num=self.top_logprobs_num[i], + token_ids_logprob=self.token_ids_logprob[i], return_text_in_logprobs=self.return_text_in_logprobs, stream=self.stream, log_metrics=self.log_metrics, @@ -244,6 +267,8 @@ class TokenizedGenerateReqInput: logprob_start_len: int # If return logprobs, the number of top logprobs to return at each position. top_logprobs_num: int + # If return logprobs, the token id to return logprob for + token_ids_logprob: List[int] # Whether to stream output stream: bool @@ -378,10 +403,21 @@ class BatchTokenIDOut: input_top_logprobs_idx: List[List] output_top_logprobs_val: List[List] output_top_logprobs_idx: List[List] + input_token_ids_logprobs_val: List[List] + input_token_ids_logprobs_idx: List[List] + output_token_ids_logprobs_val: List[List] + output_token_ids_logprobs_idx: List[List] + # Hidden states output_hidden_states: List[List[float]] +@dataclass +class BatchMultimodalDecodeReq: + # The request id + rids: List[str] + + @dataclass class BatchStrOut: # The request id @@ -406,10 +442,21 @@ class BatchStrOut: input_top_logprobs_idx: List[List] output_top_logprobs_val: List[List] output_top_logprobs_idx: List[List] + input_token_ids_logprobs_val: List[List] + input_token_ids_logprobs_idx: List[List] + output_token_ids_logprobs_val: List[List] + output_token_ids_logprobs_idx: List[List] + # Hidden states output_hidden_states: List[List[float]] +@dataclass +class BatchMultimodalOut: + # The request id + rids: List[str] + + @dataclass class BatchEmbeddingOut: # The request id @@ -439,6 +486,8 @@ class UpdateWeightFromDiskReqInput: class UpdateWeightFromDiskReqOutput: success: bool message: str + # Number of paused requests during weight sync. + num_paused_requests: Optional[int] = 0 @dataclass @@ -526,11 +575,57 @@ class AbortReq: rid: str -class ProfileReq(Enum): +@dataclass +class GetInternalStateReq: + pass + + +@dataclass +class GetInternalStateReqOutput: + internal_state: Dict[Any, Any] + + +@dataclass +class SetInternalStateReq: + server_args: Dict[str, Any] + + +@dataclass +class SetInternalStateReqOutput: + updated: bool + server_args: Dict[str, Any] + + +@dataclass +class ProfileReqInput: + # The output directory + output_dir: Optional[str] = None + # If set, it profile as many as this number of steps. + # If it is set, profiling is automatically stopped after this step, and + # the caller doesn't need to run stop_profile. + num_steps: Optional[int] = None + activities: Optional[List[str]] = None + + +class ProfileReqType(Enum): START_PROFILE = 1 STOP_PROFILE = 2 +@dataclass +class ProfileReq: + type: ProfileReqType + output_dir: Optional[str] = None + num_steps: Optional[int] = None + activities: Optional[List[str]] = None + + +@dataclass +class ProfileReqOutput: + success: bool + message: str + + @dataclass class ConfigureLoggingReq: log_requests: Optional[bool] = None @@ -556,6 +651,11 @@ class OpenSessionReqOutput: success: bool +@dataclass +class HealthCheckOutput: + pass + + @dataclass class Function: description: Optional[str] = None diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3d34cefb1..f1edcd461 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch It contains low-level tensor data. Most of the data consists of GPU tensors. """ +import copy import dataclasses import logging from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union @@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: - from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm + from sglang.srt.server_args import ServerArgs + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + from sglang.srt.speculative.spec_info import SpeculativeAlgorithm + INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 @@ -65,6 +69,8 @@ global_server_args_dict = { "enable_dp_attention": ServerArgs.enable_dp_attention, "enable_ep_moe": ServerArgs.enable_ep_moe, "device": ServerArgs.device, + "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single, + "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla, "disable_radix_cache": ServerArgs.disable_radix_cache, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, @@ -230,6 +236,7 @@ class Req: sampling_params: SamplingParams, return_logprob: bool = False, top_logprobs_num: int = 0, + token_ids_logprob: List[int] = None, stream: bool = False, origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_path: Optional[str] = None, @@ -256,17 +263,24 @@ class Req: self.input_embeds = input_embeds # Sampling info + if isinstance(sampling_params.custom_params, dict): + sampling_params = copy.copy(sampling_params) + sampling_params.custom_params = sampling_params.custom_params | { + "__req__": self + } self.sampling_params = sampling_params self.custom_logit_processor = custom_logit_processor self.return_hidden_states = return_hidden_states # Memory pool info - self.req_pool_idx = None + self.req_pool_idx: Optional[int] = None # Check finish self.tokenizer = None self.finished_reason = None + # If we want to abort the request in the middle of the event loop, set this to true + # Note: We should never set finished_reason in the middle, the req will get filtered and never respond self.to_abort = False self.stream = stream self.eos_token_ids = eos_token_ids @@ -289,38 +303,56 @@ class Req: self.image_inputs: Optional[ImageInputs] = None # Prefix info + # The indices to kv cache for the shared prefix. self.prefix_indices = [] - # Tokens to run prefill. input_tokens - shared_prefix_tokens. - # Updated if chunked. + # Number of tokens to run prefill. self.extend_input_len = 0 + # The relative logprob_start_len in an extend batch + self.extend_logprob_start_len = 0 self.last_node = None - # Chunked prefill - self.is_being_chunked = 0 + # Whether or not if it is chunked. It increments whenever + # it is chunked, and decrement whenever chunked request is + # processed. + self.is_chunked = 0 # For retraction self.is_retracted = False # Logprobs (arguments) self.return_logprob = return_logprob + # Start index to compute logprob from. self.logprob_start_len = 0 self.top_logprobs_num = top_logprobs_num + self.token_ids_logprob = token_ids_logprob # Logprobs (return values) self.input_token_logprobs_val: Optional[List[float]] = None self.input_token_logprobs_idx: Optional[List[int]] = None self.input_top_logprobs_val: Optional[List[float]] = None self.input_top_logprobs_idx: Optional[List[int]] = None + self.input_token_ids_logprobs_val: Optional[List[float]] = None + self.input_token_ids_logprobs_idx: Optional[List[int]] = None + # Temporary holder to store input_token_logprobs. + self.input_token_logprobs: Optional[List[Tuple[int]]] = None + self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None + self.temp_input_top_logprobs_idx: Optional[List[int]] = None + self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None + self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None if return_logprob: self.output_token_logprobs_val = [] self.output_token_logprobs_idx = [] self.output_top_logprobs_val = [] self.output_top_logprobs_idx = [] + self.output_token_ids_logprobs_val = [] + self.output_token_ids_logprobs_idx = [] else: self.output_token_logprobs_val = self.output_token_logprobs_idx = ( self.output_top_logprobs_val - ) = self.output_top_logprobs_idx = None + ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = ( + self.output_token_ids_logprobs_idx + ) = None self.hidden_states = [] # Logprobs (internal values) @@ -345,6 +377,13 @@ class Req: self.spec_verify_ct = 0 self.lora_path = lora_path + # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop + self.to_abort_message: str = "Unknown error" + + @property + def seqlen(self): + return len(self.origin_input_ids) + len(self.output_ids) + def extend_image_inputs(self, image_inputs): if self.image_inputs is None: self.image_inputs = image_inputs @@ -422,7 +461,9 @@ class Req: return if self.to_abort: - self.finished_reason = FINISH_ABORT() + self.finished_reason = FINISH_ABORT( + message=self.to_abort_message, + ) return if len(self.output_ids) >= self.sampling_params.max_new_tokens: @@ -517,6 +558,8 @@ class Req: self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k] self.output_top_logprobs_val = self.output_top_logprobs_val[:k] self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k] + self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k] + self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k] self.logprob_start_len = prompt_tokens + k self.last_update_decode_tokens = len(self.output_ids) - k @@ -527,16 +570,19 @@ class Req: self.last_node = None self.extend_input_len = 0 self.is_retracted = True + self.input_token_logprobs = None + self.temp_input_top_logprobs_val = None + self.temp_input_top_logprobs_idx = None + self.extend_logprob_start_len = 0 + self.is_chunked = 0 + self.req_pool_idx = None - # For incremental logprobs - # TODO: Fix the `logprob_start_len` self.last_update_decode_tokens = 0 - self.logprob_start_len = 10**9 def __repr__(self): return ( - f"rid(n={self.rid}, " - f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}" + f"Req(rid={self.rid}, " + f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})" ) @@ -576,11 +622,13 @@ class ScheduleBatch: # For DP attention global_num_tokens: Optional[List[int]] = None + global_num_tokens_for_logprob: Optional[List[int]] = None can_run_dp_cuda_graph: bool = False # For processing logprobs return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None + token_ids_logprobs: Optional[List[List[int]]] = None # For extend and mixed chunekd prefill prefix_lens: List[int] = None @@ -588,6 +636,8 @@ class ScheduleBatch: extend_num_tokens: int = None decoding_reqs: List[Req] = None extend_logprob_start_lens: List[int] = None + # It comes empty list if logprob is not required. + extend_input_logprob_token_ids: Optional[torch.Tensor] = None # For encoder-decoder encoder_cached: Optional[List[bool]] = None @@ -606,7 +656,7 @@ class ScheduleBatch: # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None - spec_info: Optional[SpecInfo] = None + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None # Enable custom logit processor enable_custom_logit_processor: bool = False @@ -653,8 +703,10 @@ class ScheduleBatch: req_pool_indices = self.req_to_token_pool.alloc(num_reqs) if req_pool_indices is None: raise RuntimeError( - "Out of memory. " - "Please set a smaller number for `--max-running-requests`." + "alloc_req_slots runs out of memory. " + "Please set a smaller number for `--max-running-requests`. " + f"{self.req_to_token_pool.available_size()=}, " + f"{num_reqs=}, " ) return req_pool_indices @@ -765,6 +817,7 @@ class ScheduleBatch: out_cache_loc = self.alloc_token_slots(extend_num_tokens) input_embeds = [] + extend_input_logprob_token_ids = [] pt = 0 for i, req in enumerate(reqs): @@ -783,22 +836,64 @@ class ScheduleBatch: # If req.input_embeds is already a list, append its content directly input_embeds.extend(req.input_embeds) # Use extend to avoid nesting - if req.return_logprob: - # Compute the relative logprob_start_len in an extend batch - if req.logprob_start_len >= pre_len: - extend_logprob_start_len = min( - req.logprob_start_len - pre_len, req.extend_input_len - 1 - ) - else: - raise RuntimeError( - f"This should never happen. {req.logprob_start_len=}, {pre_len=}" - ) - req.extend_logprob_start_len = extend_logprob_start_len - req.cached_tokens += pre_len - req.already_computed req.already_computed = seq_len req.is_retracted = False pre_lens.append(pre_len) + # Compute the relative logprob_start_len in an extend batch + if req.logprob_start_len >= pre_len: + req.extend_logprob_start_len = min( + req.logprob_start_len - pre_len, + req.extend_input_len, + req.seqlen - 1, + ) + else: + req.extend_logprob_start_len = 0 + + if self.return_logprob: + # Find input logprob token ids. + # First, find a global index within origin_input_ids and slide it by 1 + # to compute input logprobs. It is because you need the next token + # to compute input logprobs. E.g., (chunk size 2) + # + # input_logprobs = [1, 2, 3, 4] + # fill_ids = [1, 2] + # extend_input_logprob_token_id = [2, 3] + # + # Note that it can also overflow. In this case, we pad it with 0. + # input_logprobs = [1, 2, 3, 4] + # fill_ids = [3, 4] + # extend_input_logprob_token_id = [4, 0] + global_start_idx, global_end_idx = ( + len(req.prefix_indices), + len(req.fill_ids), + ) + # Apply logprob_start_len + if global_start_idx < req.logprob_start_len: + global_start_idx = req.logprob_start_len + + logprob_token_ids = req.origin_input_ids[ + global_start_idx + 1 : global_end_idx + 1 + ] + extend_input_logprob_token_ids.extend(logprob_token_ids) + + # We will need req.extend_input_len - req.extend_logprob_start_len number of + # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0. + extend_input_logprob_token_ids.extend( + [0] + * ( + req.extend_input_len + - req.extend_logprob_start_len + - len(logprob_token_ids) + ) + ) + + if self.return_logprob: + extend_input_logprob_token_ids = torch.tensor( + extend_input_logprob_token_ids + ) + else: + extend_input_logprob_token_ids = None # Set fields self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( @@ -821,10 +916,12 @@ class ScheduleBatch: self.seq_lens_sum = sum(seq_lens) if self.return_logprob: self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + self.token_ids_logprobs = [r.token_ids_logprob for r in reqs] self.extend_num_tokens = extend_num_tokens self.prefix_lens = [len(r.prefix_indices) for r in reqs] self.extend_lens = [r.extend_input_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] + self.extend_input_logprob_token_ids = extend_input_logprob_token_ids # Write to req_to_token_pool pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to( @@ -860,7 +957,6 @@ class ScheduleBatch: self.sampling_info = SamplingBatchInfo.from_schedule_batch( self, self.model_config.vocab_size, - enable_overlap_schedule=self.enable_overlap, ) def mix_with_running(self, running_batch: "ScheduleBatch"): @@ -905,25 +1001,43 @@ class ScheduleBatch: return False - def retract_decode(self): + def retract_decode(self, server_args: ServerArgs): """Retract the decoding requests when there is not enough memory.""" sorted_indices = [i for i in range(len(self.reqs))] # TODO(lsyin): improve retraction policy for radix cache - sorted_indices.sort( - key=lambda i: ( - len(self.reqs[i].output_ids), - -len(self.reqs[i].origin_input_ids), - ), - reverse=True, - ) + # For spec decoding, filter_batch API can only filter + # requests from the back, so we can only retract from the back. + # TODO(sang): Clean up finish path and support better retract + # policy. + if not server_args.speculative_algorithm: + sorted_indices.sort( + key=lambda i: ( + len(self.reqs[i].output_ids), + -len(self.reqs[i].origin_input_ids), + ), + reverse=True, + ) + + def get_required_tokens(num_reqs: int): + headroom_for_spec_decode = 0 + if server_args.speculative_algorithm: + headroom_for_spec_decode += ( + num_reqs + * server_args.speculative_eagle_topk + * server_args.speculative_num_steps + + num_reqs * server_args.speculative_num_draft_tokens + ) + return ( + num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode + ) retracted_reqs = [] seq_lens_cpu = self.seq_lens.cpu().numpy() first_iter = True while ( self.token_to_kv_pool.available_size() - < len(sorted_indices) * global_config.retract_decode_steps + < get_required_tokens(len(sorted_indices)) or first_iter ): if len(sorted_indices) == 1: @@ -1048,17 +1162,40 @@ class ScheduleBatch: self.sampling_info = SamplingBatchInfo.from_schedule_batch( self, self.model_config.vocab_size, - enable_overlap_schedule=self.enable_overlap, ) def prepare_for_decode(self): self.forward_mode = ForwardMode.DECODE if self.spec_algorithm.is_eagle(): + # if spec decoding is used, the decode batch is prepared inside + # `forward_batch_speculative_generation` after running draft models. return + if self.sampling_info.penalizer_orchestrator.is_required: + if self.enable_overlap: + # TODO: this can be slow, optimize this. + delayed_output_ids = torch.tensor( + [ + ( + req.output_ids[-1] + if len(req.output_ids) + else req.origin_input_ids[-1] + ) + for req in self.reqs + ], + dtype=torch.int64, + device=self.device, + ) + self.sampling_info.penalizer_orchestrator.cumulate_output_tokens( + delayed_output_ids + ) + else: + self.sampling_info.penalizer_orchestrator.cumulate_output_tokens( + self.output_ids.to(torch.int64) + ) + self.input_ids = self.output_ids self.output_ids = None - self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids) # Alloc mem bs = len(self.reqs) @@ -1086,14 +1223,15 @@ class ScheduleBatch: def filter_batch( self, - being_chunked_req: Optional[Req] = None, + chunked_req_to_exclude: Optional[Req] = None, keep_indices: Optional[List[int]] = None, ): if keep_indices is None: keep_indices = [ i for i in range(len(self.reqs)) - if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req + if not self.reqs[i].finished() + and self.reqs[i] is not chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: @@ -1105,31 +1243,34 @@ class ScheduleBatch: # No need to filter return + keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to( + self.device, non_blocking=True + ) + if self.model_config.is_encoder_decoder: - self.encoder_lens = self.encoder_lens[keep_indices] + self.encoder_lens = self.encoder_lens[keep_indices_device] self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices] - new_indices = torch.tensor(keep_indices, dtype=torch.int64).to( - self.device, non_blocking=True - ) - self.req_pool_indices = self.req_pool_indices[new_indices] - self.seq_lens = self.seq_lens[new_indices] + self.req_pool_indices = self.req_pool_indices[keep_indices_device] + self.seq_lens = self.seq_lens[keep_indices_device] self.out_cache_loc = None self.seq_lens_sum = self.seq_lens.sum().item() - self.output_ids = self.output_ids[new_indices] + self.output_ids = self.output_ids[keep_indices_device] self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob: self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices] + self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices] else: self.top_logprobs_nums = None + self.token_ids_logprobs = None self.has_stream = any(req.stream for req in self.reqs) self.has_grammar = any(req.grammar for req in self.reqs) - self.sampling_info.filter_batch(keep_indices, new_indices) + self.sampling_info.filter_batch(keep_indices, keep_indices_device) if self.spec_info: - self.spec_info.filter_batch(new_indices) + self.spec_info.filter_batch(keep_indices_device) def merge_batch(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because @@ -1152,10 +1293,13 @@ class ScheduleBatch: self.output_ids = torch.concat([self.output_ids, other.output_ids]) if self.return_logprob and other.return_logprob: self.top_logprobs_nums.extend(other.top_logprobs_nums) + self.token_ids_logprobs.extend(other.token_ids_logprobs) elif self.return_logprob: self.top_logprobs_nums.extend([0] * len(other.reqs)) + self.token_ids_logprobs.extend([None] * len(other.reqs)) elif other.return_logprob: self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums + self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs self.reqs.extend(other.reqs) self.return_logprob |= other.return_logprob @@ -1192,7 +1336,9 @@ class ScheduleBatch: seq_lens_sum=self.seq_lens_sum, return_logprob=self.return_logprob, top_logprobs_nums=self.top_logprobs_nums, + token_ids_logprobs=self.token_ids_logprobs, global_num_tokens=self.global_num_tokens, + global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, extend_num_tokens=self.extend_num_tokens, extend_seq_lens=extend_seq_lens, @@ -1219,6 +1365,7 @@ class ScheduleBatch: else CaptureHiddenMode.NULL ) ), + extend_input_logprob_token_ids=self.extend_input_logprob_token_ids, ) def copy(self): @@ -1262,9 +1409,11 @@ class ModelWorkerBatch: # For logprob return_logprob: bool top_logprobs_nums: Optional[List[int]] + token_ids_logprobs: Optional[List[List[int]]] # For DP attention global_num_tokens: Optional[List[int]] + global_num_tokens_for_logprob: Optional[List[int]] can_run_dp_cuda_graph: bool # For extend @@ -1272,6 +1421,7 @@ class ModelWorkerBatch: extend_seq_lens: Optional[List[int]] extend_prefix_lens: Optional[List[int]] extend_logprob_start_lens: Optional[List[int]] + extend_input_logprob_token_ids: Optional[torch.Tensor] # For multimodal image_inputs: Optional[List[ImageInputs]] @@ -1293,7 +1443,8 @@ class ModelWorkerBatch: # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None - spec_info: Optional[SpecInfo] = None + spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None + # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index a3a099b83..916692446 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -272,7 +272,7 @@ class PrefillAdder: self.req_states = None self.can_run_list = [] - self.new_being_chunked_req = None + self.new_chunked_req = None self.log_hit_tokens = 0 self.log_input_tokens = 0 @@ -327,7 +327,7 @@ class PrefillAdder: self.log_hit_tokens += prefix_len self.log_input_tokens += extend_input_len - def add_being_chunked_req(self, req: Req): + def add_chunked_req(self, req: Req): truncated = req.extend_input_len > self.rem_chunk_tokens req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] @@ -354,7 +354,7 @@ class PrefillAdder: finally: self.tree_cache.dec_lock_ref(last_node) - def add_one_req_ignore_eos(self, req: Req): + def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool): def add_req_state(r, insert_sort=False): new_token_ratio = ( 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio @@ -403,6 +403,7 @@ class PrefillAdder: self.rem_chunk_tokens is None or req.extend_input_len <= self.rem_chunk_tokens ): + # Non-chunked prefill self.can_run_list.append(req) self._prefill_one_req( 0, @@ -418,14 +419,14 @@ class PrefillAdder: req.extend_input_len = trunc_len req.fill_ids = req.fill_ids[:trunc_len] self.can_run_list.append(req) - self.new_being_chunked_req = req + self.new_chunked_req = req self._prefill_one_req(0, trunc_len, 0) return self.budget_state() - def add_one_req(self, req: Req): + def add_one_req(self, req: Req, has_chunked_req: bool): if req.sampling_params.ignore_eos and self.tree_cache.disable: - return self.add_one_req_ignore_eos(req) + return self.add_one_req_ignore_eos(req, has_chunked_req) total_tokens = req.extend_input_len + min( req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION @@ -443,14 +444,7 @@ class PrefillAdder: if total_tokens > self.rem_total_tokens: return AddReqResult.NO_TOKEN - if ( - self.rem_chunk_tokens is None - or input_tokens <= self.rem_chunk_tokens - or ( - req.return_logprob - and req.logprob_start_len != len(req.origin_input_ids) - 1 - ) - ): + if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens: # Non-chunked prefill self.can_run_list.append(req) self.tree_cache.inc_lock_ref(req.last_node) @@ -470,8 +464,9 @@ class PrefillAdder: req.extend_input_len = trunc_len req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] + self.can_run_list.append(req) - self.new_being_chunked_req = req + self.new_chunked_req = req self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req(prefix_len, trunc_len, 0) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1de73137f..a02bcf785 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -17,10 +17,11 @@ import faulthandler import logging import os import signal +import sys import threading import time import warnings -from collections import deque +from collections import defaultdict, deque from concurrent import futures from dataclasses import dataclass from http import HTTPStatus @@ -41,20 +42,28 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, + BatchMultimodalDecodeReq, BatchTokenIDOut, CloseSessionReqInput, FlushCacheReq, + GetInternalStateReq, + GetInternalStateReqOutput, GetWeightsByNameReqInput, GetWeightsByNameReqOutput, + HealthCheckOutput, InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqOutput, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ProfileReqOutput, + ProfileReqType, ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, + SetInternalStateReq, + SetInternalStateReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, @@ -95,6 +104,8 @@ from sglang.srt.utils import ( crash_on_warnings, get_bool_env_var, get_zmq_socket, + kill_itself_when_parent_died, + pyspy_dump_schedulers, set_gpu_proc_affinity, set_random_seed, suppress_other_loggers, @@ -104,13 +115,16 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) # Test retract decode for debugging purposes -test_retract = get_bool_env_var("SGLANG_TEST_RETRACT") +TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT") +RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME") @dataclass class GenerationBatchResult: logits_output: LogitsProcessorOutput next_token_ids: List[int] + extend_input_len_per_req: List[int] + extend_logprob_start_len_per_req: List[int] bid: int @@ -142,15 +156,23 @@ class Scheduler: self.enable_overlap = not server_args.disable_overlap_schedule self.skip_tokenizer_init = server_args.skip_tokenizer_init self.enable_metrics = server_args.enable_metrics + self.stream_interval = server_args.stream_interval self.spec_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm ) + self.gpu_id = gpu_id + self.enable_hierarchical_cache = server_args.enable_hierarchical_cache self.decode_mem_cache_buf_multiplier = ( - self.server_args.speculative_num_draft_tokens + ( + self.server_args.speculative_num_draft_tokens + + ( + self.server_args.speculative_eagle_topk + * self.server_args.speculative_num_steps + ) + ) if not self.spec_algorithm.is_none() else 1 ) - self.enable_hierarchical_cache = server_args.enable_hierarchical_cache # Distributed rank info self.dp_size = server_args.dp_size @@ -246,7 +268,7 @@ class Scheduler: nccl_port=port_args.nccl_port, ) - # Launch a worker for speculative decoding if needed + # Launch a draft worker for speculative decoding if self.spec_algorithm.is_eagle(): from sglang.srt.speculative.eagle_worker import EAGLEWorker @@ -258,8 +280,10 @@ class Scheduler: target_worker=self.tp_worker, dp_rank=dp_rank, ) + self.prefill_only_one_req = True else: self.draft_worker = None + self.prefill_only_one_req = False # Get token and memory info from the model worker ( @@ -280,6 +304,7 @@ class Scheduler: self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func() global_server_args_dict.update(worker_global_server_args_dict) set_random_seed(self.random_seed) + # Print debug info logger.info( f"max_total_num_tokens={self.max_total_num_tokens}, " @@ -301,19 +326,18 @@ class Scheduler: token_to_kv_pool=self.token_to_kv_pool, ) else: - self.tree_cache = ( - HiRadixCache( + if self.enable_hierarchical_cache: + self.tree_cache = HiRadixCache( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool=self.token_to_kv_pool, ) - if self.enable_hierarchical_cache - else RadixCache( + else: + self.tree_cache = RadixCache( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool=self.token_to_kv_pool, disable=server_args.disable_radix_cache, ) - ) - self.tree_cache_metrics = {"total": 0, "hit": 0} + self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache) # Init running status @@ -330,12 +354,23 @@ class Scheduler: self.num_generated_tokens = 0 self.spec_num_total_accepted_tokens = 0 self.spec_num_total_forward_ct = 0 + self.cum_spec_accept_length = 0 + self.cum_spec_accept_count = 0 self.last_decode_stats_tic = time.time() + self.return_health_check_ct = 0 self.stream_interval = server_args.stream_interval self.current_stream = torch.get_device_module(self.device).current_stream() if self.device == "cpu": self.current_stream.synchronize = lambda: None # No-op for CPU + # For metrics only. + # The largest prefill length of a single request + self._largest_prefill_len: int = 0 + # The largest context length (prefill + generation) of a single request + self._largest_prefill_decode_len: int = 0 + self.last_gen_throughput: float = 0.0 + self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] + # Session info self.sessions: Dict[str, Session] = {} @@ -343,7 +378,7 @@ class Scheduler: self.chunked_prefill_size = server_args.chunked_prefill_size if self.chunked_prefill_size <= 0: # -1 means disable self.chunked_prefill_size = None - self.being_chunked_req = None + self.chunked_req = None self.is_mixed_chunk = ( self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) @@ -377,7 +412,7 @@ class Scheduler: ) / global_config.default_new_token_ratio_decay_steps self.new_token_ratio = self.init_new_token_ratio - # Tells whether the current running batch is full so that we can skip + # Tell whether the current running batch is full so that we can skip # the check of whether to prefill new requests. # This is an optimization to reduce the overhead of the prefill check. self.batch_is_full = False @@ -388,26 +423,16 @@ class Scheduler: t.start() self.parent_process = psutil.Process().parent() + # Init memory saver self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=server_args.enable_memory_saver ) # Init profiler - if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": - self.profiler = None - else: - self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR") - logger.info( - "Profiling enabled. Traces will be saved to: %s", - self.torch_profiler_trace_dir, - ) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - ) + self.torch_profiler = None + self.torch_profiler_output_dir: Optional[str] = None + self.torch_profiler_activities: Optional[List[str]] = None + self.profiler_target_forward_ct: Optional[int] = None # Init metrics stats self.stats = SchedulerStats() @@ -431,6 +456,8 @@ class Scheduler: (TokenizedEmbeddingReqInput, self.handle_embedding_request), (FlushCacheReq, self.flush_cache_wrapped), (AbortReq, self.abort_request), + (OpenSessionReqInput, self.open_session), + (CloseSessionReqInput, self.close_session), (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), ( @@ -439,22 +466,16 @@ class Scheduler: ), (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), (GetWeightsByNameReqInput, self.get_weights_by_name), + (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), + (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), (ProfileReq, self.profile), - (OpenSessionReqInput, self.open_session), - (CloseSessionReqInput, self.close_session), - ( - ReleaseMemoryOccupationReqInput, - lambda _: self.release_memory_occupation(), - ), - ( - ResumeMemoryOccupationReqInput, - lambda _: self.resume_memory_occupation(), - ), + (GetInternalStateReq, self.get_internal_state), + (SetInternalStateReq, self.set_internal_state), ] ) def watchdog_thread(self): - """A watch dog thread that will try to kill the server itself if one batch takes too long.""" + """A watch dog thread that will try to kill the server itself if one forward batch takes too long.""" self.watchdog_last_forward_ct = 0 self.watchdog_last_time = time.time() @@ -469,7 +490,18 @@ class Scheduler: self.watchdog_last_forward_ct = self.forward_ct self.watchdog_last_time = current time.sleep(self.watchdog_timeout // 2) - # Wait sometimes so that the parent process can print the error. + + # Print batch size and memory pool info to check whether there are de-sync issues. + logger.error( + f"{self.cur_batch.batch_size()=}, " + f"{self.cur_batch.reqs=}, " + f"{self.token_to_kv_pool.available_size()=}, " + f"{self.tree_cache.evictable_size()=}, " + ) + # Wait for some time so that the parent process can print the error. + pyspy_dump_schedulers() + print(file=sys.stderr, flush=True) + print(file=sys.stdout, flush=True) time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) @@ -586,6 +618,13 @@ class Scheduler: def process_input_requests(self, recv_reqs: List): for recv_req in recv_reqs: + # If it is a health check generation request and there are running requests, ignore it. + if is_health_check_generate_req(recv_req) and ( + self.chunked_req is not None or self.running_batch is not None + ): + self.return_health_check_ct += 1 + continue + output = self._request_dispatcher(recv_req) if output is not None: self.send_to_tokenizer.send_pyobj(output) @@ -600,7 +639,6 @@ class Scheduler: or recv_req.session_params.id is None or recv_req.session_params.id not in self.sessions ): - if recv_req.input_embeds is not None: # Generate fake input_ids based on the length of input_embeds seq_length = len(recv_req.input_embeds) @@ -627,6 +665,7 @@ class Scheduler: recv_req.sampling_params, return_logprob=recv_req.return_logprob, top_logprobs_num=recv_req.top_logprobs_num, + token_ids_logprob=recv_req.token_ids_logprob, stream=recv_req.stream, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, @@ -643,14 +682,14 @@ class Scheduler: req.finished_reason = FINISH_ABORT( f"Invalid request: session id {recv_req.session_params.id} does not exist" ) - self.waiting_queue.append(req) + self._add_request_to_queue(req) return else: # Create a new request from a previous session session = self.sessions[recv_req.session_params.id] req = session.create_req(recv_req, self.tokenizer) if isinstance(req.finished_reason, FINISH_ABORT): - self.waiting_queue.append(req) + self._add_request_to_queue(req) return # Handle multimodal inputs @@ -674,7 +713,7 @@ class Scheduler: req.finished_reason = FINISH_ABORT( error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" ) - self.waiting_queue.append(req) + self._add_request_to_queue(req) return # Validate prompts length @@ -686,16 +725,26 @@ class Scheduler: if error_msg: req.origin_input_ids = [0] req.sampling_params.max_new_tokens = 0 - self.waiting_queue.append(req) + self._add_request_to_queue(req) return # Copy more attributes - if recv_req.logprob_start_len == -1: + if recv_req.logprob_start_len == -1 or not recv_req.return_logprob: # By default, only return the logprobs for output tokens req.logprob_start_len = len(req.origin_input_ids) - 1 else: req.logprob_start_len = recv_req.logprob_start_len + if req.logprob_start_len >= len(req.origin_input_ids): + req.finished_reason = FINISH_ABORT( + f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.", + HTTPStatus.BAD_REQUEST, + "BadRequestError", + ) + req.logprob_start_len = len(req.origin_input_ids) - 1 + self._add_request_to_queue(req) + return + req.sampling_params.max_new_tokens = min( ( req.sampling_params.max_new_tokens @@ -731,7 +780,13 @@ class Scheduler: if add_to_grammar_queue: self.grammar_queue.append(req) else: - self.waiting_queue.append(req) + self._add_request_to_queue(req) + + def _add_request_to_queue(self, req: Req): + self.waiting_queue.append(req) + + def _extend_requests_to_queue(self, reqs: List[Req]): + self.waiting_queue.extend(reqs) def handle_embedding_request( self, @@ -752,61 +807,62 @@ class Scheduler: self.server_args.allow_auto_truncate, ) if error_msg: - self.waiting_queue.append(req) + self._add_request_to_queue(req) return # Copy more attributes req.logprob_start_len = len(req.origin_input_ids) - 1 - self.waiting_queue.append(req) + self._add_request_to_queue(req) def log_prefill_stats( self, adder: PrefillAdder, can_run_list: List[Req], - running_bs: ScheduleBatch, - has_being_chunked: bool, + running_bs: int, ): - self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 - tree_cache_hit_rate = ( - self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] - ) - num_used = self.max_total_num_tokens - ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) + self._largest_prefill_len = max( + self._largest_prefill_len, adder.log_input_tokens + ) - logger.info( + f = ( f"Prefill batch. " f"#new-seq: {len(can_run_list)}, " f"#new-token: {adder.log_input_tokens}, " f"#cached-token: {adder.log_hit_tokens}, " - f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"#running-req: {running_bs}, " - f"#queue-req: {len(self.waiting_queue) + has_being_chunked}" + f"#queue-req: {len(self.waiting_queue)}, " ) + logger.info(f) if self.enable_metrics: + cache_hit_rate = adder.log_hit_tokens / ( + adder.log_input_tokens + adder.log_hit_tokens + ) self.stats.num_running_reqs = running_bs self.stats.num_used_tokens = num_used self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) - self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked - self.stats.cache_hit_rate = tree_cache_hit_rate + self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.cache_hit_rate = cache_hit_rate self.metrics_collector.log_stats(self.stats) def log_decode_stats(self): + gap_latency = time.time() - self.last_decode_stats_tic + self.last_decode_stats_tic = time.time() + self.last_gen_throughput = self.num_generated_tokens / gap_latency + self.num_generated_tokens = 0 + num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 num_used = self.max_total_num_tokens - ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - gen_throughput = self.num_generated_tokens / ( - time.time() - self.last_decode_stats_tic - ) - self.num_generated_tokens = 0 - self.last_decode_stats_tic = time.time() - num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 + + if RECORD_STEP_TIME: + self.step_time_dict[num_running_reqs].append( + gap_latency / self.server_args.decode_log_interval + ) if self.spec_algorithm.is_none(): msg = ( @@ -814,14 +870,17 @@ class Scheduler: f"#running-req: {num_running_reqs}, " f"#token: {num_used}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}" + f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " + f"largest-len: {self._largest_prefill_decode_len}, " + f"#queue-req: {len(self.waiting_queue)}, " ) spec_accept_length = 0 else: spec_accept_length = ( self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct ) + self.cum_spec_accept_length += self.spec_num_total_accepted_tokens + self.cum_spec_accept_count += self.spec_num_total_forward_ct self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 msg = ( f"Decode batch. " @@ -829,8 +888,9 @@ class Scheduler: f"#token: {num_used}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"accept len: {spec_accept_length:.2f}, " - f"gen throughput (token/s): {gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}" + f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " + f"largest-len: {self._largest_prefill_decode_len}, " + f"#queue-req: {len(self.waiting_queue)}, " ) logger.info(msg) @@ -838,7 +898,8 @@ class Scheduler: self.stats.num_running_reqs = num_running_reqs self.stats.num_used_tokens = num_used self.stats.token_usage = num_used / self.max_total_num_tokens - self.stats.gen_throughput = gen_throughput + self.stats.cache_hit_rate = 0.0 + self.stats.gen_throughput = self.last_gen_throughput self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.spec_accept_length = spec_accept_length self.metrics_collector.log_stats(self.stats) @@ -872,21 +933,42 @@ class Scheduler: if crash_on_warnings(): raise ValueError(msg) + if ( + self.enable_metrics + and self.attn_tp_rank == 0 + and time.time() > self.metrics_collector.last_log_time + 30 + ): + # During idle time, also collect metrics every 30 seconds. + num_used = self.max_total_num_tokens - ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + ) + num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 + self.stats.num_running_reqs = num_running_reqs + self.stats.num_used_tokens = num_used + self.stats.token_usage = num_used / self.max_total_num_tokens + self.stats.gen_throughput = 0 + self.stats.num_queue_reqs = len(self.waiting_queue) + self.metrics_collector.log_stats(self.stats) + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Merge the prefill batch into the running batch if self.last_batch and self.last_batch.forward_mode.is_extend(): - if self.being_chunked_req: - # Move the chunked request out of the batch - self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req) - self.tree_cache.cache_unfinished_req(self.being_chunked_req) - # being chunked request keeps its rid but will get a new req_pool_idx - self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx) + if self.chunked_req: + # Move the chunked request out of the batch so that we can merge + # only finished requests to running_batch. + self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req) + self.tree_cache.cache_unfinished_req(self.chunked_req) + # chunked request keeps its rid but will get a new req_pool_idx + self.req_to_token_pool.free(self.chunked_req.req_pool_idx) self.batch_is_full = False + self.last_batch.filter_batch() if not self.last_batch.is_empty(): if self.running_batch is None: self.running_batch = self.last_batch else: + # merge running_batch with prefill batch self.running_batch.merge_batch(self.last_batch) new_batch = self.get_new_batch_prefill() @@ -915,7 +997,7 @@ class Scheduler: # Handle the cases where prefill is not allowed if ( self.batch_is_full or len(self.waiting_queue) == 0 - ) and self.being_chunked_req is None: + ) and self.chunked_req is None: return None running_bs = len(self.running_batch.reqs) if self.running_batch else 0 @@ -937,10 +1019,10 @@ class Scheduler: running_bs if self.is_mixed_chunk else 0, ) - has_being_chunked = self.being_chunked_req is not None - if has_being_chunked: - self.being_chunked_req.init_next_round_input() - self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req) + is_chunked = self.chunked_req is not None + if is_chunked: + self.chunked_req.init_next_round_input() + self.chunked_req = adder.add_chunked_req(self.chunked_req) if self.lora_paths: lora_set = ( @@ -994,7 +1076,7 @@ class Scheduler: self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid]) del self.staging_reqs[req.rid] - res = adder.add_one_req(req) + res = adder.add_one_req(req, self.chunked_req) if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: if self.enable_hierarchical_cache: @@ -1006,27 +1088,27 @@ class Scheduler: else: self.batch_is_full = True break - if self.server_args.prefill_only_one_req: + if self.prefill_only_one_req: break # Update waiting queue - can_run_list = adder.can_run_list + can_run_list: List[Req] = adder.can_run_list if len(can_run_list) == 0: return None self.waiting_queue = [ x for x in self.waiting_queue if x not in set(can_run_list) ] - if adder.new_being_chunked_req is not None: - assert self.being_chunked_req is None - self.being_chunked_req = adder.new_being_chunked_req + if adder.new_chunked_req is not None: + assert self.chunked_req is None + self.chunked_req = adder.new_chunked_req - if self.being_chunked_req: - self.being_chunked_req.is_being_chunked += 1 + if self.chunked_req: + self.chunked_req.is_chunked += 1 # Print stats if self.attn_tp_rank == 0: - self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked) + self.log_prefill_stats(adder, can_run_list, running_bs) # Create a new batch new_batch = ScheduleBatch.init_new( @@ -1062,8 +1144,6 @@ class Scheduler: def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: """Update the current running decoding batch.""" - global test_retract - initial_bs = batch.batch_size() batch.filter_batch() @@ -1073,11 +1153,11 @@ class Scheduler: # Check if decode out of memory if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or ( - test_retract and batch.batch_size() > 10 + TEST_RETRACT and batch.batch_size() > 10 ): old_ratio = self.new_token_ratio - retracted_reqs, new_token_ratio = batch.retract_decode() + retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args) self.new_token_ratio = new_token_ratio if self.draft_worker: self.draft_worker.finish_request(retracted_reqs) @@ -1087,7 +1167,7 @@ class Scheduler: f"#retracted_reqs: {len(retracted_reqs)}, " f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) - self.waiting_queue.extend(retracted_reqs) + self._extend_requests_to_queue(retracted_reqs) else: self.new_token_ratio = max( self.new_token_ratio - self.new_token_ratio_decay, @@ -1097,7 +1177,7 @@ class Scheduler: # Check for jump-forward if not self.disable_jump_forward and batch.has_grammar: jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) - self.waiting_queue.extend(jump_forward_reqs) + self._extend_requests_to_queue(jump_forward_reqs) if batch.is_empty(): self.batch_is_full = False return None @@ -1115,6 +1195,13 @@ class Scheduler: """Run a batch.""" self.forward_ct += 1 + # Check profiler + if ( + self.profiler_target_forward_ct + and self.profiler_target_forward_ct <= self.forward_ct + ): + self.stop_profile() + if self.is_generation: if self.spec_algorithm.is_none(): model_worker_batch = batch.get_model_worker_batch() @@ -1135,9 +1222,23 @@ class Scheduler: self.num_generated_tokens += num_accepted_tokens batch.output_ids = next_token_ids + # These 2 values are needed for processing the output, but the values can be + # modified by overlap schedule. So we have to copy them here so that + # we can use the correct values in output processing. + if batch.return_logprob: + extend_input_len_per_req = [req.extend_input_len for req in batch.reqs] + extend_logprob_start_len_per_req = [ + req.extend_logprob_start_len for req in batch.reqs + ] + else: + extend_input_len_per_req = None + extend_logprob_start_len_per_req = None + ret = GenerationBatchResult( logits_output=logits_output, next_token_ids=next_token_ids, + extend_input_len_per_req=extend_input_len_per_req, + extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, bid=model_worker_batch.bid, ) else: # embedding or reward model @@ -1171,6 +1272,13 @@ class Scheduler: self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() + if self.return_health_check_ct: + # Return some signal for the health check. + # This is used to prevent the health check signal being blocked by long context prefill. + # However, one minor issue is that this code path does not check the status of detokenizer manager. + self.return_health_check_ct -= 1 + self.send_to_tokenizer.send_pyobj(HealthCheckOutput()) + def process_batch_result_prefill( self, batch: ScheduleBatch, @@ -1182,10 +1290,14 @@ class Scheduler: ( logits_output, next_token_ids, + extend_input_len_per_req, + extend_logprob_start_len_per_req, bid, ) = ( result.logits_output, result.next_token_ids, + result.extend_input_len_per_req, + result.extend_logprob_start_len_per_req, result.bid, ) @@ -1195,12 +1307,14 @@ class Scheduler: # Move next_token_ids and logprobs to cpu next_token_ids = next_token_ids.tolist() if batch.return_logprob: - logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs.tolist() - ) - logits_output.input_token_logprobs = ( - logits_output.input_token_logprobs.tolist() - ) + if logits_output.next_token_logprobs is not None: + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs.tolist() + ) + if logits_output.input_token_logprobs is not None: + logits_output.input_token_logprobs = tuple( + logits_output.input_token_logprobs.tolist() + ) hidden_state_offset = 0 @@ -1216,19 +1330,33 @@ class Scheduler: self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1]) continue - if req.is_being_chunked <= 0: + if req.is_chunked <= 0: + # req output_ids are set here req.output_ids.append(next_token_id) req.check_finished() if req.finished(): self.tree_cache.cache_finished_req(req) elif not batch.decoding_reqs or req not in batch.decoding_reqs: + # This updates radix so others can match self.tree_cache.cache_unfinished_req(req) if req.return_logprob: - logprob_pt += self.add_logprob_return_values( - i, req, logprob_pt, next_token_ids, logits_output + assert extend_logprob_start_len_per_req is not None + assert extend_input_len_per_req is not None + extend_logprob_start_len = extend_logprob_start_len_per_req[i] + extend_input_len = extend_input_len_per_req[i] + num_input_logprobs = extend_input_len - extend_logprob_start_len + self.add_logprob_return_values( + i, + req, + logprob_pt, + next_token_ids, + num_input_logprobs, + logits_output, ) + logprob_pt += num_input_logprobs + if ( req.return_hidden_states and logits_output.hidden_states is not None @@ -1249,12 +1377,31 @@ class Scheduler: req.grammar.finished = req.finished() else: # being chunked reqs' prefill is not finished - req.is_being_chunked -= 1 + req.is_chunked -= 1 # There is only at most one request being currently chunked. # Because this request does not finish prefill, # we don't want to stream the request currently being chunked. skip_stream_req = req + # Incrementally update input logprobs. + if req.return_logprob: + extend_logprob_start_len = extend_logprob_start_len_per_req[i] + extend_input_len = extend_input_len_per_req[i] + if extend_logprob_start_len < extend_input_len: + # Update input logprobs. + num_input_logprobs = ( + extend_input_len - extend_logprob_start_len + ) + self.add_input_logprob_return_values( + i, + req, + logits_output, + logprob_pt, + num_input_logprobs, + last_prefill_chunk=False, + ) + logprob_pt += num_input_logprobs + if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() self.current_stream.synchronize() @@ -1270,7 +1417,7 @@ class Scheduler: continue req.embedding = embeddings[i] - if req.is_being_chunked <= 0: + if req.is_chunked <= 0: # Dummy output token for embedding models req.output_ids.append(0) req.check_finished() @@ -1281,7 +1428,7 @@ class Scheduler: self.tree_cache.cache_unfinished_req(req) else: # being chunked reqs' prefill is not finished - req.is_being_chunked -= 1 + req.is_chunked -= 1 self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) @@ -1322,11 +1469,11 @@ class Scheduler: req.output_ids.append(next_token_id) req.check_finished() - if req.finished(): self.tree_cache.cache_finished_req(req) - if req.return_logprob: + if req.return_logprob and batch.spec_algorithm.is_none(): + # speculative worker handles logprob in speculative decoding req.output_token_logprobs_val.append(next_token_logprobs[i]) req.output_token_logprobs_idx.append(next_token_id) if req.top_logprobs_num > 0: @@ -1336,11 +1483,18 @@ class Scheduler: req.output_top_logprobs_idx.append( logits_output.next_token_top_logprobs_idx[i] ) + if req.token_ids_logprob is not None: + req.output_token_ids_logprobs_val.append( + logits_output.next_token_token_ids_logprobs_val[i] + ) + req.output_token_ids_logprobs_idx.append( + logits_output.next_token_token_ids_logprobs_idx[i] + ) if req.return_hidden_states and logits_output.hidden_states is not None: req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) - if req.grammar is not None: + if req.grammar is not None and batch.spec_algorithm.is_none(): req.grammar.accept_token(next_token_id) req.grammar.finished = req.finished() @@ -1360,48 +1514,156 @@ class Scheduler: ): self.log_decode_stats() - def add_logprob_return_values( + def add_input_logprob_return_values( self, i: int, req: Req, - pt: int, - next_token_ids: List[int], output: LogitsProcessorOutput, + logprob_pt: int, + num_input_logprobs: int, + last_prefill_chunk: bool, # If True, it means prefill is finished. ): - """Attach logprobs to the return values.""" - req.output_token_logprobs_val.append(output.next_token_logprobs[i]) - req.output_token_logprobs_idx.append(next_token_ids[i]) + """Incrementally add input logprobs to `req`. - # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. - num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len + Args: + i: The request index in a batch. + req: The request. Input logprobs inside req are modified as a + consequence of the API + fill_ids: The prefill ids processed. + output: Logit processor output that's used to compute input logprobs + last_prefill_chunk: True if it is the last prefill (when chunked). + Some of input logprob operation should only happen at the last + prefill (e.g., computing input token logprobs). + """ + assert output.input_token_logprobs is not None + # It is for jump decoding that will be deprecated. + assert req.last_update_decode_tokens == 0 + if req.input_token_logprobs is None: + req.input_token_logprobs = [] + if req.temp_input_top_logprobs_val is None: + req.temp_input_top_logprobs_val = [] + if req.temp_input_top_logprobs_idx is None: + req.temp_input_top_logprobs_idx = [] + if req.temp_input_token_ids_logprobs_val is None: + req.temp_input_token_ids_logprobs_val = [] + if req.temp_input_token_ids_logprobs_idx is None: + req.temp_input_token_ids_logprobs_idx = [] - if req.input_token_logprobs_val is None: - input_token_logprobs_val = output.input_token_logprobs[ - pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens - ] + if req.input_token_logprobs_val is not None: + # The input logprob has been already computed. It only happens + # upon retract. + if req.top_logprobs_num > 0: + assert req.input_token_logprobs_val is not None + return - input_token_logprobs_idx = req.fill_ids[ - len(req.fill_ids) - - num_input_logprobs - + 1 : len(req.fill_ids) - - req.last_update_decode_tokens - ] + # Important for the performance. + assert isinstance(output.input_token_logprobs, tuple) + input_token_logprobs: Tuple[int] = output.input_token_logprobs + input_token_logprobs = input_token_logprobs[ + logprob_pt : logprob_pt + num_input_logprobs + ] + req.input_token_logprobs.extend(input_token_logprobs) + + if req.top_logprobs_num > 0: + req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i]) + req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i]) + + if req.token_ids_logprob is not None: + req.temp_input_token_ids_logprobs_val.append( + output.input_token_ids_logprobs_val[i] + ) + req.temp_input_token_ids_logprobs_idx.append( + output.input_token_ids_logprobs_idx[i] + ) + + if last_prefill_chunk: + input_token_logprobs = req.input_token_logprobs + req.input_token_logprobs = None + assert req.input_token_logprobs_val is None + assert req.input_token_logprobs_idx is None + assert req.input_top_logprobs_val is None + assert req.input_top_logprobs_idx is None + + # Compute input_token_logprobs_val + # Always pad the first one with None. + req.input_token_logprobs_val = [None] + req.input_token_logprobs_val.extend(input_token_logprobs) + # The last input logprob is for sampling, so just pop it out. + req.input_token_logprobs_val.pop() + + # Compute input_token_logprobs_idx + input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :] # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. input_token_logprobs_idx = [ x if x < self.model_config.vocab_size - 1 else 0 for x in input_token_logprobs_idx ] - - if ( - req.logprob_start_len == 0 - ): # The first token does not have logprob, pad it. - input_token_logprobs_val = [None] + input_token_logprobs_val - input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx - - req.input_token_logprobs_val = input_token_logprobs_val req.input_token_logprobs_idx = input_token_logprobs_idx + if req.top_logprobs_num > 0: + req.input_top_logprobs_val = [None] + req.input_top_logprobs_idx = [None] + + for val, idx in zip( + req.temp_input_top_logprobs_val, + req.temp_input_top_logprobs_idx, + strict=True, + ): + req.input_top_logprobs_val.extend(val) + req.input_top_logprobs_idx.extend(idx) + + # Last token is a sample token. + req.input_top_logprobs_val.pop() + req.input_top_logprobs_idx.pop() + req.temp_input_top_logprobs_idx = None + req.temp_input_top_logprobs_val = None + + if req.token_ids_logprob is not None: + req.input_token_ids_logprobs_val = [None] + req.input_token_ids_logprobs_idx = [None] + + for val, idx in zip( + req.temp_input_token_ids_logprobs_val, + req.temp_input_token_ids_logprobs_idx, + strict=True, + ): + req.input_token_ids_logprobs_val.extend(val) + req.input_token_ids_logprobs_idx.extend(idx) + + # Last token is a sample token. + req.input_token_ids_logprobs_val.pop() + req.input_token_ids_logprobs_idx.pop() + req.temp_input_token_ids_logprobs_idx = None + req.temp_input_token_ids_logprobs_val = None + + if req.return_logprob: + relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len + assert len(req.input_token_logprobs_val) == relevant_tokens_len + assert len(req.input_token_logprobs_idx) == relevant_tokens_len + if req.top_logprobs_num > 0: + assert len(req.input_top_logprobs_val) == relevant_tokens_len + assert len(req.input_top_logprobs_idx) == relevant_tokens_len + if req.token_ids_logprob is not None: + assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len + assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len + + def add_logprob_return_values( + self, + i: int, + req: Req, + pt: int, + next_token_ids: List[int], + num_input_logprobs: int, + output: LogitsProcessorOutput, + ): + """Attach logprobs to the return values.""" + req.output_token_logprobs_val.append(output.next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_ids[i]) + + self.add_input_logprob_return_values( + i, req, output, pt, num_input_logprobs, last_prefill_chunk=True + ) if req.last_update_decode_tokens != 0: # Some decode tokens are re-computed in an extend batch req.output_token_logprobs_val.extend( @@ -1422,13 +1684,6 @@ class Scheduler: ) if req.top_logprobs_num > 0: - if req.input_top_logprobs_val is None: - req.input_top_logprobs_val = output.input_top_logprobs_val[i] - req.input_top_logprobs_idx = output.input_top_logprobs_idx[i] - if req.logprob_start_len == 0: - req.input_top_logprobs_val = [None] + req.input_top_logprobs_val - req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx - if req.last_update_decode_tokens != 0: req.output_top_logprobs_val.extend( output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] @@ -1440,6 +1695,26 @@ class Scheduler: req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) + if req.token_ids_logprob is not None: + if req.last_update_decode_tokens != 0: + req.output_token_ids_logprobs_val.extend( + output.input_token_ids_logprobs_val[i][ + -req.last_update_decode_tokens : + ] + ) + req.output_token_ids_logprobs_idx.extend( + output.input_token_ids_logprobs_idx[i][ + -req.last_update_decode_tokens : + ] + ) + + req.output_token_ids_logprobs_val.append( + output.next_token_token_ids_logprobs_val[i] + ) + req.output_token_ids_logprobs_idx.append( + output.next_token_token_ids_logprobs_idx[i] + ) + return num_input_logprobs def stream_output( @@ -1474,24 +1749,41 @@ class Scheduler: input_top_logprobs_idx = [] output_top_logprobs_val = [] output_top_logprobs_idx = [] + input_token_ids_logprobs_val = [] + input_token_ids_logprobs_idx = [] + output_token_ids_logprobs_val = [] + output_token_ids_logprobs_idx = [] else: input_token_logprobs_val = input_token_logprobs_idx = ( output_token_logprobs_val ) = output_token_logprobs_idx = input_top_logprobs_val = ( input_top_logprobs_idx - ) = output_top_logprobs_val = output_top_logprobs_idx = None + ) = output_top_logprobs_val = output_top_logprobs_idx = ( + input_token_ids_logprobs_val + ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = ( + output_token_ids_logprobs_idx + ) = None for req in reqs: if req is skip_req: continue - # TODO(lianmin): revisit this for overlap + retract + stream + # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here. + if self.model_config.is_multimodal_gen and req.to_abort: + continue + if ( req.finished() # If stream, follow the given stream_interval or (req.stream and len(req.output_ids) % self.stream_interval == 0) # If not stream, we still want to output some tokens to get the benefit of incremental decoding. - or (not req.stream and len(req.output_ids) % 50 == 0) + # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not + # always increase one-by-one. + or ( + not req.stream + and len(req.output_ids) % 50 == 0 + and not self.model_config.is_multimodal_gen + ) ): if self.draft_worker and req.finished(): self.draft_worker.finish_request(req) @@ -1529,6 +1821,18 @@ class Scheduler: input_top_logprobs_idx.append(req.input_top_logprobs_idx) output_top_logprobs_val.append(req.output_top_logprobs_val) output_top_logprobs_idx.append(req.output_top_logprobs_idx) + input_token_ids_logprobs_val.append( + req.input_token_ids_logprobs_val + ) + input_token_ids_logprobs_idx.append( + req.input_token_ids_logprobs_idx + ) + output_token_ids_logprobs_val.append( + req.output_token_ids_logprobs_val + ) + output_token_ids_logprobs_idx.append( + req.output_token_ids_logprobs_idx + ) if req.return_hidden_states: if output_hidden_states is None: @@ -1537,6 +1841,9 @@ class Scheduler: # Send to detokenizer if rids: + if self.model_config.is_multimodal_gen: + raise NotImplementedError() + self.send_to_detokenizer.send_pyobj( BatchTokenIDOut( rids, @@ -1561,6 +1868,10 @@ class Scheduler: input_top_logprobs_idx, output_top_logprobs_val, output_top_logprobs_idx, + input_token_ids_logprobs_val, + input_token_ids_logprobs_idx, + output_token_ids_logprobs_val, + output_token_ids_logprobs_idx, output_hidden_states, ) ) @@ -1668,7 +1979,7 @@ class Scheduler: ].grammar.result() num_ready_reqs = num_ready_reqs_max - self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs]) + self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] def flush_cache_wrapped(self, recv_req: FlushCacheReq): @@ -1679,6 +1990,8 @@ class Scheduler: if len(self.waiting_queue) == 0 and ( self.running_batch is None or len(self.running_batch.reqs) == 0 ): + self.cur_batch = None + self.last_batch = None self.tree_cache.reset() self.tree_cache_metrics = {"total": 0, "hit": 0} if self.grammar_backend: @@ -1694,6 +2007,8 @@ class Scheduler: self.forward_ct_decode = 0 self.spec_num_total_accepted_tokens = 0 self.spec_num_total_forward_ct = 0 + self.cum_spec_accept_length = 0 + self.cum_spec_accept_count = 0 torch.cuda.empty_cache() logger.info("Cache flushed successfully!") if_success = True @@ -1706,6 +2021,49 @@ class Scheduler: if_success = False return if_success + def get_internal_state(self, recv_req: GetInternalStateReq): + ret = dict(global_server_args_dict) + ret["last_gen_throughput"] = self.last_gen_throughput + if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0: + ret["avg_spec_accept_length"] = ( + self.cum_spec_accept_length / self.cum_spec_accept_count + ) + + if RECORD_STEP_TIME: + ret["step_time_dict"] = self.step_time_dict + return GetInternalStateReqOutput( + internal_state=ret, + ) + + def set_internal_state(self, recv_req: SetInternalStateReq): + server_args_dict = recv_req.server_args + args_allow_update = set( + [ + "speculative_accept_threshold_single", + "speculative_accept_threshold_acc", + ] + ) + if_success = True + for k, v in server_args_dict.items(): + if k not in args_allow_update: + logging.warning(f"Updating {k} is not supported.") + if_success = False + break + if if_success: + if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0: + avg_spec_accept_length = ( + self.cum_spec_accept_length / self.cum_spec_accept_count + ) + logger.info(f"{avg_spec_accept_length=}") + self.cum_spec_accept_length = self.cum_spec_accept_count = 0 + for k, v in server_args_dict.items(): + global_server_args_dict[k] = v + logger.info(f"Global server args updated! " f"{global_server_args_dict=}") + return SetInternalStateReqOutput( + updated=True, + server_args=global_server_args_dict, + ) + def abort_request(self, recv_req: AbortReq): # Delete requests in the waiting queue to_del = None @@ -1735,7 +2093,7 @@ class Scheduler: assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return UpdateWeightFromDiskReqOutput(success, message) + return UpdateWeightFromDiskReqOutput(success, message, 0) def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): """Initialize the online model parameter update group.""" @@ -1771,7 +2129,7 @@ class Scheduler: parameter = self.tp_worker.get_weights_by_name(recv_req) return GetWeightsByNameReqOutput(parameter) - def release_memory_occupation(self): + def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput): self.stashed_model_static_state = _export_static_state( self.tp_worker.worker.model_runner.model ) @@ -1779,7 +2137,7 @@ class Scheduler: self.flush_cache() return ReleaseMemoryOccupationReqOutput() - def resume_memory_occupation(self): + def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput): self.memory_saver_adapter.resume() _import_static_state( self.tp_worker.worker.model_runner.model, self.stashed_model_static_state @@ -1788,24 +2146,96 @@ class Scheduler: return ResumeMemoryOccupationReqOutput() def profile(self, recv_req: ProfileReq): - if recv_req == ProfileReq.START_PROFILE: - self.start_profile() + if recv_req.type == ProfileReqType.START_PROFILE: + return self.start_profile( + recv_req.output_dir, recv_req.num_steps, recv_req.activities + ) else: - self.stop_profile() + return self.stop_profile() - def start_profile(self) -> None: - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.start() + def start_profile( + self, + output_dir: Optional[str], + num_steps: Optional[int], + activities: Optional[List[str]], + ) -> None: + if self.torch_profiler_activities: + return ProfileReqOutput( + success=False, + message="Profiling is already in progress. Call /stop_profile first.", + ) + + if output_dir is None: + output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp") + if activities is None: + activities = ["CPU", "GPU"] + + self.torch_profiler_output_dir = output_dir + self.torch_profiler_activities = activities + logger.info( + "Profiling starts. Traces will be saved to: %s", + self.torch_profiler_output_dir, + ) + + activity_map = { + "CPU": torch.profiler.ProfilerActivity.CPU, + "GPU": torch.profiler.ProfilerActivity.CUDA, + } + torchprof_activities = [ + activity_map[a] for a in activities if a in activity_map + ] + + if torchprof_activities: + self.torch_profiler = torch.profiler.profile( + activities=torchprof_activities, + with_stack=True, + ) + self.torch_profiler.start() + + if "MEM" in activities: + torch.cuda.memory._record_memory_history(max_entries=100000) + + if num_steps: + self.profiler_target_forward_ct = self.forward_ct + num_steps + # The caller will be notified when reaching profiler_target_forward_ct + else: + self.profiler_target_forward_ct = None + return ProfileReqOutput(success=True, message="Succeeded") def stop_profile(self) -> None: - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() - self.profiler.export_chrome_trace( - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz" + if self.torch_profiler_activities is None: + return + + logger.info("Stop profiling...") + if self.torch_profiler is not None: + self.torch_profiler.stop() + self.torch_profiler.export_chrome_trace( + os.path.join( + self.torch_profiler_output_dir, + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz", + ) + ) + + if "MEM" in self.torch_profiler_activities: + memory_profile_path = os.path.join( + self.torch_profiler_trace_dir, + str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle", + ) + torch.cuda.memory._dump_snapshot(memory_profile_path) + torch.cuda.memory._record_memory_history(enabled=None) + + logger.info( + "Profiling done. Traces are saved to: %s", + self.torch_profiler_output_dir, ) - logger.info("Profiler is done") + self.torch_profiler = None + self.torch_profiler_output_dir = None + self.torch_profiler_activities = None + + if self.profiler_target_forward_ct: + self.send_to_tokenizer.send_pyobj( + ProfileReqOutput(success=True, message="Succeeded.") + ) def open_session(self, recv_req: OpenSessionReqInput): # handle error @@ -1814,7 +2244,7 @@ class Scheduler: logger.warning(f"session id {session_id} already exist, cannot open.") return OpenSessionReqOutput(session_id, False) elif session_id is None: - logger.warning(f"session id is None, cannot open.") + logger.warning("session id is None, cannot open.") return OpenSessionReqOutput(session_id, False) else: self.sessions[session_id] = Session( @@ -1831,6 +2261,10 @@ class Scheduler: del self.sessions[session_id] +def is_health_check_generate_req(recv_req): + return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK") + + def _export_static_state(model): return dict( buffers=[ @@ -1853,8 +2287,11 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): - setproctitle.setproctitle("sglang::scheduler") + # Config the process + # kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2` + setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}") faulthandler.enable() + parent_process = psutil.Process().parent() # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var if dp_rank is None and "SGLANG_DP_RANK" in os.environ: @@ -1862,9 +2299,10 @@ def run_scheduler_process( # Configure the logger if dp_rank is None: - configure_logger(server_args, prefix=f" TP{tp_rank}") + prefix = f" TP{tp_rank}" else: - configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") + prefix = f" DP{dp_rank} TP{tp_rank}" + configure_logger(server_args, prefix=prefix) suppress_other_loggers() # Set cpu affinity to this gpu process diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index 4f4af6367..9aa6e4c59 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -35,12 +35,12 @@ class SessionReqNode: for req_node in self.childs: req_node.clear(req_dict) - if self.req.finished_reason == None: + if self.req.finished_reason is None: self.req.to_abort = True del req_dict[self.req.rid] def abort(self): - if self.req.finished_reason == None: + if self.req.finished_reason is None: self.req.to_abort = True def __str__(self): @@ -132,6 +132,10 @@ class Session: lora_path=req.lora_path, session_id=self.session_id, custom_logit_processor=req.custom_logit_processor, + stream=req.stream, + return_logprob=req.return_logprob, + top_logprobs_num=req.top_logprobs_num, + token_ids_logprob=req.token_ids_logprob, ) if last_req is not None: new_req.image_inputs = last_req.image_inputs diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 40348edc0..87e0ff847 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -16,6 +16,7 @@ import asyncio import copy import dataclasses +import json import logging import os import pickle @@ -24,9 +25,21 @@ import sys import threading import time import uuid +from collections import deque from datetime import datetime from http import HTTPStatus -from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Awaitable, + Deque, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, +) import fastapi import uvloop @@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import ( from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, + BatchMultimodalOut, BatchStrOut, BatchTokenIDOut, CloseSessionReqInput, @@ -51,18 +65,25 @@ from sglang.srt.managers.io_struct import ( EmbeddingReqInput, FlushCacheReq, GenerateReqInput, + GetInternalStateReq, + GetInternalStateReqOutput, GetWeightsByNameReqInput, GetWeightsByNameReqOutput, + HealthCheckOutput, InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqOutput, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ProfileReqOutput, + ProfileReqType, ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, SessionParams, + SetInternalStateReq, + SetInternalStateReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, @@ -98,7 +119,10 @@ class ReqState: # For metrics created_time: float - first_token_time: Optional[float] = None + finished_time: float = 0.0 + first_token_time: float = 0.0 + last_time: float = 0.0 + last_completion_tokens: int = 1 # For streaming output last_output_offset: int = 0 @@ -113,11 +137,10 @@ class TokenizerManager: port_args: PortArgs, ): # Parse args - self.server_args = server_args self.enable_metrics = server_args.enable_metrics self.log_requests = server_args.log_requests - self.log_requests_level = 0 + self.log_requests_level = server_args.log_requests_level # Init inter-process communication context = zmq.asyncio.Context(2) @@ -143,6 +166,7 @@ class TokenizerManager: ) self.is_generation = self.model_config.is_generation + self.is_image_gen = self.model_config.is_image_gen self.context_len = self.model_config.context_len self.image_token_id = self.model_config.image_token_id @@ -178,9 +202,12 @@ class TokenizerManager: # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} + self.gracefully_exit = False + self.last_receive_tstamp = 0 self.dump_requests_folder = "" # By default do not dump self.dump_requests_threshold = 1000 self.dump_request_list: List[Tuple] = [] + self.log_request_metadata = self.get_log_request_metadata() # The event to notify the weight sync is finished. self.model_update_lock = RWLock() @@ -192,8 +219,19 @@ class TokenizerManager: # For session info self.session_futures = {} # session_id -> asyncio event - # Others - self.gracefully_exit = False + # Set after scheduler is initialized + self.max_req_input_len = None + + # Metrics + if self.enable_metrics: + self.metrics_collector = TokenizerMetricsCollector( + labels={ + "model_name": self.server_args.served_model_name, + # TODO: Add lora name/path in the future, + }, + ) + + # Communicators self.init_weights_update_group_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -212,22 +250,26 @@ class TokenizerManager: self.resume_memory_occupation_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) - # Set after scheduler is initialized - self.max_req_input_len = None - - # Metrics - if self.enable_metrics: - self.metrics_collector = TokenizerMetricsCollector( - labels={ - "model_name": self.server_args.served_model_name, - # TODO: Add lora name/path in the future, - }, - ) + self.start_profile_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1) + self.get_internal_state_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.set_internal_state_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self._result_dispatcher = TypeBasedDispatcher( [ ( - (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut), + ( + BatchStrOut, + BatchEmbeddingOut, + BatchTokenIDOut, + BatchMultimodalOut, + ), self._handle_batch_output, ), (OpenSessionReqOutput, self._handle_open_session_req_output), @@ -259,6 +301,19 @@ class TokenizerManager: ResumeMemoryOccupationReqOutput, self.resume_memory_occupation_communicator.handle_recv, ), + ( + ProfileReqOutput, + self.start_profile_communicator.handle_recv, + ), + ( + GetInternalStateReqOutput, + self.get_internal_state_communicator.handle_recv, + ), + ( + SetInternalStateReqOutput, + self.set_internal_state_communicator.handle_recv, + ), + (HealthCheckOutput, lambda x: None), ] ) @@ -280,9 +335,9 @@ class TokenizerManager: obj.normalize_batch_and_arguments() if self.log_requests: - max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + max_length, skip_names, _ = self.log_request_metadata logger.info( - f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" + f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}" ) async with self.model_update_lock.reader_lock: @@ -336,6 +391,7 @@ class TokenizerManager: return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len top_logprobs_num = obj.top_logprobs_num + token_ids_logprob = obj.token_ids_logprob session_params = ( SessionParams(**obj.session_params) if obj.session_params else None ) @@ -378,6 +434,7 @@ class TokenizerManager: return_logprob, logprob_start_len, top_logprobs_num, + token_ids_logprob, obj.stream, lora_path=obj.lora_path, input_embeds=input_embeds, @@ -401,8 +458,7 @@ class TokenizerManager: tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], created_time: Optional[float] = None, ): - event = asyncio.Event() - state = ReqState([], False, event, obj, created_time=created_time) + state = ReqState([], False, asyncio.Event(), obj, created_time=created_time) self.rid_to_state[obj.rid] = state self.send_to_scheduler.send_pyobj(tokenized_obj) @@ -420,7 +476,10 @@ class TokenizerManager: except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): self.abort_request(obj.rid) - raise ValueError(f"Abort request {obj.rid}") + raise ValueError( + "Request is disconnected from the client side. " + f"Abort request {obj.rid}" + ) continue out = state.out_list[-1] @@ -428,8 +487,11 @@ class TokenizerManager: state.out_list = [] if state.finished: if self.log_requests: - max_length = 2048 if self.log_requests_level == 0 else 1 << 30 - msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" + max_length, skip_names, out_skip_names = self.log_request_metadata + if self.model_config.is_multimodal_gen: + msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}" + else: + msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}" logger.info(msg) del self.rid_to_state[obj.rid] @@ -452,7 +514,10 @@ class TokenizerManager: else: if request is not None and await request.is_disconnected(): self.abort_request(obj.rid) - raise ValueError(f"Abort request {obj.rid}") + raise ValueError( + "Request is disconnected from the client side. " + f"Abort request {obj.rid}" + ) async def _handle_batch_request( self, @@ -543,12 +608,25 @@ class TokenizerManager: req = AbortReq(rid) self.send_to_scheduler.send_pyobj(req) - def start_profile(self): - req = ProfileReq.START_PROFILE - self.send_to_scheduler.send_pyobj(req) + async def start_profile( + self, + output_dir: Optional[str] = None, + num_steps: Optional[int] = None, + activities: Optional[List[str]] = None, + ): + req = ProfileReq( + type=ProfileReqType.START_PROFILE, + output_dir=output_dir, + num_steps=num_steps, + activities=activities, + ) + result = (await self.start_profile_communicator(req))[0] + if not result.success: + raise RuntimeError(result.message) + return result def stop_profile(self): - req = ProfileReq.STOP_PROFILE + req = ProfileReq(type=ProfileReqType.STOP_PROFILE) self.send_to_scheduler.send_pyobj(req) async def update_weights_from_disk( @@ -581,7 +659,7 @@ class TokenizerManager: self.server_args.model_path = obj.model_path self.server_args.load_format = obj.load_format self.model_path = obj.model_path - return result.success, result.message + return result.success, result.message, result.num_paused_requests else: # self.server_args.dp_size > 1 self.model_update_tmp = [] result = await self.model_update_result @@ -593,7 +671,8 @@ class TokenizerManager: self.model_path = obj.model_path all_message = [r.message for r in result] all_message = " | ".join(all_message) - return all_success, all_message + all_paused_requests = [r.num_paused_requests for r in result] + return all_success, all_message, all_paused_requests async def init_weights_update_group( self, @@ -688,6 +767,54 @@ class TokenizerManager: ): await self.send_to_scheduler.send_pyobj(obj) + async def get_internal_state(self) -> Dict[Any, Any]: + req = GetInternalStateReq() + res: List[GetInternalStateReqOutput] = ( + await self.get_internal_state_communicator(req) + ) + return res[0].internal_state + + async def set_internal_state( + self, obj: SetInternalStateReq + ) -> SetInternalStateReqOutput: + res: List[SetInternalStateReqOutput] = ( + await self.set_internal_state_communicator(obj) + ) + return res[0] + + def get_log_request_metadata(self): + max_length = None + skip_names = None + out_skip_names = None + if self.log_requests: + if self.log_requests_level == 0: + max_length = 1 << 30 + skip_names = set( + [ + "text", + "input_ids", + "input_embeds", + "image_data", + "audio_data", + "lora_path", + ] + ) + out_skip_names = set( + [ + "text", + "output_ids", + ] + ) + elif self.log_requests_level == 1: + max_length = 2048 + elif self.log_requests_level == 2: + max_length = 1 << 30 + else: + raise ValueError( + f"Invalid --log-requests-level: {self.log_requests_level=}" + ) + return max_length, skip_names, out_skip_names + def configure_logging(self, obj: ConfigureLoggingReq): if obj.log_requests is not None: self.log_requests = obj.log_requests @@ -698,6 +825,7 @@ class TokenizerManager: if obj.dump_requests_threshold is not None: self.dump_requests_threshold = obj.dump_requests_threshold logging.info(f"Config logging: {obj=}") + self.log_request_metadata = self.get_log_request_metadata() def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. @@ -762,15 +890,20 @@ class TokenizerManager: while True: recv_obj = await self.recv_from_detokenizer.recv_pyobj() self._result_dispatcher(recv_obj) + self.last_receive_tstamp = time.time() def _handle_batch_output( - self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] + self, + recv_obj: Union[ + BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut + ], ): for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) if state is None: continue + # Build meta_info and return value meta_info = { "id": rid, "finish_reason": recv_obj.finished_reasons[i], @@ -781,14 +914,12 @@ class TokenizerManager: self.convert_logprob_style( meta_info, state.obj.top_logprobs_num, + state.obj.token_ids_logprob, state.obj.return_text_in_logprobs, recv_obj, i, ) - if self.server_args.speculative_algorithm: - meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] - if not isinstance(recv_obj, BatchEmbeddingOut): meta_info.update( { @@ -806,10 +937,20 @@ class TokenizerManager: "meta_info": meta_info, } elif isinstance(recv_obj, BatchTokenIDOut): + if self.server_args.stream_output and state.obj.stream: + output_token_ids = recv_obj.output_ids[i][ + state.last_output_offset : + ] + state.last_output_offset = len(recv_obj.output_ids[i]) + else: + output_token_ids = recv_obj.output_ids[i] + out_dict = { - "token_ids": recv_obj.output_ids[i], + "output_ids": output_token_ids, "meta_info": meta_info, } + elif isinstance(recv_obj, BatchMultimodalOut): + raise NotImplementedError() else: assert isinstance(recv_obj, BatchEmbeddingOut) out_dict = { @@ -817,10 +958,17 @@ class TokenizerManager: "meta_info": meta_info, } - state.out_list.append(out_dict) state.finished = recv_obj.finished_reasons[i] is not None + if state.finished: + if self.server_args.speculative_algorithm: + meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] + state.finished_time = time.time() + meta_info["e2e_latency"] = state.finished_time - state.created_time + + state.out_list.append(out_dict) state.event.set() + # Log metrics and dump if self.enable_metrics and state.obj.log_metrics: self.collect_metrics(state, recv_obj, i) if self.dump_requests_folder and state.finished and state.obj.log_metrics: @@ -830,6 +978,7 @@ class TokenizerManager: self, meta_info: dict, top_logprobs_num: int, + token_ids_logprob: List[int], return_text_in_logprobs: bool, recv_obj: BatchStrOut, recv_obj_index: int, @@ -857,6 +1006,20 @@ class TokenizerManager: return_text_in_logprobs, ) + if token_ids_logprob is not None: + meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.input_token_ids_logprobs_val[recv_obj_index], + recv_obj.input_token_ids_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + meta_info["output_token_ids_logprobs"] = ( + self.detokenize_top_logprobs_tokens( + recv_obj.output_token_ids_logprobs_val[recv_obj_index], + recv_obj.output_token_ids_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + ) + def detokenize_logprob_tokens( self, token_logprobs_val: List[float], @@ -900,34 +1063,30 @@ class TokenizerManager: else 0 ) - if state.first_token_time is None: - state.first_token_time = time.time() + if state.first_token_time == 0.0: + state.first_token_time = state.last_time = time.time() + state.last_completion_tokens = completion_tokens self.metrics_collector.observe_time_to_first_token( state.first_token_time - state.created_time ) else: - if completion_tokens >= 2: - # Compute time_per_output_token for the streaming case - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.first_token_time) / (completion_tokens - 1) + num_new_tokens = completion_tokens - state.last_completion_tokens + if num_new_tokens: + new_time = time.time() + interval = new_time - state.last_time + self.metrics_collector.observe_inter_token_latency( + interval, + num_new_tokens, ) + state.last_time = new_time + state.last_completion_tokens = completion_tokens if state.finished: self.metrics_collector.observe_one_finished_request( - recv_obj.prompt_tokens[i], completion_tokens + recv_obj.prompt_tokens[i], + completion_tokens, + state.finished_time - state.created_time, ) - self.metrics_collector.observe_e2e_request_latency( - time.time() - state.created_time - ) - # Compute time_per_output_token for the non-streaming case - if ( - hasattr(state.obj, "stream") - and not state.obj.stream - and completion_tokens >= 1 - ): - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.created_time) / completion_tokens - ) def dump_requests(self, state: ReqState, out_dict: dict): self.dump_request_list.append( @@ -996,22 +1155,38 @@ T = TypeVar("T") class _Communicator(Generic[T]): + """Note: The communicator now only run up to 1 in-flight request at any time.""" + def __init__(self, sender, fan_out: int): self._sender = sender self._fan_out = fan_out - self._result_future: Optional[asyncio.Future] = None + self._result_event: Optional[asyncio.Event] = None self._result_values: Optional[List[T]] = None + self._ready_queue: Deque[asyncio.Future] = deque() async def __call__(self, obj): - self._sender.send_pyobj(obj) - self._result_future = asyncio.Future() + ready_event = asyncio.Event() + if self._result_event is not None or len(self._ready_queue) > 0: + self._ready_queue.append(ready_event) + await ready_event.wait() + assert self._result_event is None + assert self._result_values is None + + if obj: + self._sender.send_pyobj(obj) + + self._result_event = asyncio.Event() self._result_values = [] - await self._result_future + await self._result_event.wait() result_values = self._result_values - self._result_future = self._result_values = None + self._result_event = self._result_values = None + + if len(self._ready_queue) > 0: + self._ready_queue.popleft().set() + return result_values def handle_recv(self, recv_obj: T): self._result_values.append(recv_obj) if len(self._result_values) == self._fan_out: - self._result_future.set_result(None) + self._result_event.set() diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 8b93ee5aa..ddb8a7c2e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -15,10 +15,13 @@ import logging import threading -from typing import Optional +from typing import Optional, Tuple + +import torch from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, InitWeightsUpdateGroupReqInput, @@ -159,7 +162,7 @@ class TpModelWorker: model_worker_batch: ModelWorkerBatch, launch_done: Optional[threading.Event] = None, skip_sample: bool = False, - ): + ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]: forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) if launch_done: diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 26d8d5748..74a2be5a2 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -175,7 +175,7 @@ class TpModelWorkerClient: logits_output.next_token_logprobs.tolist() ) if logits_output.input_token_logprobs is not None: - logits_output.input_token_logprobs = ( + logits_output.input_token_logprobs = tuple( logits_output.input_token_logprobs.tolist() ) next_token_ids = next_token_ids.tolist() @@ -188,8 +188,7 @@ class TpModelWorkerClient: model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace( sampling_info, sampling_info_done=threading.Event(), - scaling_penalties=sampling_info.scaling_penalties, - linear_penalties=sampling_info.linear_penalties, + penalizer_orchestrator=None, ) # A cuda stream sync here to avoid the cuda illegal memory access error. diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index b50199ca2..8f58e146c 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -2,7 +2,9 @@ from __future__ import annotations """Cache for chunked prefill, used when RadixCache is disabled.""" -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple + +import torch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -12,7 +14,7 @@ if TYPE_CHECKING: class ChunkCacheEntry: - def __init__(self, rid, value): + def __init__(self, rid: str, value: torch.Tensor): self.rid = rid self.value = value @@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache): self.disable = True self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool = token_to_kv_pool + self.entries: Dict[str, ChunkCacheEntry] = {} self.reset() @@ -53,11 +56,8 @@ class ChunkCache(BasePrefixCache): if req.rid in self.entries: del self.entries[req.rid] - def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None): - if token_ids is None: - token_id_len = len(req.fill_ids) - else: - token_id_len = len(token_ids) + def cache_unfinished_req(self, req: Req): + token_id_len = len(req.fill_ids) kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, :token_id_len @@ -86,5 +86,8 @@ class ChunkCache(BasePrefixCache): def evictable_size(self): return 0 + def pretty_print(self): + return "" + def protected_size(self): return 0 diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index 26eb2fc27..0a4a14973 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -13,6 +13,7 @@ # ============================================================================== """Utilities for Prometheus Metrics Collection.""" +import time from dataclasses import dataclass from typing import Dict, Union @@ -35,19 +36,20 @@ class SchedulerMetricsCollector: from prometheus_client import Gauge self.labels = labels + self.last_log_time = time.time() self.num_running_reqs = Gauge( name="sglang:num_running_reqs", documentation="The number of running requests.", labelnames=labels.keys(), - multiprocess_mode="sum", + multiprocess_mode="mostrecent", ) self.num_used_tokens = Gauge( name="sglang:num_used_tokens", documentation="The number of used tokens.", labelnames=labels.keys(), - multiprocess_mode="sum", + multiprocess_mode="mostrecent", ) self.token_usage = Gauge( @@ -61,14 +63,14 @@ class SchedulerMetricsCollector: name="sglang:gen_throughput", documentation="The generation throughput (token/s).", labelnames=labels.keys(), - multiprocess_mode="sum", + multiprocess_mode="mostrecent", ) self.num_queue_reqs = Gauge( name="sglang:num_queue_reqs", documentation="The number of requests in the waiting queue.", labelnames=labels.keys(), - multiprocess_mode="sum", + multiprocess_mode="mostrecent", ) self.cache_hit_rate = Gauge( @@ -97,6 +99,7 @@ class SchedulerMetricsCollector: self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) self._log_gauge(self.spec_accept_length, stats.spec_accept_length) + self.last_log_time = time.time() class TokenizerMetricsCollector: @@ -130,12 +133,15 @@ class TokenizerMetricsCollector: labelnames=labels.keys(), buckets=[ 0.1, - 0.25, + 0.3, 0.5, - 0.75, + 0.7, + 0.9, 1, 2, - 5, + 4, + 6, + 8, 10, 20, 40, @@ -151,24 +157,56 @@ class TokenizerMetricsCollector: documentation="Histogram of time per output token in seconds.", labelnames=labels.keys(), buckets=[ + 0.002, 0.005, - 0.01, + 0.010, + 0.020, + 0.030, + 0.040, + 0.050, + 0.060, + 0.070, + 0.080, + 0.090, + 0.100, + 0.150, + 0.200, + 0.300, + 0.400, + 0.600, + 0.800, + 1.000, + 2.000, + ], + ) + + self.histogram_inter_token_latency_seconds = Histogram( + name="sglang:inter_token_latency_seconds", + documentation="Histogram of inter-token latency in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.002, + 0.004, + 0.006, + 0.008, + 0.010, 0.015, - 0.02, + 0.020, 0.025, - 0.03, - 0.04, - 0.05, + 0.030, + 0.035, + 0.040, + 0.050, 0.075, - 0.1, - 0.15, - 0.2, - 0.3, - 0.4, - 0.5, - 0.75, - 1.0, - 2.5, + 0.100, + 0.150, + 0.200, + 0.300, + 0.400, + 0.500, + 0.750, + 1.000, + 2.000, ], ) @@ -178,8 +216,9 @@ class TokenizerMetricsCollector: labelnames=labels.keys(), buckets=[ 0.1, - 0.25, - 0.5, + 0.2, + 0.4, + 0.8, 1, 2, 5, @@ -188,28 +227,161 @@ class TokenizerMetricsCollector: 40, 60, 80, + 100, + 150, + 200, + 250, + 300, + 350, + 500, + 1000, + ], + ) + + self.histogram_prefill_prealloc_duration = Histogram( + name="sglang:prefill_prealloc_duration_seconds", + documentation="Histogram of prefill prealloc duration in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.1, + 0.3, + 0.5, + 0.7, + 0.9, + 1, + 2, + 4, + 6, + 8, + 10, + 20, + 40, + 60, + 80, 120, 160, ], ) + self.histogram_prefill_queue_duration = Histogram( + name="sglang:prefill_queue_duration_seconds", + documentation="Histogram of prefill queue duration in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.1, + 0.3, + 0.5, + 0.7, + 0.9, + 2, + 4, + 8, + 16, + 64, + ], + ) + + self.histogram_prefill_forward_duration = Histogram( + name="sglang:prefill_forward_duration_seconds", + documentation="Histogram of prefill forward duration in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.1, + 0.3, + 0.5, + 0.7, + 0.9, + 2, + 4, + 8, + 16, + 64, + ], + ) + + self.histogram_prefill_transfer_duration = Histogram( + name="sglang:prefill_transfer_duration_seconds", + documentation="Histogram of prefill transfer duration in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.050, + 0.100, + 0.150, + 0.200, + 0.300, + 0.400, + 0.500, + 1.000, + 2.000, + ], + ) + + self.histogram_decode_prealloc_duration = Histogram( + name="sglang:decode_prealloc_duration_seconds", + documentation="Histogram of decode prealloc duration in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.1, + 0.3, + 0.5, + 0.7, + 0.9, + 2, + 4, + 8, + 16, + 64, + ], + ) + + self.histogram_decode_queue_duration = Histogram( + name="sglang:decode_queue_duration_seconds", + documentation="Histogram of decode queue duration in seconds.", + labelnames=labels.keys(), + buckets=[ + 0.1, + 0.3, + 0.5, + 0.7, + 0.9, + 2, + 4, + 8, + 16, + 64, + ], + ) + def _log_histogram(self, histogram, data: Union[int, float]) -> None: histogram.labels(**self.labels).observe(data) - def _log_counter(self, counter, data: Union[int, float]) -> None: - # Convenience function for logging to counter. - counter.labels(**self.labels).inc(data) - - def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int): + def observe_one_finished_request( + self, + prompt_tokens: int, + generation_tokens: int, + e2e_latency: float, + ): self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens) self.generation_tokens_total.labels(**self.labels).inc(generation_tokens) self.num_requests_total.labels(**self.labels).inc(1) + self._log_histogram(self.histogram_e2e_request_latency, e2e_latency) + if generation_tokens >= 1: + self.histogram_time_per_output_token.labels(**self.labels).observe( + e2e_latency / generation_tokens + ) - def observe_time_to_first_token(self, value: Union[float, int]): - self._log_histogram(self.histogram_time_to_first_token, value) + def observe_time_to_first_token(self, value: float): + self.histogram_time_to_first_token.labels(**self.labels).observe(value) - def observe_time_per_output_token(self, value: Union[float, int]): - self._log_histogram(self.histogram_time_per_output_token, value) + def observe_inter_token_latency(self, internval: float, num_new_tokens: int): + adjusted_interval = internval / num_new_tokens - def observe_e2e_request_latency(self, value: Union[float, int]): - self._log_histogram(self.histogram_e2e_request_latency, value) + # A faster version of the Histogram::observe which observes multiple values at the same time. + # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639 + his = self.histogram_inter_token_latency_seconds.labels(**self.labels) + his._sum.inc(internval) + + for i, bound in enumerate(his._upper_bounds): + if adjusted_interval <= bound: + his._buckets[i].inc(num_new_tokens) + break diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index dd1b0da94..385f67e57 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -109,11 +109,15 @@ def set_torch_compile_config(): def get_batch_sizes_to_capture(model_runner: ModelRunner): server_args = model_runner.server_args capture_bs = server_args.cuda_graph_bs + if capture_bs is None: - if server_args.disable_cuda_graph_padding: - capture_bs = list(range(1, 33)) + [64, 128] + if server_args.speculative_algorithm is None: + if server_args.disable_cuda_graph_padding: + capture_bs = list(range(1, 33)) + [64, 96, 128, 160] + else: + capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] else: - capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + capture_bs = list(range(1, 33)) if is_hip_: capture_bs += [i * 8 for i in range(21, 33)] @@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ) ) ) + capture_bs = [ bs for bs in capture_bs @@ -385,9 +390,6 @@ class CudaGraphRunner: run_once() - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - torch.cuda.synchronize() self.model_runner.tp_group.barrier() @@ -401,12 +403,11 @@ class CudaGraphRunner: global_graph_memory_pool = graph.pool() return graph, out - def replay(self, forward_batch: ForwardBatch): - assert forward_batch.out_cache_loc is not None + def recapture_if_needed(self, forward_batch: ForwardBatch): + # If the capture_hidden_mode changes, we need to recapture the graph hidden_mode_from_spec_info = getattr( forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL ) - # If the capture_hidden_mode changes, we need to recapture the graph if ( forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL and self.capture_hidden_mode != CaptureHiddenMode.FULL @@ -420,6 +421,9 @@ class CudaGraphRunner: self.capture_hidden_mode = hidden_mode_from_spec_info self.capture() + def replay(self, forward_batch: ForwardBatch): + self.recapture_if_needed(forward_batch) + raw_bs = forward_batch.batch_size raw_num_token = raw_bs * self.num_tokens_per_bs diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index cdd03bec4..79f445da0 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -31,7 +31,7 @@ from __future__ import annotations from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Union import torch import triton @@ -46,7 +46,8 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo - from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + from sglang.srt.speculative.spec_info import SpeculativeAlgorithm class ForwardMode(IntEnum): @@ -112,7 +113,9 @@ class ForwardMode(IntEnum): class CaptureHiddenMode(IntEnum): NULL = auto() + # Capture hidden states of all tokens. FULL = auto() + # Capture a hidden state of the last token. LAST = auto() def need_capture(self): @@ -148,6 +151,7 @@ class ForwardBatch: # For logprob return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None + token_ids_logprobs: Optional[List[List[int]]] = None # Position information positions: torch.Tensor = None @@ -160,6 +164,7 @@ class ForwardBatch: extend_prefix_lens_cpu: Optional[List[int]] = None extend_seq_lens_cpu: Optional[List[int]] = None extend_logprob_start_lens_cpu: Optional[List[int]] = None + extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None # For multimodal image_inputs: Optional[List[ImageInputs]] = None @@ -190,10 +195,13 @@ class ForwardBatch: can_run_dp_cuda_graph: bool = False # Speculative decoding - spec_info: SpecInfo = None + spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None spec_algorithm: SpeculativeAlgorithm = None capture_hidden_mode: CaptureHiddenMode = None + # For padding + padded_static_len: int = -1 # -1 if not padded + # For Qwen2-VL mrope_positions: torch.Tensor = None @@ -203,8 +211,13 @@ class ForwardBatch: batch: ModelWorkerBatch, model_runner: ModelRunner, ): - device = model_runner.device + extend_input_logprob_token_ids_gpu = None + if batch.extend_input_logprob_token_ids is not None: + extend_input_logprob_token_ids_gpu = ( + batch.extend_input_logprob_token_ids.to(device, non_blocking=True) + ) + ret = cls( forward_mode=batch.forward_mode, batch_size=len(batch.seq_lens), @@ -220,6 +233,7 @@ class ForwardBatch: seq_lens_sum=batch.seq_lens_sum, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, + token_ids_logprobs=batch.token_ids_logprobs, global_num_tokens=batch.global_num_tokens, can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, lora_paths=batch.lora_paths, @@ -231,6 +245,7 @@ class ForwardBatch: spec_info=batch.spec_info, capture_hidden_mode=batch.capture_hidden_mode, input_embeds=batch.input_embeds, + extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu, ) if ret.global_num_tokens is not None: @@ -341,6 +356,7 @@ class ForwardBatch: ) batch.image_inputs[i].mrope_position_delta = mrope_position_delta mrope_positions_list[i] = mrope_positions + self.mrope_positions = torch.concat( [torch.tensor(pos, device=device) for pos in mrope_positions_list], axis=1, @@ -379,7 +395,7 @@ def compute_position_kernel( extend_seq_lens, ): BLOCK_SIZE: tl.constexpr = 512 - pid = tl.program_id(0) + pid = tl.program_id(0).to(tl.int64) prefix_len = tl.load(extend_prefix_lens + pid) seq_len = tl.load(extend_seq_lens + pid) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4b486ec8b..581dcbd88 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -13,9 +13,12 @@ # ============================================================================== """ModelRunner runs the forward passes of the models.""" +import collections +import datetime import gc import json import logging +import os import time from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -73,10 +77,15 @@ from sglang.srt.utils import ( set_cpu_offload_max_bytes, set_cuda_arch, ) +from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) +SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) +UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 + + class ModelRunner: """ModelRunner runs the forward passes of the models.""" @@ -180,9 +189,13 @@ class ModelRunner: "enable_dp_attention": server_args.enable_dp_attention, "enable_ep_moe": server_args.enable_ep_moe, "device": server_args.device, + "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, + "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, "enable_flashinfer_mla": server_args.enable_flashinfer_mla, "disable_radix_cache": server_args.disable_radix_cache, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, + "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder, + "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, } ) @@ -199,6 +212,18 @@ class ModelRunner: self.sampler = Sampler() self.load_model() + # Handle the case where some of models don't finish loading. + try: + dist.monitored_barrier( + group=get_tp_group().cpu_group, + timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S), + wait_all_ranks=True, + ) + except RuntimeError: + raise ValueError( + f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." + ) from None + # Apply torchao quantization torchao_applied = getattr(self.model, "torchao_applied", False) # In layered loading, torchao may have been applied @@ -625,6 +650,9 @@ class ModelRunner: 4096, ) + if SGLANG_CI_SMALL_KV_SIZE: + self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE) + if not self.spec_algorithm.is_none(): if self.is_draft_worker: self.max_total_num_tokens = self.server_args.draft_runner_cache_size @@ -655,6 +683,7 @@ class ModelRunner: device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, ) + if ( self.model_config.attention_arch == AttentionArch.MLA and not self.server_args.disable_mla @@ -758,9 +787,13 @@ class ModelRunner: return tic = time.time() - logger.info("Capture cuda graph begin. This can take up to several minutes.") + logger.info( + f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) self.cuda_graph_runner = CudaGraphRunner(self) - logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") + logger.info( + f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") @@ -820,11 +853,10 @@ class ModelRunner: else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") - def sample( - self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch - ) -> torch.Tensor: + def _preprocess_logits( + self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo + ): # Apply logit bias - sampling_info = forward_batch.sampling_info if sampling_info.sampling_info_done: # Overlap mode: the function update_regex_vocab_mask was executed # in process_batch_result of the last batch. @@ -833,15 +865,77 @@ class ModelRunner: else: # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass. sampling_info.update_regex_vocab_mask() - sampling_info.update_penalties() sampling_info.apply_logits_bias(logits_output.next_token_logits) + def update_output_logprobs( + self, + logits_output: LogitsProcessorOutput, + sampling_info: SamplingBatchInfo, + top_logprobs_nums: List[int], + token_ids_logprobs: List[int], + next_token_ids: torch.Tensor, + *, + num_tokens_per_req: List[int], + ): + """Update the logits_output's output logprob based on next_token_ids + + Args: + logits_output: The logits output from the model forward + sampling_info: Sampling info for logprob calculation + top_logprobs_nums: Number of logprobs per request. + next_token_ids: Next token ids. + num_tokens_per_req: The number of tokens per request. + + Returns: + A list of next_token_ids + """ + self._preprocess_logits(logits_output, sampling_info) + # We should repeat top_logprobs_nums to match num_tokens_per_req. + top_logprobs_nums_repeat_interleaved = [] + token_ids_logprobs_repeat_interleaved = [] + for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req): + top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens) + for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req): + token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens) + self.sampler( + logits_output, + sampling_info, + True, + top_logprobs_nums_repeat_interleaved, + token_ids_logprobs_repeat_interleaved, + batch_next_token_ids=next_token_ids, + ) + + def sample( + self, + logits_output: LogitsProcessorOutput, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + """Sample and compute logprobs and update logits_output. + + Args: + logits_output: The logits output from the model forward + forward_batch: The forward batch that generates logits_output + + Returns: + A list of next_token_ids + """ + # For duplex models with multiple output streams. + if isinstance(logits_output, tuple): + return torch.stack( + [self.sample(values, forward_batch) for values in logits_output], + axis=-1, + ) + + self._preprocess_logits(logits_output, forward_batch.sampling_info) + # Sample the next tokens next_token_ids = self.sampler( logits_output, - sampling_info, + forward_batch.sampling_info, forward_batch.return_logprob, forward_batch.top_logprobs_nums, + forward_batch.token_ids_logprobs, ) return next_token_ids diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 822e28844..1106c6cb7 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -25,10 +25,10 @@ import filelock import gguf import huggingface_hub.constants import numpy as np +import safetensors.torch import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator -from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm from sglang.srt.configs.load_config import LoadConfig @@ -62,7 +62,6 @@ enable_hf_transfer() class DisabledTqdm(tqdm): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, disable=True) @@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file( ) # check if the tensors are the same - reloaded = load_file(sf_filename) + reloaded = safetensors.torch.load_file(sf_filename) for k in loaded: pt_tensor = loaded[k] sf_tensor = reloaded[k] @@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file( def get_quant_config( model_config: ModelConfig, load_config: LoadConfig ) -> QuantizationConfig: - quant_cls = get_quantization_config(model_config.quantization) # GGUF doesn't have config file @@ -402,15 +400,34 @@ def np_cache_weights_iterator( yield name, torch.from_numpy(param) +def decrypt(fn, key): + raise NotImplementedError() + + +def safetensors_encrypted_weights_iterator( + hf_weights_files: List[str], + is_all_weights_sharded: bool = False, + decryption_key: Optional[str] = None, +): + raise NotImplementedError() + + def safetensors_weights_iterator( hf_weights_files: List[str], is_all_weights_sharded: bool = False, + decryption_key: Optional[str] = None, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files. If is_all_weights_sharded is True, it uses more optimize read by reading an entire file instead of reading each tensor one by one. """ + if decryption_key: + yield from safetensors_encrypted_weights_iterator( + hf_weights_files, is_all_weights_sharded, decryption_key + ) + return + enable_tqdm = ( not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) @@ -420,15 +437,9 @@ def safetensors_weights_iterator( disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - if not is_all_weights_sharded: - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param - else: - result = load_file(st_file, device="cpu") - for name, param in result.items(): - yield name, param + result = safetensors.torch.load_file(st_file, device="cpu") + for name, param in result.items(): + yield name, param def pt_weights_iterator( diff --git a/python/sglang/srt/sampling/penaltylib/__init__.py b/python/sglang/srt/sampling/penaltylib/__init__.py index 43fff0fca..26a780517 100644 --- a/python/sglang/srt/sampling/penaltylib/__init__.py +++ b/python/sglang/srt/sampling/penaltylib/__init__.py @@ -1,13 +1,11 @@ -from .orchestrator import BatchedPenalizerOrchestrator -from .penalizers.frequency_penalty import BatchedFrequencyPenalizer -from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer -from .penalizers.presence_penalty import BatchedPresencePenalizer -from .penalizers.repetition_penalty import BatchedRepetitionPenalizer +from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer +from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer +from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator +from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer __all__ = [ "BatchedFrequencyPenalizer", "BatchedMinNewTokensPenalizer", "BatchedPresencePenalizer", - "BatchedRepetitionPenalizer", "BatchedPenalizerOrchestrator", ] diff --git a/python/sglang/srt/sampling/penaltylib/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/frequency_penalty.py new file mode 100644 index 000000000..691534627 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/frequency_penalty.py @@ -0,0 +1,66 @@ +import torch + +from sglang.srt.sampling.penaltylib.orchestrator import ( + BatchedPenalizerOrchestrator, + _BatchedPenalizer, +) + + +class BatchedFrequencyPenalizer(_BatchedPenalizer): + """ + Frequency penalizer penalizes tokens based on their frequency in the output. + """ + + def __init__(self, orchestrator: BatchedPenalizerOrchestrator): + self.orchestrator = orchestrator + self._is_prepared = False + + def _is_required(self) -> bool: + return any( + req.sampling_params.frequency_penalty != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_frequency_penalties = torch.zeros( + (len(self.orchestrator.reqs()), self.orchestrator.vocab_size), + dtype=torch.float32, + device=self.orchestrator.device, + ) + + self.frequency_penalties = ( + torch.tensor( + data=[ + req.sampling_params.frequency_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + ).unsqueeze_(1) + + def _cumulate_output_tokens(self, output_ids: torch.Tensor): + self.cumulated_frequency_penalties.scatter_add_( + dim=1, + index=output_ids.unsqueeze(1), + src=self.frequency_penalties, + ) + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + logits.sub_(self.cumulated_frequency_penalties) + + def _filter(self, keep_indices: torch.Tensor): + self.frequency_penalties = self.frequency_penalties[keep_indices] + self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[ + keep_indices + ] + + def _merge(self, their: "BatchedFrequencyPenalizer"): + print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}") + self.frequency_penalties = torch.cat( + [self.frequency_penalties, their.frequency_penalties], dim=0 + ) + self.cumulated_frequency_penalties = torch.cat( + [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties], + dim=0, + ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py b/python/sglang/srt/sampling/penaltylib/min_new_tokens.py similarity index 70% rename from python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py rename to python/sglang/srt/sampling/penaltylib/min_new_tokens.py index 0e27c7e5a..da06265d9 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +++ b/python/sglang/srt/sampling/penaltylib/min_new_tokens.py @@ -1,8 +1,9 @@ -from typing import List - import torch -from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs +from sglang.srt.sampling.penaltylib.orchestrator import ( + BatchedPenalizerOrchestrator, + _BatchedPenalizer, +) class BatchedMinNewTokensPenalizer(_BatchedPenalizer): @@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer): Min new tokens penalizer penalizes tokens based on the length of the output. """ - min_new_tokens: torch.Tensor = None - stop_token_penalties: torch.Tensor = None - len_output_tokens: torch.Tensor = None + def __init__(self, orchestrator: BatchedPenalizerOrchestrator): + self.orchestrator = orchestrator + self._is_prepared = False def _is_required(self) -> bool: return any( @@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer): padding_value=self.orchestrator.vocab_size, ) self.stop_token_penalties = torch.zeros( - size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), + size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1), dtype=torch.float32, device=self.orchestrator.device, ).scatter_add_( @@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer): ] self.len_output_tokens = torch.zeros( - size=(self.orchestrator.batch_size(), 1), + size=(len(self.orchestrator.reqs()), 1), dtype=torch.int32, device=self.orchestrator.device, ) - def _teardown(self): - self.min_new_tokens = None - self.stop_token_penalties = None - self.len_output_tokens = None - - def _cumulate_input_tokens(self, input_ids: _TokenIDs): - pass - - def _cumulate_output_tokens(self, output_ids: _TokenIDs): + def _cumulate_output_tokens(self, output_ids: torch.Tensor): self.len_output_tokens += 1 - def _apply(self, logits: torch.Tensor) -> torch.Tensor: + def _apply(self, logits: torch.Tensor): mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits) logits[mask] += self.stop_token_penalties[mask] - return logits - def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): - self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep] - self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep] - self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep] + def _filter(self, keep_indices: torch.Tensor): + self.min_new_tokens = self.min_new_tokens[keep_indices] + self.stop_token_penalties = self.stop_token_penalties[keep_indices] + self.len_output_tokens = self.len_output_tokens[keep_indices] def _merge(self, their: "BatchedMinNewTokensPenalizer"): self.min_new_tokens = torch.cat( diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index 9c393d180..a75d5e9bb 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -1,35 +1,25 @@ +from __future__ import annotations + import abc -import dataclasses -from typing import List, Set, Type, Union +from typing import TYPE_CHECKING, Set, Type import torch - -@dataclasses.dataclass -class _ReqLike: - origin_input_ids: List[int] - - -@dataclasses.dataclass -class _BatchLike: - reqs: List[_ReqLike] - - def batch_size(self): - return len(self.reqs) +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import ScheduleBatch class BatchedPenalizerOrchestrator: def __init__( self, vocab_size: int, - batch: _BatchLike, - device: str, - Penalizers: Set[Type["_BatchedPenalizer"]], + batch: ScheduleBatch, + penalizers: Set[Type["_BatchedPenalizer"]], ): self.vocab_size = vocab_size self.batch = batch - self.device = device - self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers} + self.device = batch.device + self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers} is_required = False for penalizer in self.penalizers.values(): @@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator: is_required |= pen_is_required self.is_required = is_required - input_ids = [ - torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device) - for req in self.reqs() - ] - if self.is_required: - self.cumulate_input_tokens(input_ids=input_ids) - def reqs(self): return self.batch.reqs - def batch_size(self): - return self.batch.batch_size() - - def cumulate_input_tokens(self, input_ids: List[torch.Tensor]): - """ - Feed the input tokens to the penalizers. - - Args: - input_ids (List[torch.Tensor]): The input tokens. - """ - token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids) - - for penalizer in self.penalizers.values(): - penalizer.cumulate_input_tokens(input_ids=token_ids) - def cumulate_output_tokens(self, output_ids: torch.Tensor): """ Feed the output tokens to the penalizers. @@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator: Args: output_ids (torch.Tensor): The output tokens. """ - if not self.is_required: - return - - token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids) - for penalizer in self.penalizers.values(): - penalizer.cumulate_output_tokens(output_ids=token_ids) + penalizer.cumulate_output_tokens(output_ids=output_ids) def apply(self, logits: torch.Tensor) -> torch.Tensor: """ @@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator: Returns: torch.Tensor: The logits after applying the penalizers. """ - if not self.is_required: - return - for penalizer in self.penalizers.values(): - logits = penalizer.apply(logits) + penalizer.apply(logits) - return logits - - def filter( - self, - indices_to_keep: List[int], - indices_tensor_to_keep: torch.Tensor = None, - ): + def filter(self, keep_indices: torch.Tensor): """ Filter the penalizers based on the indices to keep in the batch. Args: - indices_to_keep (List[int]): List of indices to keep in the batch. - indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor. + keep_indices (torch.Tensor): Tensor of indices to keep in the batch. """ if not self.is_required: return - empty_indices = len(indices_to_keep) == 0 + if len(keep_indices) == 0: + self.is_required = False + for penalizer in self.penalizers.values(): + penalizer.teardown() + return is_required = False for penalizer in self.penalizers.values(): tmp_is_required = penalizer.is_required() - is_required = is_required or tmp_is_required - if not tmp_is_required or empty_indices: - penalizer.teardown() + is_required |= tmp_is_required + if tmp_is_required: + penalizer.filter(keep_indices=keep_indices) else: - # create tensor index only when it's needed - if indices_tensor_to_keep is None: - indices_tensor_to_keep = torch.tensor( - indices_to_keep, dtype=torch.int32, device=self.device - ) - - penalizer.filter( - indices_to_keep=indices_to_keep, - indices_tensor_to_keep=indices_tensor_to_keep, - ) + penalizer.teardown() self.is_required = is_required def merge(self, their: "BatchedPenalizerOrchestrator"): @@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator: if not self.is_required and not their.is_required: return - self.is_required |= their.is_required - for Penalizer, their_penalizer in their.penalizers.items(): - if Penalizer not in self.penalizers: - raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers") - - self.penalizers[Penalizer].merge(their_penalizer) - - -class _TokenIDs: - """ - A class that wraps token IDs to provide additional utility functions to penalizers. - - Attributes: - orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to. - token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs. - cached_counts (torch.Tensor): The cached occurrence count tensor. - """ - - def __init__( - self, - orchestrator: BatchedPenalizerOrchestrator, - token_ids: Union[torch.Tensor, List[torch.Tensor]], - ): - self.orchestrator = orchestrator - self.token_ids = token_ids - self.cached_counts = None - - def occurrence_count(self) -> torch.Tensor: - """ - Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch. - - Returns: - torch.Tensor: The occurrence count tensor. - """ - if self.cached_counts is not None: - return self.cached_counts - - token_ids = self.token_ids - - if isinstance(token_ids, list): - # TODO: optimize this part - padded_token_ids = torch.nn.utils.rnn.pad_sequence( - sequences=token_ids, - batch_first=True, - padding_value=self.orchestrator.vocab_size, - ) - self.cached_counts = torch.zeros( - size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), - dtype=torch.int64, - device=self.orchestrator.device, - ).scatter_add_( - dim=1, - index=padded_token_ids, - src=torch.ones_like(padded_token_ids), - )[ - :, : self.orchestrator.vocab_size - ] - else: - # TODO: optimize this part. We do not need to create this big tensor every time. - # We can directly apply the results on the logits. - self.cached_counts = torch.zeros( - size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size), - device=self.orchestrator.device, - ) - self.cached_counts[ - torch.arange(len(token_ids), device=self.orchestrator.device), token_ids - ] = 1 - - return self.cached_counts + self.is_required = True + for penalizer, their_penalizer in their.penalizers.items(): + self.penalizers[penalizer].merge(their_penalizer) class _BatchedPenalizer(abc.ABC): @@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC): An abstract class for a batched penalizer. """ - def __init__(self, orchestrator: BatchedPenalizerOrchestrator): - self.orchestrator = orchestrator - self._is_prepared = False - def is_prepared(self) -> bool: return self._is_prepared @@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC): return self._is_required() def prepare(self): - if not self.is_prepared(): + if not self._is_prepared: self._prepare() self._is_prepared = True def prepare_if_required(self): - if self.is_required(): + if self._is_required(): self.prepare() return True else: return False def teardown(self): - if self.is_prepared(): - self._teardown() - self._is_prepared = False + self._is_prepared = False - def cumulate_input_tokens(self, input_ids: _TokenIDs): - if not self.is_prepared(): - return - - self._cumulate_input_tokens(input_ids=input_ids) - - def cumulate_output_tokens(self, output_ids: _TokenIDs): - if not self.is_prepared(): + def cumulate_output_tokens(self, output_ids: torch.Tensor): + if not self._is_prepared: return self._cumulate_output_tokens(output_ids=output_ids) def apply(self, logits: torch.Tensor) -> torch.Tensor: - if not self.is_prepared(): - return logits - - return self._apply(logits=logits) - - def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): - if not self.is_prepared(): + if not self._is_prepared: return - self._filter( - indices_to_keep=indices_to_keep, - indices_tensor_to_keep=indices_tensor_to_keep, - ) + self._apply(logits=logits) + + def filter(self, keep_indices: torch.Tensor): + if not self._is_prepared: + return + + self._filter(keep_indices=keep_indices) def merge(self, their: "_BatchedPenalizer"): - if not self.is_prepared() and not their.is_prepared(): + if not self._is_prepared and not their._is_prepared: return self.prepare() @@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC): pass @abc.abstractmethod - def _teardown(self): - """ - Tear down the penalizer. - Usually, this is where the penalizer frees its tensors. - """ - pass - - @abc.abstractmethod - def _cumulate_input_tokens(self, input_ids: _TokenIDs): - """ - Cumulate the input tokens. - Orchestrator will call this function to feed the input tokens to the penalizer. - """ - pass - - @abc.abstractmethod - def _cumulate_output_tokens(self, output_ids: _TokenIDs): + def _cumulate_output_tokens(self, output_ids: torch.Tensor): """ Cumulate the output tokens. Orchestrator will call this function to feed the output tokens to the penalizer. @@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC): pass @abc.abstractmethod - def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): + def _filter(self, keep_indices: torch.Tensor): """ Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch. """ diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py deleted file mode 100644 index 34fa5abbf..000000000 --- a/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import List - -import torch - -from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs - - -class BatchedFrequencyPenalizer(_BatchedPenalizer): - """ - Frequency penalizer penalizes tokens based on their frequency in the output. - """ - - frequency_penalties: torch.Tensor = None - cumulated_frequency_penalties: torch.Tensor = None - - def _is_required(self) -> bool: - return any( - req.sampling_params.frequency_penalty != 0.0 - for req in self.orchestrator.reqs() - ) - - def _prepare(self): - self.cumulated_frequency_penalties = ( - torch.tensor( - data=[0.0 for _ in self.orchestrator.reqs()], - dtype=torch.float32, - device=self.orchestrator.device, - ) - .unsqueeze_(1) - .repeat(1, self.orchestrator.vocab_size) - ) - - self.frequency_penalties = ( - torch.tensor( - data=[ - req.sampling_params.frequency_penalty - for req in self.orchestrator.reqs() - ], - dtype=torch.float32, - device=self.orchestrator.device, - ) - .unsqueeze_(1) - .expand_as(self.cumulated_frequency_penalties) - ) - - def _teardown(self): - self.frequency_penalties = None - self.cumulated_frequency_penalties = None - - def _cumulate_input_tokens(self, input_ids: _TokenIDs): - pass - - def _cumulate_output_tokens(self, output_ids: _TokenIDs): - self.cumulated_frequency_penalties += ( - self.frequency_penalties * output_ids.occurrence_count() - ) - - def _apply(self, logits: torch.Tensor) -> torch.Tensor: - logits -= self.cumulated_frequency_penalties - return logits - - def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): - self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep] - self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[ - indices_tensor_to_keep - ] - - def _merge(self, their: "BatchedFrequencyPenalizer"): - self.frequency_penalties = torch.cat( - [self.frequency_penalties, their.frequency_penalties], dim=0 - ) - self.cumulated_frequency_penalties = torch.cat( - [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties], - dim=0, - ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py deleted file mode 100644 index f86aa4a2d..000000000 --- a/python/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import List - -import torch - -from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs - - -class BatchedPresencePenalizer(_BatchedPenalizer): - """ - Presence penalizer penalizes tokens based on their presence in the output. - """ - - presence_penalties: torch.Tensor = None - cumulated_presence_penalties: torch.Tensor = None - - def _is_required(self) -> bool: - return any( - req.sampling_params.presence_penalty != 0.0 - for req in self.orchestrator.reqs() - ) - - def _prepare(self): - self.cumulated_presence_penalties = ( - torch.tensor( - data=[0.0 for _ in self.orchestrator.reqs()], - dtype=torch.float32, - device=self.orchestrator.device, - ) - .unsqueeze_(1) - .repeat(1, self.orchestrator.vocab_size) - ) - - self.presence_penalties = ( - torch.tensor( - data=[ - req.sampling_params.presence_penalty - for req in self.orchestrator.reqs() - ], - dtype=torch.float32, - device=self.orchestrator.device, - ) - .unsqueeze_(1) - .expand_as(self.cumulated_presence_penalties) - ) - - def _teardown(self): - self.presence_penalties = None - self.cumulated_presence_penalties = None - - def _cumulate_input_tokens(self, input_ids: _TokenIDs): - pass - - def _cumulate_output_tokens(self, output_ids: _TokenIDs): - mask = output_ids.occurrence_count() > 0 - self.cumulated_presence_penalties[mask] = self.presence_penalties[mask] - - def _apply(self, logits: torch.Tensor) -> torch.Tensor: - logits -= self.cumulated_presence_penalties - return logits - - def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): - self.presence_penalties = self.presence_penalties[indices_tensor_to_keep] - self.cumulated_presence_penalties = self.cumulated_presence_penalties[ - indices_tensor_to_keep - ] - - def _merge(self, their: "BatchedPresencePenalizer"): - self.presence_penalties = torch.cat( - [self.presence_penalties, their.presence_penalties], dim=0 - ) - self.cumulated_presence_penalties = torch.cat( - [self.cumulated_presence_penalties, their.cumulated_presence_penalties], - dim=0, - ) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py deleted file mode 100644 index fe687c569..000000000 --- a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import List - -import torch - -from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs -from sglang.srt.utils import get_compiler_backend - - -@torch.compile(dynamic=True, backend=get_compiler_backend()) -def apply_scaling_penalties(logits, scaling_penalties): - logits[:] = torch.where( - logits > 0, - logits / scaling_penalties, - logits * scaling_penalties, - ) - - -class BatchedRepetitionPenalizer(_BatchedPenalizer): - """ - Repetition penalizer penalizes tokens based on their repetition in the input and output. - """ - - repetition_penalties: torch.Tensor = None - cumulated_repetition_penalties: torch.Tensor = None - - def _is_required(self) -> bool: - return any( - req.sampling_params.repetition_penalty != 1.0 - for req in self.orchestrator.reqs() - ) - - def _prepare(self): - self.cumulated_repetition_penalties = ( - torch.tensor( - data=[1.0 for _ in self.orchestrator.reqs()], - dtype=torch.float32, - device=self.orchestrator.device, - ) - .unsqueeze_(1) - .repeat(1, self.orchestrator.vocab_size) - ) - - self.repetition_penalties = ( - torch.tensor( - data=[ - req.sampling_params.repetition_penalty - for req in self.orchestrator.reqs() - ], - dtype=torch.float32, - device=self.orchestrator.device, - ) - .unsqueeze_(1) - .expand_as(self.cumulated_repetition_penalties) - ) - - def _teardown(self): - self.repetition_penalties = None - self.cumulated_repetition_penalties = None - - def _cumulate_input_tokens(self, input_ids: _TokenIDs): - mask = input_ids.occurrence_count() > 0 - self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] - - def _cumulate_output_tokens(self, output_ids: _TokenIDs): - mask = output_ids.occurrence_count() > 0 - self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] - - def _apply(self, logits: torch.Tensor) -> torch.Tensor: - apply_scaling_penalties(logits, self.cumulated_repetition_penalties) - return logits - - def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): - self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] - self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ - indices_tensor_to_keep - ] - - def _merge(self, their: "BatchedRepetitionPenalizer"): - self.repetition_penalties = torch.cat( - [self.repetition_penalties, their.repetition_penalties], dim=0 - ) - self.cumulated_repetition_penalties = torch.cat( - [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties], - dim=0, - ) diff --git a/python/sglang/srt/sampling/penaltylib/presence_penalty.py b/python/sglang/srt/sampling/penaltylib/presence_penalty.py new file mode 100644 index 000000000..91266b352 --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/presence_penalty.py @@ -0,0 +1,66 @@ +import torch + +from sglang.srt.sampling.penaltylib.orchestrator import ( + BatchedPenalizerOrchestrator, + _BatchedPenalizer, +) + + +class BatchedPresencePenalizer(_BatchedPenalizer): + """ + Presence penalizer penalizes tokens based on their presence in the output. + """ + + def __init__(self, orchestrator: BatchedPenalizerOrchestrator): + self.orchestrator = orchestrator + self._is_prepared = False + + def _is_required(self) -> bool: + return any( + req.sampling_params.presence_penalty != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_presence_penalties = torch.zeros( + (len(self.orchestrator.reqs()), self.orchestrator.vocab_size), + dtype=torch.float32, + device=self.orchestrator.device, + ) + + self.presence_penalties = ( + torch.tensor( + data=[ + req.sampling_params.presence_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + ).unsqueeze_(1) + + def _cumulate_output_tokens(self, output_ids: torch.Tensor): + self.cumulated_presence_penalties.scatter_( + dim=1, + index=output_ids.unsqueeze(1), + src=self.presence_penalties, + ) + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + logits.sub_(self.cumulated_presence_penalties) + + def _filter(self, keep_indices: torch.Tensor): + self.presence_penalties = self.presence_penalties[keep_indices] + self.cumulated_presence_penalties = self.cumulated_presence_penalties[ + keep_indices + ] + + def _merge(self, their: "BatchedPresencePenalizer"): + print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}") + self.presence_penalties = torch.cat( + [self.presence_penalties, their.presence_penalties], dim=0 + ) + self.cumulated_presence_penalties = torch.cat( + [self.cumulated_presence_penalties, their.cumulated_presence_penalties], + dim=0, + ) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 393e713e9..5942b8270 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -9,9 +9,6 @@ import torch import sglang.srt.sampling.penaltylib as penaltylib from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor -from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( - apply_scaling_penalties, -) logger = logging.getLogger(__name__) @@ -22,49 +19,45 @@ if TYPE_CHECKING: @dataclasses.dataclass class SamplingBatchInfo: - # Batched sampling params + # Basic batched sampling params temperatures: torch.Tensor top_ps: torch.Tensor top_ks: torch.Tensor min_ps: torch.Tensor - # All requests use greedy sampling + # Whether all requests use greedy sampling is_all_greedy: bool - # Dispatch in CUDA graph + # Whether any request needs min_p sampling need_min_p_sampling: bool - # Whether any request has custom logit processor - has_custom_logit_processor: bool - - # Bias Tensors + # Masking tensors for grammar-guided structured outputs vocab_size: int grammars: Optional[List] = None - sampling_info_done: Optional[threading.Event] = None - logit_bias: torch.Tensor = None vocab_mask: Optional[torch.Tensor] = None - apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None + apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None + + # An event used for overlap schedule + sampling_info_done: Optional[threading.Event] = None # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None - linear_penalties: Optional[torch.Tensor] = None - scaling_penalties: Optional[torch.Tensor] = None + linear_penalty: torch.Tensor = None - # Device - device: str = "cuda" - - # Custom Parameters + # Whether any request has custom logit processor + has_custom_logit_processor: bool = False + # Custom parameters custom_params: Optional[List[Optional[Dict[str, Any]]]] = None - - # Custom Logit Processor + # Custom logit processor custom_logit_processor: Optional[ Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]] ] = None + # Device + device: str = "cuda" + @classmethod - def from_schedule_batch( - cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool - ): + def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): reqs = batch.reqs device = batch.device temperatures = ( @@ -118,106 +111,60 @@ class SamplingBatchInfo: merged_custom_logit_processor = None custom_params = None + # Each penalizers will do nothing if they evaluate themselves as not required by looking at + # the sampling_params of the requests (See {_is_required()} of each penalizers). So this + # should not add hefty computation overhead other than simple checks. + # + # While we can choose not to even create the class instances if they are not required, this + # could add additional complexity to the {ScheduleBatch} class, especially we need to + # handle {filter_batch()} and {merge_batch()} cases as well. + penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator( + vocab_size=vocab_size, + batch=batch, + penalizers={ + penaltylib.BatchedFrequencyPenalizer, + penaltylib.BatchedMinNewTokensPenalizer, + penaltylib.BatchedPresencePenalizer, + }, + ) + ret = cls( temperatures=temperatures, top_ps=top_ps, top_ks=top_ks, min_ps=min_ps, - need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), - has_custom_logit_processor=has_custom_logit_processor, + need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), vocab_size=vocab_size, - device=device, + penalizer_orchestrator=penalizer_orchestrator, + has_custom_logit_processor=has_custom_logit_processor, custom_params=custom_params, custom_logit_processor=merged_custom_logit_processor, + device=device, ) - # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. - - if enable_overlap_schedule: - # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs, - # so it is kind of tricky to make it work with overlap scheduler. - # It requires correcly updating the penalty logits before the sampling and syncing the events. - # We will support them later. - penalizers = { - penaltylib.BatchedMinNewTokensPenalizer, - } - if ( - any(req.sampling_params.frequency_penalty != 0.0 for req in reqs) - or any(req.sampling_params.presence_penalty != 0.0 for req in reqs) - or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs) - ): - logger.warning( - "frequency_penalty, presence_penalty, and repetition_penalty are not supported " - "when using the default overlap scheduler. They will be ignored. " - "Please add `--disable-overlap` when launching the server if you need these features. " - "The speed will be slower in that case." - ) - else: - penalizers = { - penaltylib.BatchedFrequencyPenalizer, - penaltylib.BatchedMinNewTokensPenalizer, - penaltylib.BatchedPresencePenalizer, - penaltylib.BatchedRepetitionPenalizer, - } - - # Each penalizers will do nothing if they evaluate themselves as not required by looking at - # the sampling_params of the requests (See {_is_required()} of each penalizers). So this - # should not add hefty computation overhead other than simple checks. - # - # While we choose not to even create the class instances if they are not required, this - # could add additional complexity to the {ScheduleBatch} class, especially we need to - # handle {filter_batch()} and {merge_batch()} cases as well. - ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator( - vocab_size=vocab_size, - batch=batch, - device=batch.device, - Penalizers=penalizers, - ) - - # Handle logit bias but only allocate when needed - ret.logit_bias = None - return ret def __len__(self): return len(self.temperatures) - def update_penalties(self): - self.scaling_penalties = None - self.linear_penalties = None - - for penalizer in self.penalizer_orchestrator.penalizers.values(): - if not penalizer.is_prepared(): - continue - - if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer): - self.scaling_penalties = penalizer.cumulated_repetition_penalties - else: - if self.linear_penalties is None: - bs = self.penalizer_orchestrator.batch.batch_size() - self.linear_penalties = torch.zeros( - (bs, self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - self.linear_penalties = penalizer.apply(self.linear_penalties) - def update_regex_vocab_mask(self): if not self.grammars: self.vocab_mask = None - self.apply_mask = None + self.apply_mask_func = None return - # find a grammar from the list + # Find a grammar from the list first_grammar = next(grammar for grammar in self.grammars if grammar) - # maybe we can reuse the existing mask? + # TODO(lianmin): Maybe we can reuse the existing mask? self.vocab_mask = first_grammar.allocate_vocab_mask( vocab_size=self.vocab_size, batch_size=len(self.temperatures), device=self.device, ) - self.apply_mask = first_grammar.apply_vocab_mask # force to use static method + self.apply_mask_func = ( + first_grammar.apply_vocab_mask + ) # force to use static method # Apply the mask for i, grammar in enumerate(self.grammars): @@ -227,35 +174,56 @@ class SamplingBatchInfo: # Move the mask to the device if needed self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device) - def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): - self.penalizer_orchestrator.filter(unfinished_indices, new_indices) + def update_penalties(self): + if self.penalizer_orchestrator.is_required: + self.linear_penalty = torch.zeros( + (len(self.temperatures), self.vocab_size), + dtype=torch.float32, + device=self.temperatures.device, + ) + self.penalizer_orchestrator.apply(self.linear_penalty) + else: + self.linear_penalty = None + + def apply_logits_bias(self, logits: torch.Tensor): + if self.linear_penalty is not None: + # Used in the overlap mode + logits.add_(self.linear_penalty) + + if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required: + # Used in the non-overlap mode + self.penalizer_orchestrator.apply(logits) + + if self.vocab_mask is not None: + self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask) + + def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor): + self.penalizer_orchestrator.filter(keep_indices_device) + if self.has_custom_logit_processor: - self._filter_batch_custom_logit_processor(unfinished_indices, new_indices) + self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device) for item in [ "temperatures", "top_ps", "top_ks", "min_ps", - "logit_bias", ]: value = getattr(self, item, None) - if value is not None: # logit_bias can be None - setattr(self, item, value[new_indices]) + setattr(self, item, value[keep_indices_device]) def _filter_batch_custom_logit_processor( - self, unfinished_indices: List[int], new_indices: torch.Tensor + self, keep_indices: List[int], keep_indices_device: torch.Tensor ): """Filter the custom logit processor and custom params""" - self.custom_logit_processor = { - k: (p, mask[new_indices]) + k: (p, mask[keep_indices_device]) for k, (p, mask) in self.custom_logit_processor.items() - if any( - mask[new_indices] + if torch.any( + mask[keep_indices_device] ) # ignore the custom logit processor whose mask is all False } - self.custom_params = [self.custom_params[i] for i in unfinished_indices] + self.custom_params = [self.custom_params[i] for i in keep_indices] # If the custom logit processor is an empty dict, set the flag to False, # and set the custom logit processor and custom params to None. @@ -264,31 +232,6 @@ class SamplingBatchInfo: self.custom_params = None self.has_custom_logit_processor = False - @staticmethod - def merge_bias_tensor( - lhs: torch.Tensor, - rhs: torch.Tensor, - bs1: int, - bs2: int, - device: str, - default: int = 0, - ): - # bias tensor can be None - if lhs is not None or rhs is not None: - shape, dtype = None, None - if lhs is not None: - shape, dtype = lhs.shape[1:], lhs.dtype - else: - shape, dtype = rhs.shape[1:], rhs.dtype - with torch.dtype(dtype): - if lhs is None: - lhs = torch.empty((bs1, *shape), device=device).fill_(default) - if rhs is None: - rhs = torch.empty((bs2, *shape), device=device).fill_(default) - return torch.cat([lhs, rhs]) - - return None - @staticmethod def merge_custom_logit_processor( lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]], @@ -332,11 +275,6 @@ class SamplingBatchInfo: def merge_batch(self, other: "SamplingBatchInfo"): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) - # Merge the logit bias tensor - self.logit_bias = SamplingBatchInfo.merge_bias_tensor( - self.logit_bias, other.logit_bias, len(self), len(other), self.device - ) - # Merge the custom logit processors and custom params lists if self.has_custom_logit_processor or other.has_custom_logit_processor: # Merge the custom logit processors @@ -370,22 +308,5 @@ class SamplingBatchInfo: other_val = getattr(other, item, None) setattr(self, item, torch.concat([self_val, other_val])) - self.is_all_greedy = self.is_all_greedy and other.is_all_greedy - self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling - - def apply_logits_bias(self, logits: torch.Tensor): - # Apply logit_bias - if self.logit_bias is not None: - logits.add_(self.logit_bias) - - # min-token, presence, frequency - if self.linear_penalties is not None: - logits.add_(self.linear_penalties) - - # repetition - if self.scaling_penalties is not None: - apply_scaling_penalties(logits, self.scaling_penalties) - - # Apply regex vocab_mask - if self.vocab_mask is not None: - self.apply_mask(logits=logits, vocab_mask=self.vocab_mask) + self.is_all_greedy |= other.is_all_greedy + self.need_min_p_sampling |= other.need_min_p_sampling diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 296400fa6..0658f4ebf 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -15,15 +15,21 @@ import argparse import dataclasses +import json import logging +import os import random +import subprocess import tempfile +import uuid +from pathlib import Path from typing import List, Optional import torch from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.utils import ( + create_checksum, get_amdgpu_memory_capacity, get_hpu_memory_capacity, get_nvgpu_memory_capacity, @@ -43,12 +49,13 @@ class ServerArgs: model_path: str tokenizer_path: Optional[str] = None tokenizer_mode: str = "auto" + skip_tokenizer_init: bool = False load_format: str = "auto" - trust_remote_code: bool = True + trust_remote_code: bool = False dtype: str = "auto" kv_cache_dtype: str = "auto" - quantization_param_path: nullable_str = None quantization: Optional[str] = None + quantization_param_path: nullable_str = None context_length: Optional[int] = None device: str = "cuda" served_model_name: Optional[str] = None @@ -67,7 +74,7 @@ class ServerArgs: max_total_tokens: Optional[int] = None chunked_prefill_size: Optional[int] = None max_prefill_tokens: int = 16384 - schedule_policy: str = "lpm" + schedule_policy: str = "fcfs" schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 prefill_only_one_req: bool = False @@ -88,6 +95,7 @@ class ServerArgs: log_level: str = "info" log_level_http: Optional[str] = None log_requests: bool = False + log_requests_level: int = 0 show_time_cost: bool = False enable_metrics: bool = False decode_log_interval: int = 40 @@ -123,11 +131,13 @@ class ServerArgs: grammar_backend: Optional[str] = "outlines" # Speculative decoding - speculative_draft_model_path: Optional[str] = None speculative_algorithm: Optional[str] = None + speculative_draft_model_path: Optional[str] = None speculative_num_steps: int = 5 - speculative_eagle_topk: int = 8 - speculative_num_draft_tokens: int = 64 + speculative_eagle_topk: int = 4 + speculative_num_draft_tokens: int = 8 + speculative_accept_threshold_single: float = 1.0 + speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None # Double Sparsity @@ -169,6 +179,12 @@ class ServerArgs: enable_hierarchical_cache: bool = False enable_flashinfer_mla: bool = False flashinfer_mla_disable_ragged: bool = False + warmups: Optional[str] = None + + # Debug tensor dumps + debug_tensor_dump_output_folder: Optional[str] = None + debug_tensor_dump_input_file: Optional[str] = None + debug_tensor_dump_inject: bool = False def __post_init__(self): # Set missing default values @@ -266,10 +282,10 @@ class ServerArgs: self.speculative_algorithm == "EAGLE" or self.speculative_algorithm == "NEXTN" ): + self.disable_overlap_schedule = True self.prefill_only_one_req = True self.disable_cuda_graph_padding = True self.disable_radix_cache = True - self.disable_overlap_schedule = True self.chunked_prefill_size = -1 logger.info( f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding." @@ -377,15 +393,6 @@ class ServerArgs: choices=["auto", "fp8_e5m2", "fp8_e4m3"], help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.', ) - parser.add_argument( - "--quantization-param-path", - type=nullable_str, - default=None, - help="Path to the JSON file containing the KV cache " - "scaling factors. This should generally be supplied, when " - "KV cache dtype is FP8. Otherwise, KV cache scaling factors " - "default to 1.0, which may cause accuracy issues. ", - ) parser.add_argument( "--quantization", type=str, @@ -404,6 +411,15 @@ class ServerArgs: ], help="The quantization method.", ) + parser.add_argument( + "--quantization-param-path", + type=nullable_str, + default=None, + help="Path to the JSON file containing the KV cache " + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", + ) parser.add_argument( "--context-length", type=int, @@ -578,7 +594,14 @@ class ServerArgs: parser.add_argument( "--log-requests", action="store_true", - help="Log the inputs and outputs of all requests.", + help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level", + ) + parser.add_argument( + "--log-requests-level", + type=int, + default=0, + help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.", + choices=[0, 1, 2], ) parser.add_argument( "--show-time-cost", @@ -742,16 +765,28 @@ class ServerArgs: parser.add_argument( "--speculative-eagle-topk", type=int, - help="The number of token sampled from draft model in eagle2 each step.", + help="The number of tokens sampled from the draft model in eagle2 each step.", choices=[1, 2, 4, 8], default=ServerArgs.speculative_eagle_topk, ) parser.add_argument( "--speculative-num-draft-tokens", type=int, - help="The number of token sampled from draft model in Speculative Decoding.", + help="The number of tokens sampled from the draft model in Speculative Decoding.", default=ServerArgs.speculative_num_draft_tokens, ) + parser.add_argument( + "--speculative-accept-threshold-single", + type=float, + help="Accept a draft token if its probability in the target model is greater than this threshold.", + default=ServerArgs.speculative_accept_threshold_single, + ) + parser.add_argument( + "--speculative-accept-threshold-acc", + type=float, + help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).", + default=ServerArgs.speculative_accept_threshold_acc, + ) parser.add_argument( "--speculative-token-map", type=str, @@ -949,6 +984,35 @@ class ServerArgs: help="Enable hierarchical cache", ) + # Server warmups + parser.add_argument( + "--warmups", + type=str, + required=False, + help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + ) + + # Debug tensor dumps + parser.add_argument( + "--debug-tensor-dump-output-folder", + type=str, + default=ServerArgs.debug_tensor_dump_output_folder, + help="The output folder for dumping tensors.", + ) + parser.add_argument( + "--debug-tensor-dump-input-file", + type=str, + default=ServerArgs.debug_tensor_dump_input_file, + help="The input filename for dumping tensors", + ) + parser.add_argument( + "--debug-tensor-dump-inject", + type=str, + default=ServerArgs.debug_tensor_dump_inject, + help="Inject the outputs from jax as the input of every layer.", + ) + @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 88c8a4b61..38ff7409d 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -32,13 +32,15 @@ import socket import subprocess import sys import tempfile +import threading import time import warnings from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from io import BytesIO +from multiprocessing import Pool from multiprocessing.reduction import ForkingPickler -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union import numpy as np import psutil @@ -480,6 +482,10 @@ def assert_pkg_version(pkg: str, min_version: str, message: str): def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None): """Kill the process and all its child processes.""" + # Remove sigchld handler to avoid spammy logs. + if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGCHLD, signal.SIG_DFL) + if parent_pid is None: parent_pid = os.getpid() include_parent = False @@ -499,17 +505,14 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N pass if include_parent: - if parent_pid == os.getpid(): - sys.exit(0) - else: - try: - itself.kill() + try: + itself.kill() - # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), - # so we send an additional signal to kill them. - itself.send_signal(signal.SIGQUIT) - except psutil.NoSuchProcess: - pass + # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), + # so we send an additional signal to kill them. + itself.send_signal(signal.SIGQUIT) + except psutil.NoSuchProcess: + pass def monkey_patch_p2p_access_check(): @@ -1215,7 +1218,11 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None)) -def dataclass_to_string_truncated(data, max_length=2048): +def dataclass_to_string_truncated( + data, max_length=2048, skip_names: Optional[Set[str]] = None +): + if skip_names is None: + skip_names = set() if isinstance(data, str): if len(data) > max_length: half_length = max_length // 2 @@ -1234,6 +1241,7 @@ def dataclass_to_string_truncated(data, max_length=2048): + ", ".join( f"'{k}': {dataclass_to_string_truncated(v, max_length)}" for k, v in data.items() + if k not in skip_names ) + "}" ) @@ -1244,6 +1252,7 @@ def dataclass_to_string_truncated(data, max_length=2048): + ", ".join( f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" for f in fields + if f.name not in skip_names ) + ")" ) @@ -1322,9 +1331,9 @@ def pyspy_dump_schedulers(): result = subprocess.run( cmd, shell=True, capture_output=True, text=True, check=True ) - logger.info(f"Profile for PID {pid}:\n{result.stdout}") + logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}") except subprocess.CalledProcessError as e: - logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}") + logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}") def kill_itself_when_parent_died(): @@ -1448,6 +1457,10 @@ def launch_dummy_health_check_server(host, port): ) +def create_checksum(directory: str): + raise NotImplementedError() + + def set_cuda_arch(): if is_flashinfer_available(): capability = torch.cuda.get_device_capability() diff --git a/python/sglang/srt/warmup.py b/python/sglang/srt/warmup.py new file mode 100644 index 000000000..fc6d2202b --- /dev/null +++ b/python/sglang/srt/warmup.py @@ -0,0 +1,47 @@ +import logging +from typing import List + +import numpy as np +import tqdm + +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager + +logger = logging.getLogger(__file__) + +_warmup_registry = {} + + +def warmup(name: str) -> callable: + def decorator(fn: callable): + _warmup_registry[name] = fn + return fn + + return decorator + + +async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager): + for warmup_name in warmup_names: + if warmup_name not in _warmup_registry: + logger.warning(f"Could not find custom warmup {warmup_name}") + continue + logger.info(f"Running warmup {warmup_name}") + await _warmup_registry[warmup_name](tokenizer_manager) + + +@warmup("voice_chat") +async def voice_chat(tokenizer_manager: TokenizerManager): + # this warms up the fused_moe triton kernels and caches them + # if we don't do this we break real time inference for voice chat + for i in tqdm.trange(1, 512): + size = i * 4 + generate_req_input = GenerateReqInput( + input_ids=(np.random.randint(2**16, size=[size])).tolist(), + sampling_params={ + "max_new_tokens": 30, + "temperature": 0.8, + "stop_token_ids": [1], + "min_p": 0.0, + }, + ) + await tokenizer_manager.generate_request(generate_req_input, None).__anext__() diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 4fd74d148..faccb16e5 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -15,7 +15,7 @@ import multiprocessing as mp import os from dataclasses import dataclass -from typing import List, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -56,6 +56,13 @@ def get_top_logprobs(logits, k): return logprobs +def get_token_ids_logprobs(logits, token_ids): + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + del logits + logprobs = logprobs[..., token_ids] + return logprobs + + def _get_sentence_transformer_embedding_model(model_path, torch_dtype): from sentence_transformers import SentenceTransformer from sentence_transformers.util import is_sentence_transformer_model @@ -84,8 +91,13 @@ class ModelOutput: output_ids: List[int] = None top_input_logprobs: List[torch.Tensor] = None top_output_logprobs: List[torch.Tensor] = None + top_output_logprob_idx: List[List[int]] = None embed_logits: List[torch.Tensor] = None scores: List[float] = None + input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None + output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None + token_ids_input_logprobs: List[torch.Tensor] = None + token_ids_output_logprobs: List[torch.Tensor] = None class HFRunner: @@ -157,7 +169,7 @@ class HFRunner: # Run forward while True: - prompts, max_new_tokens, lora_paths = in_queue.get() + prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get() if lora_paths is not None: assert len(prompts) == len(lora_paths) @@ -165,16 +177,16 @@ class HFRunner: if self.model_type == "generation": out_queue.put( self.forward_generation_raw( + base_model=self.base_model, prompts=prompts, max_new_tokens=max_new_tokens, - base_model=self.base_model, tokenizer=self.tokenizer, lora_paths=lora_paths, torch_dtype=torch_dtype, output_str_only=self.output_str_only, + token_ids_logprob=token_ids_logprob, ) ) - elif self.model_type == "embedding": assert not self.output_str_only logits = self.model.encode(prompts).tolist() @@ -199,10 +211,11 @@ class HFRunner: def forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, - max_new_tokens=8, - lora_paths=None, + max_new_tokens: int = 8, + lora_paths: Optional[List[str]] = None, + token_ids_logprob: Optional[int] = None, ): - self.in_queue.put((prompts, max_new_tokens, lora_paths)) + self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob)) return self.out_queue.get() def terminate(self): @@ -218,17 +231,24 @@ class HFRunner: @staticmethod def forward_generation_raw( - prompts: Union[List[str], List[torch.Tensor]], - max_new_tokens, base_model, + prompts: Union[List[str], List[torch.Tensor]], + max_new_tokens: int, tokenizer, - lora_paths, torch_dtype: torch.dtype, - output_str_only: bool, + lora_paths: Optional[List[str]] = None, + output_str_only: bool = False, + token_ids_logprob: Optional[int] = None, ) -> ModelOutput: output_strs = [] top_input_logprobs = [] top_output_logprobs = [] + if token_ids_logprob is not None: + token_ids_input_logprobs = [] + token_ids_output_logprobs = [] + else: + token_ids_input_logprobs = token_ids_output_logprobs = None + for i, p in enumerate(prompts): if isinstance(p, str): input_ids = tokenizer.encode(p, return_tensors="pt").cuda() @@ -275,18 +295,33 @@ class HFRunner: for logits in outputs.scores ] ) + if token_ids_logprob is not None: + token_ids_output_logprobs.append( + [ + get_token_ids_logprobs( + logits[0], token_ids_logprob + ).tolist() + for logits in outputs.scores + ] + ) del outputs input_logits = model.forward(input_ids).logits[0] top_input_logprobs.append( get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist() ) + if token_ids_logprob is not None: + token_ids_input_logprobs.append( + get_token_ids_logprobs(input_logits, token_ids_logprob).tolist() + ) del input_logits return ModelOutput( output_strs=output_strs, top_input_logprobs=top_input_logprobs, top_output_logprobs=top_output_logprobs, + token_ids_input_logprobs=token_ids_input_logprobs, + token_ids_output_logprobs=token_ids_output_logprobs, ) @@ -303,11 +338,31 @@ class SRTRunner: lora_backend: str = "triton", disable_cuda_graph: bool = False, disable_radix_cache: bool = False, + chunked_prefill_size: Optional[int] = None, + dp_size: int = 1, + tokenizer_path: Optional[str] = None, + enable_ep_moe: bool = False, mem_fraction_static: float = 0.65, trust_remote_code: bool = False, + speculative_draft_model_path: Optional[str] = None, + speculative_algorithm: Optional[str] = None, + speculative_num_steps: Optional[int] = None, + speculative_eagle_topk: Optional[int] = None, + speculative_num_draft_tokens: Optional[int] = None, + disable_overlap_schedule: bool = False, ): self.model_type = model_type self.is_generation = model_type == "generation" + enable_dp_attention = dp_size > 1 + + spec_kwargs = {} + if speculative_draft_model_path: + spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path + spec_kwargs["speculative_algorithm"] = speculative_algorithm + spec_kwargs["speculative_num_steps"] = speculative_num_steps + spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk + spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens + self.engine = Engine( model_path=model_path, tp_size=tp_size, @@ -321,21 +376,41 @@ class SRTRunner: lora_backend=lora_backend, disable_cuda_graph=disable_cuda_graph, disable_radix_cache=disable_radix_cache, + chunked_prefill_size=chunked_prefill_size, + enable_dp_attention=enable_dp_attention, + dp_size=dp_size, + tokenizer_path=tokenizer_path, + enable_ep_moe=enable_ep_moe, + disable_overlap_schedule=disable_overlap_schedule, + cuda_graph_max_bs=4, + **spec_kwargs, ) - self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code) + + if tokenizer_path is None: + self.tokenizer = get_tokenizer( + model_path, trust_remote_code=trust_remote_code + ) + else: + self.tokenizer = None def forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, - max_new_tokens=8, - lora_paths=None, + max_new_tokens: int = 8, + lora_paths: Optional[List[str]] = None, + logprob_start_len: int = 0, + top_k: Optional[int] = None, + token_ids_logprob: Optional[List[int]] = None, ): if self.is_generation: return self.forward_generation_raw( + engine=self.engine, prompts=prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths, - engine=self.engine, + logprob_start_len=logprob_start_len, + top_k=top_k, + token_ids_logprob=token_ids_logprob, ) else: response = self.engine.encode(prompts) @@ -358,10 +433,10 @@ class SRTRunner: """ if self.is_generation: return self.batch_forward_generation_raw( + engine=self.engine, prompts=prompts, max_new_tokens=max_new_tokens, lora_paths=lora_paths, - engine=self.engine, ) else: response = self.engine.encode(prompts) @@ -381,24 +456,43 @@ class SRTRunner: @staticmethod def forward_generation_raw( + engine: Engine, prompts: Union[List[str], List[torch.Tensor]], - max_new_tokens, - lora_paths, - engine, + max_new_tokens: int = 8, + lora_paths: Optional[List[str]] = None, + logprob_start_len: int = 0, + top_k: Optional[int] = None, + token_ids_logprob: Optional[List[int]] = None, ): # the return value contains logprobs from prefill output_strs = [] + output_ids = [] + # Input logprobs. Note that the last item in input logprob is equivalent to + # the first item in the output logprob. top_input_logprobs = [] + input_token_logprobs_lst = [] top_output_logprobs = [] + output_token_logprobs_lst = [] + top_output_logprob_idx = [] + if token_ids_logprob is not None: + token_ids_input_logprobs = [] + token_ids_output_logprobs = [] + else: + token_ids_input_logprobs = token_ids_output_logprobs = None + sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} + if top_k: + sampling_params["top_k"] = top_k + for i, prompt in enumerate(prompts): response = engine.generate( prompt, lora_path=lora_paths[i] if lora_paths else None, sampling_params=sampling_params, return_logprob=True, - logprob_start_len=0, + logprob_start_len=logprob_start_len, top_logprobs_num=NUM_TOP_LOGPROBS, + token_ids_logprob=token_ids_logprob, ) text = response["text"] @@ -408,12 +502,36 @@ class SRTRunner: "Received an empty text response. Please verify your input or model configuration." ) output_strs.append(text) + # output_ids.append(response["output_ids"]) + + input_token_logprobs = response["meta_info"]["input_token_logprobs"] + output_token_logprobs = response["meta_info"]["output_token_logprobs"] + # print(i, input_token_logprobs) + # print(i, output_token_logprobs) + logprobs = response["meta_info"]["input_top_logprobs"] + if token_ids_logprob is not None: + input_token_ids_logprobs = response["meta_info"][ + "input_token_ids_logprobs" + ][1:] + else: + input_token_ids_logprobs = None + + num_prompt_tokens = response["meta_info"]["prompt_tokens"] + assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len + assert len(logprobs) == num_prompt_tokens - logprob_start_len + + # The first token logprob has no meaning in sglang. + input_token_logprobs = input_token_logprobs[1:] + logprobs = logprobs[1:] + assert len(input_token_logprobs) == len(logprobs) + + input_token_logprobs_lst.append( + input_token_logprobs + [output_token_logprobs[0]] + ) + output_token_logprobs_lst.append(output_token_logprobs) top_input_logprobs.append( - [ - [tup[0] for tup in x[:NUM_TOP_LOGPROBS]] - for x in response["meta_info"]["input_top_logprobs"][1:] - ] + [[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs] + [ [ tup[0] @@ -429,11 +547,41 @@ class SRTRunner: for x in response["meta_info"]["output_top_logprobs"] ] ) + top_output_logprob_idx.append( + [ + [tup[1] for tup in x[:NUM_TOP_LOGPROBS]] + for x in response["meta_info"]["output_top_logprobs"] + ] + ) + if token_ids_logprob is not None: + token_ids_input_logprobs.append( + [[tup[0] for tup in x] for x in input_token_ids_logprobs] + + [ + [ + tup[0] + for tup in response["meta_info"][ + "output_token_ids_logprobs" + ][0] + ] + ] + ) + token_ids_output_logprobs.append( + [ + [tup[0] for tup in x] + for x in response["meta_info"]["output_token_ids_logprobs"] + ] + ) return ModelOutput( output_strs=output_strs, + output_ids=output_ids, top_input_logprobs=top_input_logprobs, top_output_logprobs=top_output_logprobs, + input_token_logprobs_lst=input_token_logprobs_lst, + output_token_logprobs_lst=output_token_logprobs_lst, + top_output_logprob_idx=top_output_logprob_idx, + token_ids_input_logprobs=token_ids_input_logprobs, + token_ids_output_logprobs=token_ids_output_logprobs, ) @staticmethod diff --git a/python/sglang/test/send_one.py b/python/sglang/test/send_one.py new file mode 100644 index 000000000..376b44588 --- /dev/null +++ b/python/sglang/test/send_one.py @@ -0,0 +1,88 @@ +""" +Run one test prompt. + +Usage: +python3 -m sglang.test.send_one +""" + +import argparse +import json + +import requests + + +def send_one_prompt(args): + if args.image: + args.prompt = ( + "Human: Describe this image in a very short sentence.\n\nAssistant:" + ) + image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" + else: + image_data = None + + response = requests.post( + "http://localhost:30000/generate", + json={ + "text": args.prompt, + "image_data": image_data, + "sampling_params": { + "temperature": args.temperature, + "max_new_tokens": args.max_new_tokens, + "frequency_penalty": args.frequency_penalty, + "presence_penalty": args.presence_penalty, + }, + "return_logprob": args.return_logprob, + "stream": args.stream, + }, + stream=args.stream, + ) + + if args.stream: + for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + ret = json.loads(chunk[5:].strip("\n")) + else: + ret = response.json() + + latency = ret["meta_info"]["e2e_latency"] + + if "spec_verify_ct" in ret["meta_info"]: + acc_length = ( + ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"] + ) + else: + acc_length = 1.0 + + speed = ret["meta_info"]["completion_tokens"] / latency + + print(ret["text"]) + print() + print(f"{acc_length=:.2f}") + print(f"{speed=:.2f} token/s") + + return acc_length, speed + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--frequency-penalty", type=float, default=0.0) + parser.add_argument("--presence-penalty", type=float, default=0.0) + parser.add_argument("--return-logprob", action="store_true") + parser.add_argument( + "--prompt", + type=str, + default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:", + ) + parser.add_argument( + "--image", + action="store_true", + ) + parser.add_argument("--stream", action="store_true") + args = parser.parse_args() + + send_one_prompt(args) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 3dc1ae347..ab472cc7a 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -8,10 +8,11 @@ import random import subprocess import threading import time +import unittest from concurrent.futures import ThreadPoolExecutor from functools import partial from types import SimpleNamespace -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Tuple import numpy as np import requests @@ -408,26 +409,49 @@ def popen_launch_server( other_args: list[str] = (), env: Optional[dict] = None, return_stdout_stderr: Optional[tuple] = None, + pd_seperated: bool = False, ): _, host, port = base_url.split(":") host = host[2:] + if pd_seperated: + command = "sglang.launch_pd_server" + else: + command = "sglang.launch_server" + command = [ "python3", "-m", - "sglang.launch_server", + command, "--model-path", model, - "--host", - host, - "--port", - port, - *other_args, + *[str(x) for x in other_args], ] + if pd_seperated: + command.extend( + [ + "--lb-host", + host, + "--lb-port", + port, + ] + ) + else: + command.extend( + [ + "--host", + host, + "--port", + port, + ] + ) + if api_key: command += ["--api-key", api_key] + print(f"command={' '.join(command)}") + if return_stdout_stderr: process = subprocess.Popen( command, @@ -456,6 +480,8 @@ def popen_launch_server( except requests.RequestException: pass time.sleep(10) + + kill_process_tree(process.pid) raise TimeoutError("Server failed to start within the timeout period.") @@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float): success = True for filename in files: - global process + process = None def run_one_file(filename): + nonlocal process + filename = os.path.join(os.getcwd(), filename) print(f"\n\nRun:\npython3 {filename}\n\n", flush=True) process = subprocess.Popen( @@ -534,11 +562,14 @@ def get_benchmark_args( dataset_path="", tokenizer="", num_prompts=500, + sharegpt_output_len=None, random_input_len=4096, random_output_len=2048, + sharegpt_context_len=None, request_rate=float("inf"), disable_stream=False, disable_ignore_eos=False, + pd_seperated: bool = False, ): return SimpleNamespace( backend="sglang", @@ -550,8 +581,8 @@ def get_benchmark_args( model=None, tokenizer=tokenizer, num_prompts=num_prompts, - sharegpt_output_len=None, - sharegpt_context_len=None, + sharegpt_output_len=sharegpt_output_len, + sharegpt_context_len=sharegpt_context_len, random_input_len=random_input_len, random_output_len=random_output_len, random_range_ratio=0.0, @@ -567,6 +598,8 @@ def get_benchmark_args( apply_chat_template=False, profile=None, lora_name=None, + prompt_suffix="", + pd_seperated=pd_seperated, ) @@ -580,6 +613,7 @@ def run_bench_serving( tokenizer=None, random_input_len=4096, random_output_len=2048, + sharegpt_context_len=None, disable_stream=False, disable_ignore_eos=False, need_warmup=False, @@ -602,6 +636,7 @@ def run_bench_serving( num_prompts=num_prompts, random_input_len=random_input_len, random_output_len=random_output_len, + sharegpt_context_len=sharegpt_context_len, request_rate=request_rate, disable_stream=disable_stream, disable_ignore_eos=disable_ignore_eos, @@ -626,6 +661,7 @@ def run_bench_serving_multi( other_server_args, benchmark_args, need_warmup=False, + pd_seperated=False, ): # Launch the server process = popen_launch_server( @@ -633,6 +669,7 @@ def run_bench_serving_multi( base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=other_server_args, + pd_seperated=pd_seperated, ) # run benchmark for all @@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args): "128", "--output", "8", - *other_args, + *[str(x) for x in other_args], ] process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None): stdout = open(STDOUT_FILENAME, "w") stderr = open(STDERR_FILENAME, "w") process = subprocess.Popen( - command, stdout=stdout, stderr=stderr, env=env, text=True + command, stdout=stdout, stderr=stdout, env=env, text=True ) # Launch a thread to stream the output @@ -914,3 +951,78 @@ def run_mulit_request_test( def write_github_step_summary(content): with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: f.write(content) + + +def run_logprob_check(self: unittest.TestCase, arg: Tuple): + ( + input_len, + output_len, + temperature, + logprob_start_len, + return_logprob, + top_logprobs_num, + ) = arg + input_ids = list(range(input_len)) + + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids, + "sampling_params": { + "temperature": temperature, + "max_new_tokens": output_len, + "ignore_eos": True, + }, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + }, + ) + response_json = response.json() + + res = response_json + self.assertEqual(res["meta_info"]["prompt_tokens"], input_len) + self.assertEqual(res["meta_info"]["completion_tokens"], output_len) + + # Test the number of tokens are correct + if return_logprob: + self.assertEqual( + len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len, + res["meta_info"]["prompt_tokens"], + ) + self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len) + + if top_logprobs_num: + self.assertEqual( + len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len, + res["meta_info"]["prompt_tokens"], + ) + self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len) + + for i in range(output_len): + self.assertEqual( + len(res["meta_info"]["output_top_logprobs"][i]), + top_logprobs_num, + ) + + # Test the top-1 tokens are the same as output tokens if temperature == 0 + if temperature == 0: + rank = 0 + while rank < len(res["meta_info"]["output_top_logprobs"][i]): + try: + self.assertListEqual( + res["meta_info"]["output_token_logprobs"][i], + res["meta_info"]["output_top_logprobs"][i][rank], + ) + break + except AssertionError: + # There's a tie. Allow the second item in this case. + if ( + res["meta_info"]["output_top_logprobs"][i][rank][0] + == res["meta_info"]["output_top_logprobs"][i][rank + 1][ + 0 + ] + ): + rank += 1 + else: + raise diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index 163a60f18..fdafbbb98 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,14 +1,5 @@ #!/bin/bash -# Check if sudo is available -if command -v sudo >/dev/null 2>&1; then - sudo apt-get update - sudo apt-get install -y lsof -else - apt-get update - apt-get install -y lsof -fi - # Show current GPU status nvidia-smi @@ -20,6 +11,14 @@ kill -9 $(ps aux | grep 'sglang.data_parallel' | grep -v 'grep' | awk '{print $2 # Clean all GPU processes if any argument is provided if [ $# -gt 0 ]; then + # Check if sudo is available + if command -v sudo >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y lsof + else + apt-get update + apt-get install -y lsof + fi kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null fi diff --git a/scripts/playground/bench_speculative.py b/scripts/playground/bench_speculative.py new file mode 100644 index 000000000..81f0a03f2 --- /dev/null +++ b/scripts/playground/bench_speculative.py @@ -0,0 +1,257 @@ +""" +Usage: +# single GPU +python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B +""" + +import argparse +import asyncio +import json +import os +import time +from types import SimpleNamespace + +import numpy as np +import requests + +from sglang.bench_serving import benchmark, set_global_args +from sglang.srt.server_args import ServerArgs +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + kill_process_tree, + popen_launch_server, +) + + +def node0_print(msg): + if server_args.node_rank == 0: + print(msg) + + +prompts = [ + "Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:", + "Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:", + "Human: Write a travel blog post to Hawaii.\n\nAssistant:", + "Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:", + "Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:", + "Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:", + "Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:", + "Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:", +] + + +class FakeTokenizer: + def encode(self, text: str, add_special_tokens: bool = False): + return [] + + +def send_one_batch(base_url, num_prompts, batch_size): + padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[ + :num_prompts + ] + + # format: (prompt, input_len, output len). We set input_len as a dummy value 0. + input_requests = [(p, 0, 512) for p in padded_prompts] + + # We need to set some dummy values in order to call `benchmark` below. + args = SimpleNamespace( + disable_ignore_eos=False, + disable_stream=False, + return_logprob=False, + backend="sglang", + dataset_name="custom", + num_prompts=None, + sharegpt_output_len=None, + random_input_len=None, + random_output_len=None, + random_range_ratio=None, + output_file=None, + ) + set_global_args(args) + tokenizer = FakeTokenizer() + + # Run benchmark + results = asyncio.run( + benchmark( + backend="sglang", + api_url=f"{base_url}/generate", + base_url=base_url, + model_id="default", + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=float("inf"), + max_concurrency=batch_size, + disable_tqdm=False, + lora_name=None, + extra_request_body={}, + profile=None, + ) + ) + + assert results["completed"] == len(input_requests) + acc_length = results["accept_length"] or 1.0 + avg_output_token = results["total_output_tokens"] / results["completed"] + + server_info = requests.get(base_url + "/get_server_info").json() + # We use 20% percentile instead of median on purpose + step_time = np.percentile(server_info["step_time_dict"][str(batch_size)], 20) + speed = 1 / step_time * acc_length + + return ( + round(acc_length, 3), + round(step_time, 5), + round(speed, 3), + avg_output_token, + ) + + +def main(args, server_args): + base_url = "http://127.0.0.1:20000" + + configs = [] + for batch_size in args.batch_size: + for steps in args.steps: + for topk in args.topk: + for num_draft_tokens in args.num_draft_tokens: + if steps * topk + 1 < num_draft_tokens: + continue + + if (steps == 0 or topk == 0 or num_draft_tokens == 0) and ( + steps + topk + num_draft_tokens != 0 + ): + # steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding. + continue + + configs.append((batch_size, steps, topk, num_draft_tokens)) + + for i in range(args.start, args.end or len(configs)): + batch_size, steps, topk, num_draft_tokens = configs[i] + + node0_print( + f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}" + ) + + # Create an LLM. + if steps == 0: + other_args = [] + else: + other_args = [ + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + steps, + "--speculative-eagle-topk", + topk, + "--speculative-num-draft-tokens", + num_draft_tokens, + ] + if server_args.speculative_draft_model_path is not None: + other_args.extend( + [ + "--speculative-draft-model-path", + server_args.speculative_draft_model_path, + ] + ) + + other_args.extend( + [ + "--cuda-graph-max-bs", + batch_size, + "--mem-fraction-static", + server_args.mem_fraction_static, + "--tp-size", + server_args.tp_size, + "--max-running-requests", + batch_size, + ] + ) + + if server_args.quantization: + other_args.extend( + [ + "--quantization", + server_args.quantization, + ] + ) + + process = popen_launch_server( + args.model_path, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + env={ + "SGLANG_RECORD_STEP_TIME": "1", + **os.environ, + }, + ) + + try: + # Warmup + send_one_batch(base_url, batch_size, batch_size) + + # Benchmark + acc_length, step_time, speed, completion_tokens = send_one_batch( + base_url, max(args.num_prompts, batch_size), batch_size + ) + finally: + kill_process_tree(process.pid) + + node0_print( + f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms" + ) + + record = { + "batch_size": batch_size, + "steps": steps, + "topk": topk, + "num_draft_tokens": num_draft_tokens, + "acc_length": acc_length, + "step_time": step_time, + "speed": speed, + "completion_tokens": completion_tokens, + } + + with open(args.output, "a") as fout: + fout.write(json.dumps(record) + "\n") + + # Wait for the server to shutdown + time.sleep(5) + + +# The __main__ condition is necessary here because we use "spawn" to create subprocesses +# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + parser.add_argument( + "--batch-size", + type=int, + nargs="+", + default=(1, 2, 4, 8, 16), + ) + parser.add_argument( + "--steps", + type=int, + nargs="+", + default=(0, 1, 3, 5, 7), # use (0, 1, 2, 3, 4) for large batch size + ) + parser.add_argument( + "--topk", + type=int, + nargs="+", + default=(0, 1, 2, 4, 8), + ) + parser.add_argument( + "--num_draft_tokens", + type=int, + nargs="+", + default=(0, 2, 4, 8, 16, 32), # use (0, 2, 4, 8) for large batch size + ) + parser.add_argument("--num-prompts", type=int, default=16) + parser.add_argument("--start", type=int, default=0) + parser.add_argument("--end", type=int) + parser.add_argument("--output", type=str, default="output.jsonl") + args = parser.parse_args() + server_args: ServerArgs = ServerArgs.from_cli_args(args) + + main(args, server_args) diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 8af022434..ad554e60c 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -111,6 +111,8 @@ else: "cublas_grouped_gemm", "custom_dispose", "custom_reduce", + "build_tree_kernel_efficient", + "build_tree_kernel", "fp8_blockwise_scaled_mm", "fp8_scaled_mm", "fused_add_rmsnorm", @@ -127,12 +129,10 @@ else: "register_graph_buffers", "rmsnorm", "sampling_scaling_penalties", + "sgl_per_token_group_quant_fp8", "silu_and_mul", "top_k_renorm_prob", "top_k_top_p_sampling_from_probs", "top_p_renorm_prob", "tree_speculative_sampling_target_only", - "build_tree_kernel_efficient", - "build_tree_kernel", - "sgl_per_token_group_quant_fp8", ] diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index 086faeabb..92e828c26 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -30,7 +30,9 @@ class TestSRTBackend(unittest.TestCase): @classmethod def setUpClass(cls): - cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST) + cls.backend = sgl.Runtime( + model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4 + ) sgl.set_default_backend(cls.backend) @classmethod diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 8aa8e0fd1..326b96e33 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -12,7 +12,6 @@ suites = { "models/test_generation_models.py", "models/test_qwen_models.py", "models/test_reward_models.py", - "sampling/penaltylib", "test_abort.py", "test_chunked_prefill.py", "test_custom_allreduce.py", @@ -31,6 +30,7 @@ suites = { "test_no_chunked_prefill.py", "test_no_overlap_scheduler.py", "test_openai_server.py", + "test_penalty.py", "test_pytorch_sampling_backend.py", "test_radix_attention.py", "test_regex_constrained.py", @@ -38,7 +38,8 @@ suites = { "test_request_length_validation.py", "test_retract_decode.py", "test_server_args.py", - "test_session_control.py", + # Disabled temporarily + # "test_session_control.py", "test_skip_tokenizer_init.py", "test_srt_engine.py", "test_srt_endpoint.py", @@ -64,9 +65,6 @@ suites = { # Disable temporarily # "test_nightly_math_eval.py", ], - "sampling/penaltylib": glob.glob( - "sampling/penaltylib/**/test_*.py", recursive=True - ), } # Expand suite @@ -83,7 +81,7 @@ if __name__ == "__main__": arg_parser.add_argument( "--timeout-per-file", type=int, - default=2000, + default=1800, help="The time limit for running one file in seconds.", ) arg_parser.add_argument( diff --git a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py deleted file mode 100644 index e8a8fe033..000000000 --- a/test/srt/sampling/penaltylib/penalizers/test_frequency_penalty.py +++ /dev/null @@ -1,97 +0,0 @@ -import unittest -from typing import List - -import torch - -from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import ( - BatchedFrequencyPenalizer, -) -from sglang.test.srt.sampling.penaltylib.utils import ( - BaseBatchedPenalizerTest, - MockSamplingParams, - Step, - StepType, - Subject, -) - - -class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest): - Penalizer = BatchedFrequencyPenalizer - frequency_penalty: float - - def setUp(self): - if self.__class__ == BaseBatchedFrequencyPenalizerTest: - self.skipTest("Base class for frequency_penalty tests") - - super().setUp() - - def _create_subject(self, frequency_penalty: float) -> Subject: - return Subject( - sampling_params=MockSamplingParams( - frequency_penalty=frequency_penalty, - ), - steps=[ - Step( - type=StepType.INPUT, - token_ids=[0, 1, 2], - expected_tensors={ - "frequency_penalties": self.tensor( - [[frequency_penalty] * self.vocab_size], dtype=torch.float32 - ), - "cumulated_frequency_penalties": self.tensor( - [[0.0] * self.vocab_size], dtype=torch.float32 - ), - }, - expected_logits=self.tensor( - [[1] * self.vocab_size], dtype=torch.float32 - ), - ), - Step( - type=StepType.OUTPUT, - token_ids=[ - 1, - 2, - 2, - ], # This is the output ids of one request in three steps. - expected_tensors={ - "frequency_penalties": self.tensor( - [[frequency_penalty] * self.vocab_size], dtype=torch.float32 - ), - "cumulated_frequency_penalties": self.tensor( - [ - [ - frequency_penalty * i if i in {1, 2} else 0.0 - for i in range(self.vocab_size) - ], - ], - dtype=torch.float32, - ), - }, - expected_logits=self.tensor( - [ - [ - 1.0 - frequency_penalty * i if i in {1, 2} else 1.0 - for i in range(self.vocab_size) - ], - ], - dtype=torch.float32, - ), - ), - ], - ) - - def create_test_subjects(self) -> List[Subject]: - self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty) - self.disabled = self._create_subject(frequency_penalty=0.0) - - -class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest): - frequency_penalty = 0.12 - - -class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest): - frequency_penalty = -0.12 - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py b/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py deleted file mode 100644 index 298dd2cc1..000000000 --- a/test/srt/sampling/penaltylib/penalizers/test_min_new_tokens.py +++ /dev/null @@ -1,152 +0,0 @@ -import unittest -from typing import List - -import torch - -from sglang.srt.sampling.penaltylib.penalizers.min_new_tokens import ( - BatchedMinNewTokensPenalizer, -) -from sglang.test.srt.sampling.penaltylib.utils import ( - BaseBatchedPenalizerTest, - MockSamplingParams, - Step, - StepType, - Subject, -) - -MIN_NEW_TOKENS = 2 -EOS_TOKEN_ID = 4 -STOP_TOKEN_ID = 3 - -ALL_STOP_TOKEN_IDS = {STOP_TOKEN_ID, EOS_TOKEN_ID} - - -class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest): - Penalizer = BatchedMinNewTokensPenalizer - - def _create_subject(self, min_new_tokens: int) -> Subject: - return Subject( - eos_token_id=EOS_TOKEN_ID, - sampling_params=MockSamplingParams( - min_new_tokens=min_new_tokens, - stop_token_ids={STOP_TOKEN_ID}, - ), - steps=[ - Step( - type=StepType.INPUT, - token_ids=[0, 1, 2], - expected_tensors={ - "min_new_tokens": self.tensor( - [[min_new_tokens]], dtype=torch.int32 - ), - "stop_token_penalties": self.tensor( - [ - [ - float("-inf") if i in ALL_STOP_TOKEN_IDS else 0 - for i in range(self.vocab_size) - ] - ], - dtype=torch.float32, - ), - "len_output_tokens": self.tensor([[0]], dtype=torch.int32), - }, - expected_logits=( - self.tensor( - [ - [ - float("-inf") if i in ALL_STOP_TOKEN_IDS else 1 - for i in range(self.vocab_size) - ] - ], - dtype=torch.float32, - ) - if min_new_tokens > 0 - else torch.ones( - (1, self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - ), - ), - Step( - type=StepType.OUTPUT, - token_ids=[0], - expected_tensors={ - "min_new_tokens": self.tensor( - [[min_new_tokens]], dtype=torch.int32 - ), - "stop_token_penalties": self.tensor( - [ - [ - float("-inf") if i in ALL_STOP_TOKEN_IDS else 0 - for i in range(self.vocab_size) - ] - ], - dtype=torch.float32, - ), - "len_output_tokens": self.tensor([[1]], dtype=torch.int32), - }, - expected_logits=( - self.tensor( - [ - [ - float("-inf") if i in ALL_STOP_TOKEN_IDS else 1 - for i in range(self.vocab_size) - ] - ], - dtype=torch.float32, - ) - if min_new_tokens > 1 - else torch.ones( - (1, self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - ), - ), - Step( - type=StepType.OUTPUT, - token_ids=[0], - expected_tensors={ - "min_new_tokens": self.tensor( - [[min_new_tokens]], dtype=torch.int32 - ), - "stop_token_penalties": self.tensor( - [ - [ - float("-inf") if i in ALL_STOP_TOKEN_IDS else 0 - for i in range(self.vocab_size) - ] - ], - dtype=torch.float32, - ), - "len_output_tokens": self.tensor([[2]], dtype=torch.int32), - }, - expected_logits=( - self.tensor( - [ - [ - float("-inf") if i in ALL_STOP_TOKEN_IDS else 1 - for i in range(self.vocab_size) - ] - ], - dtype=torch.float32, - ) - if min_new_tokens > 2 - else torch.ones( - (1, self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - ), - ), - ], - ) - - def create_test_subjects(self) -> List[Subject]: - self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS) - self.disabled = self._create_subject(min_new_tokens=0.0) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py deleted file mode 100644 index b249283ac..000000000 --- a/test/srt/sampling/penaltylib/penalizers/test_presence_penalty.py +++ /dev/null @@ -1,93 +0,0 @@ -import unittest -from typing import List - -import torch - -from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import ( - BatchedPresencePenalizer, -) -from sglang.test.srt.sampling.penaltylib.utils import ( - BaseBatchedPenalizerTest, - MockSamplingParams, - Step, - StepType, - Subject, -) - - -class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest): - Penalizer = BatchedPresencePenalizer - presence_penalty: float - - def setUp(self): - if self.__class__ == BaseBatchedPresencePenalizerTest: - self.skipTest("Base class for presence_penalty tests") - - super().setUp() - - def _create_subject(self, presence_penalty: float) -> Subject: - return Subject( - sampling_params=MockSamplingParams( - presence_penalty=presence_penalty, - ), - steps=[ - Step( - type=StepType.INPUT, - token_ids=[0, 1, 2], - expected_tensors={ - "presence_penalties": self.tensor( - [[presence_penalty] * self.vocab_size], dtype=torch.float32 - ), - "cumulated_presence_penalties": self.tensor( - [[0.0] * self.vocab_size], dtype=torch.float32 - ), - }, - expected_logits=self.tensor( - [[1] * self.vocab_size], dtype=torch.float32 - ), - ), - Step( - type=StepType.OUTPUT, - token_ids=[1, 2, 2], - expected_tensors={ - "presence_penalties": self.tensor( - [[presence_penalty] * self.vocab_size], dtype=torch.float32 - ), - "cumulated_presence_penalties": self.tensor( - [ - [ - presence_penalty if i in {1, 2} else 0.0 - for i in range(self.vocab_size) - ], - ], - dtype=torch.float32, - ), - }, - expected_logits=self.tensor( - [ - [ - 1.0 - presence_penalty if i in {1, 2} else 1.0 - for i in range(self.vocab_size) - ], - ], - dtype=torch.float32, - ), - ), - ], - ) - - def create_test_subjects(self) -> List[Subject]: - self.enabled = self._create_subject(presence_penalty=self.presence_penalty) - self.disabled = self._create_subject(presence_penalty=0.0) - - -class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest): - presence_penalty = 0.12 - - -class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest): - presence_penalty = -0.12 - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py b/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py deleted file mode 100644 index 2f8671391..000000000 --- a/test/srt/sampling/penaltylib/penalizers/test_repetition_penalty.py +++ /dev/null @@ -1,87 +0,0 @@ -import unittest -from typing import List - -import torch - -from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( - BatchedRepetitionPenalizer, -) -from sglang.test.srt.sampling.penaltylib.utils import ( - BaseBatchedPenalizerTest, - MockSamplingParams, - Step, - StepType, - Subject, -) - -REPETITION_PENALTY = 2.0 - - -class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest): - Penalizer = BatchedRepetitionPenalizer - - def _create_subject(self, repetition_penalty: float) -> Subject: - l = 1.0 / repetition_penalty - return Subject( - sampling_params=MockSamplingParams( - repetition_penalty=repetition_penalty, - ), - steps=[ - Step( - type=StepType.INPUT, - token_ids=[0, 1, 2], - expected_tensors={ - "repetition_penalties": self.tensor( - [[repetition_penalty] * self.vocab_size], - dtype=torch.float32, - ), - "cumulated_repetition_penalties": ( - self.tensor( - [[2.0, 2.0, 2.0, 1.0, 1.0]], dtype=torch.float32 - ) - if repetition_penalty != 1.0 - else self.tensor( - [[1.0] * self.vocab_size], dtype=torch.float32 - ) - ), - }, - expected_logits=( - self.tensor([[l, l, l, 1.0, 1.0]], dtype=torch.float32) - if repetition_penalty != 1.0 - else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32) - ), - ), - Step( - type=StepType.OUTPUT, - token_ids=[0, 1, 3], - expected_tensors={ - "repetition_penalties": self.tensor( - [[repetition_penalty] * self.vocab_size], - dtype=torch.float32, - ), - "cumulated_repetition_penalties": ( - self.tensor( - [[2.0, 2.0, 2.0, 2.0, 1.0]], dtype=torch.float32 - ) - if repetition_penalty != 1.0 - else self.tensor( - [[1.0] * self.vocab_size], dtype=torch.float32 - ) - ), - }, - expected_logits=( - self.tensor([[l, l, l, l, 1.0]], dtype=torch.float32) - if repetition_penalty != 1.0 - else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32) - ), - ), - ], - ) - - def create_test_subjects(self) -> List[Subject]: - self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY) - self.disabled = self._create_subject(repetition_penalty=1.0) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py deleted file mode 100644 index d9d77a9ae..000000000 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ /dev/null @@ -1,114 +0,0 @@ -import json -import unittest -from multiprocessing import Process - -import requests - -from sglang.srt.utils import kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_SMALL_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - popen_launch_server, -) - - -class TestBatchPenalizerE2E(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=( - "--random-seed", - "0", - ), - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def run_decode( - self, - return_logprob=True, - top_logprobs_num=5, - return_text=True, - n=1, - **sampling_params, - ): - response = requests.post( - self.base_url + "/generate", - json={ - # prompt that is supposed to generate < 32 tokens - "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - "sampling_params": { - "max_new_tokens": 32, - "n": n, - **sampling_params, - }, - "stream": False, - "return_logprob": return_logprob, - "top_logprobs_num": top_logprobs_num, - "return_text_in_logprobs": return_text, - "logprob_start_len": 0, - }, - ) - assert response.status_code == 200, "Request failed: " + response.text - - def test_default_values(self): - self.run_decode() - - def test_mixed(self): - """ - Sends two requests with one with penalizers disabled, and the other with penalizers enabled. - This will cause two different {ScheduleBatch} to be initialized and eventually gets merged. - - Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not. - This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge. - - This test triggers the merge of disabled + enabled. - """ - - processes = [] - - p = Process( - target=self.run_decode, - ) - processes.append(p) - p.start() - - p = Process( - target=self.run_decode, - kwargs={ - "frequency_penalty": 2, - "min_new_tokens": 16, - "presence_penalty": 2, - "repetition_penalty": 2, - }, - ) - processes.append(p) - p.start() - - for p in processes: - p.join() - - def test_frequency_penalty(self): - self.run_decode(frequency_penalty=2) - - def test_min_new_tokens(self): - self.run_decode(min_new_tokens=16) - - def test_presence_penalty(self): - self.run_decode(presence_penalty=2) - - def test_repetition_penalty(self): - self.run_decode(repetition_penalty=2) - - -if __name__ == "__main__": - unittest.main(verbosity=3) diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index b0c6dcd19..d9970a2ec 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -138,6 +138,7 @@ class TestBenchServing(unittest.TestCase): model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, num_prompts=50, request_rate=1, + sharegpt_context_len=3072, disable_ignore_eos=True, dataset_name="sharegpt", other_server_args=[ @@ -148,22 +149,23 @@ class TestBenchServing(unittest.TestCase): "--speculative-num-steps", "5", "--speculative-eagle-topk", - "8", + "4", "--speculative-num-draft-tokens", - "64", + "16", "--mem-fraction-static", "0.7", - "--cuda-graph-max-bs", - "32", ], + need_warmup=True, ) if is_in_ci(): write_github_step_summary( f"### test_online_latency_eagle\n" f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' + f'accept_length : {res["accept_length"]:.2f} \n' ) - self.assertLess(res["median_e2e_latency_ms"], 450) + self.assertLess(res["median_e2e_latency_ms"], 700) + self.assertGreater(res["accept_length"], 2.50) def test_moe_offline_throughput_default(self): res = run_bench_serving( diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py index f7fb3cec3..dd923777f 100644 --- a/test/srt/test_eval_accuracy_large.py +++ b/test/srt/test_eval_accuracy_large.py @@ -12,7 +12,9 @@ from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + is_in_ci, popen_launch_server, + write_github_step_summary, ) @@ -44,6 +46,9 @@ class TestEvalAccuracyLarge(unittest.TestCase): metrics = run_eval(args) self.assertGreater(metrics["score"], 0.71) + if is_in_ci(): + write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n') + def test_human_eval(self): args = SimpleNamespace( base_url=self.base_url, @@ -56,6 +61,11 @@ class TestEvalAccuracyLarge(unittest.TestCase): metrics = run_eval(args) self.assertGreater(metrics["score"], 0.64) + if is_in_ci(): + write_github_step_summary( + f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n' + ) + def test_mgsm_en(self): args = SimpleNamespace( base_url=self.base_url, @@ -68,6 +78,11 @@ class TestEvalAccuracyLarge(unittest.TestCase): metrics = run_eval(args) self.assertGreater(metrics["score"], 0.835) + if is_in_ci(): + write_github_step_summary( + f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n' + ) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_health_check.py b/test/srt/test_health_check.py new file mode 100644 index 000000000..708230ffd --- /dev/null +++ b/test/srt/test_health_check.py @@ -0,0 +1,27 @@ +import unittest + +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestHealthCheck(unittest.TestCase): + def test_health_check(self): + """Test that metrics endpoint returns data when enabled""" + with self.assertRaises(TimeoutError): + popen_launch_server( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + timeout=60, + other_args=[ + "--disable-cuda-graph", + "--json-model-override-args", + '{"architectures": ["LlamaForCausalLMForHealthTest"]}', + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 0deec49d7..83fda7756 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -49,7 +49,7 @@ class TestHiddenState(unittest.TestCase): with torch.inference_mode(): hf_out = model( torch.tensor( - [input_id + output["token_ids"][:-1]], device=model.device + [input_id + output["output_ids"][:-1]], device=model.device ), output_hidden_states=True, ) diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index 2837107a1..6fd3295b5 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -56,11 +56,13 @@ class TestEnableMetrics(unittest.TestCase): "sglang:gen_throughput", "sglang:num_queue_reqs", "sglang:cache_hit_rate", + "sglang:spec_accept_length", "sglang:prompt_tokens_total", "sglang:generation_tokens_total", "sglang:num_requests_total", "sglang:time_to_first_token_seconds", "sglang:time_per_output_token_seconds", + "sglang:inter_token_latency_seconds", "sglang:e2e_request_latency_seconds", ] diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index 0bc64ea3e..abc2fb656 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -141,7 +141,7 @@ class TestDeepseekV3MTP(unittest.TestCase): metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["accuracy"], 0.60) if __name__ == "__main__": diff --git a/test/srt/test_penalty.py b/test/srt/test_penalty.py new file mode 100644 index 000000000..cb9b6b3dc --- /dev/null +++ b/test/srt/test_penalty.py @@ -0,0 +1,91 @@ +import json +import random +import unittest +from concurrent.futures import ThreadPoolExecutor + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestPenalty(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_decode(self, sampling_params): + return_logprob = True + top_logprobs_num = 5 + return_text = True + n = 1 + + response = requests.post( + self.base_url + "/generate", + json={ + # prompt that is supposed to generate < 32 tokens + "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + "sampling_params": { + "max_new_tokens": 32, + "n": n, + **sampling_params, + }, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": return_text, + "logprob_start_len": 0, + }, + ) + self.assertEqual(response.status_code, 200) + print(json.dumps(response.json())) + print("=" * 100) + + def test_default_values(self): + self.run_decode({}) + + def test_frequency_penalty(self): + self.run_decode({"frequency_penalty": 2}) + + def test_min_new_tokens(self): + self.run_decode({"min_new_tokens": 16}) + + def test_presence_penalty(self): + self.run_decode({"presence_penalty": 2}) + + def test_mixed(self): + args = [ + {}, + {}, + {}, + {"frequency_penalty": 2}, + {"min_new_tokens": 16}, + {"presence_penalty": 1}, + {"frequency_penalty": 0.2}, + {"min_new_tokens": 8}, + {"presence_penalty": 0.4}, + {"presence_penalty": 0.4, "frequency_penalty": 2}, + {"min_new_tokens": 12, "frequency_penalty": 2}, + ] + random.shuffle(args * 5) + with ThreadPoolExecutor(8) as executor: + list(executor.map(self.run_decode, args)) + + +if __name__ == "__main__": + unittest.main(verbosity=3) diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py index 2915133f4..9a3de8d13 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -70,7 +70,10 @@ class TestSessionControl(unittest.TestCase): first_rid = None outputs_from_session = [] + logprobs_from_session = [] + cur_logprob_start_len = 0 for i, chunk_ids in enumerate(chunks_ids): + max_new_tokens = gen_len if i > 0 else 1 # prefill only for the first chunk response = requests.post( self.base_url + "/generate", json={ @@ -83,12 +86,12 @@ class TestSessionControl(unittest.TestCase): }, "sampling_params": { "temperature": 0, - "max_new_tokens": ( - gen_len if i > 0 else 1 - ), # prefill only for the first chunk + "max_new_tokens": max_new_tokens, "no_stop_trim": True, "skip_special_tokens": False, }, + "return_logprob": True, + "logprob_start_len": cur_logprob_start_len - 1, }, ).json() rid = response["meta_info"]["id"] @@ -96,8 +99,39 @@ class TestSessionControl(unittest.TestCase): first_rid = rid if i > 0: outputs_from_session.append(response["text"]) + logprobs_from_session.extend( + [ + round(sublist[0], 2) + for sublist in response["meta_info"]["output_token_logprobs"] + ] + ) + cur_logprob_start_len += len(chunk_ids) + max_new_tokens + + # query with a logprob_start_len longer than the request, should see error + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": chunk_ids, + "session_params": { + "id": session_id, + "rid": rid, + "offset": -1, + "replace": True, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + "return_logprob": True, + "logprob_start_len": cur_logprob_start_len + len(chunk_ids), + }, + ).json() + assert "Request with a lower logprob_start_len" in response["error"]["message"] # backtrack to the first request and regenerate + cur_logprob_start_len = 0 response = requests.post( self.base_url + "/generate", json={ @@ -114,9 +148,17 @@ class TestSessionControl(unittest.TestCase): "no_stop_trim": True, "skip_special_tokens": False, }, + "return_logprob": True, + "logprob_start_len": cur_logprob_start_len, }, ).json() outputs_from_session.append(response["text"]) + logprobs_from_session.extend( + [ + round(sublist[0], 2) + for sublist in response["meta_info"]["output_token_logprobs"] + ] + ) # query with a non-existing rid (the last one should be disappeared becuase of backtrack), should see abort response = requests.post( @@ -135,6 +177,7 @@ class TestSessionControl(unittest.TestCase): "no_stop_trim": True, "skip_special_tokens": False, }, + "return_logprob": True, }, ).json() assert response["meta_info"]["finish_reason"]["type"] == "abort" @@ -162,6 +205,7 @@ class TestSessionControl(unittest.TestCase): "no_stop_trim": True, "skip_special_tokens": False, }, + "return_logprob": True, }, ).json() assert response["meta_info"]["finish_reason"]["type"] == "abort" @@ -172,6 +216,7 @@ class TestSessionControl(unittest.TestCase): input_ids_first_req = None input_ids = [] outputs_normal = [] + logprobs_normal = [] for i, chunk_ids in enumerate(chunks_ids): input_ids += chunk_ids response = requests.post( @@ -186,6 +231,7 @@ class TestSessionControl(unittest.TestCase): "no_stop_trim": True, "skip_special_tokens": False, }, + "return_logprob": True, }, ).json() if i > 0: @@ -194,6 +240,12 @@ class TestSessionControl(unittest.TestCase): output_ids = output_ids[1:] input_ids += output_ids[:-1] outputs_normal.append(response["text"]) + logprobs_normal.extend( + [ + round(sublist[0], 2) + for sublist in response["meta_info"]["output_token_logprobs"] + ] + ) if i == 0: input_ids_first_req = input_ids.copy() @@ -208,17 +260,31 @@ class TestSessionControl(unittest.TestCase): "no_stop_trim": True, "skip_special_tokens": False, }, + "return_logprob": True, }, ).json() outputs_normal.append(response["text"]) + logprobs_normal.extend( + [ + round(sublist[0], 2) + for sublist in response["meta_info"]["output_token_logprobs"] + ] + ) print("outputs from chunked queries with session control:") print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert ( - outputs_from_session == outputs_normal - ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" + assert outputs_from_session == outputs_normal + print("logprobs from chunked queries with session control:") + print(logprobs_from_session) + print("logprobs from normal queries:") + print(logprobs_normal) + assert len(logprobs_from_session) == len( + logprobs_normal + ), "logprobs must have equal length" + for a, b in zip(logprobs_from_session, logprobs_normal): + assert abs(a - b) <= 0.1, f"logprobs {a} and {b} differ by more than 0.1" async def async_generate(self, payload): url = self.base_url + "/generate" diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index db7094409..d714a593c 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -1,3 +1,8 @@ +""" +python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample +python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_stream +""" + import json import unittest @@ -12,42 +17,26 @@ from sglang.test.test_utils import ( popen_launch_server, ) -_server_process = None -_base_url = None -_tokenizer = None - - -def setUpModule(): - """ - Launch the server once before all tests and initialize the tokenizer. - """ - global _server_process, _base_url, _tokenizer - _server_process = popen_launch_server( - DEFAULT_SMALL_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_TEST, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--skip-tokenizer-init"], - ) - _base_url = DEFAULT_URL_FOR_TEST - - _tokenizer = AutoTokenizer.from_pretrained( - DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False - ) - print(">>> setUpModule: Server launched, tokenizer ready") - - -def tearDownModule(): - """ - Terminate the server once after all tests have completed. - """ - global _server_process - if _server_process is not None: - kill_process_tree(_server_process.pid) - _server_process = None - print(">>> tearDownModule: Server terminated") - class TestSkipTokenizerInit(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--skip-tokenizer-init", "--stream-output"], + ) + cls.tokenizer = AutoTokenizer.from_pretrained( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + def run_decode( self, prompt_text="The capital of France is", @@ -56,19 +45,19 @@ class TestSkipTokenizerInit(unittest.TestCase): top_logprobs_num=0, n=1, ): - input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][ + input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][ 0 ].tolist() response = requests.post( - _base_url + "/generate", + self.base_url + "/generate", json={ "input_ids": input_ids, "sampling_params": { "temperature": 0 if n == 1 else 0.5, "max_new_tokens": max_new_tokens, "n": n, - "stop_token_ids": [_tokenizer.eos_token_id], + "stop_token_ids": [self.tokenizer.eos_token_id], }, "stream": False, "return_logprob": return_logprob, @@ -83,13 +72,13 @@ class TestSkipTokenizerInit(unittest.TestCase): if item["meta_info"]["finish_reason"]["type"] == "stop": self.assertEqual( item["meta_info"]["finish_reason"]["matched"], - _tokenizer.eos_token_id, + self.tokenizer.eos_token_id, ) elif item["meta_info"]["finish_reason"]["type"] == "length": self.assertEqual( - len(item["token_ids"]), item["meta_info"]["completion_tokens"] + len(item["output_ids"]), item["meta_info"]["completion_tokens"] ) - self.assertEqual(len(item["token_ids"]), max_new_tokens) + self.assertEqual(len(item["output_ids"]), max_new_tokens) self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids)) if return_logprob: @@ -113,6 +102,63 @@ class TestSkipTokenizerInit(unittest.TestCase): print("=" * 100) + def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1): + max_new_tokens = 32 + input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is + requests.post(self.base_url + "/flush_cache") + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids, + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": max_new_tokens, + "n": n, + "stop_token_ids": [119690], + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + ) + ret = response.json() + print(json.dumps(ret)) + output_ids = ret["output_ids"] + + requests.post(self.base_url + "/flush_cache") + response_stream = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids, + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": max_new_tokens, + "n": n, + "stop_token_ids": [119690], + }, + "stream": True, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + ) + ret = response.json() + output_ids = ret["output_ids"] + print("output from non-streaming request:") + print(output_ids) + + response_stream_json = [] + for line in response_stream.iter_lines(): + if line.startswith(b"data: ") and line[6:] != b"[DONE]": + response_stream_json.append(json.loads(line[6:])) + out_stream_ids = [] + for x in response_stream_json: + out_stream_ids += x["output_ids"] + print("output from streaming request:") + print(out_stream_ids) + assert output_ids == out_stream_ids + def test_simple_decode(self): self.run_decode() @@ -126,6 +172,9 @@ class TestSkipTokenizerInit(unittest.TestCase): def test_eos_behavior(self): self.run_decode(max_new_tokens=256) + def test_simple_decode_stream(self): + self.run_decode_stream() + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 68db1d699..9673b19c5 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -8,6 +8,7 @@ import random import time import unittest from concurrent.futures import ThreadPoolExecutor +from functools import partial from typing import Optional import numpy as np @@ -20,6 +21,7 @@ from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, + run_logprob_check, ) @@ -35,7 +37,9 @@ class TestSRTEndpoint(unittest.TestCase): other_args=( "--enable-custom-logit-processor", "--mem-fraction-static", - "0.8", + "0.7", + "--cuda-graph-max-bs", + "8", ), ) @@ -131,7 +135,7 @@ class TestSRTEndpoint(unittest.TestCase): for i, res in enumerate(response_json): self.assertEqual( res["meta_info"]["prompt_tokens"], - logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]), + logprob_start_len + len(res["meta_info"]["input_token_logprobs"]), ) assert prompts[i].endswith( "".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]]) @@ -235,83 +239,15 @@ class TestSRTEndpoint(unittest.TestCase): diff = np.abs(output_logprobs - output_logprobs_score) max_diff = np.max(diff) - self.assertLess(max_diff, 0.25) - - def run_logprob_check(self, arg): - ( - input_len, - output_len, - temperature, - logprob_start_len, - return_logprob, - top_logprobs_num, - ) = arg - input_ids = list(range(input_len)) - - response = requests.post( - self.base_url + "/generate", - json={ - "input_ids": input_ids, - "sampling_params": { - "temperature": temperature, - "max_new_tokens": output_len, - }, - "return_logprob": return_logprob, - "logprob_start_len": logprob_start_len, - "top_logprobs_num": top_logprobs_num, - }, - ) - response_json = response.json() - - res = response_json - self.assertEqual(res["meta_info"]["prompt_tokens"], input_len) - self.assertEqual(res["meta_info"]["completion_tokens"], output_len) - - # Test the number of tokens are correct - if return_logprob: - # This is because if logprob_start_len == 0, we added a padding for the first token. - # In other cases, we do not add the padding - delta = 0 if logprob_start_len == 0 else 1 - - self.assertEqual( - len(res["meta_info"]["input_token_logprobs"]) - + logprob_start_len - + delta, - res["meta_info"]["prompt_tokens"], - ) - self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len) - - if top_logprobs_num: - self.assertEqual( - len(res["meta_info"]["input_top_logprobs"]) - + logprob_start_len - + delta, - res["meta_info"]["prompt_tokens"], - ) - self.assertEqual( - len(res["meta_info"]["output_top_logprobs"]), output_len - ) - - for i in range(output_len): - self.assertEqual( - len(res["meta_info"]["output_top_logprobs"][i]), - top_logprobs_num, - ) - - # Test the top-1 tokens are the same as output tokens if temperature == 0 - if temperature == 0: - self.assertListEqual( - res["meta_info"]["output_token_logprobs"][i], - res["meta_info"]["output_top_logprobs"][i][0], - ) + self.assertLess(max_diff, 0.35) def test_logprob_mixed(self): args = [] temperature = 0 # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num - for input_len in [1000, 2000]: + for input_len in [1000, 5000, 10000, 50000]: for output_len in [4, 8]: - for logprob_start_len in [0, 500, 1000]: + for logprob_start_len in [0, 500, 2500, 5000, 25000]: for return_logprob in [True, False]: for top_logprobs_num in [0, 5]: @@ -331,8 +267,9 @@ class TestSRTEndpoint(unittest.TestCase): random.shuffle(args) + func = partial(run_logprob_check, self) with ThreadPoolExecutor(8) as executor: - list(executor.map(self.run_logprob_check, args)) + list(executor.map(func, args)) def test_logprob_grammar(self): prompts = "Question: Is Paris the Capital of France? Answer:" @@ -427,6 +364,77 @@ class TestSRTEndpoint(unittest.TestCase): f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}", ) + def run_stateful_custom_logit_processor( + self, first_token_id: int | None, delay: int = 2 + ): + """Test custom logit processor with custom params and state. + + Should sample the first `delay` tokens normally, then output first_token_id and consecutive tokens after that. + If first_token_id is None, the custom logit processor won't be passed in. + """ + + custom_params = {"token_id": first_token_id, "delay": 2} + + class DeterministicStatefulLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. + """ + + def __call__(self, logits, custom_param_list): + assert logits.shape[0] == len(custom_param_list) + + for i, param_dict in enumerate(custom_param_list): + if param_dict["delay"] > 0: + param_dict["delay"] -= 1 + continue + if param_dict["delay"] == 0: + param_dict["delay"] -= 1 + force_token = param_dict["token_id"] + else: + output_ids = param_dict["__req__"].output_ids + force_token = output_ids[-1] + 1 + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, force_token] = 0.0 + return logits + + prompts = "Question: Is Paris the Capital of France? Answer:" + + # Base case json data to be posted to the server. + base_json = { + "text": prompts, + "sampling_params": {"temperature": 0.0}, + "return_logprob": True, + } + + # Custom json data with custom logit processor and params. + custom_json = base_json.copy() + # Only set the custom logit processor if target_token_id is not None. + if first_token_id is not None: + custom_json["custom_logit_processor"] = ( + DeterministicStatefulLogitProcessor().to_str() + ) + custom_json["sampling_params"]["custom_params"] = custom_params + + custom_response = requests.post( + self.base_url + "/generate", + json=custom_json, + ).json() + + output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"] + sampled_tokens = [x[1] for x in output_token_logprobs] + # The logit processor should always sample the given token as the logits is deterministic. + if first_token_id is not None: + self.assertTrue( + all( + x == custom_params["token_id"] + k + for k, x in enumerate(sampled_tokens[custom_params["delay"] :]) + ), + # Print the detailed test case info if the test fails. + f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}", + ) + def test_custom_logit_processor(self): """Test custom logit processor with a single request.""" self.run_custom_logit_processor(target_token_id=5) @@ -438,6 +446,19 @@ class TestSRTEndpoint(unittest.TestCase): with ThreadPoolExecutor(len(target_token_ids)) as executor: list(executor.map(self.run_custom_logit_processor, target_token_ids)) + def test_stateful_custom_logit_processor(self): + """Test custom logit processor with a single request.""" + self.run_stateful_custom_logit_processor(first_token_id=5) + + def test_stateful_custom_logit_processor_batch_mixed(self): + """Test a batch of requests mixed of requests with and without custom logit processor.""" + target_token_ids = list(range(32)) + [None] * 16 + random.shuffle(target_token_ids) + with ThreadPoolExecutor(len(target_token_ids)) as executor: + list( + executor.map(self.run_stateful_custom_logit_processor, target_token_ids) + ) + def test_cache_tokens(self): for _ in range(2): time.sleep(1) @@ -476,6 +497,21 @@ class TestSRTEndpoint(unittest.TestCase): version = response_json["version"] self.assertIsInstance(version, str) + def test_get_server_info_concurrent(self): + """Make sure the concurrent get_server_info doesn't crash the server.""" + tp = ThreadPoolExecutor(max_workers=30) + + def s(): + server_info = requests.get(self.base_url + "/get_server_info") + server_info.json() + + futures = [] + for _ in range(4): + futures.append(tp.submit(s)) + + for f in futures: + f.result() + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_verl_engine.py b/test/srt/test_verl_engine.py index 1d134ad41..88455525f 100644 --- a/test/srt/test_verl_engine.py +++ b/test/srt/test_verl_engine.py @@ -168,9 +168,9 @@ def _run_subprocess( hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True) hf_outputs = HFRunner.forward_generation_raw( + base_model=hf_model, prompts=_PROMPTS, max_new_tokens=_MAX_NEW_TOKENS, - base_model=hf_model, tokenizer=hf_tokenizer, lora_paths=None, torch_dtype=_TORCH_DTYPE,