diff --git a/docs/test_process.md b/docs/test_process.md
index e7aff5b5a..18f91c6d4 100644
--- a/docs/test_process.md
+++ b/docs/test_process.md
@@ -1,8 +1,18 @@
 ## SRT Unit Tests

 ### Latency Alignment
+Make sure your changes do not slow down the following benchmarks:
 ```
+# single gpu
 python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
+python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 1 --input-len 512 --output-len 256
+
+# multiple gpu
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 32 --input-len 8192 --output-len 1
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 1 --input-len 8100 --output-len 32
+
+# moe model
+python -m sglang.bench_latency --model-path databricks/dbrx-base --tp 8 --mem-fraction-static 0.6 --batch 4 --input-len 1024 --output-len 32
 ```

 ### High-level API
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index a163cbd30..ca09028f4 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -230,7 +230,7 @@ def latency_test(
         prefill_latency = time.time() - tic
         tot_latency += prefill_latency
         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(f"Prefill. latency: {prefill_latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
+        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")

         # Decode
         for i in range(output_len):
@@ -241,13 +241,13 @@ def latency_test(
             latency = time.time() - tic
             tot_latency += latency
             throughput = bench_args.batch_size / latency
-            if i < 5: rank_print(f"Decode. latency: {latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
+            if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
         avg_decode_latency = (tot_latency - prefill_latency) / output_len
         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} ms, avg throughput: {avg_decode_throughput:9.2f} token/s")
+        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")

         throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
-        rank_print(f"Total. latency: {tot_latency:6.3f} ms, throughput: {throughput:9.2f} token/s")
+        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")

     # Warm up
     run_once(4)
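The throughput figures printed by `bench_latency.py` are plain tokens per wall-clock second, which is why the unit label above is corrected from `ms` to `s`. A minimal sketch of the prefill arithmetic, with made-up numbers (the 0.5 s latency is only for illustration, not a measured result):

```
# Hypothetical example: --batch 32 --input-len 512, prefill takes 0.5 s
batch_size, input_len, prefill_latency = 32, 512, 0.5
throughput = input_len * batch_size / prefill_latency  # 16384 tokens / 0.5 s
print(f"{throughput:9.2f} token/s")  # 32768.00 token/s
```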
diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index 0cc0f747f..377bde82e 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -35,5 +35,8 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.0001
         self.new_token_ratio_recovery = 0.05

+        # The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # This can improve the speed for large batch sizes during prefill.
+        self.layer_sync_threshold = 8192

 global_config = GlobalConfig()
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 66d206082..c46c11237 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -4,6 +4,7 @@ import numpy as np
 import torch
 from torch import nn

+from sglang.global_config import global_config
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
@@ -103,12 +104,29 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.flashinfer_prefill_wrapper.forward(
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
             logits_soft_cap=self.logit_cap,
         )

+        if input_metadata.no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            from flashinfer.cascade import merge_state
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
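In the updated prefill path, `o1, s1` are the attention output and log-sum-exp (LSE) over the newly extended tokens (ragged, causal), while `o2, s2` cover the cached prefix (paged, non-causal); `flashinfer.cascade.merge_state` combines the two partial results so the merged output equals attention over all keys. A minimal NumPy sketch of the state-merging math for a single query vector (illustrative only, not flashinfer's implementation):

```
import numpy as np

def merge_state_ref(o1, s1, o2, s2):
    # o1, o2: partial attention outputs over two disjoint key sets (shape [head_dim])
    # s1, s2: log-sum-exp of the corresponding attention scores (scalars)
    # The merged result equals softmax attention over the union of both key sets.
    s_max = max(s1, s2)
    w1, w2 = np.exp(s1 - s_max), np.exp(s2 - s_max)
    o = (w1 * o1 + w2 * o2) / (w1 + w2)
    s = s_max + np.log(w1 + w2)  # merged log-sum-exp
    return o, s
```

Because the merge is mathematically equivalent to a single attention over the full sequence, the cached prefix can stay in the paged KV pool while the freshly extended tokens are attended with the ragged layout.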
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index e41514706..bded85af9 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -65,23 +65,33 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.PREFILL
+            or self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens
+
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
@@ -92,13 +102,24 @@ class InputMetadata:
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
         ):
+            # extend part
             self.qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)

-            self.flashinfer_prefill_wrapper.end_forward()
-            self.flashinfer_prefill_wrapper.begin_forward(
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
@@ -143,7 +164,8 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
-        flashinfer_prefill_wrapper=None,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
         flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
@@ -194,7 +216,8 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )

@@ -361,6 +384,7 @@ class ModelRunner:
     def init_flash_infer(self):
         if not global_server_args_dict.get("disable_flashinfer", False):
             from flashinfer import (
+                BatchPrefillWithRaggedKVCacheWrapper,
                 BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
             )
@@ -373,17 +397,21 @@ class ModelRunner:
             else:
                 use_tensor_cores = False

-            workspace_buffer = torch.empty(
-                128 * 1024 * 1024, dtype=torch.int8, device="cuda"
+            workspace_buffers = torch.empty(
+                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
             )
-            self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffers[0], "NHD"
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffers[1], "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
+                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
             )
         else:
-            self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None
+            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None

     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
@@ -398,7 +426,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
@@ -418,7 +447,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
@@ -440,7 +470,8 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
@@ -460,7 +491,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
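For PREFILL/EXTEND batches, `init_flashinfer_args` now sizes the paged wrapper by `prefix_lens` (the cached part) and the ragged wrapper by `extend_seq_lens` (the newly added part). A small sketch of the index-pointer construction, using made-up lengths (two requests with cached prefixes of 3 and 0 tokens that extend by 4 and 5 tokens):

```
import torch

prefix_lens = torch.tensor([3, 0], dtype=torch.int32)      # cached tokens per request
extend_seq_lens = torch.tensor([4, 5], dtype=torch.int32)  # new tokens per request
batch_size = 2

qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)       # [0, 4, 9] -> ragged q/kv ranges

kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)           # [0, 3, 3] -> paged (cached) kv ranges

no_prefix = torch.all(prefix_lens == 0)                    # False here: the paged + merge_state path runs
```

When every request has an empty prefix, `no_prefix` is true and the paged call plus `merge_state` in `radix_attention.py` above are skipped entirely.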
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index e28530889..78bd2e0d1 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -152,7 +152,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.disable_disk_cache:
         disable_cache()
     if not server_args.disable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.7")
+        assert_pkg_version("flashinfer", "0.0.8")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)