[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)

This commit is contained in:
HAI
2024-09-17 00:43:52 -07:00
committed by GitHub
parent 2fa5cec775
commit 3a6e04185b
11 changed files with 104 additions and 24 deletions

View File

@@ -15,6 +15,7 @@ limitations under the License.
"""Fused operators for normalization layers."""
import logging
from typing import Optional, Tuple, Union
import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
)
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
class RMSNorm(CustomOp):
def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
return x, residual
out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
return out
# FlashInfer ships CUDA-only kernels; on ROCm (AMD GPUs) fall back to the
# vLLM implementations of these norm layers so the module stays importable.
# NOTE(review): this re-binds GemmaRMSNorm/RMSNorm at module level, shadowing
# the flashinfer-backed definitions above — confirm import ordering is intended.
if is_hip():
    logger.info(
        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm