From 396a69240fc99e54d079d9f623ad83239eb39167 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 12 Jul 2024 18:21:11 -0700 Subject: [PATCH] Cleanup attention backend: flashinfer and triton (#611) --- python/sglang/srt/layers/radix_attention.py | 15 +- python/sglang/srt/layers/token_attention.py | 3 +- .../srt/managers/controller/infer_batch.py | 259 ++++++++++-------- .../srt/managers/controller/model_runner.py | 64 ++--- 4 files changed, 180 insertions(+), 161 deletions(-) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 2c3e91af1..2964ae5b2 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -1,6 +1,5 @@ """Radix attention.""" -import numpy as np import torch from flashinfer.cascade import merge_state from torch import nn @@ -51,13 +50,13 @@ class RadixAttention(nn.Module): input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), input_metadata.req_to_token_pool.req_to_token, input_metadata.req_pool_indices, - input_metadata.start_loc, + input_metadata.triton_start_loc, input_metadata.seq_lens, - input_metadata.prefix_lens, + input_metadata.triton_prefix_lens, input_metadata.extend_start_loc, input_metadata.extend_seq_lens, - input_metadata.max_seq_len, - input_metadata.max_extend_len, + input_metadata.triton_max_seq_len, + input_metadata.triton_max_extend_len, sm_scale=self.scaling, logit_cap=self.logit_cap, ) @@ -75,9 +74,9 @@ class RadixAttention(nn.Module): o.view(-1, self.tp_q_head_num, self.head_dim), input_metadata.req_to_token_pool.req_to_token, input_metadata.req_pool_indices, - input_metadata.start_loc, + input_metadata.triton_start_loc, input_metadata.seq_lens, - input_metadata.max_seq_len, + input_metadata.triton_max_seq_len, input_metadata.total_num_tokens, sm_scale=self.scaling, logit_cap=self.logit_cap, @@ -95,7 +94,7 @@ class RadixAttention(nn.Module): logits_soft_cap=self.logit_cap, ) - if input_metadata.no_prefix: + if input_metadata.extend_no_prefix: o = o1 else: o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse( diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py index 9d7bda145..2d0250114 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/token_attention.py @@ -312,7 +312,7 @@ def token_attention_fwd( b_seq_len, max_len_in_batch, total_num_tokens, - sm_scale=None, + sm_scale, logit_cap=-1, att_m=None, ): @@ -320,7 +320,6 @@ def token_attention_fwd( att_m = torch.empty( (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda" ) - sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale _token_att_m_fwd( q, diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 27d041d1d..b7e05a5dc 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -75,6 +75,7 @@ class Req: """Store all inforamtion of a request.""" def __init__(self, rid, origin_input_text, origin_input_ids): + # Input and output info self.rid = rid self.origin_input_text = origin_input_text self.origin_input_ids_unpadded = origin_input_ids # Before image padding @@ -97,6 +98,11 @@ class Req: self.image_offset = 0 self.pad_value = None + # Prefix info + self.extend_input_len = 0 + self.prefix_indices = [] + self.last_node = None + # Sampling parameters self.sampling_params = None self.stream = False @@ -105,11 
+111,6 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None

-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -261,35 +262,36 @@ class Req:
 class Batch:
     """Store all inforamtion of a batch."""

+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache

-    # batched arguments to model runner
+    # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_start: int = None
+    out_cache_cont_end: int = None

-    # for processing logprobs
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

-    # for multimodal
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # other arguments for control
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None

-    # batched sampling params
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
@@ -312,8 +314,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0

-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)

     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
@@ -347,7 +349,7 @@ class Batch:

         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)

-        # Alloc mem
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
@@ -703,7 +705,6 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor

     return probs_sort, probs_idx

-
 @dataclass
 class InputMetadata:
     """Store all inforamtion of a forward pass."""
@@ -711,110 +712,37 @@ class InputMetadata:
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
-    max_seq_len: int
     req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
     seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool

-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool

+    # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_start: int = None
+    out_cache_cont_end: int = None

+    # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
+    # Triton attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if self.forward_mode == ForwardMode.DECODE:
-            paged_kernel_lens = self.seq_lens
-        else:
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-
-        kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-        kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-        if self.forward_mode == ForwardMode.DECODE:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
     @classmethod
     def create(
         cls,
@@ -830,14 +758,20 @@ class InputMetadata:
         top_logprobs_nums=None,
         return_logprob=False,
     ):
+        if not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
+
         batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))

         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatible with cuda graph.
+ total_num_tokens = None + else: + total_num_tokens = int(torch.sum(seq_lens)) else: seq_lens_cpu = seq_lens.cpu().numpy() prefix_lens_cpu = prefix_lens.cpu().numpy() @@ -855,22 +789,27 @@ class InputMetadata: ), device="cuda", ) + extend_seq_lens = seq_lens - prefix_lens + extend_start_loc = torch.zeros_like(seq_lens) + extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0) + extend_no_prefix = torch.all(prefix_lens == 0) + total_num_tokens = int(torch.sum(seq_lens)) ret = cls( forward_mode=forward_mode, batch_size=batch_size, total_num_tokens=total_num_tokens, - max_seq_len=max_seq_len, req_pool_indices=req_pool_indices, - start_loc=start_loc, seq_lens=seq_lens, - prefix_lens=prefix_lens, positions=positions, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, out_cache_loc=out_cache_loc, out_cache_cont_start=out_cache_cont_start, out_cache_cont_end=out_cache_cont_end, + extend_seq_lens=extend_seq_lens, + extend_start_loc=extend_start_loc, + extend_no_prefix=extend_no_prefix, return_logprob=return_logprob, top_logprobs_nums=top_logprobs_nums, flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged, @@ -878,14 +817,96 @@ class InputMetadata: flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper, ) - if forward_mode == ForwardMode.EXTEND: - ret.init_extend_args() - - if not global_server_args_dict.get("disable_flashinfer", False): - ret.init_flashinfer_args( - model_runner.model_config.num_attention_heads // model_runner.tp_size, - model_runner.model_config.get_num_kv_heads(model_runner.tp_size), - model_runner.model_config.head_dim, - ) + if model_runner.server_args.disable_flashinfer: + (ret.triton_max_seq_len, + ret.triton_max_extend_len, + ret.triton_start_loc, + ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens) return ret + + +def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens): + num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size + num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) + head_dim = model_runner.model_config.head_dim + batch_size = len(req_pool_indices) + + if forward_mode == ForwardMode.DECODE: + paged_kernel_lens = seq_lens + else: + paged_kernel_lens = prefix_lens + + kv_indptr = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + req_pool_indices_cpu = req_pool_indices.cpu().numpy() + paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() + kv_indices = torch.cat( + [ + model_runner.req_to_token_pool.req_to_token[ + req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i] + ] + for i in range(batch_size) + ], + dim=0, + ).contiguous() + kv_last_page_len = torch.ones( + (batch_size,), dtype=torch.int32, device="cuda" + ) + + if forward_mode == ForwardMode.DECODE: + model_runner.flashinfer_decode_wrapper.end_forward() + model_runner.flashinfer_decode_wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) + else: + # extend part + qo_indptr = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0) + + model_runner.flashinfer_prefill_wrapper_ragged.end_forward() + model_runner.flashinfer_prefill_wrapper_ragged.begin_forward( + qo_indptr, + qo_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + ) + + # cached part + 
model_runner.flashinfer_prefill_wrapper_paged.end_forward() + model_runner.flashinfer_prefill_wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) + + +def init_triton_args(forward_mode, seq_lens, prefix_lens): + batch_size = len(seq_lens) + max_seq_len = int(torch.max(seq_lens)) + start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) + + if forward_mode == ForwardMode.DECODE: + max_extend_len = None + else: + extend_seq_lens = seq_lens - prefix_lens + max_extend_len = int(torch.max(extend_seq_lens)) + + return max_seq_len, max_extend_len, start_loc, prefix_lens diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index 30a8001e7..5945282ab 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -182,39 +182,39 @@ class ModelRunner: return c def init_flash_infer(self): - if not global_server_args_dict.get("disable_flashinfer", False): - from flashinfer import ( - BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - BatchPrefillWithRaggedKVCacheWrapper, - ) - from flashinfer.decode import _grouped_size_compiled_for_decode_kernels - - if not _grouped_size_compiled_for_decode_kernels( - self.model_config.num_attention_heads // self.tp_size, - self.model_config.get_num_kv_heads(self.tp_size), - ): - use_tensor_cores = True - else: - use_tensor_cores = False - - workspace_buffers = torch.empty( - 2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda" - ) - self.flashinfer_prefill_wrapper_ragged = ( - BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD") - ) - self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( - workspace_buffers[1], "NHD" - ) - self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores - ) - else: - self.flashinfer_prefill_wrapper_ragged = ( - self.flashinfer_prefill_wrapper_paged - ) = None + if self.server_args.disable_flashinfer: + self.flashinfer_prefill_wrapper_ragged = None + self.flashinfer_prefill_wrapper_paged = None self.flashinfer_decode_wrapper = None + return + + from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + ) + from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + + if not _grouped_size_compiled_for_decode_kernels( + self.model_config.num_attention_heads // self.tp_size, + self.model_config.get_num_kv_heads(self.tp_size), + ): + use_tensor_cores = True + else: + use_tensor_cores = False + + workspace_buffers = torch.empty( + 3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda" + ) + self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffers[0], "NHD" + ) + self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( + workspace_buffers[1], "NHD" + ) + self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores + ) @torch.inference_mode() def forward_extend(self, batch: Batch):
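
Note (illustrative, not part of the patch): a minimal sketch of how the two module-level helpers added in infer_batch.py are meant to be selected, mirroring the dispatch that InputMetadata.create performs after this change. The wrapper function below and its name are hypothetical; it assumes the patched module layout and a model_runner object exposing server_args.disable_flashinfer.

    # Hypothetical glue for illustration only; mirrors the backend dispatch in
    # InputMetadata.create after this patch.
    from sglang.srt.managers.controller.infer_batch import (
        init_flashinfer_args,
        init_triton_args,
    )

    def prepare_attention_backend(model_runner, forward_mode, req_pool_indices, seq_lens, prefix_lens):
        if not model_runner.server_args.disable_flashinfer:
            # FlashInfer backend: builds kv_indptr/kv_indices and calls
            # begin_forward on the prefill/decode wrappers.
            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
            return None
        # Triton backend: returns (max_seq_len, max_extend_len, start_loc, prefix_lens)
        # consumed by the triton kernels via the triton_* fields of InputMetadata.
        return init_triton_args(forward_mode, seq_lens, prefix_lens)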