diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 9c4468c11..cc2c0ca4d 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -2,7 +2,7 @@ # docker build --build-arg SGL_BRANCH=v0.4.3.post4 -t v0.4.3.post4-rocm630 -f Dockerfile.rocm . # default base image -ARG BASE_IMAGE="rocm/sgl-dev:20250114vllm-blas-flash" +ARG BASE_IMAGE="rocm/sgl-dev:vllm20250114" FROM $BASE_IMAGE AS base USER root @@ -16,10 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT} ARG TRITON_REPO="https://github.com/ROCm/triton.git" ARG TRITON_COMMIT="improve_fa_decode_3.0.0" + ARG AITER_REPO="https://github.com/ROCm/aiter.git" ARG AITER_COMMIT="testx" - RUN git clone ${SGL_REPO} \ && cd sglang \ && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \ @@ -59,7 +59,6 @@ RUN git clone ${AITER_REPO} \ && git submodule update --init --recursive \ && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop - # Copy config files to support MI300X in virtualized environments (MI300X_VF). Symlinks will not be created in image build. RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \ /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py deleted file mode 100644 index 5f7e091c9..000000000 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ /dev/null @@ -1,605 +0,0 @@ -from __future__ import annotations - -""" -end to end attention solution with aiter kernels -""" - -import math -import os -from dataclasses import dataclass -from enum import Enum, auto -from functools import partial -from typing import TYPE_CHECKING, List, Optional, Union - -import torch -import triton -import triton.language as tl - -from sglang.global_config import global_config -from sglang.srt.layers.attention.base_attn_backend import AttentionBackend -from sglang.srt.layers.dp_attention import get_attention_tp_size -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode - -if TYPE_CHECKING: - from sglang.srt.layers.radix_attention import RadixAttention - from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo - -# flashinfer AMD fork -from flashinfer import BatchPrefillWithPagedKVCacheWrapper - -try: - from aiter import paged_attention_rocm -except ImportError: - print( - "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." 
- ) - - -class WrapperDispatch(Enum): - SLIDING_WINDOW = auto() - CROSS_ATTENTION = auto() - - -@dataclass -class DecodeMetadata: - kv_indptr: torch.Tensor - kv_indices: torch.Tensor - - -@dataclass -class PrefillMetadata: - prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper - extend_no_prefix: bool - - -global_workspace_buffer = None - -_AITER_PARTITION_SIZE_ROCM = 256 - - -class AiterAttnBackend(AttentionBackend): - def __init__( - self, - model_runner: ModelRunner, - skip_prefill: bool = False, - kv_indptr_buf: Optional[torch.Tensor] = None, - ): - super().__init__() - - self.device = model_runner.device - self.is_multimodal = model_runner.model_config.is_multimodal - self.num_head = ( - model_runner.model_config.num_attention_heads // get_attention_tp_size() - ) - self.head_dim = model_runner.model_config.head_dim - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] - self.num_kv_head = model_runner.model_config.get_num_kv_heads( - get_attention_tp_size() - ) - self.kv_cache_dtype = model_runner.kv_cache_dtype - - self.req_to_token = model_runner.req_to_token_pool.req_to_token - - # Parse constants - self.max_context_len = model_runner.model_config.context_len - self.skip_prefill = skip_prefill - - # Qwen2 models require higher flashinfer workspace size - if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures: - global_config.flashinfer_workspace_size = 512 * 1024 * 1024 - - global global_workspace_buffer - if global_workspace_buffer is None: - global_workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device=model_runner.device, - ) - - self.workspace_buffer = global_workspace_buffer - max_bs = model_runner.req_to_token_pool.size - - if kv_indptr_buf is None: - self.kv_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) - else: - self.kv_indptr = kv_indptr_buf - - self.kv_last_page_len = torch.ones( - (max_bs,), dtype=torch.int32, device=model_runner.device - ) - self.qo_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) - - self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD", backend="fa2" - ) - self.prefill_wrapper_verify = BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD" - ) - - # Create prefill indices updater - if not skip_prefill: - self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( - model_runner, self - ) - - # aiter kernel related initialization - self.max_num_partitions = ( - self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 - ) // _AITER_PARTITION_SIZE_ROCM - - nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 - - self.workspace_buffer = torch.empty( - (max_bs * self.num_head * self.max_num_partitions * self.head_dim) - * nbyes_per_qo_elem - + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, - dtype=torch.uint8, - device=self.device, - ) - - self.scale = float(1.0 / (self.head_dim**0.5)) - self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to( - self.device - ) - self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to( - self.device - ) - - self.logits_soft_cap = 0.0 - - self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None - self.decode_cuda_graph_metadata = {} - self.prefill_cuda_graph_metadata = {} - - def init_forward_metadata(self, forward_batch: ForwardBatch): - if forward_batch.forward_mode.is_decode_or_idle(): - # update for aiter - # create kv_indices 
and kv_inptr - bs = forward_batch.batch_size - kv_indptr = self.kv_indptr - spec_info = forward_batch.spec_info - if spec_info is None: - kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.zeros( - forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device - ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - else: - kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices - bs = kv_indptr.shape[0] - 1 - - self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices) - - elif forward_batch.forward_mode.is_draft_extend(): - self.indices_updater_prefill.update( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens=None, - prefill_wrapper=self.prefill_wrapper_paged, - encoder_lens=forward_batch.encoder_lens, - spec_info=forward_batch.spec_info, - ) - self.forward_metadata = PrefillMetadata( - self.prefill_wrapper_paged, False, False - ) - elif forward_batch.forward_mode.is_target_verify(): - self.indices_updater_prefill.update( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens=None, - prefill_wrapper=self.prefill_wrapper_verify, - encoder_lens=forward_batch.encoder_lens, - spec_info=forward_batch.spec_info, - ) - self.forward_metadata = PrefillMetadata( - self.prefill_wrapper_verify, False, False - ) - else: - prefix_lens = forward_batch.extend_prefix_lens - - if self.is_multimodal: - extend_no_prefix = False - else: - extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) - - self.indices_updater_prefill.update( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens, - prefill_wrapper=self.prefill_wrapper_paged, - encoder_lens=forward_batch.encoder_lens, - spec_info=None, - ) - self.forward_metadata = PrefillMetadata( - self.prefill_wrapper_paged, extend_no_prefix - ) - - def init_cuda_graph_state( - self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None - ): - if kv_indices_buf is None: - self.cuda_graph_kv_indices = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.int32, - device=self.device, - ) - else: - self.cuda_graph_kv_indices = kv_indices_buf - - if not self.skip_prefill: - self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.uint8, - device=self.device, - ) - - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - num_tokens: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], - ): - if forward_mode.is_decode_or_idle(): - if spec_info is None: - kv_indptr = self.kv_indptr - kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = self.cuda_graph_kv_indices - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - else: - kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices - self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices) - self.decode_cuda_graph_metadata[bs] = self.forward_metadata - - elif forward_mode.is_target_verify(): - prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, - "NHD", - 
use_cuda_graph=False, - qo_indptr_buf=self.cuda_graph_qo_indptr[: bs + 1], - paged_kv_indptr_buf=self.kv_indptr[: bs + 1], - paged_kv_indices_buf=self.cuda_graph_kv_indices, - paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], - custom_mask_buf=self.cuda_graph_custom_mask, - mask_indptr_buf=self.cuda_graph_qk_indptr[: bs + 1], - ) - - seq_lens_sum = seq_lens.sum().item() - self.indices_updater_prefill.update( - req_pool_indices, - seq_lens, - seq_lens_sum, - prefix_lens=None, - prefill_wrapper=prefill_wrapper, - encoder_lens=encoder_lens, - spec_info=spec_info, - ) - self.prefill_cuda_graph_metadata[bs] = prefill_wrapper - self.forward_metadata = PrefillMetadata(prefill_wrapper, False) - else: - raise ValueError(f"Invalid mode: {forward_mode=}") - - def init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], - seq_lens_cpu: Optional[torch.Tensor], - ): - if forward_mode.is_decode_or_idle(): - kv_indptr = self.kv_indptr - kv_indices = self.cuda_graph_kv_indices - if spec_info is None: - kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0) - kv_indptr = kv_indptr[: bs + 1] - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices[:bs], - seq_lens[:bs], - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - else: - kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr - kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices - - elif forward_mode.is_target_verify(): - self.indices_updater_prefill.update( - req_pool_indices[:bs], - seq_lens[:bs], - seq_lens_sum, - prefix_lens=None, - prefill_wrapper=self.prefill_cuda_graph_metadata[bs], - encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, - spec_info=spec_info, - ) - else: - raise ValueError("Invalid forward mode") - - def get_cuda_graph_seq_len_fill_value(self): - return 1 - - def forward_extend( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): - prefill_wrapper_paged = self.forward_metadata.prefill_wrapper - cache_loc = ( - forward_batch.out_cache_loc - if not layer.is_cross_attention - else forward_batch.encoder_out_cache_loc - ) - - self.logits_soft_cap = layer.logit_cap - - if k is not None: - assert v is not None - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale - ) - - o = prefill_wrapper_paged.forward( - q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=not layer.is_cross_attention, - sm_scale=layer.scaling, - window_left=layer.sliding_window_size, - logits_soft_cap=self.logits_soft_cap, - k_scale=layer.k_scale, - v_scale=layer.v_scale, - ) - - return o.view(-1, layer.tp_q_head_num * layer.head_dim) - - def forward_decode( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): - q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) - - if layer.qk_head_dim != layer.v_head_dim: - o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) - else: - o = torch.empty_like(q) - - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) - - 
self.logits_soft_cap = layer.logit_cap - paged_attention_rocm( - o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - self.workspace_buffer, - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view( - -1, 1, layer.tp_k_head_num, layer.qk_head_dim - ), - forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view( - -1, 1, layer.tp_v_head_num, layer.v_head_dim - ), - self.scale, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, - self.kv_last_page_lens, - 1, - self.max_num_partitions, - None, - "auto", - "NHD", - self.logits_soft_cap, - self.k_scale, - self.v_scale, - None, - _AITER_PARTITION_SIZE_ROCM, - ) - - return o - - -class FlashInferIndicesUpdaterPrefill: - def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): - # Parse Constants - self.num_qo_heads = ( - model_runner.model_config.num_attention_heads // get_attention_tp_size() - ) - self.num_kv_heads = model_runner.model_config.get_num_kv_heads( - get_attention_tp_size() - ) - self.head_dim = model_runner.model_config.head_dim - self.data_type = model_runner.kv_cache_dtype - self.q_data_type = model_runner.dtype - self.sliding_window_size = model_runner.sliding_window_size - self.attn_backend = attn_backend - - # Buffers and wrappers - self.kv_indptr = attn_backend.kv_indptr - self.kv_last_page_len = attn_backend.kv_last_page_len - self.qo_indptr = attn_backend.qo_indptr - self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.update = self.update_single_wrapper - - def update( - self, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - prefix_lens: torch.Tensor, - prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper, - encoder_lens: Optional[torch.Tensor], - spec_info: Optional[SpecInfo], - ): - # Keep the signature for type checking. It will be assigned during runtime. 
- raise NotImplementedError() - - def update_single_wrapper( - self, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - prefix_lens: torch.Tensor, - prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper, - encoder_lens: Optional[torch.Tensor], - spec_info: Optional[SpecInfo], - ): - - paged_kernel_lens = seq_lens - paged_kernel_lens_sum = seq_lens_sum - - self.call_begin_forward( - prefill_wrapper, - req_pool_indices, - paged_kernel_lens, - paged_kernel_lens_sum, - seq_lens, - prefix_lens, - None, - self.kv_indptr, - self.qo_indptr, - spec_info, - ) - - def call_begin_forward( - self, - wrapper_paged: BatchPrefillWithPagedKVCacheWrapper, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - paged_kernel_lens_sum: int, - seq_lens: torch.Tensor, - prefix_lens: torch.Tensor, - kv_start_idx: torch.Tensor, - kv_indptr: torch.Tensor, - qo_indptr: torch.Tensor, - spec_info: Optional[SpecInfo], - ): - bs = len(req_pool_indices) - if spec_info is None: - # Normal extend - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum + 256, - dtype=torch.int32, - device=req_pool_indices.device, - ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - kv_indices, - self.req_to_token.shape[1], - ) - - qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) - qo_indptr = qo_indptr[: bs + 1] - custom_mask = None - else: - kv_indices, kv_indptr, qo_indptr, custom_mask = ( - spec_info.generate_attn_arg_prefill( - req_pool_indices, - paged_kernel_lens, - self.req_to_token, - ) - ) - - # cached part - # adding logits_soft_cap arg in plan() stage - wrapper_paged.begin_forward( - qo_indptr, - kv_indptr, - kv_indices, - self.kv_last_page_len[:bs], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - 1, - q_data_type=self.q_data_type, - custom_mask=custom_mask, - non_blocking=True, - logits_soft_cap=self.attn_backend.logits_soft_cap, - ) - - -@triton.jit -def create_flashinfer_kv_indices_triton( - req_to_token_ptr, # [max_batch, max_context_len] - req_pool_indices_ptr, - page_kernel_lens_ptr, - kv_indptr, - kv_start_idx, - kv_indices_ptr, - req_to_token_ptr_stride: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 512 - pid = tl.program_id(axis=0) - - req_pool_index = tl.load(req_pool_indices_ptr + pid) - kv_indices_offset = tl.load(kv_indptr + pid) - - kv_start = 0 - kv_end = 0 - if kv_start_idx: - kv_start = tl.load(kv_start_idx + pid).to(tl.int32) - kv_end = kv_start - kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) - - num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) - for i in range(num_loop): - offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = offset < kv_end - kv_start - data = tl.load( - req_to_token_ptr - + req_pool_index * req_to_token_ptr_stride - + kv_start - + offset, - mask=mask, - ) - tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask) diff --git a/python/sglang/srt/layers/attention/aiter_decode_backend.py b/python/sglang/srt/layers/attention/aiter_decode_backend.py deleted file mode 100644 index b8d3c77ef..000000000 --- a/python/sglang/srt/layers/attention/aiter_decode_backend.py +++ /dev/null @@ -1,535 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -import torch -import triton -import triton.language as tl - -from sglang.srt.layers.attention.base_attn_backend import 
AttentionBackend -from sglang.srt.layers.dp_attention import get_attention_tp_size -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode - -if TYPE_CHECKING: - from sglang.srt.layers.radix_attention import RadixAttention - from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo - -try: - from aiter import paged_attention_rocm -except ImportError: - print( - "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." - ) - -from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd - -_AITER_PARTITION_SIZE_ROCM = 256 - - -class AiterDecodeAttnBackend(AttentionBackend): - def __init__( - self, - model_runner: ModelRunner, - skip_prefill: bool = False, - kv_indptr_buf: Optional[torch.Tensor] = None, - ): - super().__init__() - - self.decode_attention_fwd = paged_attention_rocm - self.extend_attention_fwd = extend_attention_fwd - - self.skip_prefill = skip_prefill - - max_bs = model_runner.req_to_token_pool.size - - if kv_indptr_buf is None: - self.kv_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) - else: - self.kv_indptr = kv_indptr_buf - - self.req_to_token = model_runner.req_to_token_pool.req_to_token - - if not self.skip_prefill: - self.qo_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) - - self.mask_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int64, device=model_runner.device - ) - - self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens - - # tp sharding on number of heads - self.num_head = ( - model_runner.model_config.num_attention_heads // get_attention_tp_size() - ) - - self.head_dim = model_runner.model_config.head_dim - - # triton prefill initialization - self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits - - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] - - self.num_v_head = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-2] - - self.forward_metadata = None - - self.max_context_len = model_runner.model_config.context_len - - self.device = model_runner.device - - self.kv_cache_dtype = model_runner.kv_cache_dtype - - self.q_dtype = model_runner.model_config.dtype - - # aiter decode initialization - self.max_num_partitions = ( - self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 - ) // _AITER_PARTITION_SIZE_ROCM - - nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 - - self.workspace_buffer = torch.empty( - (max_bs * self.num_head * self.max_num_partitions * self.head_dim) - * nbyes_per_qo_elem - + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, - dtype=torch.uint8, - device=self.device, - ) - - self.scale = float(1.0 / (self.head_dim**0.5)) - self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to( - self.device - ) - self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to( - self.device - ) - - def init_forward_metadata(self, forward_batch: ForwardBatch): - """Init auxiliary variables""" - bs = forward_batch.batch_size - kv_indptr = self.kv_indptr - spec_info = forward_batch.spec_info - - if forward_batch.forward_mode.is_decode_or_idle(): - if spec_info is None: - kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.zeros( - forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device - ) - # prepare kv_indices and 
kv_indptr - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - else: - kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices - bs = kv_indptr.shape[0] - 1 - - attn_logits = None # accomodate forward_metadata format - qo_indptr = None - custom_mask = None - mask_indptr = None - max_extend_len = None - elif forward_batch.forward_mode.is_target_verify(): - bs = len(forward_batch.req_pool_indices) - qo_indptr = torch.arange( - 0, - (1 + bs) * self.num_draft_tokens, - step=self.num_draft_tokens, - dtype=torch.int32, - device=self.device, - ) - kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.zeros( - kv_indptr[-1], dtype=torch.int32, device=self.device - ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - - custom_mask = spec_info.custom_mask - seq_mask_len = self.num_draft_tokens * ( - forward_batch.seq_lens + self.num_draft_tokens - ) - mask_indptr = self.mask_indptr - mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0) - mask_indptr = mask_indptr[: bs + 1] - max_extend_len = self.num_draft_tokens - attn_logits = None - elif forward_batch.forward_mode.is_draft_extend(): - kv_indices, kv_indptr, qo_indptr, custom_mask = ( - spec_info.generate_attn_arg_prefill( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - self.req_to_token, - ) - ) - mask_indptr = None - max_extend_len = torch.max(spec_info.accept_length).item() - attn_logits = None - else: - kv_indptr[1 : bs + 1] = torch.cumsum( - forward_batch.extend_prefix_lens, dim=0 - ) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.zeros( - forward_batch.extend_prefix_lens.sum().item(), - dtype=torch.int32, - device=self.device, - ) - - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - forward_batch.req_pool_indices, - forward_batch.extend_prefix_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - - qo_indptr = self.qo_indptr - qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - qo_indptr = qo_indptr[: bs + 1] - custom_mask = None - mask_indptr = None - attn_logits = None - max_extend_len = torch.max(forward_batch.extend_seq_lens).item() - - self.forward_metadata = ( - attn_logits, - max_extend_len, - kv_indptr, - kv_indices, - qo_indptr, - custom_mask, - mask_indptr, - ) - - def init_cuda_graph_state( - self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None - ): - - self.cuda_graph_attn_logits = torch.zeros( - (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), - dtype=torch.float32, - device=self.device, - ) - if kv_indices_buf is None: - self.cuda_graph_kv_indices = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.int32, - device=self.device, - ) - else: - self.cuda_graph_kv_indices = kv_indices_buf - - if not self.skip_prefill: - self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.uint8, - device=self.device, - ) - - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - num_tokens: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], - ): - assert encoder_lens is None, "Not 
supported" - - if forward_mode.is_decode_or_idle(): - if spec_info is None: - kv_indptr = self.kv_indptr - kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = self.cuda_graph_kv_indices - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - else: - kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices - attn_logits = None - max_extend_len = None - qo_indptr = None - custom_mask = None - mask_indptr = None - elif forward_mode.is_target_verify(): - qo_indptr = self.qo_indptr[: bs + 1] - qo_indptr[: bs + 1] = torch.arange( - 0, - (1 + bs) * self.num_draft_tokens, - step=self.num_draft_tokens, - dtype=torch.int32, - device=self.device, - ) - kv_indptr = self.kv_indptr[: bs + 1] - kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) - kv_indices = self.cuda_graph_kv_indices - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - - custom_mask = self.cuda_graph_custom_mask - seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) - mask_indptr = self.mask_indptr[: bs + 1] - mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0) - max_extend_len = self.num_draft_tokens - attn_logits = None - else: - raise ValueError( - f"Invalid forward mode: {forward_mode=} for CUDA Graph capture." - ) - - self.forward_metadata = ( - attn_logits, - max_extend_len, - kv_indptr, - kv_indices, - qo_indptr, - custom_mask, - mask_indptr, - ) - - def init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], - seq_lens_cpu: Optional[torch.Tensor], - ): - # NOTE: encoder_lens expected to be zeros or None - if forward_mode.is_decode_or_idle(): - # Update kv_indptr, kv_indices - kv_indptr = self.kv_indptr - kv_indices = self.cuda_graph_kv_indices - if spec_info is None: - kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0) - kv_indptr = kv_indptr[: bs + 1] - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices[:bs], - seq_lens[:bs], - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - else: - kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr - kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices - elif forward_mode.is_target_verify(): - # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr - bs = len(req_pool_indices) - qo_indptr = self.qo_indptr[: bs + 1] - qo_indptr[: bs + 1] = torch.arange( - 0, - (1 + bs) * self.num_draft_tokens, - step=self.num_draft_tokens, - dtype=torch.int32, - device=self.device, - ) - kv_indptr = self.kv_indptr[: bs + 1] - kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) - kv_indices = self.cuda_graph_kv_indices - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - custom_mask = self.cuda_graph_custom_mask - custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask - seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) - mask_indptr = self.mask_indptr[: bs + 1] - mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0) - else: - raise ValueError( - 
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay." - ) - - def get_cuda_graph_seq_len_fill_value(self): - return 1 - - def forward_extend( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): - # TODO: reuse the buffer across layers - if layer.qk_head_dim != layer.v_head_dim: - o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) - else: - o = torch.empty_like(q) - - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) - - ( - _, - max_extend_len, - kv_indptr, - kv_indices, - qo_indptr, - custom_mask, - mask_indptr, - ) = self.forward_metadata - - self.extend_attention_fwd( - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - k.contiguous(), - v.contiguous(), - o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), - forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), - qo_indptr, - kv_indptr, - kv_indices, - custom_mask, - mask_indptr, - max_extend_len, - layer.scaling, - layer.logit_cap, - ) - return o - - def forward_decode( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): - # During torch.compile, there is a bug in rotary_emb that causes the - # output value to have a 3D tensor shape. This reshapes the output correctly. - q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) - - if layer.qk_head_dim != layer.v_head_dim: - o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) - else: - o = torch.empty_like(q) - - attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata - - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) - - self.decode_attention_fwd( - o.view( - -1, layer.tp_q_head_num, layer.qk_head_dim - ), # (bs, head_num_q, head_dim_q) - self.workspace_buffer, - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view( - -1, 1, layer.tp_k_head_num, layer.qk_head_dim - ), - forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view( - -1, 1, layer.tp_v_head_num, layer.v_head_dim - ), - self.scale, - kv_indptr, - kv_indices, - self.kv_last_page_lens, - 1, - self.max_num_partitions, - None, - "auto", - "NHD", - layer.logit_cap, - self.k_scale, - self.v_scale, - None, - _AITER_PARTITION_SIZE_ROCM, - ) - - return o - - -@triton.jit -def create_flashinfer_kv_indices_triton( - req_to_token_ptr, # [max_batch, max_context_len] - req_pool_indices_ptr, - page_kernel_lens_ptr, - kv_indptr, - kv_start_idx, - kv_indices_ptr, - req_to_token_ptr_stride: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 512 - pid = tl.program_id(axis=0) - - req_pool_index = tl.load(req_pool_indices_ptr + pid) - kv_indices_offset = tl.load(kv_indptr + pid) - - kv_start = 0 - kv_end = 0 - if kv_start_idx: - kv_start = tl.load(kv_start_idx + pid).to(tl.int32) - kv_end = kv_start - kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) - - num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) - for i in range(num_loop): - offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = offset < kv_end - kv_start - data = tl.load( - req_to_token_ptr - + req_pool_index * req_to_token_ptr_stride - + kv_start - + offset, - mask=mask, - ) - tl.store(kv_indices_ptr + kv_indices_offset + offset, data, 
mask=mask) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8c3df4341..666b97e2b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -79,12 +79,6 @@ from sglang.srt.utils import ( ) from sglang.utils import get_exception_traceback -is_hip_ = is_hip() - -if is_hip_: - from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend - from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend - logger = logging.getLogger(__name__) @@ -647,7 +641,7 @@ class ModelRunner: if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype elif self.server_args.kv_cache_dtype == "fp8_e5m2": - if is_hip_: # Using natively supported format + if is_hip(): # Using natively supported format self.kv_cache_dtype = torch.float8_e5m2fnuz else: self.kv_cache_dtype = torch.float8_e5m2 @@ -784,59 +778,33 @@ class ModelRunner: def init_attention_backend(self): """Init attention kernel backend.""" - if is_cuda(): - if self.server_args.attention_backend == "flashinfer": - # Init streams - if self.server_args.speculative_algorithm == "EAGLE": - self.plan_stream_for_flashinfer = torch.cuda.Stream() + if self.server_args.attention_backend == "flashinfer": + # Init streams + if self.server_args.speculative_algorithm == "EAGLE": + self.plan_stream_for_flashinfer = torch.cuda.Stream() - self.attn_backend = FlashInferAttnBackend(self) - elif self.server_args.attention_backend == "triton": - assert self.sliding_window_size is None, ( - "Window attention is not supported in the triton attention backend. " - "Please use `--attention-backend flashinfer`." - ) - assert not self.model_config.is_encoder_decoder, ( - "Cross attention is not supported in the triton attention backend. " - "Please use `--attention-backend flashinfer`." - ) - if self.server_args.enable_double_sparsity: - self.attn_backend = DoubleSparseAttnBackend(self) - else: - self.attn_backend = TritonAttnBackend(self) - elif self.server_args.attention_backend == "torch_native": - self.attn_backend = TorchNativeAttnBackend(self) - elif self.server_args.attention_backend == "flashinfer_mla": - self.attn_backend = FlashInferMLAAttnBackend(self) + self.attn_backend = FlashInferAttnBackend(self) + elif self.server_args.attention_backend == "triton": + assert self.sliding_window_size is None, ( + "Window attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." + ) + assert not self.model_config.is_encoder_decoder, ( + "Cross attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." + ) + if self.server_args.enable_double_sparsity: + self.attn_backend = DoubleSparseAttnBackend(self) else: - raise ValueError( - f"Invalid attention backend: {self.server_args.attention_backend}" - ) - elif is_hip_: - # AMD hip supported attention backends - if self.server_args.attention_backend == "aiter": - self.attn_backend = AiterAttnBackend(self) - elif self.server_args.attention_backend == "aiter_decode": - self.attn_backend = AiterDecodeAttnBackend(self) - elif self.server_args.attention_backend == "triton": - assert self.sliding_window_size is None, ( - "Window attention is not supported in the triton attention backend. " - "Please use `--attention-backend flashinfer`." - ) - assert not self.model_config.is_encoder_decoder, ( - "Cross attention is not supported in the triton attention backend. 
" - "Please use `--attention-backend flashinfer`." - ) - if self.server_args.enable_double_sparsity: - self.attn_backend = DoubleSparseAttnBackend(self) - else: - self.attn_backend = TritonAttnBackend(self) - elif self.server_args.attention_backend == "torch_native": - self.attn_backend = TorchNativeAttnBackend(self) - else: - raise ValueError( - f"Invalid attention backend: {self.server_args.attention_backend}" - ) + self.attn_backend = TritonAttnBackend(self) + elif self.server_args.attention_backend == "torch_native": + self.attn_backend = TorchNativeAttnBackend(self) + elif self.server_args.attention_backend == "flashinfer_mla": + self.attn_backend = FlashInferMLAAttnBackend(self) + else: + raise ValueError( + f"Invalid attention backend: {self.server_args.attention_backend}" + ) def init_double_sparsity_channel_config(self, selected_channel): selected_channel = "." + selected_channel + "_proj" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f53e8068d..c5b8b920e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -710,23 +710,13 @@ class ServerArgs: ) # Kernel backend - if is_hip(): - parser.add_argument( - "--attention-backend", - type=str, - choices=["triton", "torch_native", "aiter", "aiter_decode"], - default=ServerArgs.attention_backend, - help="Choose the kernels for attention layers.", - ) - else: - parser.add_argument( - "--attention-backend", - type=str, - choices=["flashinfer", "triton", "torch_native"], - default=ServerArgs.attention_backend, - help="Choose the kernels for attention layers.", - ) - + parser.add_argument( + "--attention-backend", + type=str, + choices=["flashinfer", "triton", "torch_native"], + default=ServerArgs.attention_backend, + help="Choose the kernels for attention layers.", + ) parser.add_argument( "--sampling-backend", type=str, diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip deleted file mode 100644 index d0144a617..000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip +++ /dev/null @@ -1,118 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -#include "hip/hip_runtime.h" -/* Copyright 2025 SGLang Team. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include - -#include "utils_hip.h" - -#define WARP_SIZE 32 - -template -__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, size_t numel) { - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; - - for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i]; - int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); - sorted_token_ids[rank_post_pad] = i; - } -} - -template -__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) { - __shared__ int32_t shared_counts[WARP_SIZE][8]; - - const int warp_id = threadIdx.x / WARP_SIZE; - const int experts_per_warp = 8; - const int my_expert_start = warp_id * experts_per_warp; - - for (int i = 0; i < experts_per_warp; ++i) { - if (my_expert_start + i < num_experts) { - shared_counts[warp_id][i] = 0; - } - } - - __syncthreads(); - - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int expert_id = topk_ids[i]; - int warp_idx = expert_id / experts_per_warp; - int expert_offset = expert_id % experts_per_warp; - atomicAdd(&shared_counts[warp_idx][expert_offset], 1); - } - - __syncthreads(); - - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - int expert_count = 0; - int warp_idx = (i - 1) / experts_per_warp; - int expert_offset = (i - 1) % experts_per_warp; - expert_count = shared_counts[warp_idx][expert_offset]; - - cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { - expert_ids[i / block_size] = threadIdx.x; - } - } -} - -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, - torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { - const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); - TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now."); - - DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - auto align_kernel = moe_align_block_size_kernel; - hipLaunchKernelGGL(( align_kernel), dim3(1), dim3(1024), 0, stream, topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), - num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); - - const int block_threads = 256; - const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; - const int max_blocks = 65535; - const int actual_blocks = ::min(num_blocks, max_blocks); - - auto sort_kernel = count_and_sort_expert_tokens_kernel; - hipLaunchKernelGGL(( sort_kernel), dim3(actual_blocks), dim3(block_threads), 0, stream, 
topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel()); - }); -} diff --git a/sgl-kernel/src/sgl-kernel/include/utils_hip.h b/sgl-kernel/src/sgl-kernel/include/utils_hip.h deleted file mode 100644 index e5f1a9355..000000000 --- a/sgl-kernel/src/sgl-kernel/include/utils_hip.h +++ /dev/null @@ -1,98 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -/* Copyright 2025 SGLang Team. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include -#ifndef USE_ROCM -#include -#endif -#include - -#include - -struct cuda_error : public std::runtime_error { - /** - * @brief Constructs a `cuda_error` object with the given `message`. - * - * @param message The error char array used to construct `cuda_error` - */ - cuda_error(const char* message) : std::runtime_error(message) {} - /** - * @brief Constructs a `cuda_error` object with the given `message` string. - * - * @param message The `std::string` used to construct `cuda_error` - */ - cuda_error(std::string const& message) : cuda_error{message.c_str()} {} -}; - -#define CHECK_CUDA_SUCCESS(cmd) \ - do { \ - hipError_t e = cmd; \ - if (e != hipSuccess) { \ - std::stringstream _message; \ - auto s = hipGetErrorString(e); \ - _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ - throw cuda_error(_message.str()); \ - } \ - } while (0) - -#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_CUDA_INPUT(x) \ - CHECK_IS_CUDA(x); \ - CHECK_IS_CONTIGUOUS(x) - -inline int getSMVersion() { - int device{-1}; - CHECK_CUDA_SUCCESS(hipGetDevice(&device)); - int sm_major = 0; - int sm_minor = 0; - CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_major, hipDeviceAttributeComputeCapabilityMajor, device)); - CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_minor, hipDeviceAttributeComputeCapabilityMinor, device)); - return sm_major * 10 + sm_minor; -} - -#ifndef USE_ROCM -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - case at::ScalarType::Float: { \ - using c_type = float; \ - return __VA_ARGS__(); \ - } \ - _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ - _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() -#endif - -#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) - -#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) 
\ - AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) - -#define CEILDIV(x, y) (((x) + (y)-1) / (y))
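
Editor's note: the removed aiter backends above both build a paged-KV index pair (kv_indptr, kv_indices) via the deleted create_flashinfer_kv_indices_triton kernel before calling paged_attention_rocm. The following is a minimal, self-contained PyTorch sketch of that indexing scheme for readers unfamiliar with it; it is an illustration only, not SGLang or aiter API code, and the helper name build_kv_indices is invented for this example.

    # Sketch: pack each request's physical KV-cache slots contiguously.
    # kv_indptr[i]:kv_indptr[i+1] delimits request i's slots in kv_indices,
    # mirroring what the deleted Triton kernel computed on-device.
    import torch

    def build_kv_indices(req_to_token, req_pool_indices, seq_lens):
        bs = req_pool_indices.numel()
        kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
        kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
        kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
        for i in range(bs):
            start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
            row = req_to_token[int(req_pool_indices[i])]
            kv_indices[start:end] = row[: int(seq_lens[i])]
        return kv_indptr, kv_indices

    # Example: two requests holding 3 and 2 cached tokens.
    req_to_token = torch.arange(20, dtype=torch.int32).reshape(4, 5)
    kv_indptr, kv_indices = build_kv_indices(
        req_to_token,
        req_pool_indices=torch.tensor([2, 0]),
        seq_lens=torch.tensor([3, 2]),
    )
    print(kv_indptr.tolist())   # [0, 3, 5]
    print(kv_indices.tolist())  # [10, 11, 12, 0, 1]

The device kernel does the same gather in parallel, one program per request, in BLOCK_SIZE-wide chunks; the resulting kv_indptr/kv_indices pair is what paged_attention_rocm consumes alongside the per-request last-page lengths.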