[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)
@@ -13,6 +13,7 @@ limitations under the License.
 
 """Fused operators for activation layers."""
 
+import logging
 from typing import Optional
 
 import torch
@@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +140,10 @@ def get_act_fn(
         act_fn, intermediate_size, input_is_parallel, params_dtype
     )
     return act_fn
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
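For context, the SiluAndMul being swapped in computes the SwiGLU gate: split the last dimension in half, apply SiLU to the first half, multiply by the second. A minimal PyTorch-native sketch of that operator (the standard definition, not code from this commit):

import torch
import torch.nn.functional as F

def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    # x packs [gate, up] along the last dimension; the fused kernel
    # computes silu(gate) * up in a single pass.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 2 * 11008)
out = silu_and_mul_native(x)  # shape (4, 11008)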
@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
 
 class AttentionBackend(ABC):
     """The base class of attention backends"""
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)
 
 
@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.a2_scale.max(), requires_grad=False
             )
 
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_scale, layer.a13_scale
+                )
+                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_scale, layer.a2_scale
+                )
+                # Reset the parameters
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+                if a13_scale is not None:
+                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+                if a2_scale is not None:
+                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
             assert layer.w13_scale is not None
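The e4m3fnuz conversion matters because MI300 hardware implements the fnuz variant of fp8: it has no negative zero, the bit pattern 0x80 encodes NaN instead of -0.0, and its exponent bias is 8 rather than 7, so an identical bit pattern decodes to half the e4m3fn value. A simplified sketch of the conversion idea behind vLLM's normalize_e4m3fn_to_e4m3fnuz (for illustration only; see vLLM for the real implementation, which also handles activation scales):

import torch

def to_e4m3fnuz(weight: torch.Tensor, weight_scale: torch.Tensor):
    # Reinterpret the fp8 bits, patch the one incompatible encoding,
    # and double the scale to compensate for the larger exponent bias.
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.uint8)
    bits = torch.where(bits == 0x80, torch.zeros_like(bits), bits)  # -0.0 -> +0.0
    return bits.view(torch.float8_e4m3fnuz), weight_scale * 2.0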
@@ -15,6 +15,7 @@ limitations under the License.
 
 """Fused operators for normalization layers."""
 
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
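The vLLM RMSNorm that replaces the flashinfer kernel reduces, in its native path, to the textbook definition: normalize by the root mean square over the hidden dimension, then apply the learned scale. For reference, the standard formulation (not copied from either library):

import torch

def rms_norm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Compute in float32 for numerical stability, then cast back.
    orig_dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * weight.float()).to(orig_dtype)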
@@ -2,17 +2,21 @@ import logging
 from typing import Union
 
 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
 from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )
 
 logger = logging.getLogger(__name__)
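On ROCm the sampler therefore runs with the "pytorch" sampling backend, which reproduces top-k/top-p filtering in plain tensor ops. A minimal sketch of the standard computation (illustrative, not the commit's exact code):

import torch

def top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # Sort descending, zero out everything beyond top_k and beyond the
    # top_p cumulative mass, renormalize, then sample one token per row.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    mask = (cumsum - sorted_probs) > top_p  # keep tokens until mass exceeds p
    mask[..., top_k:] = True                # and keep at most top_k tokens
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)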
@@ -21,12 +21,15 @@ import re
 from dataclasses import dataclass
 
 import torch
-from flashinfer import SegmentGEMMWrapper
 
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_hip, replace_submodule
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import SegmentGEMMWrapper
 
 
 def get_stacked_name(name):
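SegmentGEMMWrapper exists to batch LoRA matmuls whose row segments have different lengths and different adapter weights. Its semantics can be stated as a simple reference loop, sketched below; the flashinfer kernel fuses this into one launch, and the (num_segments, out_dim, in_dim) weight layout here is an assumption of the sketch:

import torch

def segment_gemm_reference(x: torch.Tensor, weights: torch.Tensor,
                           seg_lens: list[int]) -> torch.Tensor:
    # Rows of x are grouped into consecutive segments; segment i is
    # multiplied by its own weight matrix weights[i].
    out, start = [], 0
    for i, n in enumerate(seg_lens):
        out.append(x[start:start + n] @ weights[i].T)
        start += n
    return torch.cat(out, dim=0)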
@@ -19,7 +19,6 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -48,6 +47,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
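bmm_fp8 is a batched matmul over fp8 operands with per-tensor dequantization scales. Where it is unavailable, the same math (up to fp8 rounding and speed) is dequantize-then-bmm; a sketch of those assumed semantics, not flashinfer's implementation:

import torch

def bmm_fp8_reference(a: torch.Tensor, b: torch.Tensor,
                      a_scale: float, b_scale: float,
                      out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Upcast, apply the per-tensor dequantization scales, then multiply.
    a_f = a.to(torch.float32) * a_scale
    b_f = b.to(torch.float32) * b_scale
    return torch.bmm(a_f, b_f).to(out_dtype)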
@@ -19,7 +19,6 @@ import math
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class MiniCPM3MLP(nn.Module):
@@ -78,6 +78,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     enable_show_time_cost,
+    is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model,
@@ -434,6 +435,10 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
+    if is_hip():
+        # TODO: figure out a better method than using fork
+        mp.set_start_method("spawn", force=True)
+
 
 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
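Forcing the "spawn" start method avoids forking a process that has already initialized the ROCm/HIP runtime. The practical consequence is that child processes re-import the main module instead of inheriting its state, so entry points must be guarded:

import multiprocessing as mp

def worker(rank: int) -> None:
    print(f"worker {rank} started")

if __name__ == "__main__":
    # With spawn, children re-import this module rather than forking
    # the parent, so the __main__ guard above is mandatory.
    mp.set_start_method("spawn", force=True)
    procs = [mp.Process(target=worker, args=(r,)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()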
@@ -21,6 +21,8 @@ import logging
 import random
 from typing import List, Optional, Union
 
+from sglang.srt.utils import is_hip
+
 logger = logging.getLogger(__name__)
 
 
@@ -164,6 +166,11 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"
 
+        # ROCm: flashinfer available later
+        if is_hip():
+            self.attention_backend = "triton"
+            self.sampling_backend = "pytorch"
+
         # Default kernel backends
         if self.enable_mla:
             logger.info("MLA optimization is turned on. Use triton backend.")
@@ -51,6 +51,11 @@ show_time_cost = False
 time_infos = {}
 
 
+# torch flag for AMD GPU detection
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
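The check is cheap and safe at import time: torch.version.hip is None on CUDA and CPU builds of PyTorch, and a version string on ROCm builds. Quick usage:

import torch

def is_hip() -> bool:
    return torch.version.hip is not None

# ROCm builds report a version string; CUDA/CPU builds report None.
print(torch.version.hip, "->", "AMD GPU (ROCm)" if is_hip() else "not ROCm")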