From aa2750beb30fe6663fa68162b8937399cebb03e4 Mon Sep 17 00:00:00 2001
From: HAI
Date: Wed, 18 Sep 2024 02:01:35 -0700
Subject: [PATCH] [Bugfix] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1419) (#1453)

---
 python/sglang/srt/layers/activation.py |  9 ++++++---
 python/sglang/srt/layers/layernorm.py  | 17 ++++++++++-------
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index a3aeda9c4..6cae1fd9a 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -19,7 +19,12 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+
+from sglang.srt.utils import is_hip
+
+if not is_hip():
+    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -29,8 +34,6 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
-from sglang.srt.utils import is_hip
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index c4803a334..042c88e24 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -20,16 +20,19 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from flashinfer.norm import (
-    fused_add_rmsnorm,
-    gemma_fused_add_rmsnorm,
-    gemma_rmsnorm,
-    rmsnorm,
-)
-from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.utils import is_hip
 
+if not is_hip():
+    from flashinfer.norm import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
+
+from vllm.model_executor.custom_op import CustomOp
+
 logger = logging.getLogger(__name__)
 
 