ROCm: Flex Attention Enablement with custom backends (#4178)
Co-authored-by: linsun12 <linsun12@amd.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
# docker build --build-arg SGL_BRANCH=v0.4.3.post4 -t v0.4.3.post4-rocm630 -f Dockerfile.rocm .
|
# docker build --build-arg SGL_BRANCH=v0.4.3.post4 -t v0.4.3.post4-rocm630 -f Dockerfile.rocm .
|
||||||
|
|
||||||
# default base image
|
# default base image
|
||||||
ARG BASE_IMAGE="rocm/sgl-dev:vllm20250114"
|
ARG BASE_IMAGE="rocm/sgl-dev:20250114vllm-blas-flash"
|
||||||
|
|
||||||
FROM $BASE_IMAGE AS base
|
FROM $BASE_IMAGE AS base
|
||||||
USER root
|
USER root
|
||||||
@@ -16,10 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT}
|
|||||||
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
|
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
|
||||||
ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
|
ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
|
||||||
|
|
||||||
|
|
||||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||||
ARG AITER_COMMIT="testx"
|
ARG AITER_COMMIT="testx"
|
||||||
|
|
||||||
|
|
||||||
RUN git clone ${SGL_REPO} \
|
RUN git clone ${SGL_REPO} \
|
||||||
&& cd sglang \
|
&& cd sglang \
|
||||||
&& if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
|
&& if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
|
||||||
@@ -59,6 +59,7 @@ RUN git clone ${AITER_REPO} \
|
|||||||
&& git submodule update --init --recursive \
|
&& git submodule update --init --recursive \
|
||||||
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop
|
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop
|
||||||
|
|
||||||
|
|
||||||
# Copy config files to support MI300X in virtualized environments (MI300X_VF). Symlinks will not be created in image build.
|
# Copy config files to support MI300X in virtualized environments (MI300X_VF). Symlinks will not be created in image build.
|
||||||
RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \
|
RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \
|
||||||
/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \
|
/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \
|
||||||
|
|||||||
605
python/sglang/srt/layers/attention/aiter_backend.py
Normal file
605
python/sglang/srt/layers/attention/aiter_backend.py
Normal file
@@ -0,0 +1,605 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
"""
|
||||||
|
end to end attention solution with aiter kernels
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
|
from functools import partial
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.global_config import global_config
|
||||||
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
|
# flashinfer AMD fork
|
||||||
|
from flashinfer import BatchPrefillWithPagedKVCacheWrapper
|
||||||
|
|
||||||
|
try:
|
||||||
|
from aiter import paged_attention_rocm
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WrapperDispatch(Enum):
    """Kinds of attention that may need a dedicated wrapper."""

    # Explicit values match what ``auto()`` assigns (1, 2, ...).
    SLIDING_WINDOW = 1
    CROSS_ATTENTION = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DecodeMetadata:
    """Per-batch index tensors consumed by the decode kernel."""

    # Cumulative KV lengths; the backend fills entry i + 1 with the running
    # sum of the first i + 1 sequence lengths.
    kv_indptr: torch.Tensor
    # Flattened KV-cache slot indices for every request in the batch.
    kv_indices: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PrefillMetadata:
    """State prepared for an extend/prefill (or verify) forward pass."""

    # Wrapper whose ``forward`` is invoked by ``forward_extend``.
    prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper
    # True when no request in the batch has a cached prefix.
    extend_no_prefix: bool
|
||||||
|
|
||||||
|
|
||||||
|
global_workspace_buffer = None
|
||||||
|
|
||||||
|
_AITER_PARTITION_SIZE_ROCM = 256
|
||||||
|
|
||||||
|
|
||||||
|
class AiterAttnBackend(AttentionBackend):
    """Attention backend that decodes with aiter's ``paged_attention_rocm``.

    Extend / prefill / verify passes go through
    ``BatchPrefillWithPagedKVCacheWrapper`` (imported from the flashinfer AMD
    fork); decode passes call ``paged_attention_rocm`` directly with
    ``kv_indptr`` / ``kv_indices`` built by
    ``create_flashinfer_kv_indices_triton``.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        """Allocate index buffers, prefill wrappers and the aiter workspace.

        Args:
            model_runner: Source of model config, device, and the KV / request
                pools.
            skip_prefill: When True, no prefill indices updater is created and
                no CUDA-graph prefill buffers are allocated later.
            kv_indptr_buf: Optional pre-allocated ``kv_indptr`` buffer to use
                instead of allocating a fresh one.
        """
        global global_workspace_buffer

        super().__init__()

        model_config = model_runner.model_config
        self.device = model_runner.device
        self.is_multimodal = model_config.is_multimodal
        self.num_head = model_config.num_attention_heads // get_attention_tp_size()
        self.head_dim = model_config.head_dim
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.num_kv_head = model_config.get_num_kv_heads(get_attention_tp_size())
        self.kv_cache_dtype = model_runner.kv_cache_dtype

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        # Parse constants
        self.max_context_len = model_config.context_len
        self.skip_prefill = skip_prefill

        # Qwen2 models require higher flashinfer workspace size
        if "Qwen2ForCausalLM" in model_config.hf_config.architectures:
            global_config.flashinfer_workspace_size = 512 * 1024 * 1024

        # The prefill workspace is allocated once per process and shared.
        if global_workspace_buffer is None:
            global_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )

        self.workspace_buffer = global_workspace_buffer
        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer, "NHD", backend="fa2"
        )
        self.prefill_wrapper_verify = BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer, "NHD"
        )

        # Create prefill indices updater (it reads kv_indptr, qo_indptr and
        # kv_last_page_len from this object, so they must exist by now).
        if not skip_prefill:
            self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
                model_runner, self
            )

        # aiter kernel related initialization
        self.max_num_partitions = (
            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
        ) // _AITER_PARTITION_SIZE_ROCM

        nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8

        # NOTE(review): this rebinds ``self.workspace_buffer`` from the shared
        # prefill workspace to the aiter decode workspace. The wrappers built
        # above keep the shared buffer, but the wrapper created during
        # CUDA-graph capture receives this one -- confirm that is intended.
        self.workspace_buffer = torch.empty(
            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
            * nbytes_per_qo_elem
            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
            dtype=torch.uint8,
            device=self.device,
        )

        self.scale = float(1.0 / (self.head_dim**0.5))
        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
            self.device
        )
        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
            self.device
        )

        self.logits_soft_cap = 0.0

        self.forward_metadata: Optional[Union[PrefillMetadata, DecodeMetadata]] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Build ``self.forward_metadata`` for the given batch.

        Decode/idle batches get a ``DecodeMetadata``; draft-extend,
        target-verify and regular extend batches get a ``PrefillMetadata``
        after the matching prefill wrapper has been planned.
        """
        if forward_batch.forward_mode.is_decode_or_idle():
            # update for aiter
            # create kv_indices and kv_indptr
            bs = forward_batch.batch_size
            kv_indptr = self.kv_indptr
            spec_info = forward_batch.spec_info
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices)

        elif forward_batch.forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper=self.prefill_wrapper_paged,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=forward_batch.spec_info,
            )
            # PrefillMetadata has exactly two fields; a third positional
            # argument would raise TypeError.
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
        elif forward_batch.forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper=self.prefill_wrapper_verify,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
        else:
            prefix_lens = forward_batch.extend_prefix_lens

            if self.is_multimodal:
                extend_no_prefix = False
            else:
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrapper=self.prefill_wrapper_paged,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=None,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_paged, extend_no_prefix
            )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        """Allocate the persistent buffers used for CUDA-graph capture/replay.

        Args:
            max_bs: Largest batch size that will be captured.
            kv_indices_buf: Optional pre-allocated KV index buffer to reuse.
        """
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )
            # The target-verify capture path slices these two buffers; without
            # them it fails with AttributeError.
            self.cuda_graph_qk_indptr = self.kv_indptr.clone()
            self.cuda_graph_qo_indptr = self.kv_indptr.clone()

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Prepare forward metadata while a CUDA graph is being captured.

        Raises:
            ValueError: If ``forward_mode`` is neither decode/idle nor
                target-verify.
        """
        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
            self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices)
            self.decode_cuda_graph_metadata[bs] = self.forward_metadata

        elif forward_mode.is_target_verify():
            prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=False,
                qo_indptr_buf=self.cuda_graph_qo_indptr[: bs + 1],
                paged_kv_indptr_buf=self.kv_indptr[: bs + 1],
                paged_kv_indices_buf=self.cuda_graph_kv_indices,
                paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
                custom_mask_buf=self.cuda_graph_custom_mask,
                mask_indptr_buf=self.cuda_graph_qk_indptr[: bs + 1],
            )

            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper=prefill_wrapper,
                encoder_lens=encoder_lens,
                spec_info=spec_info,
            )
            # Remember the wrapper so replay can re-plan it for this batch size.
            self.prefill_cuda_graph_metadata[bs] = prefill_wrapper
            self.forward_metadata = PrefillMetadata(prefill_wrapper, False)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Refresh the captured buffers in place before a CUDA-graph replay.

        Raises:
            ValueError: If ``forward_mode`` is neither decode/idle nor
                target-verify.
        """
        if forward_mode.is_decode_or_idle():
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                # Copy into the persistent buffers rather than rebinding them.
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices

        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper=self.prefill_cuda_graph_metadata[bs],
                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                spec_info=spec_info,
            )
        else:
            raise ValueError("Invalid forward mode")

    def get_cuda_graph_seq_len_fill_value(self):
        """Return the sequence length used to fill padded CUDA-graph slots."""
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run extend/prefill attention through the planned prefill wrapper.

        Optionally writes ``k``/``v`` into the KV pool first, then returns the
        output reshaped to ``(-1, tp_q_head_num * head_dim)``.
        """
        prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        self.logits_soft_cap = layer.logit_cap

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )

        o = prefill_wrapper_paged.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            causal=not layer.is_cross_attention,
            sm_scale=layer.scaling,
            window_left=layer.sliding_window_size,
            logits_soft_cap=self.logits_soft_cap,
            k_scale=layer.k_scale,
            v_scale=layer.v_scale,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run decode attention with ``paged_attention_rocm``.

        Returns a tensor of shape ``(num_tokens, tp_q_head_num * v_head_dim)``.
        """
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.logits_soft_cap = layer.logit_cap
        paged_attention_rocm(
            # ``o`` is allocated with v_head_dim per head, so it must be viewed
            # with v_head_dim (identical to qk_head_dim when the two match).
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            self.workspace_buffer,
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
            ),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
                -1, 1, layer.tp_v_head_num, layer.v_head_dim
            ),
            self.scale,
            self.forward_metadata.kv_indptr,
            self.forward_metadata.kv_indices,
            self.kv_last_page_lens,
            1,
            self.max_num_partitions,
            None,
            "auto",
            "NHD",
            self.logits_soft_cap,
            self.k_scale,
            self.v_scale,
            None,
            _AITER_PARTITION_SIZE_ROCM,
        )

        return o
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferIndicesUpdaterPrefill:
    """Builds prefill index tensors and plans a paged prefill wrapper."""

    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        """Cache model constants and share the backend's index buffers.

        Args:
            model_runner: Source of head counts, dtypes and the request pool.
            attn_backend: Backend whose ``kv_indptr``, ``kv_last_page_len`` and
                ``qo_indptr`` buffers are reused (not copied).
        """
        # Parse Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        # Instance attribute shadows the placeholder ``update`` method below.
        self.update = self.update_single_wrapper

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        """Signature placeholder; ``__init__`` rebinds it per instance.

        Raises:
            NotImplementedError: Always, when called on an instance that did
                not go through ``__init__``.
        """
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        """Plan ``prefill_wrapper`` using the full sequence lengths as KV lengths.

        ``encoder_lens`` is accepted for signature compatibility and not used.
        """
        paged_kernel_lens = seq_lens
        paged_kernel_lens_sum = seq_lens_sum

        self.call_begin_forward(
            prefill_wrapper,
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            None,
            self.kv_indptr,
            self.qo_indptr,
            spec_info,
        )

    def call_begin_forward(
        self,
        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: torch.Tensor,
        kv_start_idx: torch.Tensor,
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        spec_info: Optional[SpecInfo],
    ):
        """Compute qo/kv index tensors and call ``wrapper_paged.begin_forward``.

        Without ``spec_info`` the indices are derived from the request pool;
        with it they come from ``spec_info.generate_attn_arg_prefill``.
        """
        bs = len(req_pool_indices)
        if spec_info is None:
            # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                paged_kernel_lens_sum + 256,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                kv_start_idx,
                kv_indices,
                # Row stride of the request table; matches the other call
                # sites and stays correct if the table is not contiguous.
                self.req_to_token.stride(0),
            )

            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    self.req_to_token,
                )
            )

        # cached part
        # adding logits_soft_cap arg in plan() stage
        wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            q_data_type=self.q_data_type,
            custom_mask=custom_mask,
            non_blocking=True,
            logits_soft_cap=self.attn_backend.logits_soft_cap,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    # One program per request: copy that request's token slots from its row
    # of the request table into the flat kv_indices array, starting at the
    # offset read from kv_indptr[pid].
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    row = tl.load(req_pool_indices_ptr + pid)
    out_base = tl.load(kv_indptr + pid)

    # Optional per-request start offset inside the row.
    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    span = kv_end - kv_start
    num_loop = tl.cdiv(span, BLOCK_SIZE)
    for blk in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + blk * BLOCK_SIZE
        # Mask off the tail of the last block.
        mask = offset < span
        data = tl.load(
            req_to_token_ptr + row * req_to_token_ptr_stride + kv_start + offset,
            mask=mask,
        )
        tl.store(kv_indices_ptr + out_base + offset, data, mask=mask)
|
||||||
535
python/sglang/srt/layers/attention/aiter_decode_backend.py
Normal file
535
python/sglang/srt/layers/attention/aiter_decode_backend.py
Normal file
@@ -0,0 +1,535 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
|
try:
|
||||||
|
from aiter import paged_attention_rocm
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
|
||||||
|
|
||||||
|
_AITER_PARTITION_SIZE_ROCM = 256
|
||||||
|
|
||||||
|
|
||||||
|
class AiterDecodeAttnBackend(AttentionBackend):
|
||||||
|
def __init__(
    self,
    model_runner: ModelRunner,
    skip_prefill: bool = False,
    kv_indptr_buf: Optional[torch.Tensor] = None,
):
    """Set up index buffers, model constants and the aiter decode workspace.

    Args:
        model_runner: Source of model config, server args, device and pools.
        skip_prefill: When True, ``qo_indptr`` / ``mask_indptr`` are not
            allocated.
        kv_indptr_buf: Optional pre-allocated ``kv_indptr`` buffer.
    """
    super().__init__()

    # Decode goes through aiter, extend through the triton kernel.
    self.decode_attention_fwd = paged_attention_rocm
    self.extend_attention_fwd = extend_attention_fwd

    self.skip_prefill = skip_prefill

    max_bs = model_runner.req_to_token_pool.size

    if kv_indptr_buf is not None:
        self.kv_indptr = kv_indptr_buf
    else:
        self.kv_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

    self.req_to_token = model_runner.req_to_token_pool.req_to_token

    if not self.skip_prefill:
        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )
        self.mask_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
        )

    self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

    # tp sharding on number of heads
    self.num_head = (
        model_runner.model_config.num_attention_heads // get_attention_tp_size()
    )
    self.head_dim = model_runner.model_config.head_dim

    # triton prefill initialization
    self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
    self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
    self.num_v_head = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-2]

    self.forward_metadata = None
    self.max_context_len = model_runner.model_config.context_len
    self.device = model_runner.device
    self.kv_cache_dtype = model_runner.kv_cache_dtype
    self.q_dtype = model_runner.model_config.dtype

    # aiter decode initialization
    self.max_num_partitions = (
        self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
    ) // _AITER_PARTITION_SIZE_ROCM

    bytes_per_fp32 = torch.finfo(torch.float32).bits // 8
    partition_slots = max_bs * self.num_head * self.max_num_partitions
    self.workspace_buffer = torch.empty(
        (partition_slots * self.head_dim) * bytes_per_fp32
        + 2 * partition_slots * 4,
        dtype=torch.uint8,
        device=self.device,
    )

    self.scale = float(1.0 / (self.head_dim**0.5))
    unit_scale = torch.tensor([1.0], dtype=torch.float32).to(self.device)
    # Both names refer to the same tensor object.
    self.k_scale = self.v_scale = unit_scale
    self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
        self.device
    )
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
    """Init auxiliary variables.

    Stores a 7-tuple in ``self.forward_metadata``:
    ``(attn_logits, max_extend_len, kv_indptr, kv_indices, qo_indptr,
    custom_mask, mask_indptr)``. ``attn_logits`` is always None here; it is
    kept to accommodate the forward_metadata format.
    """
    batch_size = forward_batch.batch_size
    kv_indptr = self.kv_indptr
    spec_info = forward_batch.spec_info

    # Defaults; each branch overrides only what it actually provides.
    attn_logits = None
    qo_indptr = None
    custom_mask = None
    mask_indptr = None
    max_extend_len = None

    if forward_batch.forward_mode.is_decode_or_idle():
        if spec_info is not None:
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
            batch_size = kv_indptr.shape[0] - 1
        else:
            kv_indptr[1 : batch_size + 1] = torch.cumsum(
                forward_batch.seq_lens, dim=0
            )
            kv_indptr = kv_indptr[: batch_size + 1]
            kv_indices = torch.zeros(
                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
            )
            # prepare kv_indices and kv_indptr
            create_flashinfer_kv_indices_triton[(batch_size,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
    elif forward_batch.forward_mode.is_target_verify():
        batch_size = len(forward_batch.req_pool_indices)
        draft = self.num_draft_tokens
        # Every request contributes exactly ``draft`` query tokens.
        qo_indptr = torch.arange(
            0,
            (1 + batch_size) * draft,
            step=draft,
            dtype=torch.int32,
            device=self.device,
        )
        kv_indptr[1 : batch_size + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
        kv_indptr = kv_indptr[: batch_size + 1]
        kv_indices = torch.zeros(
            kv_indptr[-1], dtype=torch.int32, device=self.device
        )
        create_flashinfer_kv_indices_triton[(batch_size,)](
            self.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.stride(0),
        )

        custom_mask = spec_info.custom_mask
        seq_mask_len = draft * (forward_batch.seq_lens + draft)
        mask_indptr = self.mask_indptr
        mask_indptr[1 : batch_size + 1] = torch.cumsum(
            seq_mask_len[:batch_size], dim=0
        )
        mask_indptr = mask_indptr[: batch_size + 1]
        max_extend_len = draft
    elif forward_batch.forward_mode.is_draft_extend():
        kv_indices, kv_indptr, qo_indptr, custom_mask = (
            spec_info.generate_attn_arg_prefill(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                self.req_to_token,
            )
        )
        max_extend_len = torch.max(spec_info.accept_length).item()
    else:
        # Regular extend: KV indices cover only the cached prefixes.
        prefix_lens = forward_batch.extend_prefix_lens
        kv_indptr[1 : batch_size + 1] = torch.cumsum(prefix_lens, dim=0)
        kv_indptr = kv_indptr[: batch_size + 1]
        kv_indices = torch.zeros(
            forward_batch.extend_prefix_lens.sum().item(),
            dtype=torch.int32,
            device=self.device,
        )

        create_flashinfer_kv_indices_triton[(batch_size,)](
            self.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.extend_prefix_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.stride(0),
        )

        qo_indptr = self.qo_indptr
        qo_indptr[1 : batch_size + 1] = torch.cumsum(
            forward_batch.extend_seq_lens, dim=0
        )
        qo_indptr = qo_indptr[: batch_size + 1]
        max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

    self.forward_metadata = (
        attn_logits,
        max_extend_len,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask,
        mask_indptr,
    )
|
||||||
|
|
||||||
|
def init_cuda_graph_state(
    self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
    """Allocate the persistent buffers used while running under CUDA graphs.

    Args:
        max_bs: Largest batch size the buffers are sized for.
        kv_indices_buf: Optional pre-allocated KV-index buffer. When given it
            is stored as-is instead of allocating a new one.

    Sets ``self.cuda_graph_attn_logits`` (float32, shape
    ``(max_bs, num_head, num_kv_splits, v_head_dim + 1)``),
    ``self.cuda_graph_kv_indices`` and, unless ``self.skip_prefill`` is set,
    ``self.cuda_graph_custom_mask`` (uint8, ``max_bs * max_context_len``
    elements).
    """
    logits_shape = (
        max_bs,
        self.num_head,
        self.num_kv_splits,
        self.v_head_dim + 1,
    )
    self.cuda_graph_attn_logits = torch.zeros(
        logits_shape, dtype=torch.float32, device=self.device
    )

    if kv_indices_buf is not None:
        # Reuse the caller-provided buffer.
        self.cuda_graph_kv_indices = kv_indices_buf
    else:
        self.cuda_graph_kv_indices = torch.zeros(
            max_bs * self.max_context_len,
            dtype=torch.int32,
            device=self.device,
        )

    if self.skip_prefill:
        return

    # Flat mask buffer, one uint8 per (batch slot, context position).
    self.cuda_graph_custom_mask = torch.zeros(
        max_bs * self.max_context_len,
        dtype=torch.uint8,
        device=self.device,
    )
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(
    self,
    bs: int,
    num_tokens: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInfo],
):
    """Build ``self.forward_metadata`` while a CUDA graph is being captured.

    Decode/idle mode fills ``self.kv_indptr`` with the running sum of
    ``seq_lens`` and launches ``create_flashinfer_kv_indices_triton`` to
    populate ``self.cuda_graph_kv_indices``; when ``spec_info`` is given its
    ``kv_indptr`` / ``kv_indices`` are used instead. Target-verify mode also
    fills ``self.qo_indptr`` with multiples of ``self.num_draft_tokens`` and
    ``self.mask_indptr`` with the running sum of
    ``num_draft_tokens * (seq_lens + num_draft_tokens)``.

    ``num_tokens`` is accepted but not used by this implementation.

    Raises:
        AssertionError: if ``encoder_lens`` is not None.
        ValueError: for any other forward mode.
    """
    assert encoder_lens is None, "Not supported"

    # Entries a given mode does not populate are stored as None.
    attn_logits = max_extend_len = qo_indptr = custom_mask = mask_indptr = None
    end = bs + 1

    if forward_mode.is_decode_or_idle():
        if spec_info is not None:
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
        else:
            self.kv_indptr[1:end] = torch.cumsum(seq_lens, dim=0)
            kv_indptr = self.kv_indptr[:end]
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
    elif forward_mode.is_target_verify():
        draft_tokens = self.num_draft_tokens

        # Every request contributes exactly `draft_tokens` query rows.
        qo_indptr = self.qo_indptr[:end]
        qo_indptr[:end] = torch.arange(
            0,
            (1 + bs) * draft_tokens,
            step=draft_tokens,
            dtype=torch.int32,
            device=self.device,
        )

        kv_indptr = self.kv_indptr[:end]
        kv_indptr[1:end] = torch.cumsum(seq_lens, dim=0)
        kv_indices = self.cuda_graph_kv_indices
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            seq_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.stride(0),
        )

        custom_mask = self.cuda_graph_custom_mask
        mask_indptr = self.mask_indptr[:end]
        mask_indptr[1:end] = torch.cumsum(
            draft_tokens * (seq_lens + draft_tokens), dim=0
        )
        max_extend_len = draft_tokens
    else:
        raise ValueError(
            f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
        )

    self.forward_metadata = (
        attn_logits,
        max_extend_len,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask,
        mask_indptr,
    )
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
    self,
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInfo],
    seq_lens_cpu: Optional[torch.Tensor],
):
    """Refresh the persistent index buffers in place before a graph replay.

    Nothing is returned and ``self.forward_metadata`` is not reassigned; only
    the contents of ``self.kv_indptr``, ``self.cuda_graph_kv_indices`` and,
    in target-verify mode, ``self.qo_indptr``, ``self.cuda_graph_custom_mask``
    and ``self.mask_indptr`` are overwritten.

    ``seq_lens_sum``, ``encoder_lens`` and ``seq_lens_cpu`` are accepted but
    not read by this implementation.

    Raises:
        ValueError: if the mode is neither decode/idle nor target-verify.
    """
    # NOTE: encoder_lens expected to be zeros or None
    if forward_mode.is_decode_or_idle():
        indptr_buf = self.kv_indptr
        indices_buf = self.cuda_graph_kv_indices

        if spec_info is not None:
            # Copy the precomputed indices into the front of the buffers.
            indptr_buf[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
            indices_buf[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
            return

        # Only the first `bs` entries of the inputs are used.
        indptr_buf[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices[:bs],
            seq_lens[:bs],
            indptr_buf[: bs + 1],
            None,
            indices_buf,
            self.req_to_token.stride(0),
        )
        return

    if not forward_mode.is_target_verify():
        raise ValueError(
            f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
        )

    # Target verify: the batch size is taken from req_pool_indices, not `bs`.
    bs = len(req_pool_indices)
    end = bs + 1
    draft_tokens = self.num_draft_tokens

    self.qo_indptr[:end] = torch.arange(
        0,
        (1 + bs) * draft_tokens,
        step=draft_tokens,
        dtype=torch.int32,
        device=self.device,
    )

    kv_indptr = self.kv_indptr[:end]
    kv_indptr[1:end] = torch.cumsum(seq_lens, dim=0)
    create_flashinfer_kv_indices_triton[(bs,)](
        self.req_to_token,
        req_pool_indices,
        seq_lens,
        kv_indptr,
        None,
        self.cuda_graph_kv_indices,
        self.req_to_token.stride(0),
    )

    new_mask = spec_info.custom_mask
    self.cuda_graph_custom_mask[: new_mask.shape[0]] = new_mask
    self.mask_indptr[1:end] = torch.cumsum(
        draft_tokens * (seq_lens + draft_tokens), dim=0
    )
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
    """Return the constant 1 as this backend's CUDA-graph seq_len fill value."""
    fill_value = 1
    return fill_value
