[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)

This commit is contained in:
HAI
2024-09-17 00:43:52 -07:00
committed by GitHub
parent 2fa5cec775
commit 3a6e04185b
11 changed files with 104 additions and 24 deletions

View File

@@ -21,6 +21,8 @@ import logging
import random
from typing import List, Optional, Union
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
@@ -164,6 +166,11 @@ class ServerArgs:
)
self.sampling_backend = "pytorch"
# ROCm: flashinfer available later
if is_hip():
self.attention_backend = "triton"
self.sampling_backend = "pytorch"
# Default kernel backends
if self.enable_mla:
logger.info("MLA optimization is tunred on. Use triton backend.")