[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)
@@ -13,6 +13,7 @@ limitations under the License.
 
 """Fused operators for activation layers."""
 
+import logging
 from typing import Optional
 
 import torch
@@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +140,10 @@ def get_act_fn(
         act_fn, intermediate_size, input_is_parallel, params_dtype
     )
     return act_fn
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
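This fallback relies on Python's module-level rebinding: `SiluAndMul` is defined with FlashInfer-backed kernels, and on ROCm builds the name is overwritten at import time, so a downstream `from sglang.srt.layers.activation import SiluAndMul` transparently resolves to vLLM's portable implementation. A minimal sketch of the pattern (the class body here is illustrative, not the real kernel):

```python
# Sketch of the conditional re-export pattern used in activation.py.
import torch
import torch.nn.functional as F

def is_hip() -> bool:
    # Mirrors sglang.srt.utils.is_hip
    return torch.version.hip is not None

class SiluAndMul(torch.nn.Module):
    """Stand-in for the FlashInfer-backed implementation."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

if is_hip():
    # Rebind the public name; importers of this module now get
    # vLLM's CustomOp-based implementation instead.
    from vllm.model_executor.layers.activation import SiluAndMul  # noqa: F811
```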
@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
 
 class AttentionBackend(ABC):
     """The base class of attention backends"""
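Moving the FlashInfer imports under `if not is_hip():` means those names simply do not exist at module scope on ROCm, so every call site must stay behind the same guard (in practice, the Triton attention backend is selected instead; see the ServerArgs hunk below). A hedged sketch of how a call site stays safe:

```python
# The guarded-import pattern: flashinfer names are bound only on
# CUDA builds, so any use must sit behind the same check.
import torch

def is_hip() -> bool:
    return torch.version.hip is not None

if not is_hip():
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper

def make_decode_wrapper(workspace_buffer: torch.Tensor):
    if is_hip():
        raise RuntimeError("FlashInfer unavailable on ROCm; use the triton backend")
    return BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
```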
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)
 
@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
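The dtype switch matters because the two fp8 formats are not interchangeable: `float8_e4m3fn` (exponent bias 7, max 448) is the OCP variant used on NVIDIA hardware, while `float8_e4m3fnuz` (exponent bias 8, no infinities, a single NaN, max 240) is what MI300-class hardware implements natively. Scales derived for one range would clip or underutilize the other:

```python
# Compare the representable ranges of the two fp8 variants.
import torch

print(torch.finfo(torch.float8_e4m3fn).max)        # 448.0
if hasattr(torch, "float8_e4m3fnuz"):              # present in recent PyTorch
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```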
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.a2_scale.max(), requires_grad=False
             )
 
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_scale, layer.a13_scale
+                )
+                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_scale, layer.a2_scale
+                )
+                # Reset the parameters
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+                if a13_scale is not None:
+                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+                if a2_scale is not None:
+                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
         # Fp8 moe kernel needs single weight scale for w13 per expert.
         # We take the max then dequant and requant each expert.
         assert layer.w13_scale is not None
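`normalize_e4m3fn_to_e4m3fnuz` comes from vLLM's w8a8 utils. The conversion idea, to my understanding: because fnuz uses exponent bias 8 instead of 7, the same bit pattern decodes to half the value, so the weight bits can be reinterpreted as-is while the scale is doubled, keeping `weight * scale` unchanged; the `-0.0` pattern (0x80), which decodes as NaN in fnuz, must be zeroed first. A hedged sketch:

```python
# Illustrative fn -> fnuz conversion; the real helper lives in
# vllm.model_executor.layers.quantization.utils.w8a8_utils.
import torch

def to_fnuz(weight: torch.Tensor, scale: torch.Tensor):
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.uint8)
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; map it to +0.0
    bits = torch.where(bits == 0x80, torch.zeros_like(bits), bits)
    # The same bits decode to half the value under bias 8, so double the scale
    return bits.view(torch.float8_e4m3fnuz), scale * 2.0
```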
@@ -15,6 +15,7 @@ limitations under the License.
 
 """Fused operators for normalization layers."""
 
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
@@ -2,17 +2,21 @@ import logging
 from typing import Union
 
 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
 from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )
 
 logger = logging.getLogger(__name__)
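With `flashinfer.sampling` gated off, ROCm runs fall through to the `"pytorch"` sampling backend. A minimal pure-PyTorch top-k/top-p sampler of the kind that backend implements (illustrative, not sglang's exact code):

```python
# Minimal pure-PyTorch top-k/top-p sampling fallback.
import torch

def sample_top_k_top_p(probs: torch.Tensor, k: int, p: float) -> torch.Tensor:
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # top-k: zero everything past the k-th candidate
    sorted_probs[..., k:] = 0.0
    # top-p: zero the tail once the cumulative mass exceeds p
    cumsum = sorted_probs.cumsum(dim=-1)
    sorted_probs[cumsum - sorted_probs > p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)
```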
@@ -21,12 +21,15 @@ import re
 from dataclasses import dataclass
 
 import torch
-from flashinfer import SegmentGEMMWrapper
 
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_hip, replace_submodule
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import SegmentGEMMWrapper
 
 
 def get_stacked_name(name):
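`SegmentGEMMWrapper` fuses the per-request LoRA matmuls, where consecutive segments of the flattened token batch each use a different adapter weight, into one kernel launch. A naive stand-in that spells out the computation:

```python
# Naive equivalent of a segment GEMM: apply a per-segment weight to
# each request's slice of the token batch. Purely illustrative; the
# real wrapper does this in a single fused kernel.
import torch

def segment_gemm(x: torch.Tensor, weights: torch.Tensor,
                 seg_lens: torch.Tensor, weight_indices: torch.Tensor):
    out, start = [], 0
    for i, n in enumerate(seg_lens.tolist()):
        w = weights[weight_indices[i]]        # (in_dim, out_dim)
        out.append(x[start:start + n] @ w)
        start += n
    return torch.cat(out, dim=0)
```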
@@ -19,7 +19,6 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -48,6 +47,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
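`bmm_fp8` is FlashInfer's fp8 batched matmul with per-tensor scales, used by this model's attention path. A dequantize-then-`torch.bmm` reference of what it computes (a sketch with an assumed argument layout; the real kernel never materializes the dequantized operands):

```python
# Reference semantics for an fp8 batched matmul with per-tensor scales.
import torch

def bmm_fp8_ref(a_fp8, b_fp8, a_scale, b_scale, out_dtype=torch.bfloat16):
    a = a_fp8.to(torch.float32) * a_scale
    b = b_fp8.to(torch.float32) * b_scale
    return torch.bmm(a, b).to(out_dtype)
```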
@@ -19,7 +19,6 @@ import math
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class MiniCPM3MLP(nn.Module):
@@ -78,6 +78,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     enable_show_time_cost,
+    is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model,
@@ -434,6 +435,10 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
+    if is_hip():
+        # TODO: find a better approach than forcing "spawn" here
+        mp.set_start_method("spawn", force=True)
+
 
 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
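`fork` duplicates a parent process that may already have initialized the HIP runtime, and GPU contexts generally do not survive `fork`; `spawn` pays a slower startup in exchange for clean child interpreters. Minimal usage:

```python
# Force the "spawn" start method on ROCm builds, as the hunk above does.
import multiprocessing as mp
import torch

def is_hip() -> bool:
    return torch.version.hip is not None

if __name__ == "__main__":
    if is_hip():
        mp.set_start_method("spawn", force=True)
    print(mp.get_start_method())
```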
@@ -21,6 +21,8 @@ import logging
 import random
 from typing import List, Optional, Union
 
+from sglang.srt.utils import is_hip
+
 logger = logging.getLogger(__name__)
 
@@ -164,6 +166,11 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"
 
+        # ROCm: flashinfer available later
+        if is_hip():
+            self.attention_backend = "triton"
+            self.sampling_backend = "pytorch"
+
         # Default kernel backends
         if self.enable_mla:
             logger.info("MLA optimization is turned on. Use triton backend.")
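Net effect of this hunk: on ROCm builds the server silently defaults to the Triton attention backend and the PyTorch sampling backend. Passing the equivalent options explicitly should behave the same (the field names are taken from the hunk; the dataclass wiring and CLI flags are assumed from this sglang version):

```python
# Equivalent explicit configuration; roughly what
#   python -m sglang.launch_server --model-path <model> \
#       --attention-backend triton --sampling-backend pytorch
# would set.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")
print(args.attention_backend, args.sampling_backend)
# On a ROCm build this prints: triton pytorch
```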
@@ -51,6 +51,11 @@ show_time_cost = False
 time_infos = {}
 
 
+# Detect AMD GPUs via torch's HIP build flag
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
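`is_hip` keys off the build flavor rather than probing devices: `torch.version.hip` is a version string on ROCm wheels and `None` on CUDA/CPU wheels, which makes it safe to call before any GPU is touched:

```python
import torch

def is_hip() -> bool:
    return torch.version.hip is not None

# CUDA wheel:  torch.version.cuda == "12.1",  torch.version.hip is None
# ROCm wheel:  torch.version.cuda is None,    torch.version.hip is a string
print(is_hip())
```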