From debbdb5178f347159b42550298806625d0989ff8 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 12 Sep 2024 00:38:18 -0700
Subject: [PATCH] kernel: use tensor cores for flashinfer gqa kernels (#1403)

---
 python/sglang/srt/layers/attention_backend.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py
index 3fc79fe0d..af3986bc2 100644
--- a/python/sglang/srt/layers/attention_backend.py
+++ b/python/sglang/srt/layers/attention_backend.py
@@ -79,9 +79,17 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
         self.model_runner = model_runner
 
-        if not _grouped_size_compiled_for_decode_kernels(
-            model_runner.model_config.num_attention_heads // model_runner.tp_size,
-            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+        local_num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        if (
+            not _grouped_size_compiled_for_decode_kernels(
+                local_num_qo_heads, local_num_kv_heads
+            )
+            or local_num_qo_heads // local_num_kv_heads > 4
         ):
             self.decode_use_tensor_cores = True
         else: