From d2f8bfb2e142348b38cdb4f8c5cd82f0ef3dcbff Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 20 Jun 2024 23:19:52 -0700 Subject: [PATCH] Follow-up fixes for flashinfer 0.0.5 (#556) --- .../srt/managers/controller/model_runner.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index b28f30806..ecca79976 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -67,7 +67,7 @@ class InputMetadata: flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None - def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim): + def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim): self.kv_indptr = torch.zeros( (self.batch_size + 1,), dtype=torch.int32, device="cuda" ) @@ -102,8 +102,8 @@ class InputMetadata: self.kv_indptr, self.kv_indices, self.kv_last_page_len, - num_attention_heads, - num_key_value_heads, + num_qo_heads, + num_kv_heads, head_dim, 1 ) @@ -113,8 +113,8 @@ class InputMetadata: self.kv_indptr, self.kv_indices, self.kv_last_page_len, - num_attention_heads, - num_key_value_heads, + num_qo_heads, + num_kv_heads, head_dim, 1, pos_encoding_mode="NONE", @@ -203,7 +203,7 @@ class InputMetadata: if global_server_args_dict.get("enable_flashinfer", False): ret.init_flashinfer_args( model_runner.model_config.num_attention_heads // tp_size, - model_runner.model_config.num_key_value_heads // tp_size, + model_runner.model_config.get_num_kv_heads(tp_size), model_runner.model_config.head_dim ) @@ -350,6 +350,15 @@ class ModelRunner: BatchPrefillWithPagedKVCacheWrapper, BatchDecodeWithPagedKVCacheWrapper, ) + from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + + if not _grouped_size_compiled_for_decode_kernels( + self.model_config.num_attention_heads // self.tp_size, + self.model_config.get_num_kv_heads(self.tp_size)): + use_tensor_cores = True + else: + use_tensor_cores = False + workspace_buffer = torch.empty( 32 * 1024 * 1024, dtype=torch.int8, device="cuda" ) @@ -357,8 +366,10 @@ class ModelRunner: workspace_buffer, "NHD" ) self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, "NHD" + workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores ) + else: + self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None @torch.inference_mode() def forward_prefill(self, batch: Batch):