Use is_flashinfer_available to replace is_hip for flashinfer check (#1596)

Co-authored-by: Zhang Liangang <liangang.zhang@intel.com>
This commit is contained in:
Lianmin Zheng
2024-10-06 22:54:05 -07:00
committed by GitHub
parent 565b05f02f
commit 6a5b352aaf
9 changed files with 29 additions and 28 deletions

View File

@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from sglang.srt.utils import is_hip
from sglang.srt.utils import is_flashinfer_available
if not is_hip():
if is_flashinfer_available():
from flashinfer.norm import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
return out
if is_hip():
if not is_flashinfer_available():
logger.info(
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
"FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm