[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)
This commit is contained in:
@@ -21,6 +21,8 @@ import logging
|
||||
import random
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -164,6 +166,11 @@ class ServerArgs:
|
||||
)
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
# ROCm: flashinfer is not available yet, so fall back to the triton attention backend and pytorch sampling
|
||||
if is_hip():
|
||||
self.attention_backend = "triton"
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
# Default kernel backends
|
||||
if self.enable_mla:
|
||||
logger.info("MLA optimization is tunred on. Use triton backend.")
|
||||
|
||||
Reference in New Issue
Block a user