Use is_flashinfer_available to replace is_hip for flashinfer check (#1596)
Co-authored-by: Zhang Liangang <liangang.zhang@intel.com>
This commit is contained in:
@@ -20,9 +20,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if not is_hip():
|
||||
if is_flashinfer_available():
|
||||
from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
|
||||
from vllm.distributed import (
|
||||
@@ -146,8 +146,8 @@ def get_act_fn(
|
||||
return act_fn
|
||||
|
||||
|
||||
if is_hip():
|
||||
if not is_flashinfer_available():
|
||||
logger.info(
|
||||
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
|
||||
"FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
|
||||
)
|
||||
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
||||
|
||||
@@ -19,13 +19,12 @@ from sglang.srt.layers.attention.flashinfer_utils import (
|
||||
update_flashinfer_indices,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
if not is_hip():
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import (
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
|
||||
@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if not is_hip():
|
||||
if is_flashinfer_available():
|
||||
from flashinfer.norm import (
|
||||
fused_add_rmsnorm,
|
||||
gemma_fused_add_rmsnorm,
|
||||
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
|
||||
return out
|
||||
|
||||
|
||||
if is_hip():
|
||||
if not is_flashinfer_available():
|
||||
logger.info(
|
||||
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
|
||||
"FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
|
||||
@@ -7,10 +7,9 @@ from torch import nn
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
if not is_hip():
|
||||
if is_flashinfer_available():
|
||||
from flashinfer.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_renorm_prob,
|
||||
|
||||
@@ -25,13 +25,11 @@ import torch
|
||||
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_hip, replace_submodule
|
||||
from sglang.srt.utils import is_flashinfer_available, replace_submodule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
if not is_hip():
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import SegmentGEMMWrapper
|
||||
|
||||
|
||||
|
||||
@@ -47,10 +47,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
if not is_hip():
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import bmm_fp8
|
||||
|
||||
|
||||
|
||||
@@ -43,10 +43,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
if not is_hip():
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import bmm_fp8
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import random
|
||||
import tempfile
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.utils import is_hip, is_ipv6, is_port_available
|
||||
from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -151,8 +151,7 @@ class ServerArgs:
|
||||
)
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
if is_hip():
|
||||
if not is_flashinfer_available():
|
||||
self.attention_backend = "triton"
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
|
||||
@@ -50,11 +50,19 @@ show_time_cost = False
|
||||
time_infos = {}
|
||||
|
||||
|
||||
# torch flag AMD GPU
|
||||
def is_hip() -> bool:
|
||||
"""Return whether it is HIP on the AMD ROCm platform."""
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
def is_flashinfer_available():
|
||||
"""
|
||||
Check whether flashinfer is available.
|
||||
As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
|
||||
"""
|
||||
return torch.cuda.is_available() and not is_hip()
|
||||
|
||||
|
||||
def is_ipv6(address):
|
||||
try:
|
||||
ipaddress.IPv6Address(address)
|
||||
|
||||
Reference in New Issue
Block a user