feat: add should_use_tensor_core (#2179)
python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,7 +18,11 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_core,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
         # Parse constants
-        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = get_bool_env_var(
-                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
-            )
-        else:
-            if not _grouped_size_compiled_for_decode_kernels(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-            ):
-                self.decode_use_tensor_cores = True
-            else:
-                self.decode_use_tensor_cores = False
+        self.decode_use_tensor_cores = should_use_tensor_core(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )
 
         self.max_context_len = model_runner.model_config.context_len
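Note on behavior (editorial, not part of the commit): the SGLANG_FLASHINFER_USE_TENSOR_CORE override that the removed branch handled in __init__ is still honored, it now just lives inside the helper. A minimal sketch of the override path, using only names that appear in this diff:

    import os
    import torch

    from sglang.srt.utils import should_use_tensor_core

    # An explicit override is checked before any dtype/head-count heuristic.
    os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"] = "true"
    assert should_use_tensor_core(
        kv_cache_dtype=torch.float16,
        num_attention_heads=32,
        num_kv_heads=32,  # MHA: the fallback heuristic alone would return False
    )

One subtlety visible in the new code: the helper accepts only the literal string "true" (case-insensitive), so any other value, including "1", makes the override return False instead of falling through to the heuristic.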
python/sglang/srt/utils.py
@@ -1108,3 +1108,51 @@ def cuda_device_count_stateless() -> int:
     # This can be removed and simply replaced with torch.cuda.get_device_count
     # after https://github.com/pytorch/pytorch/pull/122815 is released.
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
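Usage note (editorial, not part of the commit): when no environment override is set and the installed flashinfer no longer exposes the legacy _grouped_size_compiled_for_decode_kernels probe (i.e. versions after 0.1.6, per the comment above), the fallback heuristic enables tensor cores unconditionally for FP8 KV caches, and for FP16/BF16 only when the GQA group size num_attention_heads // num_kv_heads exceeds 4:

    import torch

    from sglang.srt.utils import should_use_tensor_core

    # FP8 KV cache: tensor cores are always used.
    should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8)   # True

    # FP16 with GQA group size 64 // 8 = 8 > 4: tensor cores.
    should_use_tensor_core(torch.float16, num_attention_heads=64, num_kv_heads=8)         # True

    # FP16 plain multi-head attention (group size 1): CUDA-core decode kernels.
    should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=32)        # False

    # Any other KV cache dtype falls through to False.
    should_use_tensor_core(torch.float32, num_attention_heads=64, num_kv_heads=8)         # False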