diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py
index 3fc79fe0d..af3986bc2 100644
--- a/python/sglang/srt/layers/attention_backend.py
+++ b/python/sglang/srt/layers/attention_backend.py
@@ -79,9 +79,17 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
         self.model_runner = model_runner
 
-        if not _grouped_size_compiled_for_decode_kernels(
-            model_runner.model_config.num_attention_heads // model_runner.tp_size,
-            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+        local_num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        if (
+            not _grouped_size_compiled_for_decode_kernels(
+                local_num_qo_heads, local_num_kv_heads
+            )
+            or local_num_qo_heads // local_num_kv_heads > 4
         ):
             self.decode_use_tensor_cores = True
         else: