diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py
index 71dbfe0e3..d7c1cf39d 100644
--- a/python/sglang/srt/layers/attention_backend.py
+++ b/python/sglang/srt/layers/attention_backend.py
@@ -86,17 +86,9 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
         self.model_runner = model_runner
 
-        local_num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
-        local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
-        )
-        if (
-            not _grouped_size_compiled_for_decode_kernels(
-                local_num_qo_heads, local_num_kv_heads
-            )
-            or local_num_qo_heads // local_num_kv_heads > 4
+        if not _grouped_size_compiled_for_decode_kernels(
+            model_runner.model_config.num_attention_heads // model_runner.tp_size,
+            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
         ):
             self.decode_use_tensor_cores = True
         else: