diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 16929498b..258659efa 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,7 +18,11 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_core,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
-        # Parse constants
-        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = get_bool_env_var(
-                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
-            )
-        else:
-            if not _grouped_size_compiled_for_decode_kernels(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-            ):
-                self.decode_use_tensor_cores = True
-            else:
-                self.decode_use_tensor_cores = False
+        self.decode_use_tensor_cores = should_use_tensor_core(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )
 
         self.max_context_len = model_runner.model_config.context_len
 
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index f9286e7f3..19ea78015 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1108,3 +1108,51 @@ def cuda_device_count_stateless() -> int:
     # This can be removed and simply replaced with torch.cuda.get_device_count
     # after https://github.com/pytorch/pytorch/pull/122815 is released.
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
diff --git a/scripts/deprecated/test_flashinfer.py b/scripts/deprecated/test_flashinfer.py
index 2929d7bb8..7e7282a3a 100644
--- a/scripts/deprecated/test_flashinfer.py
+++ b/scripts/deprecated/test_flashinfer.py
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )
+from sglang.srt.utils import should_use_tensor_core
 
 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None
@@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(
 
 
 def init_flashinfer(num_attention_heads, num_kv_heads):
-    if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
-        use_tensor_cores = True
-    else:
-        use_tensor_cores = False
+    use_tensor_cores = should_use_tensor_core(
+        torch.half, num_attention_heads, num_kv_heads
+    )
 
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
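
For reference, a minimal sketch of how the new helper resolves in practice, assuming no SGLANG_FLASHINFER_USE_TENSOR_CORE override is set and the installed flashinfer no longer exposes _grouped_size_compiled_for_decode_kernels (so the dtype/GQA fallback applies); the head counts below are illustrative only, not taken from this PR:

    import torch
    from sglang.srt.utils import should_use_tensor_core

    # fp16 KV cache: tensor cores are used only when the GQA group size
    # (num_attention_heads // num_kv_heads) exceeds 4.
    should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=8)  # 32 // 8 = 4 -> False
    should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=4)  # 32 // 4 = 8 -> True

    # fp8 KV cache always selects the tensor-core decode path.
    should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8)  # True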