Fuse MLA set kv cache kernel (#5748)
This commit is contained in:
@@ -625,6 +625,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
# For multi-head latent attention
|
# For multi-head latent attention
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
@@ -639,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||||
layer,
|
layer,
|
||||||
cache_loc,
|
cache_loc,
|
||||||
k,
|
k,
|
||||||
v,
|
k_rope,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use precomputed metadata across all layers
|
# Use precomputed metadata across all layers
|
||||||
@@ -887,6 +888,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
# For multi-head latent attention
|
# For multi-head latent attention
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
@@ -901,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||||
layer,
|
layer,
|
||||||
cache_loc,
|
cache_loc,
|
||||||
k,
|
k,
|
||||||
v,
|
k_rope,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use precomputed metadata across all layers
|
# Use precomputed metadata across all layers
|
||||||
|
|||||||
@@ -92,8 +92,11 @@ class RadixAttention(nn.Module):
|
|||||||
if k is not None:
|
if k is not None:
|
||||||
# For cross-layer sharing, kv can be None
|
# For cross-layer sharing, kv can be None
|
||||||
assert v is not None
|
assert v is not None
|
||||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
if "k_rope" not in kwargs:
|
||||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||||
|
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||||
|
else:
|
||||||
|
k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
|
||||||
|
|
||||||
return forward_batch.attn_backend.forward(
|
return forward_batch.attn_backend.forward(
|
||||||
q,
|
q,
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import debug_timing, get_compiler_backend
|
from sglang.srt.utils import debug_timing, get_compiler_backend
|
||||||
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
|||||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def set_mla_kv_buffer_kernel(
    kv_buffer_ptr,
    cache_k_nope_ptr,
    cache_k_rope_ptr,
    loc_ptr,
    buffer_stride: tl.constexpr,
    nope_stride: tl.constexpr,
    rope_stride: tl.constexpr,
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    BLOCK: tl.constexpr,
):
    # Copy one row of `cache_k_nope` followed by one row of `cache_k_rope`
    # into a single destination row of `kv_buffer`, i.e. write the
    # concatenation [nope | rope] without materialising it first.
    #
    # Launch grid: axis 0 indexes the entries of `loc_ptr` (one source row
    # each); axis 1 indexes BLOCK-wide column tiles covering
    # `nope_dim + rope_dim` columns.
    #
    # Pointer arithmetic below adds column offsets directly, so elements
    # within a row are assumed to be unit-stride in all three tensors; only
    # the row strides are passed in.
    pid_loc = tl.program_id(0)
    pid_blk = tl.program_id(1)

    base = pid_blk * BLOCK
    offs = base + tl.arange(0, BLOCK)
    total_dim = nope_dim + rope_dim
    # The last tile may run past the end of the concatenated row.
    mask = offs < total_dim

    # Destination row index for this source row.
    loc = tl.load(loc_ptr + pid_loc)
    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs

    nope_row_ptr = cache_k_nope_ptr + pid_loc * nope_stride
    rope_row_ptr = cache_k_rope_ptr + pid_loc * rope_stride

    if base + BLOCK <= nope_dim:
        # Tile lies entirely inside the nope part.
        src = tl.load(nope_row_ptr + offs, mask=mask)
    elif base >= nope_dim:
        # Tile lies entirely inside the rope part.
        src = tl.load(rope_row_ptr + (offs - nope_dim), mask=mask)
    else:
        # Tile straddles the nope/rope boundary (only possible when
        # nope_dim is not a multiple of BLOCK). Previously such a tile took
        # the rope path for every lane, producing negative rope offsets for
        # the leading lanes. Load each side under its own mask and merge.
        in_nope = offs < nope_dim
        in_rope = mask & (offs >= nope_dim)
        nope_part = tl.load(nope_row_ptr + offs, mask=in_nope)
        rope_part = tl.load(rope_row_ptr + (offs - nope_dim), mask=in_rope)
        src = tl.where(in_nope, nope_part, rope_part)

    tl.store(dst_ptr, src, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def set_mla_kv_buffer_triton(
    kv_buffer: torch.Tensor,
    loc: torch.Tensor,
    cache_k_nope: torch.Tensor,
    cache_k_rope: torch.Tensor,
):
    """Launch ``set_mla_kv_buffer_kernel`` to scatter nope+rope rows into ``kv_buffer``.

    Args:
        kv_buffer: Destination buffer; its ``stride(0)`` is used as the
            per-row stride of the destination.
        loc: Destination row indices; one kernel program row is launched per
            element (``loc.numel()``).
        cache_k_nope: Source for the leading columns of each destination
            row; the last dimension gives the column count.
        cache_k_rope: Source for the trailing columns of each destination
            row; the last dimension gives the column count.

    NOTE(review): only ``stride(0)`` of each tensor is forwarded, so the
    kernel presumably requires the last dimension to be contiguous and any
    middle dimensions to be of size 1 -- confirm against callers.
    """
    block_size = 128
    nope_width = cache_k_nope.shape[-1]
    rope_width = cache_k_rope.shape[-1]

    # One program per (row in `loc`, block_size-wide column tile).
    column_tiles = triton.cdiv(nope_width + rope_width, block_size)
    launch_grid = (loc.numel(), column_tiles)

    set_mla_kv_buffer_kernel[launch_grid](
        kv_buffer,
        cache_k_nope,
        cache_k_rope,
        loc,
        kv_buffer.stride(0),
        cache_k_nope.stride(0),
        cache_k_rope.stride(0),
        nope_width,
        rope_width,
        BLOCK=block_size,
    )
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPool(KVCache):
|
class MLATokenToKVPool(KVCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -504,6 +572,25 @@ class MLATokenToKVPool(KVCache):
|
|||||||
else:
|
else:
|
||||||
self.kv_buffer[layer_id][loc] = cache_k
|
self.kv_buffer[layer_id][loc] = cache_k
|
||||||
|
|
||||||
|
def set_mla_kv_buffer(
    self,
    layer: RadixAttention,
    loc: torch.Tensor,
    cache_k_nope: torch.Tensor,
    cache_k_rope: torch.Tensor,
):
    """Write the nope and rope parts of K into this layer's KV buffer.

    The two tensors are handed separately to ``set_mla_kv_buffer_triton``,
    which writes them side by side into ``self.kv_buffer[layer.layer_id]``
    at the rows given by ``loc``.

    Args:
        layer: Attention layer whose ``layer_id`` selects the buffer.
        loc: Destination row indices in the KV buffer.
        cache_k_nope: First part of each cached row.
        cache_k_rope: Second part of each cached row.
    """
    layer_id = layer.layer_id

    # Cast each tensor based on its own dtype. Gating the rope cast on the
    # nope dtype would leave a mismatched rope tensor uncast whenever the
    # nope tensor already has the pool dtype.
    if cache_k_nope.dtype != self.dtype:
        cache_k_nope = cache_k_nope.to(self.dtype)
    if cache_k_rope.dtype != self.dtype:
        cache_k_rope = cache_k_rope.to(self.dtype)

    # When the buffer is stored under a different dtype than the logical
    # one, reinterpret the bits (no value conversion) to match the storage.
    if self.store_dtype != self.dtype:
        cache_k_nope = cache_k_nope.view(self.store_dtype)
        cache_k_rope = cache_k_rope.view(self.store_dtype)

    set_mla_kv_buffer_triton(
        self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
    )
|
||||||
|
|
||||||
def get_flat_data(self, indices):
|
def get_flat_data(self, indices):
|
||||||
# prepare a large chunk of contiguous data for efficient transfer
|
# prepare a large chunk of contiguous data for efficient transfer
|
||||||
return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
|
return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
|
||||||
|
|||||||
@@ -757,14 +757,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||||
|
|
||||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
|
||||||
|
|
||||||
if self.attention_backend == "fa3":
|
if self.attention_backend == "fa3":
|
||||||
attn_output = self.attn_mqa(
|
attn_output = self.attn_mqa(
|
||||||
q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
|
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||||
|
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||||
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user