Use is_flashinfer_available to replace is_hip for flashinfer check (#1596)

Co-authored-by: Zhang Liangang <liangang.zhang@intel.com>
This commit is contained in:
Lianmin Zheng
2024-10-06 22:54:05 -07:00
committed by GitHub
parent 565b05f02f
commit 6a5b352aaf
9 changed files with 29 additions and 28 deletions

View File

@@ -20,9 +20,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from sglang.srt.utils import is_hip
from sglang.srt.utils import is_flashinfer_available
if not is_hip():
if is_flashinfer_available():
from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm.distributed import (
@@ -146,8 +146,8 @@ def get_act_fn(
return act_fn
if is_hip():
if not is_flashinfer_available():
logger.info(
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
"FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul

View File

@@ -19,13 +19,12 @@ from sglang.srt.layers.attention.flashinfer_utils import (
update_flashinfer_indices,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_hip
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
# ROCm: flashinfer available later
if not is_hip():
if is_flashinfer_available():
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,

View File

@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from sglang.srt.utils import is_hip
from sglang.srt.utils import is_flashinfer_available
if not is_hip():
if is_flashinfer_available():
from flashinfer.norm import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
return out
if is_hip():
if not is_flashinfer_available():
logger.info(
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
"FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm

View File

@@ -7,10 +7,9 @@ from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import is_hip
from sglang.srt.utils import is_flashinfer_available
# ROCm: flashinfer available later
if not is_hip():
if is_flashinfer_available():
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,

View File

@@ -25,13 +25,11 @@ import torch
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip, replace_submodule
from sglang.srt.utils import is_flashinfer_available, replace_submodule
logger = logging.getLogger(__name__)
# ROCm: flashinfer available later
if not is_hip():
if is_flashinfer_available():
from flashinfer import SegmentGEMMWrapper

View File

@@ -47,10 +47,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip
from sglang.srt.utils import is_flashinfer_available
# ROCm: flashinfer available later
if not is_hip():
if is_flashinfer_available():
from flashinfer import bmm_fp8

View File

@@ -43,10 +43,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip
from sglang.srt.utils import is_flashinfer_available
# ROCm: flashinfer available later
if not is_hip():
if is_flashinfer_available():
from flashinfer import bmm_fp8

View File

@@ -22,7 +22,7 @@ import random
import tempfile
from typing import List, Optional
from sglang.srt.utils import is_hip, is_ipv6, is_port_available
from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
logger = logging.getLogger(__name__)
@@ -151,8 +151,7 @@ class ServerArgs:
)
self.sampling_backend = "pytorch"
# ROCm: flashinfer available later
if is_hip():
if not is_flashinfer_available():
self.attention_backend = "triton"
self.sampling_backend = "pytorch"

View File

@@ -50,11 +50,19 @@ show_time_cost = False
time_infos = {}
# torch flag AMD GPU
def is_hip() -> bool:
"""Return whether it is HIP on the AMD ROCm platform."""
return torch.version.hip is not None
def is_flashinfer_available():
"""
Check whether flashinfer is available.
As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
"""
return torch.cuda.is_available() and not is_hip()
def is_ipv6(address):
try:
ipaddress.IPv6Address(address)