|
||||||
|
|
||||||
|
def forward_extend(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache=True,
):
    """Run extend (prefill-style) attention for one layer.

    Optionally writes ``k`` / ``v`` into the KV pool at
    ``forward_batch.out_cache_loc``, then calls ``self.extend_attention_fwd``
    with the index tensors stored in ``self.forward_metadata``.

    Returns:
        A newly allocated output tensor with ``q.shape[0]`` rows and
        ``layer.tp_q_head_num * layer.v_head_dim`` columns (same shape as
        ``q`` when the QK and V head dims match).
    """
    # TODO: reuse the buffer across layers
    heads = layer.tp_q_head_num
    if layer.qk_head_dim == layer.v_head_dim:
        o = torch.empty_like(q)
    else:
        o = q.new_empty((q.shape[0], heads * layer.v_head_dim))

    kv_pool = forward_batch.token_to_kv_pool
    if save_kv_cache:
        kv_pool.set_kv_buffer(layer, forward_batch.out_cache_loc, k, v)

    # The first metadata slot (attn_logits) is not needed for extend.
    (
        _,
        max_extend_len,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask,
        mask_indptr,
    ) = self.forward_metadata

    self.extend_attention_fwd(
        q.view(-1, heads, layer.qk_head_dim),
        k.contiguous(),
        v.contiguous(),
        o.view(-1, heads, layer.v_head_dim),
        kv_pool.get_key_buffer(layer.layer_id),
        kv_pool.get_value_buffer(layer.layer_id),
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        max_extend_len,
        layer.scaling,
        layer.logit_cap,
    )
    return o
