feat: add should_use_tensor_core (#2179)

Yineng Zhang
2024-12-01 18:01:16 +08:00
committed by GitHub
parent 9449a95431
commit 118b6af35e
3 changed files with 65 additions and 19 deletions


@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
 )
+from sglang.srt.utils import should_use_tensor_core
 
 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None
@@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(
 def init_flashinfer(num_attention_heads, num_kv_heads):
-    if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
-        use_tensor_cores = True
-    else:
-        use_tensor_cores = False
+    use_tensor_cores = should_use_tensor_core(
+        torch.half, num_attention_heads, num_kv_heads
+    )
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
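
Below is a minimal sketch of the kind of check a helper with this call signature could perform, inferred only from the arguments at the call site (KV cache dtype, number of attention heads, number of KV heads). The name should_use_tensor_core_sketch, the dtype cases, and the GQA group-size threshold are assumptions for illustration, not the actual sglang.srt.utils implementation.

import torch

def should_use_tensor_core_sketch(
    kv_cache_dtype: torch.dtype,
    num_attention_heads: int,
    num_kv_heads: int,
) -> bool:
    # Hypothetical heuristic: FP8 KV caches are assumed to require the
    # tensor-core decode path (float8 dtypes need torch >= 2.1).
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    # Hypothetical heuristic: for fp16/bf16, tensor cores are assumed to pay
    # off once the GQA group (query heads per KV head) is wide enough.
    if kv_cache_dtype in (torch.float16, torch.bfloat16):
        return num_attention_heads // num_kv_heads > 4
    return False

Whatever the exact rule, the returned flag replaces the old _grouped_size_compiled_for_decode_kernels branch and is presumably forwarded to FlashInfer's decode wrapper (e.g. the use_tensor_cores argument of BatchDecodeWithPagedKVCacheWrapper) when it is constructed from the workspace buffer above.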