[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)

This commit is contained in:
HAI
2024-09-17 00:43:52 -07:00
committed by GitHub
parent 2fa5cec775
commit 3a6e04185b
11 changed files with 104 additions and 24 deletions

View File

@@ -19,7 +19,6 @@ import math
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from flashinfer import bmm_fp8
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.utils import is_hip
# ROCm: flashinfer available later
if not is_hip():
from flashinfer import bmm_fp8
class MiniCPM3MLP(nn.Module):