2x performance improvement for large prefill & Fix workspace conflicts (#579)
This commit is contained in:
@@ -230,7 +230,7 @@ def latency_test(
|
||||
prefill_latency = time.time() - tic
|
||||
tot_latency += prefill_latency
|
||||
throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
|
||||
rank_print(f"Prefill. latency: {prefill_latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
|
||||
rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
|
||||
|
||||
# Decode
|
||||
for i in range(output_len):
|
||||
@@ -241,13 +241,13 @@ def latency_test(
|
||||
latency = time.time() - tic
|
||||
tot_latency += latency
|
||||
throughput = bench_args.batch_size / latency
|
||||
if i < 5: rank_print(f"Decode. latency: {latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
|
||||
if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
|
||||
avg_decode_latency = (tot_latency - prefill_latency) / output_len
|
||||
avg_decode_throughput = bench_args.batch_size / avg_decode_latency
|
||||
rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} ms, avg throughput: {avg_decode_throughput:9.2f} token/s")
|
||||
rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
|
||||
|
||||
throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
|
||||
rank_print(f"Total. latency: {tot_latency:6.3f} ms, throughput: {throughput:9.2f} token/s")
|
||||
rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
|
||||
|
||||
# Warm up
|
||||
run_once(4)
|
||||
|
||||
@@ -35,5 +35,8 @@ class GlobalConfig:
|
||||
self.new_token_ratio_decay = 0.0001
|
||||
self.new_token_ratio_recovery = 0.05
|
||||
|
||||
# The threshold (number of tokens) to trigger layer-wise cuda sync.
|
||||
# This can improve the speed for large batch sizes during prefill.
|
||||
self.layer_sync_threshold = 8192
|
||||
|
||||
global_config = GlobalConfig()
|
||||
|
||||
@@ -4,6 +4,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||
@@ -103,12 +104,29 @@ class RadixAttention(nn.Module):
|
||||
def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
|
||||
o = input_metadata.flashinfer_prefill_wrapper.forward(
|
||||
o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
|
||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
||||
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
|
||||
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
|
||||
logits_soft_cap=self.logit_cap,
|
||||
)
|
||||
|
||||
if input_metadata.no_prefix:
|
||||
o = o1
|
||||
else:
|
||||
o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
|
||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
||||
causal=False,
|
||||
logits_soft_cap=self.logit_cap,
|
||||
)
|
||||
|
||||
from flashinfer.cascade import merge_state
|
||||
o, _ = merge_state(o1, s1, o2, s2)
|
||||
|
||||
if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
||||
|
||||
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
|
||||
|
||||
@@ -65,23 +65,33 @@ class InputMetadata:
|
||||
kv_indptr: torch.Tensor = None
|
||||
kv_indices: torch.Tensor = None
|
||||
kv_last_page_len: torch.Tensor = None
|
||||
flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
|
||||
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
|
||||
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
|
||||
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
|
||||
|
||||
def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
|
||||
if (
|
||||
self.forward_mode == ForwardMode.PREFILL
|
||||
or self.forward_mode == ForwardMode.EXTEND
|
||||
):
|
||||
paged_kernel_lens = self.prefix_lens
|
||||
self.no_prefix = torch.all(self.prefix_lens == 0)
|
||||
else:
|
||||
paged_kernel_lens = self.seq_lens
|
||||
|
||||
self.kv_indptr = torch.zeros(
|
||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
|
||||
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
self.kv_last_page_len = torch.ones(
|
||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||
self.kv_indices = torch.cat(
|
||||
[
|
||||
self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[i], : seq_lens_cpu[i]
|
||||
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
|
||||
]
|
||||
for i in range(self.batch_size)
|
||||
],
|
||||
@@ -92,13 +102,24 @@ class InputMetadata:
|
||||
self.forward_mode == ForwardMode.PREFILL
|
||||
or self.forward_mode == ForwardMode.EXTEND
|
||||
):
|
||||
# extend part
|
||||
self.qo_indptr = torch.zeros(
|
||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
||||
|
||||
self.flashinfer_prefill_wrapper.end_forward()
|
||||
self.flashinfer_prefill_wrapper.begin_forward(
|
||||
self.flashinfer_prefill_wrapper_ragged.end_forward()
|
||||
self.flashinfer_prefill_wrapper_ragged.begin_forward(
|
||||
self.qo_indptr,
|
||||
self.qo_indptr.clone(),
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
)
|
||||
|
||||
# cached part
|
||||
self.flashinfer_prefill_wrapper_paged.end_forward()
|
||||
self.flashinfer_prefill_wrapper_paged.begin_forward(
|
||||
self.qo_indptr,
|
||||
self.kv_indptr,
|
||||
self.kv_indices,
|
||||
@@ -143,7 +164,8 @@ class InputMetadata:
|
||||
out_cache_cont_end=None,
|
||||
top_logprobs_nums=None,
|
||||
return_logprob=False,
|
||||
flashinfer_prefill_wrapper=None,
|
||||
flashinfer_prefill_wrapper_ragged=None,
|
||||
flashinfer_prefill_wrapper_paged=None,
|
||||
flashinfer_decode_wrapper=None,
|
||||
):
|
||||
batch_size = len(req_pool_indices)
|
||||
@@ -194,7 +216,8 @@ class InputMetadata:
|
||||
other_kv_index=other_kv_index,
|
||||
return_logprob=return_logprob,
|
||||
top_logprobs_nums=top_logprobs_nums,
|
||||
flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
|
||||
flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
|
||||
flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
|
||||
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
||||
)
|
||||
|
||||
@@ -361,6 +384,7 @@ class ModelRunner:
|
||||
def init_flash_infer(self):
|
||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||
from flashinfer import (
|
||||
BatchPrefillWithRaggedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
)
|
||||
@@ -373,17 +397,21 @@ class ModelRunner:
|
||||
else:
|
||||
use_tensor_cores = False
|
||||
|
||||
workspace_buffer = torch.empty(
|
||||
128 * 1024 * 1024, dtype=torch.int8, device="cuda"
|
||||
workspace_buffers = torch.empty(
|
||||
3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
|
||||
)
|
||||
self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD"
|
||||
self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||
workspace_buffers[0], "NHD"
|
||||
)
|
||||
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffers[1], "NHD"
|
||||
)
|
||||
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||
workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
|
||||
)
|
||||
else:
|
||||
self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None
|
||||
self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
|
||||
self.flashinfer_decode_wrapper = None
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_prefill(self, batch: Batch):
|
||||
@@ -398,7 +426,8 @@ class ModelRunner:
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
return_logprob=batch.return_logprob,
|
||||
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
|
||||
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
||||
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
||||
)
|
||||
return self.model.forward(
|
||||
@@ -418,7 +447,8 @@ class ModelRunner:
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
return_logprob=batch.return_logprob,
|
||||
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
|
||||
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
||||
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
||||
)
|
||||
return self.model.forward(
|
||||
@@ -440,7 +470,8 @@ class ModelRunner:
|
||||
out_cache_cont_end=batch.out_cache_cont_end,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
return_logprob=batch.return_logprob,
|
||||
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
|
||||
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
||||
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
||||
)
|
||||
return self.model.forward(
|
||||
@@ -460,7 +491,8 @@ class ModelRunner:
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
return_logprob=batch.return_logprob,
|
||||
flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
|
||||
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
|
||||
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
|
||||
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
|
||||
)
|
||||
return self.model.forward(
|
||||
|
||||
@@ -152,7 +152,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
||||
if server_args.disable_disk_cache:
|
||||
disable_cache()
|
||||
if not server_args.disable_flashinfer:
|
||||
assert_pkg_version("flashinfer", "0.0.7")
|
||||
assert_pkg_version("flashinfer", "0.0.8")
|
||||
if server_args.chat_template:
|
||||
# TODO: replace this with huggingface transformers template
|
||||
load_chat_template_for_openai_api(server_args.chat_template)
|
||||
|
||||
Reference in New Issue
Block a user