feat: add should_use_tensor_core (#2179)
This commit is contained in:
@@ -18,7 +18,11 @@ import triton.language as tl
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
|
from sglang.srt.utils import (
|
||||||
|
get_bool_env_var,
|
||||||
|
is_flashinfer_available,
|
||||||
|
should_use_tensor_core,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
@@ -31,7 +35,6 @@ if is_flashinfer_available():
|
|||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
from flashinfer.cascade import merge_state
|
from flashinfer.cascade import merge_state
|
||||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
|
||||||
|
|
||||||
|
|
||||||
class WrapperDispatch(Enum):
|
class WrapperDispatch(Enum):
|
||||||
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Parse constants
|
self.decode_use_tensor_cores = should_use_tensor_core(
|
||||||
if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
|
kv_cache_dtype=model_runner.kv_cache_dtype,
|
||||||
self.decode_use_tensor_cores = get_bool_env_var(
|
num_attention_heads=model_runner.model_config.num_attention_heads
|
||||||
"SGLANG_FLASHINFER_USE_TENSOR_CORE"
|
// model_runner.tp_size,
|
||||||
)
|
num_kv_heads=model_runner.model_config.get_num_kv_heads(
|
||||||
else:
|
model_runner.tp_size
|
||||||
if not _grouped_size_compiled_for_decode_kernels(
|
),
|
||||||
model_runner.model_config.num_attention_heads // model_runner.tp_size,
|
)
|
||||||
model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
|
|
||||||
):
|
|
||||||
self.decode_use_tensor_cores = True
|
|
||||||
else:
|
|
||||||
self.decode_use_tensor_cores = False
|
|
||||||
|
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
|
||||||
|
|||||||
@@ -1108,3 +1108,51 @@ def cuda_device_count_stateless() -> int:
|
|||||||
# This can be removed and simply replaced with torch.cuda.get_device_count
|
# This can be removed and simply replaced with torch.cuda.get_device_count
|
||||||
# after https://github.com/pytorch/pytorch/pull/122815 is released.
|
# after https://github.com/pytorch/pytorch/pull/122815 is released.
|
||||||
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||||
|
|
||||||
|
|
||||||
|
def should_use_tensor_core(
|
||||||
|
kv_cache_dtype: torch.dtype,
|
||||||
|
num_attention_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Determine whether to use tensor cores for attention computation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kv_cache_dtype: Data type of the KV cache
|
||||||
|
num_attention_heads: Number of attention heads
|
||||||
|
num_kv_heads: Number of key/value heads
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether to use tensor cores
|
||||||
|
"""
|
||||||
|
# Try to use environment variable first
|
||||||
|
env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
|
||||||
|
if env_override is not None:
|
||||||
|
return env_override.lower() == "true"
|
||||||
|
|
||||||
|
# Try to use _grouped_size_compiled_for_decode_kernels if available
|
||||||
|
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
|
||||||
|
try:
|
||||||
|
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||||
|
|
||||||
|
if not _grouped_size_compiled_for_decode_kernels(
|
||||||
|
num_attention_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Calculate GQA group size
|
||||||
|
gqa_group_size = num_attention_heads // num_kv_heads
|
||||||
|
|
||||||
|
# Determine based on dtype and GQA group size
|
||||||
|
if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
|
||||||
|
return True
|
||||||
|
elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
|
||||||
|
return gqa_group_size > 4
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
|||||||
extend_attention_fwd,
|
extend_attention_fwd,
|
||||||
redundant_attention,
|
redundant_attention,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils import should_use_tensor_core
|
||||||
|
|
||||||
flashinfer_prefill_wrapper = None
|
flashinfer_prefill_wrapper = None
|
||||||
flashinfer_decode_wrapper = None
|
flashinfer_decode_wrapper = None
|
||||||
@@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(
|
|||||||
|
|
||||||
|
|
||||||
def init_flashinfer(num_attention_heads, num_kv_heads):
|
def init_flashinfer(num_attention_heads, num_kv_heads):
|
||||||
if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
|
use_tensor_cores = should_use_tensor_core(
|
||||||
use_tensor_cores = True
|
torch.half, num_attention_heads, num_kv_heads
|
||||||
else:
|
)
|
||||||
use_tensor_cores = False
|
|
||||||
|
|
||||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user