diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 656679a52..d1e778e92 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -1263,11 +1263,12 @@ def should_use_tensor_core(
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
 
-    # Determine based on dtype and GQA group size
+    # For Flashinfer, a GQA group size of at least 4 is needed to efficiently
+    # use Tensor Cores, as it fuses the head group with the token dimension in MMA.
     if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         return True
     elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
-        return gqa_group_size >= 4
+        return gqa_group_size >= 4
     else:
         return False
 
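
For reference, below is a minimal sketch of the patched heuristic as it reads after this diff. The hunk does not show the full function signature, so the parameter list here is an assumption reconstructed from the names used in the body; the real `should_use_tensor_core` may take additional arguments.

```python
import torch

def should_use_tensor_core(
    kv_cache_dtype: torch.dtype,   # assumed parameters, inferred from the hunk body
    num_attention_heads: int,
    num_kv_heads: int,
) -> bool:
    # Calculate GQA group size: how many query heads share one KV head.
    gqa_group_size = num_attention_heads // num_kv_heads

    # FP8 KV caches always take the tensor-core path.
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    # For FP16/BF16, FlashInfer fuses the head group with the token dimension
    # in MMA, so a group size of at least 4 already fills Tensor Cores.
    # (Note: torch.half is an alias of torch.float16, as in the original.)
    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
        return gqa_group_size >= 4
    else:
        return False

# Example of the behavioral change: a model with 32 query heads and 8 KV heads
# has gqa_group_size == 4, so it now uses Tensor Cores for an fp16 KV cache;
# the old strict `> 4` check returned False for this common configuration.
assert should_use_tensor_core(torch.float16, 32, 8) is True
```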