Use is_flashinfer_available to replace is_hip for flashinfer check (#1596)
Co-authored-by: Zhang Liangang <liangang.zhang@intel.com>
This commit is contained in:
@@ -20,9 +20,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
if not is_hip():
|
if is_flashinfer_available():
|
||||||
from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||||
|
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
@@ -146,8 +146,8 @@ def get_act_fn(
|
|||||||
return act_fn
|
return act_fn
|
||||||
|
|
||||||
|
|
||||||
if is_hip():
|
if not is_flashinfer_available():
|
||||||
logger.info(
|
logger.info(
|
||||||
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
|
"FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
||||||
|
|||||||
@@ -19,13 +19,12 @@ from sglang.srt.layers.attention.flashinfer_utils import (
|
|||||||
update_flashinfer_indices,
|
update_flashinfer_indices,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
# ROCm: flashinfer available later
|
if is_flashinfer_available():
|
||||||
if not is_hip():
|
|
||||||
from flashinfer import (
|
from flashinfer import (
|
||||||
BatchDecodeWithPagedKVCacheWrapper,
|
BatchDecodeWithPagedKVCacheWrapper,
|
||||||
BatchPrefillWithPagedKVCacheWrapper,
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
if not is_hip():
|
if is_flashinfer_available():
|
||||||
from flashinfer.norm import (
|
from flashinfer.norm import (
|
||||||
fused_add_rmsnorm,
|
fused_add_rmsnorm,
|
||||||
gemma_fused_add_rmsnorm,
|
gemma_fused_add_rmsnorm,
|
||||||
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
if is_hip():
|
if not is_flashinfer_available():
|
||||||
logger.info(
|
logger.info(
|
||||||
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
|
"FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||||
|
|||||||
@@ -7,10 +7,9 @@ from torch import nn
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
# ROCm: flashinfer available later
|
if is_flashinfer_available():
|
||||||
if not is_hip():
|
|
||||||
from flashinfer.sampling import (
|
from flashinfer.sampling import (
|
||||||
min_p_sampling_from_probs,
|
min_p_sampling_from_probs,
|
||||||
top_k_renorm_prob,
|
top_k_renorm_prob,
|
||||||
|
|||||||
@@ -25,13 +25,11 @@ import torch
|
|||||||
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
|
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
|
||||||
from sglang.srt.lora.lora_config import LoRAConfig
|
from sglang.srt.lora.lora_config import LoRAConfig
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import is_hip, replace_submodule
|
from sglang.srt.utils import is_flashinfer_available, replace_submodule
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if is_flashinfer_available():
|
||||||
# ROCm: flashinfer available later
|
|
||||||
if not is_hip():
|
|
||||||
from flashinfer import SegmentGEMMWrapper
|
from flashinfer import SegmentGEMMWrapper
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -47,10 +47,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
# ROCm: flashinfer available later
|
if is_flashinfer_available():
|
||||||
if not is_hip():
|
|
||||||
from flashinfer import bmm_fp8
|
from flashinfer import bmm_fp8
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -43,10 +43,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
# ROCm: flashinfer available later
|
if is_flashinfer_available():
|
||||||
if not is_hip():
|
|
||||||
from flashinfer import bmm_fp8
|
from flashinfer import bmm_fp8
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from sglang.srt.utils import is_hip, is_ipv6, is_port_available
|
from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -151,8 +151,7 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
self.sampling_backend = "pytorch"
|
self.sampling_backend = "pytorch"
|
||||||
|
|
||||||
# ROCm: flashinfer available later
|
if not is_flashinfer_available():
|
||||||
if is_hip():
|
|
||||||
self.attention_backend = "triton"
|
self.attention_backend = "triton"
|
||||||
self.sampling_backend = "pytorch"
|
self.sampling_backend = "pytorch"
|
||||||
|
|
||||||
|
|||||||
@@ -50,11 +50,19 @@ show_time_cost = False
|
|||||||
time_infos = {}
|
time_infos = {}
|
||||||
|
|
||||||
|
|
||||||
# torch flag AMD GPU
|
|
||||||
def is_hip() -> bool:
|
def is_hip() -> bool:
|
||||||
|
"""Return whether it is HIP on the AMD ROCm platform."""
|
||||||
return torch.version.hip is not None
|
return torch.version.hip is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_flashinfer_available():
|
||||||
|
"""
|
||||||
|
Check whether flashinfer is available.
|
||||||
|
As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
|
||||||
|
"""
|
||||||
|
return torch.cuda.is_available() and not is_hip()
|
||||||
|
|
||||||
|
|
||||||
def is_ipv6(address):
|
def is_ipv6(address):
|
||||||
try:
|
try:
|
||||||
ipaddress.IPv6Address(address)
|
ipaddress.IPv6Address(address)
|
||||||
|
|||||||
Reference in New Issue
Block a user