From 0f587e80d3538c2816ae243bfb8bda8e1b08cab9 Mon Sep 17 00:00:00 2001 From: Wenxuan Tan Date: Fri, 22 Aug 2025 10:25:15 -0500 Subject: [PATCH] Use Tensor Core Decode when gqa group size >= 4 (#8624) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/srt/layers/attention/flashinfer_backend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 656679a52..d1e778e92 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -1263,11 +1263,12 @@ def should_use_tensor_core( # Calculate GQA group size gqa_group_size = num_attention_heads // num_kv_heads - # Determine based on dtype and GQA group size + # For Flashinfer, a GQA group size of at least 4 is needed to efficiently + # use Tensor Cores, as it fuses the head group with the token dimension in MMA. if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): return True elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16): - return gqa_group_size > 4 + return gqa_group_size >= 4 else: return False