diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index cc2c0ca4d..9c4468c11 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -2,7 +2,7 @@ # docker build --build-arg SGL_BRANCH=v0.4.3.post4 -t v0.4.3.post4-rocm630 -f Dockerfile.rocm . # default base image -ARG BASE_IMAGE="rocm/sgl-dev:vllm20250114" +ARG BASE_IMAGE="rocm/sgl-dev:20250114vllm-blas-flash" FROM $BASE_IMAGE AS base USER root @@ -16,10 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT} ARG TRITON_REPO="https://github.com/ROCm/triton.git" ARG TRITON_COMMIT="improve_fa_decode_3.0.0" - ARG AITER_REPO="https://github.com/ROCm/aiter.git" ARG AITER_COMMIT="testx" + RUN git clone ${SGL_REPO} \ && cd sglang \ && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \ @@ -59,6 +59,7 @@ RUN git clone ${AITER_REPO} \ && git submodule update --init --recursive \ && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop + # Copy config files to support MI300X in virtualized environments (MI300X_VF). Symlinks will not be created in image build. RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \ /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py new file mode 100644 index 000000000..5f7e091c9 --- /dev/null +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -0,0 +1,605 @@ +from __future__ import annotations + +""" +end to end attention solution with aiter kernels +""" + +import math +import os +from dataclasses import dataclass +from enum import Enum, auto +from functools import partial +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +import triton +import triton.language as tl + +from sglang.global_config import global_config +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo + +# flashinfer AMD fork +from flashinfer import BatchPrefillWithPagedKVCacheWrapper + +try: + from aiter import paged_attention_rocm +except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." 
+ ) + + +class WrapperDispatch(Enum): + SLIDING_WINDOW = auto() + CROSS_ATTENTION = auto() + + +@dataclass +class DecodeMetadata: + kv_indptr: torch.Tensor + kv_indices: torch.Tensor + + +@dataclass +class PrefillMetadata: + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper + extend_no_prefix: bool + + +global_workspace_buffer = None + +_AITER_PARTITION_SIZE_ROCM = 256 + + +class AiterAttnBackend(AttentionBackend): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): + super().__init__() + + self.device = model_runner.device + self.is_multimodal = model_runner.model_config.is_multimodal + self.num_head = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.head_dim = model_runner.model_config.head_dim + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + self.num_kv_head = model_runner.model_config.get_num_kv_heads( + get_attention_tp_size() + ) + self.kv_cache_dtype = model_runner.kv_cache_dtype + + self.req_to_token = model_runner.req_to_token_pool.req_to_token + + # Parse constants + self.max_context_len = model_runner.model_config.context_len + self.skip_prefill = skip_prefill + + # Qwen2 models require higher flashinfer workspace size + if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures: + global_config.flashinfer_workspace_size = 512 * 1024 * 1024 + + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device=model_runner.device, + ) + + self.workspace_buffer = global_workspace_buffer + max_bs = model_runner.req_to_token_pool.size + + if kv_indptr_buf is None: + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + else: + self.kv_indptr = kv_indptr_buf + + self.kv_last_page_len = torch.ones( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + + self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD", backend="fa2" + ) + self.prefill_wrapper_verify = BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD" + ) + + # Create prefill indices updater + if not skip_prefill: + self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( + model_runner, self + ) + + # aiter kernel related initialization + self.max_num_partitions = ( + self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 + ) // _AITER_PARTITION_SIZE_ROCM + + nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 + + self.workspace_buffer = torch.empty( + (max_bs * self.num_head * self.max_num_partitions * self.head_dim) + * nbyes_per_qo_elem + + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, + dtype=torch.uint8, + device=self.device, + ) + + self.scale = float(1.0 / (self.head_dim**0.5)) + self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to( + self.device + ) + self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to( + self.device + ) + + self.logits_soft_cap = 0.0 + + self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None + self.decode_cuda_graph_metadata = {} + self.prefill_cuda_graph_metadata = {} + + def init_forward_metadata(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_decode_or_idle(): + # update for aiter + # create kv_indices 
and kv_inptr + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + spec_info = forward_batch.spec_info + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.zeros( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 + + self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices) + + elif forward_batch.forward_mode.is_draft_extend(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrapper=self.prefill_wrapper_paged, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrapper_paged, False, False + ) + elif forward_batch.forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrapper=self.prefill_wrapper_verify, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrapper_verify, False, False + ) + else: + prefix_lens = forward_batch.extend_prefix_lens + + if self.is_multimodal: + extend_no_prefix = False + else: + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens, + prefill_wrapper=self.prefill_wrapper_paged, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrapper_paged, extend_no_prefix + ) + + def init_cuda_graph_state( + self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + ): + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device=self.device, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + ): + if forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices) + self.decode_cuda_graph_metadata[bs] = self.forward_metadata + + elif forward_mode.is_target_verify(): + prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + 
use_cuda_graph=False, + qo_indptr_buf=self.cuda_graph_qo_indptr[: bs + 1], + paged_kv_indptr_buf=self.kv_indptr[: bs + 1], + paged_kv_indices_buf=self.cuda_graph_kv_indices, + paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], + custom_mask_buf=self.cuda_graph_custom_mask, + mask_indptr_buf=self.cuda_graph_qk_indptr[: bs + 1], + ) + + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_prefill.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + prefix_lens=None, + prefill_wrapper=prefill_wrapper, + encoder_lens=encoder_lens, + spec_info=spec_info, + ) + self.prefill_cuda_graph_metadata[bs] = prefill_wrapper + self.forward_metadata = PrefillMetadata(prefill_wrapper, False) + else: + raise ValueError(f"Invalid mode: {forward_mode=}") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + seq_lens_cpu: Optional[torch.Tensor], + ): + if forward_mode.is_decode_or_idle(): + kv_indptr = self.kv_indptr + kv_indices = self.cuda_graph_kv_indices + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0) + kv_indptr = kv_indptr[: bs + 1] + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr + kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices + + elif forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + prefix_lens=None, + prefill_wrapper=self.prefill_cuda_graph_metadata[bs], + encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, + spec_info=spec_info, + ) + else: + raise ValueError("Invalid forward mode") + + def get_cuda_graph_seq_len_fill_value(self): + return 1 + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + prefill_wrapper_paged = self.forward_metadata.prefill_wrapper + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + + self.logits_soft_cap = layer.logit_cap + + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + + o = prefill_wrapper_paged.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=not layer.is_cross_attention, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=self.logits_soft_cap, + k_scale=layer.k_scale, + v_scale=layer.v_scale, + ) + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + 
self.logits_soft_cap = layer.logit_cap + paged_attention_rocm( + o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + self.workspace_buffer, + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view( + -1, 1, layer.tp_k_head_num, layer.qk_head_dim + ), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view( + -1, 1, layer.tp_v_head_num, layer.v_head_dim + ), + self.scale, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.kv_last_page_lens, + 1, + self.max_num_partitions, + None, + "auto", + "NHD", + self.logits_soft_cap, + self.k_scale, + self.v_scale, + None, + _AITER_PARTITION_SIZE_ROCM, + ) + + return o + + +class FlashInferIndicesUpdaterPrefill: + def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): + # Parse Constants + self.num_qo_heads = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + get_attention_tp_size() + ) + self.head_dim = model_runner.model_config.head_dim + self.data_type = model_runner.kv_cache_dtype + self.q_data_type = model_runner.dtype + self.sliding_window_size = model_runner.sliding_window_size + self.attn_backend = attn_backend + + # Buffers and wrappers + self.kv_indptr = attn_backend.kv_indptr + self.kv_last_page_len = attn_backend.kv_last_page_len + self.qo_indptr = attn_backend.qo_indptr + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.update = self.update_single_wrapper + + def update( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + prefix_lens: torch.Tensor, + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], + ): + # Keep the signature for type checking. It will be assigned during runtime. 
+ raise NotImplementedError() + + def update_single_wrapper( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + prefix_lens: torch.Tensor, + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], + ): + + paged_kernel_lens = seq_lens + paged_kernel_lens_sum = seq_lens_sum + + self.call_begin_forward( + prefill_wrapper, + req_pool_indices, + paged_kernel_lens, + paged_kernel_lens_sum, + seq_lens, + prefix_lens, + None, + self.kv_indptr, + self.qo_indptr, + spec_info, + ) + + def call_begin_forward( + self, + wrapper_paged: BatchPrefillWithPagedKVCacheWrapper, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + seq_lens: torch.Tensor, + prefix_lens: torch.Tensor, + kv_start_idx: torch.Tensor, + kv_indptr: torch.Tensor, + qo_indptr: torch.Tensor, + spec_info: Optional[SpecInfo], + ): + bs = len(req_pool_indices) + if spec_info is None: + # Normal extend + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum + 256, + dtype=torch.int32, + device=req_pool_indices.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.req_to_token.shape[1], + ) + + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + else: + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + self.req_to_token, + ) + ) + + # cached part + # adding logits_soft_cap arg in plan() stage + wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + q_data_type=self.q_data_type, + custom_mask=custom_mask, + non_blocking=True, + logits_soft_cap=self.attn_backend.logits_soft_cap, + ) + + +@triton.jit +def create_flashinfer_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_indptr, + kv_start_idx, + kv_indices_ptr, + req_to_token_ptr_stride: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + + req_pool_index = tl.load(req_pool_indices_ptr + pid) + kv_indices_offset = tl.load(kv_indptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = offset < kv_end - kv_start + data = tl.load( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + kv_start + + offset, + mask=mask, + ) + tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask) diff --git a/python/sglang/srt/layers/attention/aiter_decode_backend.py b/python/sglang/srt/layers/attention/aiter_decode_backend.py new file mode 100644 index 000000000..b8d3c77ef --- /dev/null +++ b/python/sglang/srt/layers/attention/aiter_decode_backend.py @@ -0,0 +1,535 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.base_attn_backend import 
AttentionBackend +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo + +try: + from aiter import paged_attention_rocm +except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." + ) + +from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd + +_AITER_PARTITION_SIZE_ROCM = 256 + + +class AiterDecodeAttnBackend(AttentionBackend): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): + super().__init__() + + self.decode_attention_fwd = paged_attention_rocm + self.extend_attention_fwd = extend_attention_fwd + + self.skip_prefill = skip_prefill + + max_bs = model_runner.req_to_token_pool.size + + if kv_indptr_buf is None: + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + else: + self.kv_indptr = kv_indptr_buf + + self.req_to_token = model_runner.req_to_token_pool.req_to_token + + if not self.skip_prefill: + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + + self.mask_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int64, device=model_runner.device + ) + + self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens + + # tp sharding on number of heads + self.num_head = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + + self.head_dim = model_runner.model_config.head_dim + + # triton prefill initialization + self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits + + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + + self.num_v_head = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-2] + + self.forward_metadata = None + + self.max_context_len = model_runner.model_config.context_len + + self.device = model_runner.device + + self.kv_cache_dtype = model_runner.kv_cache_dtype + + self.q_dtype = model_runner.model_config.dtype + + # aiter decode initialization + self.max_num_partitions = ( + self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 + ) // _AITER_PARTITION_SIZE_ROCM + + nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 + + self.workspace_buffer = torch.empty( + (max_bs * self.num_head * self.max_num_partitions * self.head_dim) + * nbyes_per_qo_elem + + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, + dtype=torch.uint8, + device=self.device, + ) + + self.scale = float(1.0 / (self.head_dim**0.5)) + self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to( + self.device + ) + self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to( + self.device + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init auxiliary variables""" + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + spec_info = forward_batch.spec_info + + if forward_batch.forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.zeros( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + # prepare kv_indices and 
kv_indptr + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 + + attn_logits = None # accomodate forward_metadata format + qo_indptr = None + custom_mask = None + mask_indptr = None + max_extend_len = None + elif forward_batch.forward_mode.is_target_verify(): + bs = len(forward_batch.req_pool_indices) + qo_indptr = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.zeros( + kv_indptr[-1], dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + custom_mask = spec_info.custom_mask + seq_mask_len = self.num_draft_tokens * ( + forward_batch.seq_lens + self.num_draft_tokens + ) + mask_indptr = self.mask_indptr + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0) + mask_indptr = mask_indptr[: bs + 1] + max_extend_len = self.num_draft_tokens + attn_logits = None + elif forward_batch.forward_mode.is_draft_extend(): + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + self.req_to_token, + ) + ) + mask_indptr = None + max_extend_len = torch.max(spec_info.accept_length).item() + attn_logits = None + else: + kv_indptr[1 : bs + 1] = torch.cumsum( + forward_batch.extend_prefix_lens, dim=0 + ) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.zeros( + forward_batch.extend_prefix_lens.sum().item(), + dtype=torch.int32, + device=self.device, + ) + + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.extend_prefix_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr + qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + mask_indptr = None + attn_logits = None + max_extend_len = torch.max(forward_batch.extend_seq_lens).item() + + self.forward_metadata = ( + attn_logits, + max_extend_len, + kv_indptr, + kv_indices, + qo_indptr, + custom_mask, + mask_indptr, + ) + + def init_cuda_graph_state( + self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + ): + + self.cuda_graph_attn_logits = torch.zeros( + (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), + dtype=torch.float32, + device=self.device, + ) + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device=self.device, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + ): + assert encoder_lens is None, "Not 
supported" + + if forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + attn_logits = None + max_extend_len = None + qo_indptr = None + custom_mask = None + mask_indptr = None + elif forward_mode.is_target_verify(): + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + custom_mask = self.cuda_graph_custom_mask + seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) + mask_indptr = self.mask_indptr[: bs + 1] + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0) + max_extend_len = self.num_draft_tokens + attn_logits = None + else: + raise ValueError( + f"Invalid forward mode: {forward_mode=} for CUDA Graph capture." + ) + + self.forward_metadata = ( + attn_logits, + max_extend_len, + kv_indptr, + kv_indices, + qo_indptr, + custom_mask, + mask_indptr, + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + seq_lens_cpu: Optional[torch.Tensor], + ): + # NOTE: encoder_lens expected to be zeros or None + if forward_mode.is_decode_or_idle(): + # Update kv_indptr, kv_indices + kv_indptr = self.kv_indptr + kv_indices = self.cuda_graph_kv_indices + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0) + kv_indptr = kv_indptr[: bs + 1] + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr + kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices + elif forward_mode.is_target_verify(): + # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr + bs = len(req_pool_indices) + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + custom_mask = self.cuda_graph_custom_mask + custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask + seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) + mask_indptr = self.mask_indptr[: bs + 1] + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0) + else: + raise ValueError( + 
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay." + ) + + def get_cuda_graph_seq_len_fill_value(self): + return 1 + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + # TODO: reuse the buffer across layers + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + ( + _, + max_extend_len, + kv_indptr, + kv_indices, + qo_indptr, + custom_mask, + mask_indptr, + ) = self.forward_metadata + + self.extend_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.contiguous(), + v.contiguous(), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), + qo_indptr, + kv_indptr, + kv_indices, + custom_mask, + mask_indptr, + max_extend_len, + layer.scaling, + layer.logit_cap, + ) + return o + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + # During torch.compile, there is a bug in rotary_emb that causes the + # output value to have a 3D tensor shape. This reshapes the output correctly. + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + self.decode_attention_fwd( + o.view( + -1, layer.tp_q_head_num, layer.qk_head_dim + ), # (bs, head_num_q, head_dim_q) + self.workspace_buffer, + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view( + -1, 1, layer.tp_k_head_num, layer.qk_head_dim + ), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view( + -1, 1, layer.tp_v_head_num, layer.v_head_dim + ), + self.scale, + kv_indptr, + kv_indices, + self.kv_last_page_lens, + 1, + self.max_num_partitions, + None, + "auto", + "NHD", + layer.logit_cap, + self.k_scale, + self.v_scale, + None, + _AITER_PARTITION_SIZE_ROCM, + ) + + return o + + +@triton.jit +def create_flashinfer_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_indptr, + kv_start_idx, + kv_indices_ptr, + req_to_token_ptr_stride: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + + req_pool_index = tl.load(req_pool_indices_ptr + pid) + kv_indices_offset = tl.load(kv_indptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = offset < kv_end - kv_start + data = tl.load( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + kv_start + + offset, + mask=mask, + ) + tl.store(kv_indices_ptr + kv_indices_offset + offset, data, 
mask=mask) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 666b97e2b..8c3df4341 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -79,6 +79,12 @@ from sglang.srt.utils import ( ) from sglang.utils import get_exception_traceback +is_hip_ = is_hip() + +if is_hip_: + from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend + logger = logging.getLogger(__name__) @@ -641,7 +647,7 @@ class ModelRunner: if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype elif self.server_args.kv_cache_dtype == "fp8_e5m2": - if is_hip(): # Using natively supported format + if is_hip_: # Using natively supported format self.kv_cache_dtype = torch.float8_e5m2fnuz else: self.kv_cache_dtype = torch.float8_e5m2 @@ -778,33 +784,59 @@ class ModelRunner: def init_attention_backend(self): """Init attention kernel backend.""" - if self.server_args.attention_backend == "flashinfer": - # Init streams - if self.server_args.speculative_algorithm == "EAGLE": - self.plan_stream_for_flashinfer = torch.cuda.Stream() + if is_cuda(): + if self.server_args.attention_backend == "flashinfer": + # Init streams + if self.server_args.speculative_algorithm == "EAGLE": + self.plan_stream_for_flashinfer = torch.cuda.Stream() - self.attn_backend = FlashInferAttnBackend(self) - elif self.server_args.attention_backend == "triton": - assert self.sliding_window_size is None, ( - "Window attention is not supported in the triton attention backend. " - "Please use `--attention-backend flashinfer`." - ) - assert not self.model_config.is_encoder_decoder, ( - "Cross attention is not supported in the triton attention backend. " - "Please use `--attention-backend flashinfer`." - ) - if self.server_args.enable_double_sparsity: - self.attn_backend = DoubleSparseAttnBackend(self) + self.attn_backend = FlashInferAttnBackend(self) + elif self.server_args.attention_backend == "triton": + assert self.sliding_window_size is None, ( + "Window attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." + ) + assert not self.model_config.is_encoder_decoder, ( + "Cross attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." 
+ ) + if self.server_args.enable_double_sparsity: + self.attn_backend = DoubleSparseAttnBackend(self) + else: + self.attn_backend = TritonAttnBackend(self) + elif self.server_args.attention_backend == "torch_native": + self.attn_backend = TorchNativeAttnBackend(self) + elif self.server_args.attention_backend == "flashinfer_mla": + self.attn_backend = FlashInferMLAAttnBackend(self) else: - self.attn_backend = TritonAttnBackend(self) - elif self.server_args.attention_backend == "torch_native": - self.attn_backend = TorchNativeAttnBackend(self) - elif self.server_args.attention_backend == "flashinfer_mla": - self.attn_backend = FlashInferMLAAttnBackend(self) - else: - raise ValueError( - f"Invalid attention backend: {self.server_args.attention_backend}" - ) + raise ValueError( + f"Invalid attention backend: {self.server_args.attention_backend}" + ) + elif is_hip_: + # AMD hip supported attention backends + if self.server_args.attention_backend == "aiter": + self.attn_backend = AiterAttnBackend(self) + elif self.server_args.attention_backend == "aiter_decode": + self.attn_backend = AiterDecodeAttnBackend(self) + elif self.server_args.attention_backend == "triton": + assert self.sliding_window_size is None, ( + "Window attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." + ) + assert not self.model_config.is_encoder_decoder, ( + "Cross attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." + ) + if self.server_args.enable_double_sparsity: + self.attn_backend = DoubleSparseAttnBackend(self) + else: + self.attn_backend = TritonAttnBackend(self) + elif self.server_args.attention_backend == "torch_native": + self.attn_backend = TorchNativeAttnBackend(self) + else: + raise ValueError( + f"Invalid attention backend: {self.server_args.attention_backend}" + ) def init_double_sparsity_channel_config(self, selected_channel): selected_channel = "." + selected_channel + "_proj" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c5b8b920e..f53e8068d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -710,13 +710,23 @@ class ServerArgs: ) # Kernel backend - parser.add_argument( - "--attention-backend", - type=str, - choices=["flashinfer", "triton", "torch_native"], - default=ServerArgs.attention_backend, - help="Choose the kernels for attention layers.", - ) + if is_hip(): + parser.add_argument( + "--attention-backend", + type=str, + choices=["triton", "torch_native", "aiter", "aiter_decode"], + default=ServerArgs.attention_backend, + help="Choose the kernels for attention layers.", + ) + else: + parser.add_argument( + "--attention-backend", + type=str, + choices=["flashinfer", "triton", "torch_native"], + default=ServerArgs.attention_backend, + help="Choose the kernels for attention layers.", + ) + parser.add_argument( "--sampling-backend", type=str, diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip new file mode 100644 index 000000000..d0144a617 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip @@ -0,0 +1,118 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include + +#include "utils_hip.h" + +#define WARP_SIZE 32 + +template +__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + +template +__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) { + __shared__ int32_t shared_counts[WARP_SIZE][8]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int experts_per_warp = 8; + const int my_expert_start = warp_id * experts_per_warp; + + for (int i = 0; i < experts_per_warp; ++i) { + if (my_expert_start + i < num_experts) { + shared_counts[warp_id][i] = 0; + } + } + + __syncthreads(); + + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int expert_id = topk_ids[i]; + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + atomicAdd(&shared_counts[warp_idx][expert_offset], 1); + } + + __syncthreads(); + + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + int expert_count = 0; + int warp_idx = (i - 1) / experts_per_warp; + int expert_offset = (i - 1) % experts_per_warp; + expert_count = shared_counts[warp_idx][expert_offset]; + + cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } +} + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now."); + + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + auto align_kernel = moe_align_block_size_kernel; + hipLaunchKernelGGL(( align_kernel), dim3(1), dim3(1024), 0, stream, topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, block_size, 
topk_ids.numel(), cumsum_buffer.data_ptr()); + + const int block_threads = 256; + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = ::min(num_blocks, max_blocks); + + auto sort_kernel = count_and_sort_expert_tokens_kernel; + hipLaunchKernelGGL(( sort_kernel), dim3(actual_blocks), dim3(block_threads), 0, stream, topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), topk_ids.numel()); + }); +} diff --git a/sgl-kernel/src/sgl-kernel/include/utils_hip.h b/sgl-kernel/src/sgl-kernel/include/utils_hip.h new file mode 100644 index 000000000..e5f1a9355 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/include/utils_hip.h @@ -0,0 +1,98 @@ +// !!! This is a file automatically generated by hipify!!! +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#ifndef USE_ROCM +#include +#endif +#include + +#include + +struct cuda_error : public std::runtime_error { + /** + * @brief Constructs a `cuda_error` object with the given `message`. + * + * @param message The error char array used to construct `cuda_error` + */ + cuda_error(const char* message) : std::runtime_error(message) {} + /** + * @brief Constructs a `cuda_error` object with the given `message` string. + * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + hipError_t e = cmd; \ + if (e != hipSuccess) { \ + std::stringstream _message; \ + auto s = hipGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) + +inline int getSMVersion() { + int device{-1}; + CHECK_CUDA_SUCCESS(hipGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_major, hipDeviceAttributeComputeCapabilityMajor, device)); + CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_minor, hipDeviceAttributeComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +#ifndef USE_ROCM +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#endif + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define CEILDIV(x, y) (((x) + (y)-1) / (y))
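
Usage note (not part of the patch): with this branch installed on a ROCm system, the new kernels are selected through the same `--attention-backend` flag that the server_args.py hunk above registers for HIP builds. A minimal launch sketch, assuming a placeholder model path:

    python3 -m sglang.launch_server --model-path /path/to/model --attention-backend aiter

Passing `aiter_decode` instead routes only the decode path through the aiter paged-attention kernel, while prefill/extend falls back to the Triton extend kernel, as implemented in AiterDecodeAttnBackend above.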