[AMD] Support Wave attention backend with AMD GPU optimizations (#8660)
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
Co-authored-by: Harsh Menon <harsh@nod-labs.com>
Co-authored-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Co-authored-by: Stanley Winata <stanley@nod-labs.com>
Co-authored-by: Ivan Butygin <ivan.butygin@gmail.com>
Co-authored-by: nithinsubbiah <nithinsubbiah@gmail.com>
Co-authored-by: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com>
Co-authored-by: Ivan Butygin <ibutygin@amd.com>
@@ -14,6 +14,7 @@ You can test them according to your needs.
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
| **TRTLLM MLA** | ✅ | ❌ | ✅ | ✅ | ❌ |
| **Ascend** | ✅ | ❌ | ✅ | ❌ | ❌ |
| **Wave** | ✅ | ❌ | ❌ | ❌ | ❌ |

**Notes:**

- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to the FlashInfer MLA backend.
@@ -70,6 +71,10 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
```

- Wave

```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend wave
```

## Steps to add a new attention backend

To add a new attention backend, you can learn from the existing backends
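As a rough illustration of what that involves, the skeleton below (not part of this PR; method bodies are placeholders) follows the same `AttentionBackend` interface that the Wave backend implements further down.

```python
# Illustrative sketch only: a new backend subclasses AttentionBackend and
# implements metadata setup plus the extend (prefill) and decode paths.
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend


class MyAttnBackend(AttentionBackend):
    def __init__(self, model_runner):
        super().__init__()
        self.device = model_runner.device

    def init_forward_metadata(self, forward_batch):
        # Precompute per-batch index tensors (e.g. kv_indptr, kv_indices).
        ...

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # Prefill/extend attention: write k/v to the KV cache, return the output.
        ...

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # Single-token decode attention against the KV cache.
        ...
```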
@@ -82,6 +82,7 @@ srt_hip = [
    "sglang[runtime_common]",
    "torch",
    "petit_kernel==0.0.2",
    "wave-lang==1.0.1",
]

# CPU: torch wheel for CPU needs to be installed from https://download.pytorch.org/whl/cpu
python/sglang/srt/layers/attention/wave_backend.py (new file, 627 lines)
@@ -0,0 +1,627 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_core_count
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def get_num_kv_splits_triton(
|
||||
num_kv_splits_ptr,
|
||||
seq_lens_ptr,
|
||||
num_seq,
|
||||
num_group,
|
||||
num_head,
|
||||
num_kv_head,
|
||||
max_kv_splits,
|
||||
device_core_count,
|
||||
MAX_NUM_SEQ: tl.constexpr,
|
||||
):
|
||||
# TODO: this heuristic is tunable; we need more online serving data to tune it
|
||||
offs_seq = tl.arange(0, MAX_NUM_SEQ)
|
||||
mask_seq = offs_seq < num_seq
|
||||
|
||||
seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
|
||||
max_seq_len = tl.max(seq_lens)
|
||||
seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
|
||||
min_seq_len = tl.min(seq_lens)
|
||||
if max_seq_len * 8 < min_seq_len * 10:
|
||||
min_seq_len = max_seq_len
|
||||
max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
|
||||
kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
|
||||
|
||||
# NOTE: this is a hack to let num_kv_splits grow gradually with seqlen
|
||||
ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
|
||||
ext_device_core_count = tl.cast(
|
||||
device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
|
||||
)
|
||||
block_h, num_kv_group = 16, num_head // num_kv_head
|
||||
if num_kv_group == 1:
|
||||
token_grid = num_seq * num_group * num_head
|
||||
else:
|
||||
# from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
|
||||
block_h = tl.minimum(block_h, num_kv_group)
|
||||
token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
|
||||
max_kv_splits_2 = tl.minimum(
|
||||
tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
|
||||
)
|
||||
kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
|
||||
|
||||
num_kv_splits = tl.maximum(
|
||||
tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
|
||||
)
|
||||
|
||||
offs_token = offs_seq * num_group
|
||||
mask_token = offs_token < num_seq * num_group
|
||||
for i in range(0, num_group):
|
||||
tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
|
||||
|
||||
|
||||
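# Editorial note (comment added for clarity; not in the original diff): the kernel
# above picks, for each sequence, the larger of two split counts and stores it once
# per group:
#   (1) a count derived from kv_chunk_size_1 = cdiv(max_seq_len, min(cdiv(max_seq_len,
#       min_seq_len), max_kv_splits)), which keeps chunk sizes comparable across
#       sequences of different lengths, and
#   (2) a count derived from the device core count scaled by log2(max_seq_len / 64),
#       which grows the number of splits with sequence length so that enough
#       workgroups are launched to keep the GPU occupied.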
@dataclass
|
||||
class ForwardMetadata:
|
||||
attn_logits: torch.Tensor
|
||||
attn_lse: torch.Tensor
|
||||
max_extend_len: int
|
||||
num_kv_splits: torch.Tensor
|
||||
kv_indptr: torch.Tensor
|
||||
kv_indices: torch.Tensor
|
||||
qo_indptr: torch.Tensor
|
||||
custom_mask: torch.Tensor
|
||||
mask_indptr: torch.Tensor
|
||||
|
||||
|
||||
class WaveAttnBackend(AttentionBackend):
|
||||
def __init__(
|
||||
self,
|
||||
model_runner: ModelRunner,
|
||||
skip_prefill: bool = False,
|
||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Lazy import to avoid the initialization of cuda context
|
||||
from sglang.srt.layers.attention.wave_ops.decode_attention import (
|
||||
decode_attention_fwd,
|
||||
)
|
||||
from sglang.srt.layers.attention.wave_ops.extend_attention import (
|
||||
extend_attention_wave,
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Set unique cache dir for each process to avoid cache write races
|
||||
import wave_lang.kernel.wave.cache as cache
|
||||
|
||||
base_cache_dir = cache.CACHE_BASE_DIR
|
||||
new_dir = base_cache_dir / f"worker_{model_runner.tp_rank}"
|
||||
logger.info(f"Setting Wave cache dir: {new_dir}")
|
||||
cache.CACHE_BASE_DIR = new_dir
|
||||
|
||||
self.decode_attention_fwd = decode_attention_fwd
|
||||
self.extend_attention_fwd = extend_attention_wave
|
||||
|
||||
self.skip_prefill = skip_prefill
|
||||
|
||||
max_bs = model_runner.req_to_token_pool.size
|
||||
|
||||
if kv_indptr_buf is None:
|
||||
self.kv_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
else:
|
||||
self.kv_indptr = kv_indptr_buf
|
||||
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
|
||||
if not self.skip_prefill:
|
||||
self.qo_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
|
||||
self.mask_indptr = torch.zeros(
|
||||
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
|
||||
)
|
||||
|
||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||
|
||||
self.num_head = (
|
||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||
)
|
||||
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
||||
get_attention_tp_size()
|
||||
)
|
||||
|
||||
self.static_kv_splits = get_bool_env_var(
|
||||
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
|
||||
)
|
||||
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
||||
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
||||
|
||||
self.forward_metadata: ForwardMetadata = None
|
||||
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
|
||||
self.device = model_runner.device
|
||||
self.device_core_count = get_device_core_count(model_runner.gpu_id)
|
||||
|
||||
def get_num_kv_splits(
|
||||
self,
|
||||
num_kv_splits: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
):
|
||||
num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
|
||||
num_group = num_token // num_seq
|
||||
|
||||
assert (
|
||||
num_group * num_seq == num_token
|
||||
), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
|
||||
|
||||
if self.static_kv_splits or self.device_core_count <= 0:
|
||||
num_kv_splits.fill_(self.max_kv_splits)
|
||||
return
|
||||
|
||||
if num_seq < 256:
|
||||
SCHEDULE_SEQ = 256
|
||||
else:
|
||||
SCHEDULE_SEQ = triton.next_power_of_2(num_seq)
|
||||
|
||||
get_num_kv_splits_triton[(1,)](
|
||||
num_kv_splits,
|
||||
seq_lens,
|
||||
num_seq,
|
||||
num_group,
|
||||
self.num_head,
|
||||
self.num_kv_head,
|
||||
self.max_kv_splits,
|
||||
self.device_core_count,
|
||||
MAX_NUM_SEQ=SCHEDULE_SEQ,
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init auxiliary variables for wave attention backend."""
|
||||
|
||||
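# Editorial summary (comment added for clarity; not in the original diff): this
# method builds the per-batch index tensors for four cases:
#   - decode/idle: cumulative kv_indptr plus flattened kv_indices, dynamic
#     num_kv_splits, and intermediate logits/LSE buffers for split-KV decode;
#   - target verify (speculative): qo_indptr strided by num_draft_tokens and a
#     custom mask taken from spec_info;
#   - draft extend: indices produced by spec_info.generate_attn_arg_prefill;
#   - regular extend (prefill): kv_indptr over prefix lengths and qo_indptr over
#     extend lengths.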
bs = forward_batch.batch_size
|
||||
kv_indptr = self.kv_indptr
|
||||
spec_info = forward_batch.spec_info
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
if spec_info is None:
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
else:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
bs = kv_indptr.shape[0] - 1
|
||||
|
||||
from sglang.srt.layers.attention.wave_ops.decode_attention import (
|
||||
decode_attention_intermediate_arrays_shapes,
|
||||
)
|
||||
|
||||
attn_logits_shape, attn_logits_max_shape = (
|
||||
decode_attention_intermediate_arrays_shapes(
|
||||
bs, self.v_head_dim, self.num_head, self.max_kv_splits
|
||||
)
|
||||
)
|
||||
attn_logits = torch.empty(
|
||||
attn_logits_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
attn_lse = torch.empty(
|
||||
attn_logits_max_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
|
||||
|
||||
self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
|
||||
|
||||
qo_indptr = None
|
||||
custom_mask = None
|
||||
mask_indptr = None
|
||||
max_extend_len = None
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
bs = len(forward_batch.req_pool_indices)
|
||||
qo_indptr = torch.arange(
|
||||
0,
|
||||
(1 + bs) * self.num_draft_tokens,
|
||||
step=self.num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
# Different from the flashinfer kv_indptr and kv_indices construction
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
kv_indptr[-1], dtype=torch.int32, device=self.device
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
custom_mask = spec_info.custom_mask
|
||||
seq_mask_len = self.num_draft_tokens * (
|
||||
forward_batch.seq_lens + self.num_draft_tokens
|
||||
)
|
||||
mask_indptr = self.mask_indptr
|
||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
|
||||
mask_indptr = mask_indptr[: bs + 1]
|
||||
max_extend_len = self.num_draft_tokens
|
||||
num_kv_splits = None
|
||||
attn_logits = None
|
||||
attn_lse = None
|
||||
elif forward_batch.forward_mode.is_draft_extend():
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
None,
|
||||
self.req_to_token,
|
||||
)
|
||||
)
|
||||
mask_indptr = None
|
||||
# TODO(FIXME): This will trigger an invalid Eagle tree when using
|
||||
# `max(spec_info.accept_length_cpu)`.
|
||||
# Some value may not have been updated somewhere.
|
||||
max_extend_len = torch.max(spec_info.accept_length).item()
|
||||
num_kv_splits = None
|
||||
attn_logits = None
|
||||
attn_lse = None
|
||||
else:
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(
|
||||
forward_batch.extend_prefix_lens, dim=0
|
||||
)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
forward_batch.extend_prefix_lens.sum().item(),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.extend_prefix_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
qo_indptr = self.qo_indptr
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
mask_indptr = None
|
||||
attn_logits = None
|
||||
attn_lse = None
|
||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||
num_kv_splits = None
|
||||
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
attn_logits,
|
||||
attn_lse,
|
||||
max_extend_len,
|
||||
num_kv_splits,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
qo_indptr,
|
||||
custom_mask,
|
||||
mask_indptr,
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
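# Editorial note (comment added for clarity; not in the original diff): the buffers
# below are preallocated at the maximum batch size so that CUDA/HIP graph capture
# and replay can reuse fixed tensor addresses.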
from sglang.srt.layers.attention.wave_ops.decode_attention import (
|
||||
decode_attention_intermediate_arrays_shapes,
|
||||
)
|
||||
|
||||
attn_logits_shape, attn_logits_max_shape = (
|
||||
decode_attention_intermediate_arrays_shapes(
|
||||
max_bs, self.v_head_dim, self.num_head, self.max_kv_splits
|
||||
)
|
||||
)
|
||||
self.cuda_graph_attn_logits = torch.zeros(
|
||||
attn_logits_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.cuda_graph_attn_lse = torch.zeros(
|
||||
attn_logits_max_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.cuda_graph_num_kv_splits = torch.full(
|
||||
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
|
||||
)
|
||||
if kv_indices_buf is None:
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
self.cuda_graph_kv_indices = kv_indices_buf
|
||||
|
||||
if not self.skip_prefill:
|
||||
self.cuda_graph_custom_mask = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
assert encoder_lens is None, "Not supported"
|
||||
|
||||
if forward_mode.is_decode_or_idle():
|
||||
if spec_info is None:
|
||||
kv_indptr = self.kv_indptr
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = self.cuda_graph_kv_indices
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
else:
|
||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||
|
||||
attn_logits = self.cuda_graph_attn_logits
|
||||
attn_lse = self.cuda_graph_attn_lse
|
||||
max_extend_len = None
|
||||
num_kv_splits = self.cuda_graph_num_kv_splits
|
||||
qo_indptr = None
|
||||
custom_mask = None
|
||||
mask_indptr = None
|
||||
elif forward_mode.is_target_verify():
|
||||
qo_indptr = self.qo_indptr[: bs + 1]
|
||||
qo_indptr[: bs + 1] = torch.arange(
|
||||
0,
|
||||
(1 + bs) * self.num_draft_tokens,
|
||||
step=self.num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
kv_indptr = self.kv_indptr[: bs + 1]
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||
kv_indices = self.cuda_graph_kv_indices
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
|
||||
custom_mask = self.cuda_graph_custom_mask
|
||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||
mask_indptr = self.mask_indptr[: bs + 1]
|
||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
|
||||
max_extend_len = self.num_draft_tokens
|
||||
num_kv_splits = None
|
||||
attn_logits = None
|
||||
attn_lse = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
|
||||
)
|
||||
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
attn_logits,
|
||||
attn_lse,
|
||||
max_extend_len,
|
||||
num_kv_splits,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
qo_indptr,
|
||||
custom_mask,
|
||||
mask_indptr,
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
# NOTE: encoder_lens is expected to be all zeros or None
|
||||
if forward_mode.is_decode_or_idle():
|
||||
# Update kv_indptr, kv_indices
|
||||
kv_indptr = self.kv_indptr
|
||||
kv_indices = self.cuda_graph_kv_indices
|
||||
num_kv_splits = self.cuda_graph_num_kv_splits
|
||||
if spec_info is None:
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices[:bs],
|
||||
seq_lens[:bs],
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
num_token = bs
|
||||
else:
|
||||
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
|
||||
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
|
||||
num_token = spec_info.kv_indptr.shape[0] - 1
|
||||
self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
|
||||
elif forward_mode.is_target_verify():
|
||||
# Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
|
||||
bs = len(req_pool_indices)
|
||||
qo_indptr = self.qo_indptr[: bs + 1]
|
||||
qo_indptr[: bs + 1] = torch.arange(
|
||||
0,
|
||||
(1 + bs) * self.num_draft_tokens,
|
||||
step=self.num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
kv_indptr = self.kv_indptr[: bs + 1]
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||
kv_indices = self.cuda_graph_kv_indices
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
custom_mask = self.cuda_graph_custom_mask
|
||||
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||
mask_indptr = self.mask_indptr[: bs + 1]
|
||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
# TODO: reuse the buffer across layers
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
max_extend_len = self.forward_metadata.max_extend_len
|
||||
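# Editorial note (comment added for clarity; not in the original diff): the check
# below appears to be a workaround for the recorded max_extend_len disagreeing
# with the batch's extend_seq_lens; it only handles batch size 1 and forces the
# lengths to match before the kernel launch.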
computed_max_ext_seq_len = torch.max(forward_batch.extend_seq_lens)
|
||||
if computed_max_ext_seq_len != max_extend_len:
|
||||
assert len(forward_batch.extend_seq_lens) == 1
|
||||
forward_batch.extend_seq_lens[0] = max_extend_len
|
||||
forward_batch.seq_lens = max_extend_len
|
||||
|
||||
self.extend_attention_fwd(
|
||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
self.forward_metadata.qo_indptr,
|
||||
self.forward_metadata.kv_indptr,
|
||||
self.forward_metadata.kv_indices,
|
||||
self.forward_metadata.custom_mask,
|
||||
self.forward_metadata.mask_indptr,
|
||||
self.forward_metadata.max_extend_len,
|
||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||
is_causal=True,
|
||||
layer_scaling=layer.scaling,
|
||||
logit_cap=layer.logit_cap,
|
||||
)
|
||||
return o
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
|
||||
# TODO: reuse the buffer across layers
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
self.decode_attention_fwd(
|
||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||
self.forward_metadata.kv_indptr,
|
||||
self.forward_metadata.kv_indices,
|
||||
self.forward_metadata.attn_logits,
|
||||
self.forward_metadata.attn_lse,
|
||||
self.forward_metadata.num_kv_splits,
|
||||
self.max_kv_splits,
|
||||
layer.scaling,
|
||||
layer.logit_cap,
|
||||
)
|
||||
return o
|
||||
python/sglang/srt/layers/attention/wave_ops/decode_attention.py (new file, 186 lines)
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Memory-efficient attention for decoding.
|
||||
It supports page size = 1.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
|
||||
from wave_lang.kernel.lang.global_symbols import *
|
||||
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
|
||||
from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType
|
||||
from wave_lang.kernel.wave.templates.paged_decode_attention import (
|
||||
get_paged_decode_attention_kernels,
|
||||
get_paged_decode_intermediate_arrays_shapes,
|
||||
paged_decode_attention_shape,
|
||||
)
|
||||
from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
|
||||
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import os
|
||||
|
||||
dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
|
||||
|
||||
|
||||
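# Editorial note (comment added for clarity; not in the original diff): wave_compile
# is expensive, so the compiled phase-0/phase-1 kernels are memoized per
# (shape, max_kv_splits, dtypes, logit_cap) via lru_cache; only the first call for
# a new configuration pays the compilation cost.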
@functools.lru_cache(maxsize=4096)
|
||||
def get_wave_kernel(
|
||||
shape: paged_decode_attention_shape,
|
||||
max_kv_splits,
|
||||
input_dtype,
|
||||
output_dtype,
|
||||
logit_cap,
|
||||
):
|
||||
mha = (shape.num_query_heads // shape.num_kv_heads) == 1
|
||||
|
||||
# Get the kernels (either compile or load from cache).
|
||||
if mha:
|
||||
mfma_variant = (
|
||||
GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),
|
||||
GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),
|
||||
)
|
||||
else:
|
||||
mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
|
||||
|
||||
(
|
||||
phase_0,
|
||||
phase_1,
|
||||
hyperparams_0,
|
||||
hyperparams_1,
|
||||
dynamic_symbols_0,
|
||||
dynamic_symbols_1,
|
||||
) = get_paged_decode_attention_kernels(
|
||||
shape,
|
||||
mfma_variant,
|
||||
max_kv_splits,
|
||||
input_dtype=input_dtype,
|
||||
output_dtype=output_dtype,
|
||||
logit_cap=logit_cap,
|
||||
)
|
||||
hyperparams_0.update(get_default_scheduling_params())
|
||||
hyperparams_1.update(get_default_scheduling_params())
|
||||
|
||||
options = WaveCompileOptions(
|
||||
subs=hyperparams_0,
|
||||
canonicalize=True,
|
||||
run_bench=False,
|
||||
use_buffer_load_ops=True,
|
||||
use_buffer_store_ops=True,
|
||||
waves_per_eu=2,
|
||||
dynamic_symbols=dynamic_symbols_0,
|
||||
wave_runtime=True,
|
||||
)
|
||||
options = set_default_run_config(options)
|
||||
phase_0 = wave_compile(options, phase_0)
|
||||
|
||||
options = WaveCompileOptions(
|
||||
subs=hyperparams_1,
|
||||
canonicalize=True,
|
||||
run_bench=False,
|
||||
use_buffer_load_ops=False,
|
||||
use_buffer_store_ops=False,
|
||||
waves_per_eu=4,
|
||||
dynamic_symbols=dynamic_symbols_1,
|
||||
wave_runtime=True,
|
||||
)
|
||||
options = set_default_run_config(options)
|
||||
phase_1 = wave_compile(options, phase_1)
|
||||
|
||||
return phase_0, phase_1
|
||||
|
||||
|
||||
def decode_attention_intermediate_arrays_shapes(
|
||||
num_seqs, head_size_kv, num_query_heads, max_kv_splits
|
||||
):
|
||||
# Not all fields are used, but we need to pass them to the function
|
||||
shape = paged_decode_attention_shape(
|
||||
num_query_heads=num_query_heads,
|
||||
num_kv_heads=0,
|
||||
head_size=0,
|
||||
head_size_kv=head_size_kv,
|
||||
block_size=0,
|
||||
num_seqs=num_seqs,
|
||||
)
|
||||
return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)
|
||||
|
||||
|
||||
def decode_attention_wave(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
b_req_idx,
|
||||
req_to_token,
|
||||
attn_logits,
|
||||
attn_logits_max,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
):
|
||||
num_seqs, num_query_heads, head_size = q.shape
|
||||
_, num_kv_heads, _ = k_buffer.shape
|
||||
_, _, head_size_kv = v_buffer.shape
|
||||
block_size = 32
|
||||
shape = paged_decode_attention_shape(
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
head_size_kv,
|
||||
block_size,
|
||||
num_seqs,
|
||||
)
|
||||
|
||||
phase_0, phase_1 = get_wave_kernel(
|
||||
shape, max_kv_splits, q.dtype, o.dtype, logit_cap
|
||||
)
|
||||
|
||||
mb_qk = phase_0(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
b_req_idx,
|
||||
req_to_token,
|
||||
attn_logits,
|
||||
attn_logits_max,
|
||||
)
|
||||
if dump_generated_mlir:
|
||||
filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
|
||||
with open(filename, "w") as f:
|
||||
f.write(mb_qk.module_op.get_asm())
|
||||
|
||||
mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)
|
||||
if dump_generated_mlir:
|
||||
filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
|
||||
with open(filename, "w") as f:
|
||||
f.write(mb_sv.module_op.get_asm())
|
||||
|
||||
|
||||
def decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
b_req_idx,
|
||||
req_to_token,
|
||||
attn_logits,
|
||||
attn_logits_max,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
decode_attention_wave(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
b_req_idx,
|
||||
req_to_token,
|
||||
attn_logits,
|
||||
attn_logits_max,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
python/sglang/srt/layers/attention/wave_ops/extend_attention.py (new file, 149 lines)
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Memory-efficient attention for prefill.
|
||||
It supports page size = 1.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import os
|
||||
|
||||
import torch
|
||||
from wave_lang.kernel.lang.global_symbols import *
|
||||
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
|
||||
from wave_lang.kernel.wave.constraints import MMAType
|
||||
from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
|
||||
from wave_lang.kernel.wave.templates.attention_common import AttentionShape
|
||||
from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel
|
||||
from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
|
||||
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
|
||||
|
||||
dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_wave_kernel(
|
||||
shape: AttentionShape,
|
||||
q_shape: tuple[int],
|
||||
k_shape: tuple[int],
|
||||
v_shape: tuple[int],
|
||||
k_cache_shape: tuple[int],
|
||||
v_cache_shape: tuple[int],
|
||||
o_shape: tuple[int],
|
||||
input_dtype: torch.dtype,
|
||||
output_dtype: torch.dtype,
|
||||
size_dtype: torch.dtype,
|
||||
is_causal: bool,
|
||||
logit_cap: float,
|
||||
layer_scaling: float,
|
||||
):
|
||||
assert shape.num_query_heads % shape.num_kv_heads == 0
|
||||
|
||||
mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)
|
||||
(
|
||||
extend_attention,
|
||||
hyperparams,
|
||||
dynamic_symbols,
|
||||
) = get_extend_attention_kernel(
|
||||
shape,
|
||||
mfma_variant,
|
||||
q_shape,
|
||||
k_shape,
|
||||
v_shape,
|
||||
k_cache_shape,
|
||||
v_cache_shape,
|
||||
o_shape,
|
||||
input_dtype=input_dtype,
|
||||
output_dtype=output_dtype,
|
||||
size_dtype=size_dtype,
|
||||
is_causal=is_causal,
|
||||
layer_scaling=layer_scaling,
|
||||
logit_cap=logit_cap,
|
||||
)
|
||||
|
||||
hyperparams.update(get_default_scheduling_params())
|
||||
options = WaveCompileOptions(
|
||||
subs=hyperparams,
|
||||
canonicalize=True,
|
||||
run_bench=False,
|
||||
schedule=SchedulingType.NONE,
|
||||
use_scheduling_barriers=False,
|
||||
dynamic_symbols=dynamic_symbols,
|
||||
use_buffer_load_ops=True,
|
||||
use_buffer_store_ops=True,
|
||||
waves_per_eu=2,
|
||||
denorm_fp_math_f32="preserve-sign",
|
||||
gpu_native_math_precision=True,
|
||||
wave_runtime=True,
|
||||
)
|
||||
options = set_default_run_config(options)
|
||||
extend_attention = wave_compile(options, extend_attention)
|
||||
|
||||
return extend_attention
|
||||
|
||||
|
||||
def extend_attention_wave(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
custom_mask,
|
||||
mask_indptr,
|
||||
max_seq_len,
|
||||
output,
|
||||
is_causal=True,
|
||||
layer_scaling=None,
|
||||
logit_cap=0,
|
||||
):
|
||||
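# Editorial note (comment added for clarity; not in the original diff): the
# AttentionShape and the tensor shapes passed below form the lru_cache key of
# get_wave_kernel, so a new shape combination triggers a fresh wave_compile while
# repeated shapes reuse the cached kernel.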
shape = AttentionShape(
|
||||
num_query_heads=q_extend.shape[1],
|
||||
num_kv_heads=k_extend.shape[1],
|
||||
head_size=q_extend.shape[2],
|
||||
head_size_kv=k_extend.shape[2],
|
||||
num_seqs=kv_indptr.shape[0] - 1,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
# Run the wave kernel.
|
||||
extend_attention = get_wave_kernel(
|
||||
shape,
|
||||
q_extend.shape,
|
||||
k_extend.shape,
|
||||
v_extend.shape,
|
||||
k_buffer.shape,
|
||||
v_buffer.shape,
|
||||
output.shape,
|
||||
input_dtype=q_extend.dtype,
|
||||
output_dtype=output.dtype,
|
||||
size_dtype=qo_indptr.dtype,
|
||||
is_causal=is_causal,
|
||||
layer_scaling=layer_scaling,
|
||||
logit_cap=logit_cap,
|
||||
)
|
||||
|
||||
mb = extend_attention(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
max_seq_len,
|
||||
output,
|
||||
)
|
||||
|
||||
if dump_generated_mlir:
|
||||
shape_list = [
|
||||
q_extend.shape[0],
|
||||
q_extend.shape[1],
|
||||
k_extend.shape[1],
|
||||
q_extend.shape[2],
|
||||
k_extend.shape[2],
|
||||
]
|
||||
filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
|
||||
with open(filename, "w") as f:
|
||||
f.write(mb.module_op.get_asm())
|
||||
python/sglang/srt/layers/attention/wave_ops/prefill_attention.py (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Memory-efficient attention for prefill.
|
||||
It supports page size = 1.
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
from wave_lang.kernel.lang.global_symbols import *
|
||||
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
|
||||
from wave_lang.kernel.wave.constraints import MMAType
|
||||
from wave_lang.kernel.wave.templates.attention_common import AttentionShape
|
||||
from wave_lang.kernel.wave.templates.prefill_attention import (
|
||||
get_prefill_attention_kernel,
|
||||
)
|
||||
from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
|
||||
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
|
||||
|
||||
dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
|
||||
|
||||
|
||||
def prefill_attention_wave(
|
||||
q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
|
||||
):
|
||||
|
||||
shape = AttentionShape(
|
||||
num_query_heads=q.shape[1],
|
||||
num_kv_heads=k.shape[1],
|
||||
head_size=q.shape[2],
|
||||
head_size_kv=k.shape[2],
|
||||
num_seqs=b_seq_len.shape[0],
|
||||
max_seq_len=max_seq_len,
|
||||
total_seq_len=q.shape[0],
|
||||
)
|
||||
|
||||
assert shape.num_query_heads % shape.num_kv_heads == 0
|
||||
|
||||
output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)
|
||||
# Run the wave kernel.
|
||||
mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
|
||||
(prefill, hyperparams) = get_prefill_attention_kernel(
|
||||
shape,
|
||||
mfma_variant,
|
||||
q.shape,
|
||||
k.shape,
|
||||
v.shape,
|
||||
output_shape,
|
||||
input_dtype=q.dtype,
|
||||
output_dtype=o.dtype,
|
||||
size_dtype=b_seq_len.dtype,
|
||||
)
|
||||
|
||||
hyperparams.update(get_default_scheduling_params())
|
||||
|
||||
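# Editorial note (comment added for clarity; not in the original diff): q is
# pre-multiplied by 1/sqrt(head_size) (the usual attention scaling) and by log2(e),
# presumably because the wave prefill kernel evaluates softmax with exp2 rather
# than exp.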
log2e = 1.44269504089
|
||||
dk_sqrt = math.sqrt(1.0 / shape.head_size)
|
||||
|
||||
options = WaveCompileOptions(
|
||||
subs=hyperparams,
|
||||
canonicalize=True,
|
||||
run_bench=False,
|
||||
use_scheduling_barriers=False,
|
||||
)
|
||||
options = set_default_run_config(options)
|
||||
prefill = wave_compile(options, prefill)
|
||||
|
||||
mb = prefill(
|
||||
q * dk_sqrt * log2e,
|
||||
k,
|
||||
v,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
o,
|
||||
)
|
||||
if dump_generated_mlir:
|
||||
shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]
|
||||
filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
|
||||
with open(filename, "w") as f:
|
||||
f.write(mb.module_op.get_asm())
|
||||
python/sglang/srt/model_executor/model_runner.py
@@ -1487,6 +1487,10 @@ class ModelRunner:
    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

    return AiterAttnBackend(self)
elif self.server_args.attention_backend == "wave":
    from sglang.srt.layers.attention.wave_backend import WaveAttnBackend

    return WaveAttnBackend(self)
elif backend_str == "ascend":
    from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
python/sglang/srt/server_args.py
@@ -1323,6 +1323,7 @@ class ServerArgs:
    "trtllm_mla",
    "trtllm_mha",
    "dual_chunk_flash_attn",
    "wave",
]
parser.add_argument(
    "--attention-backend",
@@ -196,6 +196,8 @@ suite_amd = {
        TestFile("test_torch_native_attention_backend.py", 123),
        TestFile("test_triton_attention_backend.py", 150),
        # TestFile("test_vision_chunked_prefill.py", 175),  # Disabled temporarily and tracked in #7701
        TestFile("test_wave_attention_kernels.py", 2),
        TestFile("test_wave_attention_backend.py", 150),
    ],
    "per-commit-2-gpu-amd": [
        TestFile("lora/test_lora_tp.py", 116),
test/srt/test_wave_attention_backend.py (new file, 61 lines)
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_wave_attention_backend.TestWaveAttnBackend.test_mmlu
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_one_batch,
|
||||
)
|
||||
|
||||
|
||||
class TestWaveAttnBackend(unittest.TestCase):
|
||||
def test_latency(self):
|
||||
_, output_throughput, _ = run_bench_one_batch(
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
[
|
||||
"--attention-backend",
|
||||
"wave",
|
||||
"--enable-torch-compile",
|
||||
],
|
||||
)
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(output_throughput, 153)
|
||||
|
||||
def _test_mmlu(self):
|
||||
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
base_url = DEFAULT_URL_FOR_TEST
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--attention-backend", "wave"],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
test/srt/test_wave_attention_kernels.py (new file, 322 lines)
@@ -0,0 +1,322 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import (
|
||||
decode_attention_fwd_grouped as triton_decode_attention_fwd_grouped,
|
||||
)
|
||||
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||
extend_attention_fwd,
|
||||
redundant_attention,
|
||||
)
|
||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
context_attention_fwd,
|
||||
)
|
||||
from sglang.srt.layers.attention.wave_ops.decode_attention import (
|
||||
decode_attention_intermediate_arrays_shapes,
|
||||
decode_attention_wave,
|
||||
)
|
||||
from sglang.srt.layers.attention.wave_ops.extend_attention import extend_attention_wave
|
||||
from sglang.srt.layers.attention.wave_ops.prefill_attention import (
|
||||
prefill_attention_wave,
|
||||
)
|
||||
|
||||
|
||||
class TestWaveAttention(unittest.TestCase):
|
||||
|
||||
def _set_all_seeds(self, seed):
|
||||
"""Set all random seeds for reproducibility."""
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def setUp(self):
|
||||
# Set seeds before each test method
|
||||
self._set_all_seeds(42)
|
||||
|
||||
def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
|
||||
dtype = torch.float16
|
||||
extend_seq_len = 1024
|
||||
|
||||
b_seq_len_prefix = torch.full(
|
||||
(B,), N_CTX // B, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
b_seq_len_extend = torch.full(
|
||||
(B,), extend_seq_len, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
b_seq_len = b_seq_len_prefix + b_seq_len_extend
|
||||
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
|
||||
|
||||
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
|
||||
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
||||
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
||||
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||
|
||||
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
|
||||
kv_indices = torch.zeros(
|
||||
(b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
for i in range(B):
|
||||
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
|
||||
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
|
||||
)
|
||||
|
||||
total_token_num = torch.sum(b_seq_len).item()
|
||||
extend_token_num = torch.sum(b_seq_len_extend).item()
|
||||
k_buffer = torch.empty(
|
||||
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
v_buffer = torch.empty(
|
||||
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
|
||||
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
|
||||
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
|
||||
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
||||
for i in range(B):
|
||||
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
|
||||
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
|
||||
extend_start = b_start_loc_extend[i]
|
||||
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
|
||||
k_extend[extend_start:extend_end] = k_buffer[
|
||||
extend_start_in_buffer:extend_end_in_buffer
|
||||
]
|
||||
v_extend[extend_start:extend_end] = v_buffer[
|
||||
extend_start_in_buffer:extend_end_in_buffer
|
||||
]
|
||||
q_extend[extend_start:extend_end] = torch.empty(
|
||||
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
|
||||
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
||||
o_extend_mask = torch.empty(
|
||||
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
|
||||
)
|
||||
o_redundant = torch.empty(
|
||||
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
||||
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
|
||||
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
|
||||
|
||||
custom_mask = None
|
||||
mask_indptr = None
|
||||
|
||||
redundant_attention(
|
||||
q_extend,
|
||||
o_redundant,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_seq_len_prefix,
|
||||
max_len_in_batch,
|
||||
)
|
||||
|
||||
is_causal = True
|
||||
|
||||
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
||||
extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
custom_mask,
|
||||
is_causal,
|
||||
mask_indptr,
|
||||
max_len_extend,
|
||||
)
|
||||
|
||||
o_wave = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
||||
extend_attention_wave(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
custom_mask,
|
||||
mask_indptr,
|
||||
max_len_extend,
|
||||
o_wave,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
|
||||
self.assertTrue(torch.allclose(o_wave, o_redundant, rtol=1e-2))
|
||||
|
||||
def test_extend_attention(self):
|
||||
|
||||
# Define the varying parameter values
|
||||
attention_values = [128]
|
||||
|
||||
# Loop through the values and call the method
|
||||
for value in attention_values:
|
||||
self._test_extend_attention_once(32, 16384, 6, 1, value)
|
||||
|
||||
def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
|
||||
dtype = torch.float16
|
||||
seq_len = S # This represents the number of tokens already in the sequence
|
||||
total_tokens = B * seq_len
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
max_kv_splits = 8
|
||||
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
||||
|
||||
# k_buffer and v_buffer represent all previous tokens
|
||||
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
|
||||
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
# o will have the same shape as q
|
||||
o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
req_to_token = torch.arange(total_tokens, device="cuda", dtype=torch.int32)
|
||||
b_req_idx = torch.zeros(B + 1, device="cuda", dtype=torch.int32)
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda", dtype=torch.int32)
|
||||
b_req_idx[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, max_kv_splits, D_V + 1),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
attn_lse = torch.empty(
|
||||
(B, H_Q, max_kv_splits),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
logit_cap = 0.0
|
||||
triton_decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o_triton,
|
||||
b_req_idx,
|
||||
req_to_token,
|
||||
attn_logits,
|
||||
attn_lse,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
|
||||
attn_logits_shape, attn_logits_max_shape = (
|
||||
decode_attention_intermediate_arrays_shapes(B, D_V, H_Q, max_kv_splits)
|
||||
)
|
||||
|
||||
attn_logits = torch.empty(
|
||||
attn_logits_shape,
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
attn_logits_max = torch.empty(
|
||||
attn_logits_max_shape,
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
decode_attention_wave(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
b_req_idx,
|
||||
req_to_token,
|
||||
attn_logits,
|
||||
attn_logits_max,
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
o.flatten(), o_triton.flatten(), dim=0
|
||||
)
|
||||
print(cos_sim.item())
|
||||
self.assertTrue(cos_sim.item() > 0.99)
|
||||
self.assertTrue(torch.allclose(o, o_triton, atol=3e-2))
|
||||
|
||||
def test_grouped_decode_attention(self):
|
||||
seq_lens = [5, 100, 128, 500]
|
||||
configs = [
|
||||
(2, 16, 16, 64, 64),
|
||||
(2, 16, 1, 64, 64),
|
||||
(2, 128, 1, 80, 80),
|
||||
(32, 128, 2, 512, 512),
|
||||
(2, 128, 2, 512, 512),
|
||||
(2, 128, 1, 576, 512),
|
||||
]
|
||||
|
||||
for S in seq_lens:
|
||||
for B, H_Q, H_KV, D, D_V in configs:
|
||||
self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
|
||||
|
||||
def _test_context_attention_once(self, head_dim, is_causal):
|
||||
# Set up a simple test case
|
||||
dtype = torch.float16
|
||||
num_heads = 4
|
||||
kv_heads = 1
|
||||
seq_lens = [128, 256]
|
||||
max_seq_len = max(seq_lens)
|
||||
|
||||
# Create random input tensors
|
||||
q = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda")
|
||||
k = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda")
|
||||
v = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda")
|
||||
o_triton = torch.zeros(
|
||||
sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda"
|
||||
)
|
||||
o = torch.zeros(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda")
|
||||
|
||||
# Create b_start_loc and b_seq_len tensors
|
||||
b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
|
||||
b_seq_len = torch.tensor(seq_lens, device="cuda")
|
||||
|
||||
context_attention_fwd(
|
||||
q, k, v, o_triton, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
|
||||
)
|
||||
prefill_attention_wave(
|
||||
q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
|
||||
)
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
o.flatten(), o_triton.flatten(), dim=0
|
||||
)
|
||||
|
||||
print(cos_sim.item())
|
||||
self.assertTrue(torch.allclose(o, o_triton, atol=3e-2))
|
||||
self.assertTrue(cos_sim.item() > 1 - (1e-5))
|
||||
|
||||
def test_context_attention(self):
|
||||
head_dim = [128, 96]
|
||||
|
||||
for dim in head_dim:
|
||||
for is_causal in [False]:
|
||||
self._test_context_attention_once(dim, is_causal)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()