diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 9047197af..a3aeda9c4 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -13,6 +13,7 @@ limitations under the License. """Fused operators for activation layers.""" +import logging from typing import Optional import torch @@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.utils import set_weight_attrs +from sglang.srt.utils import is_hip + +logger = logging.getLogger(__name__) + class SiluAndMul(CustomOp): def forward_native(self, x: torch.Tensor) -> torch.Tensor: @@ -135,3 +140,10 @@ def get_act_fn( act_fn, intermediate_size, input_is_parallel, params_dtype ) return act_fn + + +if is_hip(): + logger.info( + "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries." + ) + from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py index c01016bbd..73bdf512b 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention_backend.py @@ -12,22 +12,26 @@ from typing import TYPE_CHECKING import torch import torch.nn as nn -from flashinfer import ( - BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - BatchPrefillWithRaggedKVCacheWrapper, -) -from flashinfer.cascade import merge_state -from flashinfer.decode import _grouped_size_compiled_for_decode_kernels from sglang.global_config import global_config from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.utils import is_hip if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner +# ROCm: flashinfer available later +if not is_hip(): + from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + ) + from flashinfer.cascade import merge_state + from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + class AttentionBackend(ABC): """The base class of attention backends""" diff --git a/python/sglang/srt/layers/fused_moe/layer.py b/python/sglang/srt/layers/fused_moe/layer.py index e08ec5c58..0511db5a1 100644 --- a/python/sglang/srt/layers/fused_moe/layer.py +++ b/python/sglang/srt/layers/fused_moe/layer.py @@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.utils import set_weight_attrs +from sglang.srt.utils import is_hip + logger = init_logger(__name__) @@ -381,6 +383,7 @@ from torch.nn import Module from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, + normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, ) from vllm.utils import print_warning_once @@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): def process_weights_after_loading(self, layer: Module) -> None: - # If checkpoint is fp16, quantize in place. + # If checkpoint is fp16 or bfloat16, quantize in place. 
if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like( - layer.w13_weight.data, dtype=torch.float8_e4m3fn - ) - w2_weight = torch.empty_like( - layer.w2_weight.data, dtype=torch.float8_e4m3fn - ) + # If ROCm, use float8_e4m3fnuz instead (MI300x HW) + fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. @@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.a2_scale.max(), requires_grad=False ) + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip(): + # Normalize the weights and scales + w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_scale, layer.a13_scale + ) + w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_scale, layer.a2_scale + ) + # Reset the parameters + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False) + if a13_scale is not None: + layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False) + if a2_scale is not None: + layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False) + # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. assert layer.w13_scale is not None diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 4c24f50ff..c4803a334 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -15,6 +15,7 @@ limitations under the License. """Fused operators for normalization layers.""" +import logging from typing import Optional, Tuple, Union import torch @@ -27,6 +28,10 @@ from flashinfer.norm import ( ) from vllm.model_executor.custom_op import CustomOp +from sglang.srt.utils import is_hip + +logger = logging.getLogger(__name__) + class RMSNorm(CustomOp): def __init__( @@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp): return x, residual out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon) return out + + +if is_hip(): + logger.info( + "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries." 
+ ) + from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 55b2daab4..88ae1322a 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -2,17 +2,21 @@ import logging from typing import Union import torch -from flashinfer.sampling import ( - min_p_sampling_from_probs, - top_k_renorm_prob, - top_k_top_p_sampling_from_probs, - top_p_renorm_prob, -) from torch import nn from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo +from sglang.srt.utils import is_hip + +# ROCm: flashinfer available later +if not is_hip(): + from flashinfer.sampling import ( + min_p_sampling_from_probs, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, + ) logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index a9ea232ed..d0a604fe6 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -21,12 +21,15 @@ import re from dataclasses import dataclass import torch -from flashinfer import SegmentGEMMWrapper from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.utils import replace_submodule +from sglang.srt.utils import is_hip, replace_submodule + +# ROCm: flashinfer available later +if not is_hip(): + from flashinfer import SegmentGEMMWrapper def get_stacked_name(name): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0ff236f8f..7bc106abc 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -19,7 +19,6 @@ limitations under the License. 
from typing import Any, Dict, Iterable, Optional, Tuple import torch -from flashinfer import bmm_fp8 from torch import nn from transformers import PretrainedConfig from vllm.config import CacheConfig @@ -48,6 +47,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.utils import is_hip + +# ROCm: flashinfer available later +if not is_hip(): + from flashinfer import bmm_fp8 class DeepseekV2MLP(nn.Module): diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index ce40a94a7..40278f45d 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -19,7 +19,6 @@ import math from typing import Any, Dict, Iterable, Optional, Tuple import torch -from flashinfer import bmm_fp8 from torch import nn from transformers import PretrainedConfig from vllm.config import CacheConfig @@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.utils import is_hip + +# ROCm: flashinfer available later +if not is_hip(): + from flashinfer import bmm_fp8 class MiniCPM3MLP(nn.Module): diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 9749075d0..9379b9d08 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -78,6 +78,7 @@ from sglang.srt.utils import ( assert_pkg_version, configure_logger, enable_show_time_cost, + is_hip, kill_child_process, maybe_set_triton_cache_manager, prepare_model, @@ -434,6 +435,10 @@ def _set_envs_and_config(server_args: ServerArgs): "at https://docs.flashinfer.ai/installation.html.", ) + if is_hip(): + # to figure out a better method of not using fork later + mp.set_start_method("spawn", force=True) + def _wait_and_warmup(server_args, pipe_finish_writer, pid): headers = {} diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 818856716..35b99b6af 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -21,6 +21,8 @@ import logging import random from typing import List, Optional, Union +from sglang.srt.utils import is_hip + logger = logging.getLogger(__name__) @@ -164,6 +166,11 @@ class ServerArgs: ) self.sampling_backend = "pytorch" + # ROCm: flashinfer available later + if is_hip(): + self.attention_backend = "triton" + self.sampling_backend = "pytorch" + # Default kernel backends if self.enable_mla: logger.info("MLA optimization is tunred on. Use triton backend.") diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1f1a44870..92e670479 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -51,6 +51,11 @@ show_time_cost = False time_infos = {} +# torch flag AMD GPU +def is_hip() -> bool: + return torch.version.hip is not None + + def enable_show_time_cost(): global show_time_cost show_time_cost = True
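Taken together, the recurring change in this patch is a single import-time guard: every module that previously imported FlashInfer unconditionally now checks `is_hip()` first and, on ROCm, either skips the import (Triton/PyTorch backends are used instead) or falls back to the vLLM implementations. A minimal standalone sketch of the pattern, using the helper and import names that appear in the diff (the surrounding module scaffolding is illustrative only):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def is_hip() -> bool:
    # torch.version.hip is set only in ROCm builds of PyTorch (AMD GPUs).
    return torch.version.hip is not None


if not is_hip():
    # CUDA path: the FlashInfer wrappers are importable and used as before.
    from flashinfer import (
        BatchDecodeWithPagedKVCacheWrapper,
        BatchPrefillWithPagedKVCacheWrapper,
        BatchPrefillWithRaggedKVCacheWrapper,
    )
else:
    # ROCm path: FlashInfer is not available, so fall back to the vLLM
    # kernels (as layernorm.py and activation.py do in the diff above).
    logger.info(
        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm  # noqa: F401
```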
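The FP8 MoE changes follow the same hardware split. MI300-series GPUs implement the `e4m3fnuz` FP8 variant rather than OCP `e4m3fn`, so on-the-fly quantization of fp16/bf16 checkpoints selects the dtype at runtime, and already-serialized `e4m3fn` weights and scales are converted with vLLM's `normalize_e4m3fn_to_e4m3fnuz`, as in the `Fp8MoEMethod` hunk above. A rough sketch of the dtype switch applied to a per-tensor quantization step (the helper below is illustrative, not the patch's code):

```python
import torch


def quantize_fp8_per_tensor(weight: torch.Tensor, use_hip: bool):
    """Illustrative per-tensor FP8 quantization with the ROCm dtype switch."""
    fp8_dtype = torch.float8_e4m3fnuz if use_hip else torch.float8_e4m3fn
    finfo = torch.finfo(fp8_dtype)
    # Scale so the largest weight magnitude maps to the dtype's maximum value.
    scale = weight.abs().max().clamp(min=1e-12).float() / finfo.max
    qweight = (weight / scale).clamp(min=finfo.min, max=finfo.max).to(fp8_dtype)
    return qweight, scale
```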
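Outside the kernel layers, the ROCm path also adjusts two runtime defaults: `ServerArgs` pins the attention backend to Triton and the sampling backend to PyTorch, and the server forces the `spawn` multiprocessing start method (the diff notes that this way of avoiding `fork` may be revisited). A condensed sketch of both, reusing the `is_hip` helper added in `sglang/srt/utils.py`; the wrapper function itself is hypothetical:

```python
import multiprocessing as mp

from sglang.srt.utils import is_hip


def apply_rocm_defaults(server_args) -> None:
    """Hypothetical helper condensing the two ROCm-only defaults from the patch."""
    if not is_hip():
        return
    # FlashInfer backends are unavailable on ROCm, so use the Triton
    # attention kernels and the PyTorch sampler.
    server_args.attention_backend = "triton"
    server_args.sampling_backend = "pytorch"
    # The patch forces spawn instead of fork on ROCm; a cleaner approach
    # is left as a follow-up in the diff.
    mp.set_start_method("spawn", force=True)
```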