|
||||||
|
|
||||||
|
def forward_decode(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache=True,
):
    """Run decode attention for one layer via ``self.decode_attention_fwd``.

    Optionally writes ``k`` / ``v`` into the KV pool at
    ``forward_batch.out_cache_loc``, then invokes the decode kernel with the
    KV index tensors taken from ``self.forward_metadata``.

    Returns:
        A newly allocated output tensor with ``q.shape[0]`` rows (after the
        reshape below) and ``layer.tp_q_head_num * layer.v_head_dim`` columns.
    """
    # During torch.compile, there is a bug in rotary_emb that causes the
    # output value to have a 3D tensor shape. This reshapes the output correctly.
    q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

    if layer.qk_head_dim != layer.v_head_dim:
        o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
    else:
        o = torch.empty_like(q)

    # Only the KV index tensors are needed for decode.
    _, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

    if save_kv_cache:
        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer, forward_batch.out_cache_loc, k, v
        )

    self.decode_attention_fwd(
        # `o` is allocated with tp_q_head_num * v_head_dim columns, so it must
        # be viewed with v_head_dim; viewing with qk_head_dim fails or yields
        # a wrong shape whenever the two head dims differ.
        o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
        self.workspace_buffer,
        q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
        forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
            -1, 1, layer.tp_k_head_num, layer.qk_head_dim
        ),
        forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
            -1, 1, layer.tp_v_head_num, layer.v_head_dim
        ),
        self.scale,
        kv_indptr,
        kv_indices,
        self.kv_last_page_lens,
        1,
        self.max_num_partitions,
        None,
        "auto",
        "NHD",
        layer.logit_cap,
        self.k_scale,
        self.v_scale,
        None,
        _AITER_PARTITION_SIZE_ROCM,
    )

    return o
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    """Copy one request's token indices from ``req_to_token`` into ``kv_indices``.

    Each program (axis 0) handles one entry: it reads that entry's row index
    from ``req_pool_indices_ptr``, its destination offset from ``kv_indptr``
    and its length from ``page_kernel_lens_ptr``, then copies that many values
    in chunks of 512. When ``kv_start_idx`` is provided, copying starts at
    that column of the row instead of column 0.
    """
    BLOCK_SIZE: tl.constexpr = 512
    row = tl.program_id(axis=0)

    pool_slot = tl.load(req_pool_indices_ptr + row)
    dst_base = tl.load(kv_indptr + row)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + row).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + row).to(tl.int32)

    # Number of values to copy for this entry.
    kv_len = kv_end - kv_start
    num_chunks = tl.cdiv(kv_len, BLOCK_SIZE)
    for chunk in range(num_chunks):
        lane = tl.arange(0, BLOCK_SIZE) + chunk * BLOCK_SIZE
        # Mask off the lanes past the end in the final partial chunk.
        in_range = lane < kv_len
        tokens = tl.load(
            req_to_token_ptr
            + pool_slot * req_to_token_ptr_stride
            + kv_start
            + lane,
            mask=in_range,
        )
        tl.store(kv_indices_ptr + dst_base + lane, tokens, mask=in_range)
