diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 1ebb44d55..8f2b0e116 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -20,9 +20,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
 from vllm.distributed import (
@@ -146,8 +146,8 @@ def get_act_fn(
     return act_fn
 
 
-if is_hip():
+if not is_flashinfer_available():
     logger.info(
-        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+        "FlashInfer is not available on non-NVIDIA platforms. Falling back to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index a753233a9..0c9ca8f9d 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -19,13 +19,12 @@ from sglang.srt.layers.attention.flashinfer_utils import (
     update_flashinfer_indices,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 042c88e24..3ae392eb9 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.norm import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if is_hip():
+if not is_flashinfer_available():
     logger.info(
-        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+        "FlashInfer is not available on non-NVIDIA platforms. Falling back to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index b45ec080b..7421bda18 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -7,10 +7,9 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.sampling import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index dd46212ed..61ac43b2a 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -25,13 +25,11 @@ import torch
 
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip, replace_submodule
+from sglang.srt.utils import is_flashinfer_available, replace_submodule
 
 logger = logging.getLogger(__name__)
 
-
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 8524be22b..ce632275a 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -47,10 +47,9 @@
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import bmm_fp8
diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py
index 0e29eb357..abdf107d3 100644
--- a/python/sglang/srt/models/minicpm3.py
+++ b/python/sglang/srt/models/minicpm3.py
@@ -43,10 +43,9 @@
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import bmm_fp8
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 12f130352..e467eb03f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -22,7 +22,7 @@ import random
 import tempfile
 from typing import List, Optional
 
-from sglang.srt.utils import is_hip, is_ipv6, is_port_available
+from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
 
 logger = logging.getLogger(__name__)
@@ -151,8 +151,7 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"
 
-        # ROCm: flashinfer available later
-        if is_hip():
+        if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index c543bb9a2..39106581a 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -50,11 +50,19 @@
 show_time_cost = False
 time_infos = {}
 
-# torch flag AMD GPU
 def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None
 
 
+def is_flashinfer_available() -> bool:
+    """
+    Check whether flashinfer is available.
+    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
+    """
+    return torch.cuda.is_available() and not is_hip()
+
+
 def is_ipv6(address):
     try:
         ipaddress.IPv6Address(address)
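
A note on what the new guard changes, with a standalone sketch that is not part of the patch: the old `not is_hip()` test is also true on CPU-only builds, where `import flashinfer` would still fail, whereas `is_flashinfer_available()` additionally requires a visible CUDA device. The sketch below restates the capability-gated import pattern the diff applies in each module; the guarded import (`fused_add_rmsnorm`, one of the symbols the patch actually gates) is chosen for illustration, and the vLLM fallback is shown only as a comment.

import logging

import torch

logger = logging.getLogger(__name__)


def is_hip() -> bool:
    """Return whether it is HIP on the AMD ROCm platform."""
    return torch.version.hip is not None


def is_flashinfer_available() -> bool:
    """True only for CUDA (non-HIP) builds with a visible GPU."""
    return torch.cuda.is_available() and not is_hip()


if is_flashinfer_available():
    # NVIDIA fast path: FlashInfer's fused CUDA kernels.
    from flashinfer.norm import fused_add_rmsnorm
else:
    logger.info(
        "FlashInfer is not available on non-NVIDIA platforms. "
        "Falling back to other kernel libraries."
    )
    # The patch falls back to vLLM's kernels at this point, e.g.:
    # from vllm.model_executor.layers.layernorm import RMSNorm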