[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)
This commit is contained in:
@@ -15,6 +15,7 @@ limitations under the License.
"""Fused operators for normalization layers."""
import logging
from typing import Optional, Tuple, Union
import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
)
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
class RMSNorm(CustomOp):
def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
return x, residual
out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
return out
# FlashInfer's fused norm kernels are CUDA-only (see the flashinfer.norm imports
# above). On ROCm (AMD) builds, rebind this module's RMSNorm/GemmaRMSNorm names
# to vLLM's implementations so importers transparently get working kernels.
# NOTE(review): this shadows the RMSNorm/GemmaRMSNorm classes defined earlier in
# this file — intentional here, as those wrap FlashInfer ops.
if is_hip():
    logger.info(
        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
Reference in New Issue
Block a user