[misc] remove is_cuda_available (#5319)
@@ -3,10 +3,10 @@ import triton
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
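Note: assembled from the hunk above (with import torch / import triton inferred from the hunk header and the torch.cuda call), the module preamble after this change reads roughly as follows. The practical effect is that platform detection goes through the shared is_cuda()/is_hip() helpers instead of probing torch.cuda.is_available() at import time, which can also report True on ROCm builds.

    import torch
    import triton
    import triton.language as tl

    from sglang.srt.managers.schedule_batch import global_server_args_dict
    from sglang.srt.utils import is_cuda, is_hip

    # Platform flags resolved once at import time via the shared helpers.
    _is_cuda = is_cuda()
    if _is_cuda:
        CUDA_CAPABILITY = torch.cuda.get_device_capability()

    _is_hip = is_hip()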
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:
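To make the tile-size logic easier to follow, the selection visible in this hunk can be read as the hypothetical helper below (illustrative only, not part of the commit; the name pick_extend_tiles and its signature are made up, and branches truncated in the hunk are left out rather than guessed):

    def pick_extend_tiles(lq: int, on_cuda: bool, capability: tuple) -> tuple:
        """Mirrors the (BLOCK_M, BLOCK_N) choice shown in the hunk above."""
        if on_cuda and capability[0] >= 9:
            # sm90 and newer: per the hunk, Lq <= 256 gets (128, 64), larger gets (32, 64).
            return (128, 64) if lq <= 256 else (32, 64)
        if on_cuda and capability[0] >= 8:
            # sm8x: the hunk only shows the Lq <= 128 case.
            if lq <= 128:
                return (128, 128)
        raise NotImplementedError("branch not visible in the hunk above")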
@@ -23,10 +23,10 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
@@ -22,8 +22,12 @@ import torch
 import triton
 import triton.language as tl

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
     out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
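The behavioral point of this last hunk: the guard now relies on the _is_cuda/_is_hip flags, and Python's short-circuit evaluation means CUDA_CAPABILITY (only defined when one of the flags is true, per the earlier hunk in this file) is never touched otherwise. A hypothetical standalone equivalent of the BLOCK choice (illustrative name and signature, not part of the commit):

    def pick_prefill_block(on_cuda_or_hip: bool, capability_major: int) -> int:
        """Mirrors the BLOCK selection shown in the hunk above."""
        # Capability major > 8 (e.g. sm90) uses the larger 128 tile; anything else,
        # including platforms where no capability was probed, falls back to 64.
        return 128 if (on_cuda_or_hip and capability_major > 8) else 64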