feat: add should_use_tensor_core (#2179)
This commit is contained in:
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||
extend_attention_fwd,
|
||||
redundant_attention,
|
||||
)
|
||||
from sglang.srt.utils import should_use_tensor_core
|
||||
|
||||
flashinfer_prefill_wrapper = None
|
||||
flashinfer_decode_wrapper = None
|
||||
@@ -195,10 +196,9 @@ def test_batch_decode_with_paged_kv_cache(
|
||||
|
||||
|
||||
def init_flashinfer(num_attention_heads, num_kv_heads):
|
||||
if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
|
||||
use_tensor_cores = True
|
||||
else:
|
||||
use_tensor_cores = False
|
||||
use_tensor_cores = should_use_tensor_core(
|
||||
torch.half, num_attention_heads, num_kv_heads
|
||||
)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user