Fuse MLA set kv cache kernel (#5748)
This commit is contained in:
@@ -625,6 +625,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
save_kv_cache=True,
|
||||
# For multi-head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
@@ -639,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
else:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||
layer,
|
||||
cache_loc,
|
||||
k,
|
||||
v,
|
||||
k_rope,
|
||||
)
|
||||
|
||||
# Use precomputed metadata across all layers
|
||||
@@ -887,6 +888,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
save_kv_cache=True,
|
||||
# For multi-head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
@@ -901,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
else:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||
layer,
|
||||
cache_loc,
|
||||
k,
|
||||
v,
|
||||
k_rope,
|
||||
)
|
||||
|
||||
# Use precomputed metadata across all layers
|
||||
|
||||
@@ -92,8 +92,11 @@ class RadixAttention(nn.Module):
|
||||
if k is not None:
|
||||
# For cross-layer sharing, kv can be None
|
||||
assert v is not None
|
||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
if "k_rope" not in kwargs:
|
||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
else:
|
||||
k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
|
||||
|
||||
return forward_batch.attn_backend.forward(
|
||||
q,
|
||||
|
||||
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import debug_timing, get_compiler_backend
|
||||
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||
|
||||
|
||||
@triton.jit
def set_mla_kv_buffer_kernel(
    kv_buffer_ptr,
    cache_k_nope_ptr,
    cache_k_rope_ptr,
    loc_ptr,
    buffer_stride: tl.constexpr,
    nope_stride: tl.constexpr,
    rope_stride: tl.constexpr,
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    BLOCK: tl.constexpr,
):
    # Fused MLA KV-cache write: stores the concatenation [k_nope | k_rope]
    # for one token into row `loc` of kv_buffer, without materializing the
    # concatenated tensor first.
    # Launch grid: (num_locs, cdiv(nope_dim + rope_dim, BLOCK)).
    pid_loc = tl.program_id(0)  # which token / cache slot
    pid_blk = tl.program_id(1)  # which BLOCK-wide chunk of the feature dim

    base = pid_blk * BLOCK
    offs = base + tl.arange(0, BLOCK)
    total_dim = nope_dim + rope_dim
    # Guard the tail chunk that runs past the concatenated dimension.
    mask = offs < total_dim

    # Destination row index for this token inside the KV buffer.
    loc = tl.load(loc_ptr + pid_loc)
    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs

    if base + BLOCK <= nope_dim:
        # Chunk lies entirely inside the nope portion.
        src = tl.load(
            cache_k_nope_ptr + pid_loc * nope_stride + offs,
            mask=mask,
        )
    else:
        # Chunk is read from the rope source, shifted left by nope_dim.
        # NOTE(review): a chunk straddling the nope/rope boundary would
        # compute negative rope offsets here, so this branch assumes
        # nope_dim % BLOCK == 0 (e.g. kv_lora_rank=512 with BLOCK=128) —
        # confirm against the callers' head dimensions.
        offs_rope = offs - nope_dim
        src = tl.load(
            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
            mask=mask,
        )

    tl.store(dst_ptr, src, mask=mask)
|
||||
|
||||
|
||||
def set_mla_kv_buffer_triton(
    kv_buffer: torch.Tensor,
    loc: torch.Tensor,
    cache_k_nope: torch.Tensor,
    cache_k_rope: torch.Tensor,
):
    """Scatter fused [nope | rope] K rows into ``kv_buffer`` at rows ``loc``.

    Launches one program per (token, BLOCK-wide slice of the concatenated
    feature dimension); the kernel itself picks the nope or rope source
    per slice.
    """
    BLOCK = 128
    dim_nope = cache_k_nope.shape[-1]
    dim_rope = cache_k_rope.shape[-1]
    launch_grid = (loc.numel(), triton.cdiv(dim_nope + dim_rope, BLOCK))

    set_mla_kv_buffer_kernel[launch_grid](
        kv_buffer,
        cache_k_nope,
        cache_k_rope,
        loc,
        kv_buffer.stride(0),
        cache_k_nope.stride(0),
        cache_k_rope.stride(0),
        dim_nope,
        dim_rope,
        BLOCK=BLOCK,
    )
|
||||
|
||||
|
||||
class MLATokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -504,6 +572,25 @@ class MLATokenToKVPool(KVCache):
|
||||
else:
|
||||
self.kv_buffer[layer_id][loc] = cache_k
|
||||
|
||||
def set_mla_kv_buffer(
|
||||
self,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k_nope: torch.Tensor,
|
||||
cache_k_rope: torch.Tensor,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k_nope.dtype != self.dtype:
|
||||
cache_k_nope = cache_k_nope.to(self.dtype)
|
||||
cache_k_rope = cache_k_rope.to(self.dtype)
|
||||
if self.store_dtype != self.dtype:
|
||||
cache_k_nope = cache_k_nope.view(self.store_dtype)
|
||||
cache_k_rope = cache_k_rope.view(self.store_dtype)
|
||||
|
||||
set_mla_kv_buffer_triton(
|
||||
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
|
||||
)
|
||||
|
||||
def get_flat_data(self, indices):
|
||||
# prepare a large chunk of contiguous data for efficient transfer
|
||||
return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
|
||||
|
||||
@@ -757,14 +757,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
|
||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||
|
||||
if self.attention_backend == "fa3":
|
||||
attn_output = self.attn_mqa(
|
||||
q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
|
||||
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
||||
)
|
||||
else:
|
||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user