[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)

This commit is contained in:
HAI
2024-09-17 00:43:52 -07:00
committed by GitHub
parent 2fa5cec775
commit 3a6e04185b
11 changed files with 104 additions and 24 deletions

View File

@@ -21,12 +21,15 @@ import re
from dataclasses import dataclass
import torch
from flashinfer import SegmentGEMMWrapper
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import replace_submodule
from sglang.srt.utils import is_hip, replace_submodule
# ROCm: flashinfer available later
if not is_hip():
from flashinfer import SegmentGEMMWrapper
def get_stacked_name(name):