[misc] remove is_cuda_available (#5319)

Author: JieXin Liang
Date: 2025-04-21 09:16:51 +08:00
Committed by: GitHub
Parent: 1195182040
Commit: 97cb762bb6
14 changed files with 42 additions and 47 deletions
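For reference, the is_cuda() and is_hip() helpers that replace the module-level torch.cuda.is_available() checks come from sglang.srt.utils. A minimal sketch of their likely shape (the exact bodies in the repo may differ): both report how PyTorch was built, not whether a GPU is reachable at runtime.

# Sketch only: assumed shape of the sglang.srt.utils helpers, not a
# verbatim copy. Both inspect PyTorch build metadata.
import torch

def is_cuda() -> bool:
    # True for CUDA builds of PyTorch, even with no GPU attached.
    return torch.version.cuda is not None

def is_hip() -> bool:
    # True for ROCm/HIP builds of PyTorch.
    return torch.version.hip is not None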

View File

@@ -3,10 +3,10 @@ import triton
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
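The behavioral difference motivating this swap: torch.cuda.is_available() probes for a usable GPU at runtime, while is_cuda() only asks whether this PyTorch was built with CUDA. A small illustration of where the two disagree (illustration only, not repo code):

# Illustration: the two guards can disagree on a CUDA build of PyTorch
# running on a machine with no visible GPU (e.g. CPU-only CI).
import torch

old_guard = torch.cuda.is_available()       # runtime probe: needs a GPU
new_guard = torch.version.cuda is not None  # build check: no GPU required

# old_guard may be False while new_guard is True; the rename to _is_cuda
# also marks the module-level flag as private.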
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:
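Pulled out of the diff context, the tile-size policy above is a small lookup keyed on compute capability and query length. A standalone sketch; the function name and the final fallback are illustrative, since the hunk truncates before the remaining branches:

# Illustrative restatement of the tile selection in extend_attention_fwd.
# The (64, 64) fallback is an assumption; the diff cuts off mid-branch.
import torch

def pick_tiles(Lq, _is_cuda):
    if _is_cuda:
        major, _minor = torch.cuda.get_device_capability()
        if major >= 9:                   # Hopper (sm90) and newer
            return (128, 64) if Lq <= 256 else (32, 64)
        if major >= 8 and Lq <= 128:     # Ampere / Ada
            return (128, 128)
    return (64, 64)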

View File

@@ -23,10 +23,10 @@ import triton.language as tl

 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
            if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
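The comment in this hunk explains the extra minor-version test: sm86 and sm89 parts carry roughly 100KB of shared memory per SM versus roughly 160KB on sm80, so they need smaller tiles. Detecting them from the capability tuple looks like this (standalone sketch; tile values are placeholders):

# Sketch: flag the low-shared-memory Ampere/Ada parts named in the comment.
import torch

major, minor = torch.cuda.get_device_capability()
low_smem = (major == 8) and (minor in (6, 9))   # sm86 / sm89
if low_smem:
    BLOCK_M, BLOCK_N = (64, 64)  # placeholder tiles, not from the diff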

View File

@@ -22,8 +22,12 @@ import torch
 import triton
 import triton.language as tl

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
@@ -172,7 +176,7 @@ def context_attention_fwd(
         b_seq_len: [b]
         out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
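With this final hunk, the prefill kernel's block size is gated on either backend rather than on a runtime CUDA probe. A compact, self-contained restatement (the note about ROCm capability tuples is an assumption, not from the diff):

# Restatement of the final guard as a helper. On ROCm builds, PyTorch
# also reports a (major, minor) tuple from get_device_capability(),
# which is presumably why one major-version test covers both backends.
def pick_prefill_block(_is_cuda, _is_hip, capability):
    # capability: the (major, minor) tuple from get_device_capability().
    if (_is_cuda or _is_hip) and capability[0] > 8:
        return 128
    return 64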