|
||||||
@@ -79,6 +79,12 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
is_hip_ = is_hip()
|
||||||
|
|
||||||
|
if is_hip_:
|
||||||
|
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
|
||||||
|
from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -641,7 +647,7 @@ class ModelRunner:
|
|||||||
if self.server_args.kv_cache_dtype == "auto":
|
if self.server_args.kv_cache_dtype == "auto":
|
||||||
self.kv_cache_dtype = self.dtype
|
self.kv_cache_dtype = self.dtype
|
||||||
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
||||||
if is_hip(): # Using natively supported format
|
if is_hip_: # Using natively supported format
|
||||||
self.kv_cache_dtype = torch.float8_e5m2fnuz
|
self.kv_cache_dtype = torch.float8_e5m2fnuz
|
||||||
else:
|
else:
|
||||||
self.kv_cache_dtype = torch.float8_e5m2
|
self.kv_cache_dtype = torch.float8_e5m2
|
||||||
@@ -778,6 +784,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
def init_attention_backend(self):
|
def init_attention_backend(self):
|
||||||
"""Init attention kernel backend."""
|
"""Init attention kernel backend."""
|
||||||
|
if is_cuda():
|
||||||
if self.server_args.attention_backend == "flashinfer":
|
if self.server_args.attention_backend == "flashinfer":
|
||||||
# Init streams
|
# Init streams
|
||||||
if self.server_args.speculative_algorithm == "EAGLE":
|
if self.server_args.speculative_algorithm == "EAGLE":
|
||||||
@@ -805,6 +812,31 @@ class ModelRunner:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||||
)
|
)
|
||||||
|
elif is_hip_:
|
||||||
|
# AMD hip supported attention backends
|
||||||
|
if self.server_args.attention_backend == "aiter":
|
||||||
|
self.attn_backend = AiterAttnBackend(self)
|
||||||
|
elif self.server_args.attention_backend == "aiter_decode":
|
||||||
|
self.attn_backend = AiterDecodeAttnBackend(self)
|
||||||
|
elif self.server_args.attention_backend == "triton":
|
||||||
|
assert self.sliding_window_size is None, (
|
||||||
|
"Window attention is not supported in the triton attention backend. "
|
||||||
|
"Please use `--attention-backend flashinfer`."
|
||||||
|
)
|
||||||
|
assert not self.model_config.is_encoder_decoder, (
|
||||||
|
"Cross attention is not supported in the triton attention backend. "
|
||||||
|
"Please use `--attention-backend flashinfer`."
|
||||||
|
)
|
||||||
|
if self.server_args.enable_double_sparsity:
|
||||||
|
self.attn_backend = DoubleSparseAttnBackend(self)
|
||||||
|
else:
|
||||||
|
self.attn_backend = TritonAttnBackend(self)
|
||||||
|
elif self.server_args.attention_backend == "torch_native":
|
||||||
|
self.attn_backend = TorchNativeAttnBackend(self)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||||
|
)
|
||||||
|
|
||||||
def init_double_sparsity_channel_config(self, selected_channel):
|
def init_double_sparsity_channel_config(self, selected_channel):
|
||||||
selected_channel = "." + selected_channel + "_proj"
|
selected_channel = "." + selected_channel + "_proj"
|
||||||
|
|||||||
@@ -710,6 +710,15 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Kernel backend
|
# Kernel backend
|
||||||
|
if is_hip():
|
||||||
|
parser.add_argument(
|
||||||
|
"--attention-backend",
|
||||||
|
type=str,
|
||||||
|
choices=["triton", "torch_native", "aiter", "aiter_decode"],
|
||||||
|
default=ServerArgs.attention_backend,
|
||||||
|
help="Choose the kernels for attention layers.",
|
||||||
|
)
|
||||||
|
else:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--attention-backend",
|
"--attention-backend",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -717,6 +726,7 @@ class ServerArgs:
|
|||||||
default=ServerArgs.attention_backend,
|
default=ServerArgs.attention_backend,
|
||||||
help="Choose the kernels for attention layers.",
|
help="Choose the kernels for attention layers.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sampling-backend",
|
"--sampling-backend",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
118
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip
Normal file
118
sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.hip
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
// !!! This is a file automatically generated by hipify!!!
|
||||||
|
#include "hip/hip_runtime.h"
|
||||||
|
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/hip/HIPContext.h>
|
||||||
|
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <THH/THHAtomics.cuh>
|
||||||
|
|
||||||
|
#include "utils_hip.h"
|
||||||
|
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
|
||||||
|
int32_t* __restrict__ sorted_token_ids,
|
||||||
|
int32_t* __restrict__ cumsum_buffer, size_t numel) {
|
||||||
|
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const size_t stride = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (size_t i = tid; i < numel; i += stride) {
|
||||||
|
int32_t expert_id = topk_ids[i];
|
||||||
|
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
|
||||||
|
sorted_token_ids[rank_post_pad] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
|
||||||
|
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||||
|
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
|
||||||
|
int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
|
||||||
|
__shared__ int32_t shared_counts[WARP_SIZE][8];
|
||||||
|
|
||||||
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
|
const int experts_per_warp = 8;
|
||||||
|
const int my_expert_start = warp_id * experts_per_warp;
|
||||||
|
|
||||||
|
for (int i = 0; i < experts_per_warp; ++i) {
|
||||||
|
if (my_expert_start + i < num_experts) {
|
||||||
|
shared_counts[warp_id][i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||||
|
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||||
|
|
||||||
|
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||||
|
int expert_id = topk_ids[i];
|
||||||
|
int warp_idx = expert_id / experts_per_warp;
|
||||||
|
int expert_offset = expert_id % experts_per_warp;
|
||||||
|
atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
cumsum[0] = 0;
|
||||||
|
for (int i = 1; i <= num_experts; ++i) {
|
||||||
|
int expert_count = 0;
|
||||||
|
int warp_idx = (i - 1) / experts_per_warp;
|
||||||
|
int expert_offset = (i - 1) % experts_per_warp;
|
||||||
|
expert_count = shared_counts[warp_idx][expert_offset];
|
||||||
|
|
||||||
|
cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
|
||||||
|
}
|
||||||
|
*total_tokens_post_pad = cumsum[num_experts];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (threadIdx.x < num_experts) {
|
||||||
|
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
|
||||||
|
expert_ids[i / block_size] = threadIdx.x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host entry point: launches the alignment kernel (one block of 1024 threads)
// followed by the count-and-sort kernel on the current HIP stream.
// Only `num_experts == 256` is accepted. `token_cnts_buffer` is part of the
// signature but is not used here.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
  const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    const auto total = topk_ids.numel();
    const scalar_t* ids_ptr = topk_ids.data_ptr<scalar_t>();
    int32_t* sorted_ptr = sorted_token_ids.data_ptr<int32_t>();
    int32_t* cumsum_ptr = cumsum_buffer.data_ptr<int32_t>();

    // Pass 1: per-expert counts and block-aligned offsets.
    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
    hipLaunchKernelGGL((align_kernel), dim3(1), dim3(1024), 0, stream, ids_ptr, sorted_ptr,
                       experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
                       block_size, total, cumsum_ptr);

    // Pass 2: scatter token indices, with the grid size capped at 65535 blocks.
    const int block_threads = 256;
    const int max_blocks = 65535;
    const int num_blocks = (total + block_threads - 1) / block_threads;
    const int actual_blocks = ::min(num_blocks, max_blocks);

    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    hipLaunchKernelGGL((sort_kernel), dim3(actual_blocks), dim3(block_threads), 0, stream, ids_ptr, sorted_ptr,
                       cumsum_ptr, total);
  });
}
|
||||||
98
sgl-kernel/src/sgl-kernel/include/utils_hip.h
Normal file
98
sgl-kernel/src/sgl-kernel/include/utils_hip.h
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
// !!! This is a file automatically generated by hipify!!!
|
||||||
|
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#include <pytorch_extension_utils.h>
|
||||||
|
#endif
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
/**
 * @brief Exception type thrown by `CHECK_CUDA_SUCCESS` on a failed runtime call.
 *
 * Derives from `std::runtime_error`, so `what()` returns the message passed
 * at construction.
 */
struct cuda_error : std::runtime_error {
  /// Build the error from a C string message.
  cuda_error(const char* message) : std::runtime_error(message) {}

  /// Build the error from a `std::string` message.
  cuda_error(const std::string& message) : std::runtime_error(message) {}
};
|
||||||
|
|
||||||
|
// Evaluates `cmd` once; if the result is not hipSuccess, throws `cuda_error`
// whose message is the HIP error string followed by "<file>:<line>".
// `cmd` is parenthesised and the temporaries use long, suffixed names so the
// macro cannot capture identifiers that appear inside `cmd`.
#define CHECK_CUDA_SUCCESS(cmd)                                              \
  do {                                                                       \
    const hipError_t check_status_ = (cmd);                                  \
    if (check_status_ != hipSuccess) {                                       \
      std::stringstream check_message_;                                      \
      check_message_ << std::string(hipGetErrorString(check_status_)) + "\n" \
                     << __FILE__ << ':' << __LINE__;                         \
      throw cuda_error(check_message_.str());                                \
    }                                                                        \
  } while (0)
|
||||||
|
|
||||||
|
// Tensor argument validation helpers. Each raises through TORCH_CHECK with a
// message that starts with the stringified argument expression.
#define CHECK_IS_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so both checks run as a single statement; otherwise
// `if (cond) CHECK_CUDA_INPUT(t);` would guard only the first check.
#define CHECK_CUDA_INPUT(x) \
  do {                      \
    CHECK_IS_CUDA(x);       \
    CHECK_IS_CONTIGUOUS(x); \
  } while (0)
|
||||||
|
|
||||||
|
// Returns `major * 10 + minor` of the current device's compute capability,
// as reported by hipDeviceGetAttribute. Throws `cuda_error` if any runtime
// query fails.
inline int getSMVersion() {
  int device_id = -1;
  CHECK_CUDA_SUCCESS(hipGetDevice(&device_id));

  int cc_major = 0;
  CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&cc_major, hipDeviceAttributeComputeCapabilityMajor, device_id));

  int cc_minor = 0;
  CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&cc_minor, hipDeviceAttributeComputeCapabilityMinor, device_id));

  return 10 * cc_major + cc_minor;
}
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
|
||||||
|
[&]() -> bool { \
|
||||||
|
switch (pytorch_dtype) { \
|
||||||
|
case at::ScalarType::Float: { \
|
||||||
|
using c_type = float; \
|
||||||
|
return __VA_ARGS__(); \
|
||||||
|
} \
|
||||||
|
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
|
||||||
|
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
|
||||||
|
default: \
|
||||||
|
std::ostringstream oss; \
|
||||||
|
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
|
||||||
|
TORCH_CHECK(false, oss.str()); \
|
||||||
|
return false; \
|
||||||
|
} \
|
||||||
|
}()
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||||
|
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
|
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
|
||||||
Reference in New Issue
Block a user