[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)
This commit is contained in:
@@ -19,7 +19,6 @@ import math
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from flashinfer import bmm_fp8
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig
|
||||
@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# ROCm: flashinfer available later
|
||||
if not is_hip():
|
||||
from flashinfer import bmm_fp8
|
||||
|
||||
|
||||
class MiniCPM3MLP(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user