[Bugfix] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1419) (#1453)

This commit is contained in:
HAI
2024-09-18 02:01:35 -07:00
committed by GitHub
parent 5e62a6b706
commit aa2750beb3
2 changed files with 16 additions and 10 deletions

View File

@@ -19,7 +19,12 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+from sglang.srt.utils import is_hip
+
+if not is_hip():
+    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -29,8 +34,6 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.utils import is_hip
 
 logger = logging.getLogger(__name__)

View File

@@ -20,16 +20,19 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from flashinfer.norm import (
-    fused_add_rmsnorm,
-    gemma_fused_add_rmsnorm,
-    gemma_rmsnorm,
-    rmsnorm,
-)
+from sglang.srt.utils import is_hip
+
+if not is_hip():
+    from flashinfer.norm import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
 from vllm.model_executor.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)