Add intel_amx backend for Radix Attention for CPU (#6408)

Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
This commit is contained in:
YanbingJiang
2025-05-31 12:37:42 +08:00
committed by GitHub
parent e39bca0756
commit 888cb175a6
8 changed files with 185 additions and 5 deletions

View File

@@ -109,3 +109,7 @@ class AttentionBackend(ABC):
):
"""Run a forward for extend."""
raise NotImplementedError()
def support_triton(self):
"""Check if the current backend supports triton."""
return True

View File

@@ -0,0 +1,128 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
class IntelAMXAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
import sgl_kernel
super().__init__()
self.forward_metadata = None
self.device = model_runner.device
self.num_head = (
model_runner.model_config.num_attention_heads // model_runner.tp_size
)
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu
self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
bs = forward_batch.batch_size
attn_logits = torch.zeros(
(
bs,
self.num_head,
8, # self.num_kv_splits,
self.v_head_dim + 1,
),
dtype=torch.float32,
device=self.device,
)
if forward_batch.forward_mode.is_decode_or_idle():
max_extend_len = None
else:
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
self.forward_metadata = (attn_logits, max_extend_len)
def forward_extend(
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
_, max_extend_len = self.forward_metadata
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k,
v,
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_seq_lens,
forward_batch.extend_start_loc,
max_extend_len,
layer.scaling,
layer.logit_cap,
)
return o
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
attn_logits, _ = self.forward_metadata
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
k,
v,
forward_batch.out_cache_loc,
attn_logits,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
layer.scaling,
layer.logit_cap,
)
return o
def support_triton(self):
return False

View File

@@ -265,3 +265,6 @@ class TorchNativeAttnBackend(AttentionBackend):
)
return o
def support_triton(self):
return False