@@ -2,7 +2,7 @@
# docker build --build-arg SGL_BRANCH=v0.4.3.post4 -t v0.4.3.post4-rocm630 -f Dockerfile.rocm .

# default base image
ARG BASE_IMAGE="rocm/sgl-dev:20250114vllm-blas-flash"
ARG BASE_IMAGE="rocm/sgl-dev:vllm20250114"

FROM $BASE_IMAGE AS base
USER root
@@ -16,10 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT}
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
ARG TRITON_COMMIT="improve_fa_decode_3.0.0"


ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG AITER_COMMIT="testx"


RUN git clone ${SGL_REPO} \
    && cd sglang \
    && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
@@ -59,7 +59,6 @@ RUN git clone ${AITER_REPO} \
    && git submodule update --init --recursive \
    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop


# Copy config files to support MI300X in virtualized environments (MI300X_VF). Symlinks will not be created in image build.
RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \
    /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \
@@ -1,605 +0,0 @@
from __future__ import annotations

"""
End-to-end attention solution with aiter kernels.
"""

import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

# flashinfer AMD fork
from flashinfer import BatchPrefillWithPagedKVCacheWrapper

try:
    from aiter import paged_attention_rocm
except ImportError:
    print(
        "aiter is an AMD-specific kernel library. Please make sure aiter is installed on your AMD device."
    )


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


@dataclass
class DecodeMetadata:
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor


@dataclass
class PrefillMetadata:
    prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper
    extend_no_prefix: bool


global_workspace_buffer = None

_AITER_PARTITION_SIZE_ROCM = 256


class AiterAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        self.device = model_runner.device
        self.is_multimodal = model_runner.model_config.is_multimodal
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.kv_cache_dtype = model_runner.kv_cache_dtype

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.skip_prefill = skip_prefill

        # Qwen2 models require a larger flashinfer workspace size
        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
            global_config.flashinfer_workspace_size = 512 * 1024 * 1024

        global global_workspace_buffer
        if global_workspace_buffer is None:
            global_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )

        self.workspace_buffer = global_workspace_buffer
        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer, "NHD", backend="fa2"
        )
        self.prefill_wrapper_verify = BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer, "NHD"
        )

        # Create prefill indices updater
        if not skip_prefill:
            self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
                model_runner, self
            )

        # aiter kernel related initialization
        self.max_num_partitions = (
            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
        ) // _AITER_PARTITION_SIZE_ROCM

        nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8

        self.workspace_buffer = torch.empty(
            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
            * nbytes_per_qo_elem
            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
            dtype=torch.uint8,
            device=self.device,
        )

        self.scale = float(1.0 / (self.head_dim**0.5))
        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
            self.device
        )
        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
            self.device
        )

        self.logits_soft_cap = 0.0

        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            # update for aiter
            # create kv_indices and kv_indptr
            bs = forward_batch.batch_size
            kv_indptr = self.kv_indptr
            spec_info = forward_batch.spec_info
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices)

        elif forward_batch.forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper=self.prefill_wrapper_paged,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_paged, False
            )
        elif forward_batch.forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper=self.prefill_wrapper_verify,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_verify, False
            )
        else:
            prefix_lens = forward_batch.extend_prefix_lens

            if self.is_multimodal:
                extend_no_prefix = False
            else:
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrapper=self.prefill_wrapper_paged,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=None,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_paged, extend_no_prefix
            )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
            self.forward_metadata = DecodeMetadata(kv_indptr, kv_indices)
            self.decode_cuda_graph_metadata[bs] = self.forward_metadata

        elif forward_mode.is_target_verify():
            prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=False,
                qo_indptr_buf=self.cuda_graph_qo_indptr[: bs + 1],
                paged_kv_indptr_buf=self.kv_indptr[: bs + 1],
                paged_kv_indices_buf=self.cuda_graph_kv_indices,
                paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
                custom_mask_buf=self.cuda_graph_custom_mask,
                mask_indptr_buf=self.cuda_graph_qk_indptr[: bs + 1],
            )

            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper=prefill_wrapper,
                encoder_lens=encoder_lens,
                spec_info=spec_info,
            )
            self.prefill_cuda_graph_metadata[bs] = prefill_wrapper
            self.forward_metadata = PrefillMetadata(prefill_wrapper, False)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices

        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper=self.prefill_cuda_graph_metadata[bs],
                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                spec_info=spec_info,
            )
        else:
            raise ValueError("Invalid forward mode")

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        self.logits_soft_cap = layer.logit_cap

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )

        o = prefill_wrapper_paged.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            causal=not layer.is_cross_attention,
            sm_scale=layer.scaling,
            window_left=layer.sliding_window_size,
            logits_soft_cap=self.logits_soft_cap,
            k_scale=layer.k_scale,
            v_scale=layer.v_scale,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.logits_soft_cap = layer.logit_cap
        paged_attention_rocm(
            o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            self.workspace_buffer,
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
            ),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
                -1, 1, layer.tp_v_head_num, layer.v_head_dim
            ),
            self.scale,
            self.forward_metadata.kv_indptr,
            self.forward_metadata.kv_indices,
            self.kv_last_page_lens,
            1,
            self.max_num_partitions,
            None,
            "auto",
            "NHD",
            self.logits_soft_cap,
            self.k_scale,
            self.v_scale,
            None,
            _AITER_PARTITION_SIZE_ROCM,
        )

        return o


class FlashInferIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.update = self.update_single_wrapper

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        paged_kernel_lens = seq_lens
        paged_kernel_lens_sum = seq_lens_sum

        self.call_begin_forward(
            prefill_wrapper,
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            None,
            self.kv_indptr,
            self.qo_indptr,
            spec_info,
        )

    def call_begin_forward(
        self,
        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: torch.Tensor,
        kv_start_idx: torch.Tensor,
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        spec_info: Optional[SpecInfo],
    ):
        bs = len(req_pool_indices)
        if spec_info is None:
            # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                paged_kernel_lens_sum + 256,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                kv_start_idx,
                kv_indices,
                self.req_to_token.shape[1],
            )

            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    self.req_to_token,
                )
            )

        # cached part
        # adding logits_soft_cap arg in plan() stage
        wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            q_data_type=self.q_data_type,
            custom_mask=custom_mask,
            non_blocking=True,
            logits_soft_cap=self.attn_backend.logits_soft_cap,
        )


@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < kv_end - kv_start
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + offset,
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
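For reference, the following is a minimal pure-PyTorch sketch of what create_flashinfer_kv_indices_triton above computes: for each request it copies that request's row of the req_to_token table into a packed kv_indices array at the offset given by kv_indptr. The helper name and the example tensors are illustrative only, not part of the deleted file.

import torch

def create_kv_indices_reference(req_to_token, req_pool_indices, seq_lens, kv_indptr, kv_start_idx=None):
    # Packed output: request i's token slots land at kv_indices[kv_indptr[i] : kv_indptr[i] + seq_lens[i]].
    kv_indices = torch.zeros(int(kv_indptr[-1]), dtype=torch.int32)
    for i, pool_idx in enumerate(req_pool_indices.tolist()):
        start = 0 if kv_start_idx is None else int(kv_start_idx[i])
        length = int(seq_lens[i])
        out = int(kv_indptr[i])
        kv_indices[out : out + length] = req_to_token[pool_idx, start : start + length]
    return kv_indices

# Two requests of lengths 3 and 2 stored in rows 1 and 0 of req_to_token.
req_to_token = torch.arange(20, dtype=torch.int32).reshape(2, 10)
print(create_kv_indices_reference(
    req_to_token,
    torch.tensor([1, 0]),
    torch.tensor([3, 2]),
    torch.tensor([0, 3, 5]),
))  # tensor([10, 11, 12,  0,  1], dtype=torch.int32)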
@@ -1,535 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

try:
    from aiter import paged_attention_rocm
except ImportError:
    print(
        "aiter is an AMD-specific kernel library. Please make sure aiter is installed on your AMD device."
    )

from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd

_AITER_PARTITION_SIZE_ROCM = 256


class AiterDecodeAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        self.decode_attention_fwd = paged_attention_rocm
        self.extend_attention_fwd = extend_attention_fwd

        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
            )

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        # tp sharding on number of heads
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )

        self.head_dim = model_runner.model_config.head_dim

        # triton prefill initialization
        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        self.num_v_head = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-2]

        self.forward_metadata = None

        self.max_context_len = model_runner.model_config.context_len

        self.device = model_runner.device

        self.kv_cache_dtype = model_runner.kv_cache_dtype

        self.q_dtype = model_runner.model_config.dtype

        # aiter decode initialization
        self.max_num_partitions = (
            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
        ) // _AITER_PARTITION_SIZE_ROCM

        nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8

        self.workspace_buffer = torch.empty(
            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
            * nbytes_per_qo_elem
            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
            dtype=torch.uint8,
            device=self.device,
        )

        self.scale = float(1.0 / (self.head_dim**0.5))
        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
            self.device
        )
        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
            self.device
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables."""
        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        spec_info = forward_batch.spec_info

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                # prepare kv_indices and kv_indptr
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            attn_logits = None  # accommodate forward_metadata format
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.zeros(
                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens
            )
            mask_indptr = self.mask_indptr
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
            mask_indptr = mask_indptr[: bs + 1]
            max_extend_len = self.num_draft_tokens
            attn_logits = None
        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    self.req_to_token,
                )
            )
            mask_indptr = None
            max_extend_len = torch.max(spec_info.accept_length).item()
            attn_logits = None
        else:
            kv_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_prefix_lens, dim=0
            )
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.zeros(
                forward_batch.extend_prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )

            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.extend_prefix_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            qo_indptr = self.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            mask_indptr = None
            attn_logits = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        self.cuda_graph_attn_logits = torch.zeros(
            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
            dtype=torch.float32,
            device=self.device,
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        assert encoder_lens is None, "Not supported"

        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
            attn_logits = None
            max_extend_len = None
            qo_indptr = None
            custom_mask = None
            mask_indptr = None
        elif forward_mode.is_target_verify():
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            custom_mask = self.cuda_graph_custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
            max_extend_len = self.num_draft_tokens
            attn_logits = None
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        # NOTE: encoder_lens expected to be zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            qo_indptr = self.qo_indptr[: bs + 1]
            qo_indptr[: bs + 1] = torch.arange(
                0,
                (1 + bs) * self.num_draft_tokens,
                step=self.num_draft_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            kv_indptr = self.kv_indptr[: bs + 1]
            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
            kv_indices = self.cuda_graph_kv_indices
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        (
            _,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        ) = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            mask_indptr,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        self.decode_attention_fwd(
            o.view(
                -1, layer.tp_q_head_num, layer.qk_head_dim
            ),  # (bs, head_num_q, head_dim_q)
            self.workspace_buffer,
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
            ),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
                -1, 1, layer.tp_v_head_num, layer.v_head_dim
            ),
            self.scale,
            kv_indptr,
            kv_indices,
            self.kv_last_page_lens,
            1,
            self.max_num_partitions,
            None,
            "auto",
            "NHD",
            layer.logit_cap,
            self.k_scale,
            self.v_scale,
            None,
            _AITER_PARTITION_SIZE_ROCM,
        )

        return o


@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < kv_end - kv_start
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + offset,
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
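As a quick sanity check on the aiter workspace sizing used by both backends above, here is the arithmetic with illustrative values. The interpretation of the two terms (fp32 partial outputs plus two fp32 bookkeeping arrays per partition) is an assumption read off the formula, not something stated in the deleted files.

_AITER_PARTITION_SIZE_ROCM = 256

max_context_len = 8192
max_bs = 32
num_head = 8      # attention heads after tensor-parallel sharding (example value)
head_dim = 128

max_num_partitions = (
    max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
) // _AITER_PARTITION_SIZE_ROCM          # 32 partitions of 256 tokens each
nbytes_per_qo_elem = 4                   # torch.finfo(torch.float32).bits // 8

workspace_bytes = (
    max_bs * num_head * max_num_partitions * head_dim * nbytes_per_qo_elem
    + 2 * (max_bs * num_head * max_num_partitions) * 4
)
print(max_num_partitions, workspace_bytes)  # 32, 4259840 bytes (about 4.1 MiB)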
@@ -79,12 +79,6 @@ from sglang.srt.utils import (
)
from sglang.utils import get_exception_traceback

is_hip_ = is_hip()

if is_hip_:
    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
    from sglang.srt.layers.attention.aiter_decode_backend import AiterDecodeAttnBackend

logger = logging.getLogger(__name__)
@@ -647,7 +641,7 @@ class ModelRunner:
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if is_hip_:  # Using natively supported format
            if is_hip():  # Using natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
@@ -784,59 +778,33 @@ class ModelRunner:
    def init_attention_backend(self):
        """Init attention kernel backend."""
        if is_cuda():
            if self.server_args.attention_backend == "flashinfer":
                # Init streams
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
        if self.server_args.attention_backend == "flashinfer":
            # Init streams
            if self.server_args.speculative_algorithm == "EAGLE":
                self.plan_stream_for_flashinfer = torch.cuda.Stream()

            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashinfer_mla":
            self.attn_backend = FlashInferMLAAttnBackend(self)
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                raise ValueError(
                    f"Invalid attention backend: {self.server_args.attention_backend}"
                )
        elif is_hip_:
            # AMD hip supported attention backends
            if self.server_args.attention_backend == "aiter":
                self.attn_backend = AiterAttnBackend(self)
            elif self.server_args.attention_backend == "aiter_decode":
                self.attn_backend = AiterDecodeAttnBackend(self)
            elif self.server_args.attention_backend == "triton":
                assert self.sliding_window_size is None, (
                    "Window attention is not supported in the triton attention backend. "
                    "Please use `--attention-backend flashinfer`."
                )
                assert not self.model_config.is_encoder_decoder, (
                    "Cross attention is not supported in the triton attention backend. "
                    "Please use `--attention-backend flashinfer`."
                )
                if self.server_args.enable_double_sparsity:
                    self.attn_backend = DoubleSparseAttnBackend(self)
                else:
                    self.attn_backend = TritonAttnBackend(self)
            elif self.server_args.attention_backend == "torch_native":
                self.attn_backend = TorchNativeAttnBackend(self)
            else:
                raise ValueError(
                    f"Invalid attention backend: {self.server_args.attention_backend}"
                )
                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashinfer_mla":
            self.attn_backend = FlashInferMLAAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
@@ -710,23 +710,13 @@ class ServerArgs:
        )

        # Kernel backend
        if is_hip():
            parser.add_argument(
                "--attention-backend",
                type=str,
                choices=["triton", "torch_native", "aiter", "aiter_decode"],
                default=ServerArgs.attention_backend,
                help="Choose the kernels for attention layers.",
            )
        else:
            parser.add_argument(
                "--attention-backend",
                type=str,
                choices=["flashinfer", "triton", "torch_native"],
                default=ServerArgs.attention_backend,
                help="Choose the kernels for attention layers.",
            )

        parser.add_argument(
            "--attention-backend",
            type=str,
            choices=["flashinfer", "triton", "torch_native"],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",
        )
        parser.add_argument(
            "--sampling-backend",
            type=str,
@@ -1,118 +0,0 @@
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <torch/extension.h>

#include <THH/THHAtomics.cuh>

#include "utils_hip.h"

#define WARP_SIZE 32

template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
                                                    int32_t* __restrict__ sorted_token_ids,
                                                    int32_t* __restrict__ cumsum_buffer, size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;

  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
                                            int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
                                            int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
                                            int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
  __shared__ int32_t shared_counts[WARP_SIZE][8];

  const int warp_id = threadIdx.x / WARP_SIZE;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

  __syncthreads();

  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;
    int expert_offset = expert_id % experts_per_warp;
    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
  }

  __syncthreads();

  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int expert_count = 0;
      int warp_idx = (i - 1) / experts_per_warp;
      int expert_offset = (i - 1) % experts_per_warp;
      expert_count = shared_counts[warp_idx][expert_offset];

      cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }
}

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
  const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
    hipLaunchKernelGGL((align_kernel), dim3(1), dim3(1024), 0, stream, topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                       experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
                       num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());

    const int block_threads = 256;
    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
    const int max_blocks = 65535;
    const int actual_blocks = ::min(num_blocks, max_blocks);

    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    hipLaunchKernelGGL((sort_kernel), dim3(actual_blocks), dim3(block_threads), 0, stream, topk_ids.data_ptr<scalar_t>(),
                       sorted_token_ids.data_ptr<int32_t>(),
                       cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
  });
}
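For orientation, a small pure-PyTorch sketch of what the moe_align_block_size kernels above produce: per-expert token counts padded up to a multiple of block_size, a block-to-expert assignment, and token ids grouped by expert. The helper name and the padding value written into unfilled slots are illustrative assumptions; the HIP kernels leave unfilled slots to whatever the caller pre-initialized.

import torch

def moe_align_block_size_reference(topk_ids, num_experts, block_size):
    flat = topk_ids.flatten()
    counts = torch.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    cumsum = torch.zeros(num_experts + 1, dtype=torch.int64)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    total = int(cumsum[-1])

    # Each block of block_size output slots belongs to exactly one expert.
    expert_ids = torch.empty(total // block_size, dtype=torch.int32)
    for e in range(num_experts):
        expert_ids[int(cumsum[e]) // block_size : int(cumsum[e + 1]) // block_size] = e

    # Scatter token ids into their expert's padded segment.
    sorted_token_ids = torch.full((total,), flat.numel(), dtype=torch.int32)
    fill = cumsum[:-1].tolist()
    for tok, e in enumerate(flat.tolist()):
        sorted_token_ids[fill[e]] = tok
        fill[e] += 1
    return sorted_token_ids, expert_ids, total

ids, experts, total = moe_align_block_size_reference(torch.tensor([0, 2, 0, 1]), 4, 4)
print(total, experts.tolist())  # 12 [0, 1, 2]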
@@ -1,98 +0,0 @@
// !!! This is a file automatically generated by hipify!!!
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <hip/hip_runtime.h>
#ifndef USE_ROCM
#include <pytorch_extension_utils.h>
#endif
#include <torch/extension.h>

#include <sstream>

struct cuda_error : public std::runtime_error {
  /**
   * @brief Constructs a `cuda_error` object with the given `message`.
   *
   * @param message The error char array used to construct `cuda_error`
   */
  cuda_error(const char* message) : std::runtime_error(message) {}
  /**
   * @brief Constructs a `cuda_error` object with the given `message` string.
   *
   * @param message The `std::string` used to construct `cuda_error`
   */
  cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
};

#define CHECK_CUDA_SUCCESS(cmd)                                          \
  do {                                                                   \
    hipError_t e = cmd;                                                  \
    if (e != hipSuccess) {                                               \
      std::stringstream _message;                                        \
      auto s = hipGetErrorString(e);                                     \
      _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__;  \
      throw cuda_error(_message.str());                                  \
    }                                                                    \
  } while (0)

#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
  CHECK_IS_CUDA(x);         \
  CHECK_IS_CONTIGUOUS(x)

inline int getSMVersion() {
  int device{-1};
  CHECK_CUDA_SUCCESS(hipGetDevice(&device));
  int sm_major = 0;
  int sm_minor = 0;
  CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_major, hipDeviceAttributeComputeCapabilityMajor, device));
  CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_minor, hipDeviceAttributeComputeCapabilityMinor, device));
  return sm_major * 10 + sm_minor;
}

#ifndef USE_ROCM
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...)          \
  [&]() -> bool {                                                                        \
    switch (pytorch_dtype) {                                                             \
      case at::ScalarType::Float: {                                                      \
        using c_type = float;                                                            \
        return __VA_ARGS__();                                                            \
      }                                                                                  \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
      default:                                                                           \
        std::ostringstream oss;                                                          \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str());                                                   \
        return false;                                                                    \
    }                                                                                    \
  }()
#endif

#define DISPATCH_CASE_INTEGRAL_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define CEILDIV(x, y) (((x) + (y)-1) / (y))