2x performance improvement for large prefill & Fix workspace conflicts (#579)
This commit is contained in:
@@ -1,8 +1,18 @@
|
|||||||
## SRT Unit Tests
|
## SRT Unit Tests
|
||||||
|
|
||||||
### Latency Alignment
|
### Latency Alignment
|
||||||
|
Make sure your changes do not slow down the following benchmarks
|
||||||
```
|
```
|
||||||
|
# single gpu
|
||||||
python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
|
python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
|
||||||
|
python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 1 --input-len 512 --output-len 256
|
||||||
|
|
||||||
|
# multiple gpu
|
||||||
|
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 32 --input-len 8192 --output-len 1
|
||||||
|
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 1 --input-len 8100 --output-len 32
|
||||||
|
|
||||||
|
# moe model
|
||||||
|
python -m sglang.bench_latency --model-path databricks/dbrx-base --tp 8 --mem-fraction-static 0.6 --batch 4 --input-len 1024 --output-len 32
|
||||||
```
|
```
|
||||||
|
|
||||||
### High-level API
|
### High-level API
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ def latency_test(
|
|||||||
prefill_latency = time.time() - tic
|
prefill_latency = time.time() - tic
|
||||||
tot_latency += prefill_latency
|
tot_latency += prefill_latency
|
||||||
throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
|
throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
|
||||||
rank_print(f"Prefill. latency: {prefill_latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
|
rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
for i in range(output_len):
|
for i in range(output_len):
|
||||||
@@ -241,13 +241,13 @@ def latency_test(
|
|||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
tot_latency += latency
|
tot_latency += latency
|
||||||
throughput = bench_args.batch_size / latency
|
throughput = bench_args.batch_size / latency
|
||||||
if i < 5: rank_print(f"Decode. latency: {latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
|
if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
|
||||||
avg_decode_latency = (tot_latency - prefill_latency) / output_len
|
avg_decode_latency = (tot_latency - prefill_latency) / output_len
|
||||||
avg_decode_throughput = bench_args.batch_size / avg_decode_latency
|
avg_decode_throughput = bench_args.batch_size / avg_decode_latency
|
||||||
rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} ms, avg throughput: {avg_decode_throughput:9.2f} token/s")
|
rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
|
||||||
|
|
||||||
throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
|
throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
|
||||||
rank_print(f"Total. latency: {tot_latency:6.3f} ms, throughput: {throughput:9.2f} token/s")
|
rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
|
||||||
|
|
||||||
# Warm up
|
# Warm up
|
||||||
run_once(4)
|
run_once(4)
|
||||||
|
|||||||
@@ -35,5 +35,8 @@ class GlobalConfig:
|
|||||||
self.new_token_ratio_decay = 0.0001
|
self.new_token_ratio_decay = 0.0001
|
||||||
self.new_token_ratio_recovery = 0.05
|
self.new_token_ratio_recovery = 0.05
|
||||||
|
|
||||||
|
# The threshold (number of tokens) to trigger layer-wise cuda sync.
|
||||||
|
# This can improve the speed for large batch sizes during prefill.
|
||||||
|
self.layer_sync_threshold = 8192
|
||||||
|
|
||||||
global_config = GlobalConfig()
|
global_config = GlobalConfig()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||||
@@ -103,12 +104,29 @@ class RadixAttention(nn.Module):
|
|||||||
def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
self.store_kv_cache(k, v, input_metadata)
|
||||||
|
|
||||||
o = input_metadata.flashinfer_prefill_wrapper.forward(
|
o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
|
||||||
|
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
|
||||||
logits_soft_cap=self.logit_cap,
|
logits_soft_cap=self.logit_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if input_metadata.no_prefix:
|
||||||
|
o = o1
|
||||||
|
else:
|
||||||
|
o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
|
||||||
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
|
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
||||||
|
causal=False,
|
||||||
|
logits_soft_cap=self.logit_cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
from flashinfer.cascade import merge_state
|
||||||
|
o, _ = merge_state(o1, s1, o2, s2)
|
||||||
|
|
||||||
|
if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
||||||
|
|
||||||
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
||||||
|
|||||||
@@ -65,23 +65,33 @@ class InputMetadata:
|
|||||||
kv_indptr: torch.Tensor = None
|
kv_indptr: torch.Tensor = None
|
||||||
kv_indices: torch.Tensor = None
|
kv_indices: torch.Tensor = None
|
||||||
kv_last_page_len: torch.Tensor = None
|
kv_last_page_len: torch.Tensor = None
|
||||||
flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
|
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
||||||
|
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
||||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
||||||
|
|
||||||
def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
|
def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
|
||||||
|
if (
|
||||||
|
self.forward_mode == ForwardMode.PREFILL
|
||||||
|
or self.forward_mode == ForwardMode.EXTEND
|
||||||
|
):
|
||||||
|
paged_kernel_lens = self.prefix_lens
|
||||||
|
self.no_prefix = torch.all(self.prefix_lens == 0)
|
||||||
|
else:
|
||||||
|
paged_kernel_lens = self.seq_lens
|
||||||
|
|
||||||
self.kv_indptr = torch.zeros(
|
self.kv_indptr = torch.zeros(
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
|
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
self.kv_last_page_len = torch.ones(
|
self.kv_last_page_len = torch.ones(
|
||||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||||
self.kv_indices = torch.cat(
|
self.kv_indices = torch.cat(
|
||||||
[
|
[
|
||||||
self.req_to_token_pool.req_to_token[
|
self.req_to_token_pool.req_to_token[
|
||||||
req_pool_indices_cpu[i], : seq_lens_cpu[i]
|
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
||||||
]
|
]
|
||||||
for i in range(self.batch_size)
|
for i in range(self.batch_size)
|
||||||
],
|
],
|
||||||
@@ -92,13 +102,24 @@ class InputMetadata:
|
|||||||
self.forward_mode == ForwardMode.PREFILL
|
self.forward_mode == ForwardMode.PREFILL
|
||||||
or self.forward_mode == ForwardMode.EXTEND
|
or self.forward_mode == ForwardMode.EXTEND
|
||||||
):
|
):
|
||||||
|
# extend part
|
||||||
self.qo_indptr = torch.zeros(
|
self.qo_indptr = torch.zeros(
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
||||||
|
|
||||||
self.flashinfer_prefill_wrapper.end_forward()
|
self.flashinfer_prefill_wrapper_ragged.end_forward()
|
||||||
self.flashinfer_prefill_wrapper.begin_forward(
|
self.flashinfer_prefill_wrapper_ragged.begin_forward(
|
||||||
|
self.qo_indptr,
|
||||||
|
self.qo_indptr.clone(),
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cached part
|
||||||
|
self.flashinfer_prefill_wrapper_paged.end_forward()
|
||||||
|
self.flashinfer_prefill_wrapper_paged.begin_forward(
|
||||||
self.qo_indptr,
|
self.qo_indptr,
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
self.kv_indices,
|
self.kv_indices,
|
||||||
@@ -143,7 +164,8 @@ class InputMetadata:
|
|||||||
out_cache_cont_end=None,
|
out_cache_cont_end=None,
|
||||||
top_logprobs_nums=None,
|
top_logprobs_nums=None,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
flashinfer_prefill_wrapper=None,
|
flashinfer_prefill_wrapper_ragged=None,
|
||||||
|
flashinfer_prefill_wrapper_paged=None,
|
||||||
flashinfer_decode_wrapper=None,
|
flashinfer_decode_wrapper=None,
|
||||||
):
|
):
|
||||||
batch_size = len(req_pool_indices)
|
batch_size = len(req_pool_indices)
|
||||||
@@ -194,7 +216,8 @@ class InputMetadata:
|
|||||||
other_kv_index=other_kv_index,
|
other_kv_index=other_kv_index,
|
||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
top_logprobs_nums=top_logprobs_nums,
|
top_logprobs_nums=top_logprobs_nums,
|
||||||
flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
|
flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
|
||||||
|
flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
|
||||||
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -361,6 +384,7 @@ class ModelRunner:
|
|||||||
def init_flash_infer(self):
|
def init_flash_infer(self):
|
||||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||||
from flashinfer import (
|
from flashinfer import (
|
||||||
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
BatchPrefillWithPagedKVCacheWrapper,
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
BatchDecodeWithPagedKVCacheWrapper,
|
BatchDecodeWithPagedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
@@ -373,17 +397,21 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
use_tensor_cores = False
|
use_tensor_cores = False
|
||||||
|
|
||||||
workspace_buffer = torch.empty(
|
workspace_buffers = torch.empty(
|
||||||
128 * 1024 * 1024, dtype=torch.int8, device="cuda"
|
3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
|
||||||
)
|
)
|
||||||
self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
workspace_buffer, "NHD"
|
workspace_buffers[0], "NHD"
|
||||||
|
)
|
||||||
|
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffers[1], "NHD"
|
||||||
)
|
)
|
||||||
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None
|
self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
|
||||||
|
self.flashinfer_decode_wrapper = None
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_prefill(self, batch: Batch):
|
def forward_prefill(self, batch: Batch):
|
||||||
@@ -398,7 +426,8 @@ class ModelRunner:
|
|||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
|
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
||||||
|
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
||||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
@@ -418,7 +447,8 @@ class ModelRunner:
|
|||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
|
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
||||||
|
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
||||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
@@ -440,7 +470,8 @@ class ModelRunner:
|
|||||||
out_cache_cont_end=batch.out_cache_cont_end,
|
out_cache_cont_end=batch.out_cache_cont_end,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
|
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
||||||
|
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
||||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
@@ -460,7 +491,8 @@ class ModelRunner:
|
|||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
|
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
||||||
|
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
||||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
if server_args.disable_disk_cache:
|
if server_args.disable_disk_cache:
|
||||||
disable_cache()
|
disable_cache()
|
||||||
if not server_args.disable_flashinfer:
|
if not server_args.disable_flashinfer:
|
||||||
assert_pkg_version("flashinfer", "0.0.7")
|
assert_pkg_version("flashinfer", "0.0.8")
|
||||||
if server_args.chat_template:
|
if server_args.chat_template:
|
||||||
# TODO: replace this with huggingface transformers template
|
# TODO: replace this with huggingface transformers template
|
||||||
load_chat_template_for_openai_api(server_args.chat_template)
|
load_chat_template_for_openai_api(server_args.chat_template)
|
||||||
|
|||||||
Reference in New Issue
Block a user