[Bugfix] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1419) (#1453)

This commit is contained in:
HAI
2024-09-18 02:01:35 -07:00
committed by GitHub
parent 5e62a6b706
commit aa2750beb3
2 changed files with 16 additions and 10 deletions

View File

@@ -20,16 +20,19 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from flashinfer.norm import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
)
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.utils import is_hip
if not is_hip():
from flashinfer.norm import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
)
from vllm.model_executor.custom_op import CustomOp
logger = logging.getLogger(__name__)