[misc] remove is_cuda_available (#5319)
This commit is contained in:
@@ -28,9 +28,9 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.utils import is_cuda_available, set_weight_attrs
|
||||
from sglang.srt.utils import is_cuda, set_weight_attrs
|
||||
|
||||
_is_cuda = is_cuda_available()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
|
||||
@@ -3,10 +3,10 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
is_cuda_available = torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
|
||||
num_warps = 4
|
||||
|
||||
else:
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
||||
if _is_cuda and CUDA_CAPABILITY[0] >= 9:
|
||||
if Lq <= 256:
|
||||
BLOCK_M, BLOCK_N = (128, 64)
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = (32, 64)
|
||||
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||
elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
|
||||
if Lq <= 128:
|
||||
BLOCK_M, BLOCK_N = (128, 128)
|
||||
elif Lq <= 256:
|
||||
|
||||
@@ -23,10 +23,10 @@ import triton.language as tl
|
||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
context_attention_fwd,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
is_cuda_available = torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -345,12 +345,12 @@ def extend_attention_fwd(
|
||||
num_warps = 4
|
||||
|
||||
else:
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
||||
if _is_cuda and CUDA_CAPABILITY[0] >= 9:
|
||||
if Lq <= 256:
|
||||
BLOCK_M, BLOCK_N = (128, 64)
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = (32, 64)
|
||||
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||
elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
|
||||
# sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
|
||||
if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
|
||||
if Lq <= 128:
|
||||
|
||||
@@ -22,8 +22,12 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
is_cuda_available = torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda or _is_hip:
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
|
||||
@@ -172,7 +176,7 @@ def context_attention_fwd(
|
||||
b_seq_len: [b]
|
||||
out: [b * s, head, head_dim]
|
||||
"""
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
|
||||
if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
|
||||
BLOCK = 128
|
||||
else:
|
||||
BLOCK = 64
|
||||
|
||||
@@ -20,9 +20,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda_available()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
|
||||
@@ -22,9 +22,9 @@ from sglang.srt.layers.quantization.utils import (
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
if is_cuda_available():
|
||||
if is_cuda():
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
# Initialize logger for the module
|
||||
|
||||
@@ -11,10 +11,10 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from sglang.srt.utils import is_cuda_available, set_weight_attrs
|
||||
from sglang.srt.utils import is_cuda, set_weight_attrs
|
||||
|
||||
is_cuda = is_cuda_available()
|
||||
if is_cuda:
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
|
||||
|
||||
|
||||
@@ -8,11 +8,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda_available = is_cuda_available()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda_available:
|
||||
if _is_cuda:
|
||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||
else:
|
||||
from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
|
||||
@@ -82,7 +82,7 @@ class RotaryEmbedding(CustomOp):
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
|
||||
if not _is_cuda_available:
|
||||
if not _is_cuda:
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
@@ -149,7 +149,7 @@ class RotaryEmbedding(CustomOp):
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
|
||||
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
|
||||
apply_rope_with_cos_sin_cache_inplace(
|
||||
positions=positions,
|
||||
query=query,
|
||||
@@ -652,7 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
def forward(self, *args, **kwargs):
|
||||
if torch.compiler.is_compiling():
|
||||
return self.forward_native(*args, **kwargs)
|
||||
if _is_cuda_available:
|
||||
if _is_cuda:
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
else:
|
||||
return self.forward_native(*args, **kwargs)
|
||||
|
||||
@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
|
||||
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
|
||||
|
||||
if is_cuda_available():
|
||||
if is_cuda():
|
||||
from sgl_kernel import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_renorm_prob,
|
||||
|
||||
@@ -40,9 +40,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, is_cuda_available
|
||||
from sglang.srt.utils import add_prefix, is_cuda
|
||||
|
||||
if is_cuda_available():
|
||||
if is_cuda():
|
||||
from sgl_kernel import bmm_fp8
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import is_cuda_available, is_hip
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
if is_cuda_available() or is_hip():
|
||||
if is_cuda() or is_hip():
|
||||
from sgl_kernel import (
|
||||
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
|
||||
)
|
||||
|
||||
@@ -19,9 +19,9 @@ from sglang.srt.managers.schedule_batch import (
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||
from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
|
||||
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
|
||||
|
||||
if is_cuda_available():
|
||||
if is_cuda():
|
||||
from sgl_kernel import (
|
||||
top_k_renorm_prob,
|
||||
top_p_renorm_prob,
|
||||
|
||||
@@ -34,14 +34,9 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
fast_topk,
|
||||
get_available_gpu_memory,
|
||||
is_cuda_available,
|
||||
)
|
||||
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
|
||||
|
||||
if is_cuda_available():
|
||||
if is_cuda():
|
||||
from sgl_kernel import segment_packbits
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -130,10 +130,6 @@ def is_flashinfer_available():
|
||||
return importlib.util.find_spec("flashinfer") is not None and is_cuda()
|
||||
|
||||
|
||||
def is_cuda_available():
|
||||
return is_cuda()
|
||||
|
||||
|
||||
_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
|
||||
"SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user