From 25caa7a8a98e1bf73335973b29adbadb108f0ac7 Mon Sep 17 00:00:00 2001 From: "jacky.cheng" Date: Wed, 13 Aug 2025 04:49:11 +0800 Subject: [PATCH] [AMD] Support Wave attention backend with AMD GPU optimizations (#8660) Signed-off-by: Stanley Winata Signed-off-by: Harsh Menon Signed-off-by: nithinsubbiah Signed-off-by: Ivan Butygin Signed-off-by: xintin Co-authored-by: Harsh Menon Co-authored-by: Stanley Winata Co-authored-by: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Co-authored-by: Stanley Winata Co-authored-by: Ivan Butygin Co-authored-by: nithinsubbiah Co-authored-by: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Co-authored-by: Ivan Butygin --- docs/advanced_features/attention_backend.md | 5 + python/pyproject.toml | 1 + .../srt/layers/attention/wave_backend.py | 627 ++++++++++++++++++ .../attention/wave_ops/decode_attention.py | 186 ++++++ .../attention/wave_ops/extend_attention.py | 149 +++++ .../attention/wave_ops/prefill_attention.py | 79 +++ .../sglang/srt/model_executor/model_runner.py | 4 + python/sglang/srt/server_args.py | 1 + test/srt/run_suite.py | 2 + test/srt/test_wave_attention_backend.py | 61 ++ test/srt/test_wave_attention_kernels.py | 322 +++++++++ 11 files changed, 1437 insertions(+) create mode 100644 python/sglang/srt/layers/attention/wave_backend.py create mode 100644 python/sglang/srt/layers/attention/wave_ops/decode_attention.py create mode 100644 python/sglang/srt/layers/attention/wave_ops/extend_attention.py create mode 100644 python/sglang/srt/layers/attention/wave_ops/prefill_attention.py create mode 100644 test/srt/test_wave_attention_backend.py create mode 100644 test/srt/test_wave_attention_kernels.py diff --git a/docs/advanced_features/attention_backend.md b/docs/advanced_features/attention_backend.md index e4c56ea53..68e4318d8 100644 --- a/docs/advanced_features/attention_backend.md +++ b/docs/advanced_features/attention_backend.md @@ -14,6 +14,7 @@ You can test them according to your needs. | **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ | | **TRTLLM MLA** | ✅ | ❌ | ✅ | ✅ | ❌ | | **Ascend** | ✅ | ❌ | ✅ | ❌ | ❌ | +| **Wave** | ✅ | ❌ | ❌ | ❌ | ❌ | **Notes:** - TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend. 
@@ -70,6 +71,10 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend ``` +- Wave +```bash +python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend wave +``` ## Steps to add a new attention backend To add a new attention backend, you can learn from the existing backends diff --git a/python/pyproject.toml b/python/pyproject.toml index 42368c9c6..08b38d629 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -82,6 +82,7 @@ srt_hip = [ "sglang[runtime_common]", "torch", "petit_kernel==0.0.2", + "wave-lang==1.0.1", ] # CPU: torch wheel for CPU needs to be installed from https://download.pytorch.org/whl/cpu diff --git a/python/sglang/srt/layers/attention/wave_backend.py b/python/sglang/srt/layers/attention/wave_backend.py new file mode 100644 index 000000000..eb6e061ac --- /dev/null +++ b/python/sglang/srt/layers/attention/wave_backend.py @@ -0,0 +1,627 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import get_bool_env_var, get_device_core_count + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + +logger = logging.getLogger(__name__) + + +@triton.jit +def get_num_kv_splits_triton( + num_kv_splits_ptr, + seq_lens_ptr, + num_seq, + num_group, + num_head, + num_kv_head, + max_kv_splits, + device_core_count, + MAX_NUM_SEQ: tl.constexpr, +): + # TODO: this method is tunable, we need more online serving data to tune it + offs_seq = tl.arange(0, MAX_NUM_SEQ) + mask_seq = offs_seq < num_seq + + seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0) + max_seq_len = tl.max(seq_lens) + seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len) + min_seq_len = tl.min(seq_lens) + if max_seq_len * 8 < min_seq_len * 10: + min_seq_len = max_seq_len + max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits) + kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1) + + # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually + ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0 + ext_device_core_count = tl.cast( + device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32 + ) + block_h, num_kv_group = 16, num_head // num_kv_head + if num_kv_group == 1: + token_grid = num_seq * num_group * num_head + else: + # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd + block_h = tl.minimum(block_h, num_kv_group) + token_grid = num_seq * num_group * tl.cdiv(num_head, block_h) + max_kv_splits_2 = tl.minimum( + tl.cdiv(ext_device_core_count, token_grid), max_kv_splits + ) + kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2) + + num_kv_splits = tl.maximum( + tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2) + ) + + offs_token = offs_seq * num_group + mask_token 
= offs_token < num_seq * num_group + for i in range(0, num_group): + tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token) + + +@dataclass +class ForwardMetadata: + attn_logits: torch.Tensor + attn_lse: torch.Tensor + max_extend_len: int + num_kv_splits: torch.Tensor + kv_indptr: torch.Tensor + kv_indices: torch.Tensor + qo_indptr: torch.Tensor + custom_mask: torch.Tensor + mask_indptr: torch.Tensor + + +class WaveAttnBackend(AttentionBackend): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): + # Lazy import to avoid the initialization of cuda context + from sglang.srt.layers.attention.wave_ops.decode_attention import ( + decode_attention_fwd, + ) + from sglang.srt.layers.attention.wave_ops.extend_attention import ( + extend_attention_wave, + ) + + super().__init__() + + # Set unique cache dir for each process to avoid cache write races + import wave_lang.kernel.wave.cache as cache + + base_cache_dir = cache.CACHE_BASE_DIR + new_dir = base_cache_dir / f"worker_{model_runner.tp_rank}" + logger.info(f"Setting Wave cache dir: {new_dir}") + cache.CACHE_BASE_DIR = new_dir + + self.decode_attention_fwd = decode_attention_fwd + self.extend_attention_fwd = extend_attention_wave + + self.skip_prefill = skip_prefill + + max_bs = model_runner.req_to_token_pool.size + + if kv_indptr_buf is None: + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + else: + self.kv_indptr = kv_indptr_buf + + self.req_to_token = model_runner.req_to_token_pool.req_to_token + + if not self.skip_prefill: + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + + self.mask_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int64, device=model_runner.device + ) + + self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens + + self.num_head = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.num_kv_head = model_runner.model_config.get_num_kv_heads( + get_attention_tp_size() + ) + + self.static_kv_splits = get_bool_env_var( + "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false" + ) + self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + + self.forward_metadata: ForwardMetadata = None + + self.max_context_len = model_runner.model_config.context_len + + self.device = model_runner.device + self.device_core_count = get_device_core_count(model_runner.gpu_id) + + def get_num_kv_splits( + self, + num_kv_splits: torch.Tensor, + seq_lens: torch.Tensor, + ): + num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0] + num_group = num_token // num_seq + + assert ( + num_group * num_seq == num_token + ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!" 
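+        # When static splits are requested or the device core count is unknown,
+        # fall back to the maximum number of KV splits instead of tuning per batch.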
+ + if self.static_kv_splits or self.device_core_count <= 0: + num_kv_splits.fill_(self.max_kv_splits) + return + + if num_seq < 256: + SCHEDULE_SEQ = 256 + else: + SCHEDULE_SEQ = triton.next_power_of_2(num_seq) + + get_num_kv_splits_triton[(1,)]( + num_kv_splits, + seq_lens, + num_seq, + num_group, + self.num_head, + self.num_kv_head, + self.max_kv_splits, + self.device_core_count, + MAX_NUM_SEQ=SCHEDULE_SEQ, + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init auxiliary variables for wave attention backend.""" + + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + spec_info = forward_batch.spec_info + + if forward_batch.forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 + + from sglang.srt.layers.attention.wave_ops.decode_attention import ( + decode_attention_intermediate_arrays_shapes, + ) + + attn_logits_shape, attn_logits_max_shape = ( + decode_attention_intermediate_arrays_shapes( + bs, self.v_head_dim, self.num_head, self.max_kv_splits + ) + ) + attn_logits = torch.empty( + attn_logits_shape, + dtype=torch.float32, + device=self.device, + ) + attn_lse = torch.empty( + attn_logits_max_shape, + dtype=torch.float32, + device=self.device, + ) + num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device) + + self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens) + + qo_indptr = None + custom_mask = None + mask_indptr = None + max_extend_len = None + elif forward_batch.forward_mode.is_target_verify(): + bs = len(forward_batch.req_pool_indices) + qo_indptr = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + # Different with flashinfer kv_indptr and kv_indices construction + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + kv_indptr[-1], dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + custom_mask = spec_info.custom_mask + seq_mask_len = self.num_draft_tokens * ( + forward_batch.seq_lens + self.num_draft_tokens + ) + mask_indptr = self.mask_indptr + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0) + mask_indptr = mask_indptr[: bs + 1] + max_extend_len = self.num_draft_tokens + num_kv_splits = None + attn_logits = None + attn_lse = None + elif forward_batch.forward_mode.is_draft_extend(): + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + None, + self.req_to_token, + ) + ) + mask_indptr = None + # TODO(FIXME): This will trigger an invalid Eagle tree when using + # `max(spec_info.accept_length_cpu)`. + # It might have been forgotten to update somewhere. 
+ max_extend_len = torch.max(spec_info.accept_length).item() + num_kv_splits = None + attn_logits = None + attn_lse = None + else: + kv_indptr[1 : bs + 1] = torch.cumsum( + forward_batch.extend_prefix_lens, dim=0 + ) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.extend_prefix_lens.sum().item(), + dtype=torch.int32, + device=self.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.extend_prefix_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr + qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + mask_indptr = None + attn_logits = None + attn_lse = None + max_extend_len = torch.max(forward_batch.extend_seq_lens).item() + num_kv_splits = None + + self.forward_metadata = ForwardMetadata( + attn_logits, + attn_lse, + max_extend_len, + num_kv_splits, + kv_indptr, + kv_indices, + qo_indptr, + custom_mask, + mask_indptr, + ) + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + from sglang.srt.layers.attention.wave_ops.decode_attention import ( + decode_attention_intermediate_arrays_shapes, + ) + + attn_logits_shape, attn_logits_max_shape = ( + decode_attention_intermediate_arrays_shapes( + max_bs, self.v_head_dim, self.num_head, self.max_kv_splits + ) + ) + self.cuda_graph_attn_logits = torch.zeros( + attn_logits_shape, + dtype=torch.float32, + device=self.device, + ) + self.cuda_graph_attn_lse = torch.zeros( + attn_logits_max_shape, + dtype=torch.float32, + device=self.device, + ) + self.cuda_graph_num_kv_splits = torch.full( + (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device + ) + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device=self.device, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + assert encoder_lens is None, "Not supported" + + if forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + + attn_logits = self.cuda_graph_attn_logits + attn_lse = self.cuda_graph_attn_lse + max_extend_len = None + num_kv_splits = self.cuda_graph_num_kv_splits + qo_indptr = None + custom_mask = None + mask_indptr = None + elif forward_mode.is_target_verify(): + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = 
torch.cumsum(seq_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + custom_mask = self.cuda_graph_custom_mask + seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) + mask_indptr = self.mask_indptr[: bs + 1] + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0) + max_extend_len = self.num_draft_tokens + num_kv_splits = None + attn_logits = None + attn_lse = None + else: + raise ValueError( + f"Invalid forward mode: {forward_mode=} for CUDA Graph capture." + ) + + self.forward_metadata = ForwardMetadata( + attn_logits, + attn_lse, + max_extend_len, + num_kv_splits, + kv_indptr, + kv_indices, + qo_indptr, + custom_mask, + mask_indptr, + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + # NOTE: encoder_lens expected to be zeros or None + if forward_mode.is_decode_or_idle(): + # Update kv_indptr, kv_indices + kv_indptr = self.kv_indptr + kv_indices = self.cuda_graph_kv_indices + num_kv_splits = self.cuda_graph_num_kv_splits + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0) + kv_indptr = kv_indptr[: bs + 1] + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + num_token = bs + else: + kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr + kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices + num_token = spec_info.kv_indptr.shape[0] - 1 + self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs]) + elif forward_mode.is_target_verify(): + # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr + bs = len(req_pool_indices) + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + custom_mask = self.cuda_graph_custom_mask + custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask + seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) + mask_indptr = self.mask_indptr[: bs + 1] + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0) + else: + raise ValueError( + f"Invalid forward mode: {forward_mode=} for CUDA Graph replay." 
+ ) + + def get_cuda_graph_seq_len_fill_value(self): + return 1 + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + # TODO: reuse the buffer across layers + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + max_extend_len = self.forward_metadata.max_extend_len + computed_max_ext_seq_len = torch.max(forward_batch.extend_seq_lens) + if computed_max_ext_seq_len != max_extend_len: + assert len(forward_batch.extend_seq_lens) == 1 + forward_batch.extend_seq_lens[0] = max_extend_len + forward_batch.seq_lens = max_extend_len + + self.extend_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.contiguous(), + v.contiguous(), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.custom_mask, + self.forward_metadata.mask_indptr, + self.forward_metadata.max_extend_len, + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + is_causal=True, + layer_scaling=layer.scaling, + logit_cap=layer.logit_cap, + ) + return o + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + # During torch.compile, there is a bug in rotary_emb that causes the + # output value to have a 3D tensor shape. This reshapes the output correctly. + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + # TODO: reuse the buffer across layers + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + self.decode_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.attn_logits, + self.forward_metadata.attn_lse, + self.forward_metadata.num_kv_splits, + self.max_kv_splits, + layer.scaling, + layer.logit_cap, + ) + return o diff --git a/python/sglang/srt/layers/attention/wave_ops/decode_attention.py b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py new file mode 100644 index 000000000..cb89697bd --- /dev/null +++ b/python/sglang/srt/layers/attention/wave_ops/decode_attention.py @@ -0,0 +1,186 @@ +""" +Memory-efficient attention for decoding. +It supports page size = 1. 
+""" + +import functools +import logging + +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType +from wave_lang.kernel.wave.templates.paged_decode_attention import ( + get_paged_decode_attention_kernels, + get_paged_decode_intermediate_arrays_shapes, + paged_decode_attention_shape, +) +from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params +from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + +logger = logging.getLogger(__name__) +import os + +dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) + + +@functools.lru_cache(maxsize=4096) +def get_wave_kernel( + shape: paged_decode_attention_shape, + max_kv_splits, + input_dtype, + output_dtype, + logit_cap, +): + mha = (shape.num_query_heads // shape.num_kv_heads) == 1 + + # Get the kernels (either compile or load from cache). + if mha: + mfma_variant = ( + GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1), + GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64), + ) + else: + mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16) + + ( + phase_0, + phase_1, + hyperparams_0, + hyperparams_1, + dynamic_symbols_0, + dynamic_symbols_1, + ) = get_paged_decode_attention_kernels( + shape, + mfma_variant, + max_kv_splits, + input_dtype=input_dtype, + output_dtype=output_dtype, + logit_cap=logit_cap, + ) + hyperparams_0.update(get_default_scheduling_params()) + hyperparams_1.update(get_default_scheduling_params()) + + options = WaveCompileOptions( + subs=hyperparams_0, + canonicalize=True, + run_bench=False, + use_buffer_load_ops=True, + use_buffer_store_ops=True, + waves_per_eu=2, + dynamic_symbols=dynamic_symbols_0, + wave_runtime=True, + ) + options = set_default_run_config(options) + phase_0 = wave_compile(options, phase_0) + + options = WaveCompileOptions( + subs=hyperparams_1, + canonicalize=True, + run_bench=False, + use_buffer_load_ops=False, + use_buffer_store_ops=False, + waves_per_eu=4, + dynamic_symbols=dynamic_symbols_1, + wave_runtime=True, + ) + options = set_default_run_config(options) + phase_1 = wave_compile(options, phase_1) + + return phase_0, phase_1 + + +def decode_attention_intermediate_arrays_shapes( + num_seqs, head_size_kv, num_query_heads, max_kv_splits +): + # Not all fields are used, but we need to pass them to the function + shape = paged_decode_attention_shape( + num_query_heads=num_query_heads, + num_kv_heads=0, + head_size=0, + head_size_kv=head_size_kv, + block_size=0, + num_seqs=num_seqs, + ) + return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits) + + +def decode_attention_wave( + q, + k_buffer, + v_buffer, + o, + b_req_idx, + req_to_token, + attn_logits, + attn_logits_max, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap, +): + num_seqs, num_query_heads, head_size = q.shape + _, num_kv_heads, _ = k_buffer.shape + _, _, head_size_kv = v_buffer.shape + block_size = 32 + shape = paged_decode_attention_shape( + num_query_heads, + num_kv_heads, + head_size, + head_size_kv, + block_size, + num_seqs, + ) + + phase_0, phase_1 = get_wave_kernel( + shape, max_kv_splits, q.dtype, o.dtype, logit_cap + ) + + mb_qk = phase_0( + q, + k_buffer, + v_buffer, + b_req_idx, + req_to_token, + attn_logits, + attn_logits_max, + ) + if dump_generated_mlir: + filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + 
f.write(mb_qk.module_op.get_asm()) + + mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o) + if dump_generated_mlir: + filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb_sv.module_op.get_asm()) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + b_req_idx, + req_to_token, + attn_logits, + attn_logits_max, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap=0.0, +): + decode_attention_wave( + q, + k_buffer, + v_buffer, + o, + b_req_idx, + req_to_token, + attn_logits, + attn_logits_max, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap, + ) diff --git a/python/sglang/srt/layers/attention/wave_ops/extend_attention.py b/python/sglang/srt/layers/attention/wave_ops/extend_attention.py new file mode 100644 index 000000000..35a53d3e2 --- /dev/null +++ b/python/sglang/srt/layers/attention/wave_ops/extend_attention.py @@ -0,0 +1,149 @@ +""" +Memory-efficient attention for prefill. +It support page size = 1. +""" + +import functools +import os + +import torch +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.constraints import MMAType +from wave_lang.kernel.wave.scheduling.schedule import SchedulingType +from wave_lang.kernel.wave.templates.attention_common import AttentionShape +from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel +from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params +from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + +dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) + + +@functools.lru_cache +def get_wave_kernel( + shape: AttentionShape, + q_shape: tuple[int], + k_shape: tuple[int], + v_shape: tuple[int], + k_cache_shape: tuple[int], + v_cache_shape: tuple[int], + o_shape: tuple[int], + input_dtype: torch.dtype, + output_dtype: torch.dtype, + size_dtype: torch.dtype, + is_causal: bool, + logit_cap: float, + layer_scaling: float, +): + assert shape.num_query_heads % shape.num_kv_heads == 0 + + mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16) + ( + extend_attention, + hyperparams, + dynamic_symbols, + ) = get_extend_attention_kernel( + shape, + mfma_variant, + q_shape, + k_shape, + v_shape, + k_cache_shape, + v_cache_shape, + o_shape, + input_dtype=input_dtype, + output_dtype=output_dtype, + size_dtype=size_dtype, + is_causal=is_causal, + layer_scaling=layer_scaling, + logit_cap=logit_cap, + ) + + hyperparams.update(get_default_scheduling_params()) + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + run_bench=False, + schedule=SchedulingType.NONE, + use_scheduling_barriers=False, + dynamic_symbols=dynamic_symbols, + use_buffer_load_ops=True, + use_buffer_store_ops=True, + waves_per_eu=2, + denorm_fp_math_f32="preserve-sign", + gpu_native_math_precision=True, + wave_runtime=True, + ) + options = set_default_run_config(options) + extend_attention = wave_compile(options, extend_attention) + + return extend_attention + + +def extend_attention_wave( + q_extend, + k_extend, + v_extend, + k_buffer, + v_buffer, + qo_indptr, + kv_indptr, + kv_indices, + custom_mask, + mask_indptr, + max_seq_len, + output, + is_causal=True, + layer_scaling=None, + logit_cap=0, +): + shape = AttentionShape( + num_query_heads=q_extend.shape[1], + num_kv_heads=k_extend.shape[1], + head_size=q_extend.shape[2], + head_size_kv=k_extend.shape[2], + 
num_seqs=kv_indptr.shape[0] - 1, + max_seq_len=max_seq_len, + ) + + # Run the wave kernel. + extend_attention = get_wave_kernel( + shape, + q_extend.shape, + k_extend.shape, + v_extend.shape, + k_buffer.shape, + v_buffer.shape, + output.shape, + input_dtype=q_extend.dtype, + output_dtype=output.dtype, + size_dtype=qo_indptr.dtype, + is_causal=is_causal, + layer_scaling=layer_scaling, + logit_cap=logit_cap, + ) + + mb = extend_attention( + q_extend, + k_extend, + v_extend, + k_buffer, + v_buffer, + qo_indptr, + kv_indptr, + kv_indices, + max_seq_len, + output, + ) + + if dump_generated_mlir: + shape_list = [ + q_extend.shape[0], + q_extend.shape[1], + k_extend.shape[1], + q_extend.shape[2], + k_extend.shape[2], + ] + filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) diff --git a/python/sglang/srt/layers/attention/wave_ops/prefill_attention.py b/python/sglang/srt/layers/attention/wave_ops/prefill_attention.py new file mode 100644 index 000000000..2d8aa4678 --- /dev/null +++ b/python/sglang/srt/layers/attention/wave_ops/prefill_attention.py @@ -0,0 +1,79 @@ +""" +Memory-efficient attention for prefill. +It support page size = 1. +""" + +import math +import os + +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.constraints import MMAType +from wave_lang.kernel.wave.templates.attention_common import AttentionShape +from wave_lang.kernel.wave.templates.prefill_attention import ( + get_prefill_attention_kernel, +) +from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params +from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + +dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) + + +def prefill_attention_wave( + q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True +): + + shape = AttentionShape( + num_query_heads=q.shape[1], + num_kv_heads=k.shape[1], + head_size=q.shape[2], + head_size_kv=k.shape[2], + num_seqs=b_seq_len.shape[0], + max_seq_len=max_seq_len, + total_seq_len=q.shape[0], + ) + + assert shape.num_query_heads % shape.num_kv_heads == 0 + + output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv) + # Run the wave kernel. 
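+    # Select MMA intrinsics and build a prefill kernel specialized to the
+    # current input shapes and dtypes.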
+ mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16) + (prefill, hyperparams) = get_prefill_attention_kernel( + shape, + mfma_variant, + q.shape, + k.shape, + v.shape, + output_shape, + input_dtype=q.dtype, + output_dtype=o.dtype, + size_dtype=b_seq_len.dtype, + ) + + hyperparams.update(get_default_scheduling_params()) + + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / shape.head_size) + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + run_bench=False, + use_scheduling_barriers=False, + ) + options = set_default_run_config(options) + prefill = wave_compile(options, prefill) + + mb = prefill( + q * dk_sqrt * log2e, + k, + v, + b_start_loc, + b_seq_len, + o, + ) + if dump_generated_mlir: + shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]] + filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c865daf6b..fe56a208b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1487,6 +1487,10 @@ class ModelRunner: from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend return AiterAttnBackend(self) + elif self.server_args.attention_backend == "wave": + from sglang.srt.layers.attention.wave_backend import WaveAttnBackend + + return WaveAttnBackend(self) elif backend_str == "ascend": from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 63c7f1145..457ed17f7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1323,6 +1323,7 @@ class ServerArgs: "trtllm_mla", "trtllm_mha", "dual_chunk_flash_attn", + "wave", ] parser.add_argument( "--attention-backend", diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index dcf2c0efb..1bbdb7f65 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -196,6 +196,8 @@ suite_amd = { TestFile("test_torch_native_attention_backend.py", 123), TestFile("test_triton_attention_backend.py", 150), # TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701 + TestFile("test_wave_attention_kernels.py", 2), + TestFile("test_wave_attention_backend.py", 150), ], "per-commit-2-gpu-amd": [ TestFile("lora/test_lora_tp.py", 116), diff --git a/test/srt/test_wave_attention_backend.py b/test/srt/test_wave_attention_backend.py new file mode 100644 index 000000000..5feab4595 --- /dev/null +++ b/test/srt/test_wave_attention_backend.py @@ -0,0 +1,61 @@ +""" +Usage: +python3 -m unittest test_wave_attention_backend.TestWaveAttnBackend.test_mmlu +""" + +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + is_in_ci, + popen_launch_server, + run_bench_one_batch, +) + + +class TestWaveAttnBackend(unittest.TestCase): + def test_latency(self): + _, output_throughput, _ = run_bench_one_batch( + DEFAULT_MODEL_NAME_FOR_TEST, + [ + "--attention-backend", + "wave", + "--enable-torch-compile", + ], + ) + + if is_in_ci(): + self.assertGreater(output_throughput, 153) + + def _test_mmlu(self): + model = DEFAULT_MODEL_NAME_FOR_TEST + base_url = 
DEFAULT_URL_FOR_TEST + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--attention-backend", "wave"], + ) + + try: + args = SimpleNamespace( + base_url=base_url, + model=model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.65) + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_wave_attention_kernels.py b/test/srt/test_wave_attention_kernels.py new file mode 100644 index 000000000..d4c2ff8e5 --- /dev/null +++ b/test/srt/test_wave_attention_kernels.py @@ -0,0 +1,322 @@ +import random +import unittest + +import torch + +from sglang.srt.layers.attention.triton_ops.decode_attention import ( + decode_attention_fwd_grouped as triton_decode_attention_fwd_grouped, +) +from sglang.srt.layers.attention.triton_ops.extend_attention import ( + extend_attention_fwd, + redundant_attention, +) +from sglang.srt.layers.attention.triton_ops.prefill_attention import ( + context_attention_fwd, +) +from sglang.srt.layers.attention.wave_ops.decode_attention import ( + decode_attention_intermediate_arrays_shapes, + decode_attention_wave, +) +from sglang.srt.layers.attention.wave_ops.extend_attention import extend_attention_wave +from sglang.srt.layers.attention.wave_ops.prefill_attention import ( + prefill_attention_wave, +) + + +class TestWaveAttention(unittest.TestCase): + + def _set_all_seeds(self, seed): + """Set all random seeds for reproducibility.""" + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + def setUp(self): + # Set seeds before each test method + self._set_all_seeds(42) + + def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): + dtype = torch.float16 + extend_seq_len = 1024 + + b_seq_len_prefix = torch.full( + (B,), N_CTX // B, dtype=torch.int32, device="cuda" + ) + b_seq_len_extend = torch.full( + (B,), extend_seq_len, dtype=torch.int32, device="cuda" + ) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + max_len_in_batch = torch.max(b_seq_len, 0)[0].item() + + b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) + kv_indices = torch.zeros( + (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda" + ) + + for i in range(B): + kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i] + ) + + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + k_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + v_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, 
device="cuda") + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] + extend_start = b_start_loc_extend[i] + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] + k_extend[extend_start:extend_end] = k_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + v_extend[extend_start:extend_end] = v_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = torch.empty( + (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_extend_mask = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + ) + o_redundant = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + ) + + b_seq_len_extend = b_seq_len - b_seq_len_prefix + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) + + custom_mask = None + mask_indptr = None + + redundant_attention( + q_extend, + o_redundant, + k_buffer, + v_buffer, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + max_len_in_batch, + ) + + is_causal = True + + o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + qo_indptr, + kv_indptr, + kv_indices, + custom_mask, + is_causal, + mask_indptr, + max_len_extend, + ) + + o_wave = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + extend_attention_wave( + q_extend, + k_extend, + v_extend, + k_buffer, + v_buffer, + qo_indptr, + kv_indptr, + kv_indices, + custom_mask, + mask_indptr, + max_len_extend, + o_wave, + is_causal=is_causal, + ) + + self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2)) + self.assertTrue(torch.allclose(o_wave, o_redundant, rtol=1e-2)) + + def test_extend_attention(self): + + # Define the varying parameter values + attention_values = [128] + + # Loop through the values and call the method + for value in attention_values: + self._test_extend_attention_once(32, 16384, 6, 1, value) + + def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): + dtype = torch.float16 + seq_len = S # This represents the number of tokens already in the sequence + total_tokens = B * seq_len + sm_scale = 1.0 / (D**0.5) + max_kv_splits = 8 + num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda") + + # q represents the new token being generated, one per batch + q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") + + # k_buffer and v_buffer represent all previous tokens + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") + v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda") + + # o will have the same shape as q + o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + + req_to_token = torch.arange(total_tokens, device="cuda", dtype=torch.int32) + b_req_idx = torch.zeros(B + 1, device="cuda", dtype=torch.int32) + b_seq_len = torch.full((B,), seq_len, device="cuda", dtype=torch.int32) + b_req_idx[1 : B + 1] = torch.cumsum(b_seq_len, dim=0) + + attn_logits = torch.empty( + (B, H_Q, max_kv_splits, D_V + 1), + dtype=torch.float32, + device="cuda", + ) + attn_lse = torch.empty( + (B, H_Q, max_kv_splits), + 
dtype=torch.float32, + device="cuda", + ) + + logit_cap = 0.0 + triton_decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o_triton, + b_req_idx, + req_to_token, + attn_logits, + attn_lse, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap, + ) + + attn_logits_shape, attn_logits_max_shape = ( + decode_attention_intermediate_arrays_shapes(B, D_V, H_Q, max_kv_splits) + ) + + attn_logits = torch.empty( + attn_logits_shape, + dtype=torch.float32, + device="cuda", + ) + + attn_logits_max = torch.empty( + attn_logits_max_shape, + dtype=torch.float32, + device="cuda", + ) + + decode_attention_wave( + q, + k_buffer, + v_buffer, + o, + b_req_idx, + req_to_token, + attn_logits, + attn_logits_max, + num_kv_splits, + max_kv_splits, + sm_scale, + logit_cap, + ) + + cos_sim = torch.nn.functional.cosine_similarity( + o.flatten(), o_triton.flatten(), dim=0 + ) + print(cos_sim.item()) + self.assertTrue(cos_sim.item() > 0.99) + self.assertTrue(torch.allclose(o, o_triton, atol=3e-2)) + + def test_grouped_decode_attention(self): + seq_lens = [5, 100, 128, 500] + configs = [ + (2, 16, 16, 64, 64), + (2, 16, 1, 64, 64), + (2, 128, 1, 80, 80), + (32, 128, 2, 512, 512), + (2, 128, 2, 512, 512), + (2, 128, 1, 576, 512), + ] + + for S in seq_lens: + for B, H_Q, H_KV, D, D_V in configs: + self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V) + + def _test_context_attention_once(self, head_dim, is_causal): + # Set up a simple test case + dtype = torch.float16 + num_heads = 4 + kv_heads = 1 + seq_lens = [128, 256] + max_seq_len = max(seq_lens) + + # Create random input tensors + q = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") + k = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda") + v = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda") + o_triton = torch.zeros( + sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda" + ) + o = torch.zeros(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") + + # Create b_start_loc and b_seq_len tensors + b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda") + b_seq_len = torch.tensor(seq_lens, device="cuda") + + context_attention_fwd( + q, k, v, o_triton, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal + ) + prefill_attention_wave( + q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal + ) + cos_sim = torch.nn.functional.cosine_similarity( + o.flatten(), o_triton.flatten(), dim=0 + ) + + print(cos_sim.item()) + self.assertTrue(torch.allclose(o, o_triton, atol=3e-2)) + self.assertTrue(cos_sim.item() > 1 - (1e-5)) + + def test_context_attention(self): + head_dim = [128, 96] + + for dim in head_dim: + for is_causal in [False]: + self._test_context_attention_once(dim, is_causal) + + +if __name__ == "__main__": + unittest.main()