diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 605ec643b..983eff0e3 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -12,7 +12,8 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -20,7 +21,6 @@ class RadixAttention(nn.Module):
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
-        self.logit_cap = logit_cap
 
         assert np.allclose(scaling, 1.0 / (head_dim**0.5))
 
@@ -30,10 +30,17 @@ class RadixAttention(nn.Module):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer only accepts a boolean logit_cap argument
+            if logit_cap > 0:
+                assert logit_cap == 30
+                self.logit_cap = True
+            else:
+                self.logit_cap = False
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap
 
     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -100,9 +107,10 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.prefill_wrapper.forward(
+        o = input_metadata.flashinfer_prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -110,9 +118,10 @@ class RadixAttention(nn.Module):
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_cap=self.logit_cap,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 95f6b4e5a..b28f30806 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Any
 
 import numpy as np
 import torch
@@ -34,7 +34,6 @@ global_server_args_dict = {}
 
 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -65,15 +64,10 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
-
-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
+    flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
+    def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -93,9 +87,6 @@ class InputMetadata:
                 dim=0,
             ).contiguous()
 
-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
@@ -104,34 +95,30 @@ class InputMetadata:
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            args = [
+
+            self.flashinfer_prefill_wrapper.end_forward()
+            self.flashinfer_prefill_wrapper.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
+                1
             )
-            self.decode_wrapper.begin_forward(
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
+                num_attention_heads,
+                num_key_value_heads,
+                head_dim,
                 1,
-                "NONE",
-                "float16",
+                pos_encoding_mode="NONE",
+                data_type="float16",
             )
 
     def init_extend_args(self):
@@ -155,6 +142,8 @@ class InputMetadata:
        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
+        flashinfer_prefill_wrapper=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -187,7 +176,6 @@ class InputMetadata:
            other_kv_index = None
 
         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -205,13 +193,19 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper=flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
         if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.num_key_value_heads // tp_size,
+                model_runner.model_config.head_dim
+            )
 
         return ret
 
@@ -234,12 +228,7 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-
-        global global_server_args_dict
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
 
         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
@@ -269,9 +258,17 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )
 
+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
+        # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.init_flash_infer()
 
     def load_model(self):
         logger.info(
@@ -347,6 +344,22 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+    def init_flash_infer(self):
+        if global_server_args_dict.get("enable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+            workspace_buffer = torch.empty(
+                32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffer, "NHD"
+            )
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
@@ -360,6 +373,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -378,6 +393,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -398,6 +415,8 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
        )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -416,6 +435,8 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper=self.flashinfer_prefill_wrapper,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 4e088a350..cdcd33be1 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -150,7 +150,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args
     if server_args.disable_disk_cache:
         disable_cache()
     if server_args.enable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.4")
+        assert_pkg_version("flashinfer", "0.0.5")
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
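
Reviewer note: the sketch below illustrates the wrapper lifecycle this patch moves to — allocate the flashinfer wrappers once with a shared workspace buffer (ModelRunner.init_flash_infer), then call end_forward()/begin_forward() around each batch (InputMetadata.init_flashinfer_args) instead of constructing a new wrapper per forward pass. It only mirrors the flashinfer 0.0.5 calls already used in the patch; the head counts, sequence lengths, and the page_size=1 paged-KV layout are illustrative assumptions, not part of the change.

# Minimal sketch (not part of the patch): driving a reused decode wrapper,
# assuming flashinfer 0.0.5, a CUDA device, and sglang's page_size=1 KV layout.
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

num_qo_heads, num_kv_heads, head_dim = 32, 32, 128
seq_lens = [3, 5]                    # two running requests (illustrative)
total_pages = sum(seq_lens)          # page_size=1 -> one page per cached token

# Created once per ModelRunner (see init_flash_infer above).
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# Paged KV cache, NHD layout assumed: (num_pages, 2, page_size, num_kv_heads, head_dim).
kv_data = torch.randn(
    total_pages, 2, 1, num_kv_heads, head_dim, dtype=torch.float16, device="cuda"
)
kv_indptr = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(total_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.ones(len(seq_lens), dtype=torch.int32, device="cuda")

# Done once per batch (see InputMetadata.init_flashinfer_args above):
# tear down the previous plan, then plan for the current batch.
decode_wrapper.end_forward()
decode_wrapper.begin_forward(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    1,
    pos_encoding_mode="NONE",
    data_type="float16",
)

# One decode step for the whole batch; logits_cap mirrors RadixAttention.logit_cap.
q = torch.randn(len(seq_lens), num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
o = decode_wrapper.forward(q, kv_data, logits_cap=False)  # (batch, num_qo_heads, head_dim)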