Add intel_amx backend for Radix Attention for CPU (#6408)
Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
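This change adds a CPU attention backend (IntelAMXAttnBackend) that dispatches Radix Attention prefill and decode to the torch.ops.sgl_kernel.extend_attention_cpu and decode_attention_cpu operators from sgl-kernel, makes it the default attention backend when the device is cpu, falls back to torch_native when the host CPU lacks AMX support, and introduces a support_triton() helper in sglang.srt.utils so Triton-only code paths are skipped for the CPU backends.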
python/sglang/srt/layers/attention/base_attn_backend.py
@@ -109,3 +109,7 @@ class AttentionBackend(ABC):
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
+
+    def support_triton(self):
+        """Check if the current backend supports triton."""
+        return True
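The base class now exposes support_triton(), defaulting to True; the Intel AMX and torch-native backends below override it to return False. A minimal sketch (illustrative classes, not the sglang ones) of how callers can branch on this flag instead of hard-coding backend names:

class Backend:
    def support_triton(self) -> bool:
        # Default: the backend works together with the Triton kernels.
        return True


class CPUBackend(Backend):
    def support_triton(self) -> bool:
        # CPU-only backends (torch_native, intel_amx) opt out of Triton paths.
        return False


def run_step(backend: Backend) -> str:
    # Callers pick the Triton kernel only when the backend supports it.
    return "triton kernel" if backend.support_triton() else "torch fallback"


assert run_step(Backend()) == "triton kernel"
assert run_step(CPUBackend()) == "torch fallback"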
python/sglang/srt/layers/attention/intel_amx_backend.py (new file, 128 lines)
@@ -0,0 +1,128 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner


class IntelAMXAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        import sgl_kernel

        super().__init__()
        self.forward_metadata = None
        self.device = model_runner.device

        self.num_head = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )

        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu
        self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""

        bs = forward_batch.batch_size
        attn_logits = torch.zeros(
            (
                bs,
                self.num_head,
                8,  # self.num_kv_splits,
                self.v_head_dim + 1,
            ),
            dtype=torch.float32,
            device=self.device,
        )
        if forward_batch.forward_mode.is_decode_or_idle():
            max_extend_len = None
        else:
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
        self.forward_metadata = (attn_logits, max_extend_len)

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        _, max_extend_len = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k,
            v,
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        attn_logits, _ = self.forward_metadata

        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            k,
            v,
            forward_batch.out_cache_loc,
            attn_logits,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            layer.scaling,
            layer.logit_cap,
        )

        return o

    def support_triton(self):
        return False
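For orientation, init_forward_metadata above pre-allocates a float32 workspace of shape (batch_size, num_head, 8, v_head_dim + 1) that the decode kernel writes into; the number of KV splits is hard-coded to 8, and the extra trailing column presumably holds a per-split normalizer next to the v_head_dim partial outputs (an assumption, not stated in the diff). A standalone shape check:

import torch

# Illustrative numbers only; real values come from the model config and KV pool.
bs, num_head, num_kv_splits, v_head_dim = 2, 4, 8, 128

# Mirrors the attn_logits workspace allocated in init_forward_metadata above.
attn_logits = torch.zeros(
    (bs, num_head, num_kv_splits, v_head_dim + 1),
    dtype=torch.float32,
    device="cpu",
)
print(attn_logits.shape)  # torch.Size([2, 4, 8, 129])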
python/sglang/srt/layers/attention/torch_native_backend.py
@@ -265,3 +265,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         )

         return o
+
+    def support_triton(self):
+        return False
python/sglang/srt/managers/schedule_batch.py
@@ -60,7 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -1257,7 +1257,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

         # Write to req_to_token_pool
-        if global_server_args_dict["attention_backend"] != "torch_native":
+        if support_triton(global_server_args_dict.get("attention_backend")):
             # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

             write_req_to_token_pool_triton[(bs,)](
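The hard-coded torch_native check above is replaced by the support_triton() helper (added to sglang.srt.utils at the end of this diff), so the new intel_amx backend also skips write_req_to_token_pool_triton and takes the non-Triton path; the same substitution is made in ForwardBatch below.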
python/sglang/srt/model_executor/forward_batch_info.py
@@ -39,7 +39,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -351,7 +351,7 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            if model_runner.server_args.attention_backend != "torch_native":
+            if support_triton(model_runner.server_args.attention_backend):
                 ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
python/sglang/srt/model_executor/model_runner.py
@@ -91,6 +91,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
+    cpu_has_amx_support,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -317,6 +318,16 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args

+        if (
+            server_args.attention_backend == "intel_amx"
+            and server_args.device == "cpu"
+            and not cpu_has_amx_support()
+        ):
+            logger.info(
+                "The current platform does not support Intel AMX, will fallback to torch_native backend."
+            )
+            server_args.attention_backend = "torch_native"
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.
@@ -369,7 +380,10 @@
                     f"Invalid attention backend for MLA: {server_args.attention_backend}"
                 )
             else:
-                raise ValueError("MLA optimization not supported on CPU.")
+                if server_args.attention_backend != "intel_amx":
+                    raise ValueError(
+                        "MLA optimization not supported on CPU except for intel_amx backend."
+                    )

         if (
             server_args.attention_backend == "fa3"
@@ -1067,6 +1081,13 @@
             )

             return CutlassMLABackend(self)
+        elif self.server_args.attention_backend == "intel_amx":
+            from sglang.srt.layers.attention.intel_amx_backend import (
+                IntelAMXAttnBackend,
+            )
+
+            logger.info(f"Intel AMX attention backend is enabled.")
+            return IntelAMXAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
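Taken together, the ModelRunner changes give CPU runs a graceful degradation path: request intel_amx, and fall back to torch_native when the host lacks AMX support. A condensed, self-contained sketch of that selection logic (simplified; the real checks live in model_specific_adjustment and the backend factory above):

from typing import Optional


def pick_cpu_attention_backend(requested: Optional[str], amx_ok: bool) -> str:
    # ServerArgs defaults the backend to intel_amx when device == "cpu".
    backend = requested or "intel_amx"
    # ModelRunner.model_specific_adjustment falls back when AMX is unavailable.
    if backend == "intel_amx" and not amx_ok:
        return "torch_native"
    return backend


assert pick_cpu_attention_backend(None, amx_ok=True) == "intel_amx"
assert pick_cpu_attention_backend("intel_amx", amx_ok=False) == "torch_native"
assert pick_cpu_attention_backend("torch_native", amx_ok=True) == "torch_native"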
python/sglang/srt/server_args.py
@@ -323,6 +323,11 @@ class ServerArgs:
             self.sampling_backend = "pytorch"

         # Set kernel backends
+        if self.device == "cpu":
+            if self.attention_backend is None:
+                self.attention_backend = "intel_amx"
+            self.sampling_backend = "pytorch"
+
         if self.sampling_backend is None:
             self.sampling_backend = (
                 "flashinfer" if is_flashinfer_available() else "pytorch"
@@ -993,6 +998,7 @@
                 "fa3",
                 "flashmla",
                 "cutlass_mla",
+                "intel_amx",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
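With these defaults, a server launched on CPU selects the intel_amx attention backend and the pytorch sampling backend automatically; the new argparse choice also lets users request intel_amx explicitly through the attention-backend argument (the flag name itself is outside this hunk).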
python/sglang/srt/utils.py
@@ -2225,3 +2225,21 @@ def bind_or_assign(target, source):
         return target
     else:
         return source
+
+
+def support_triton(backend: str) -> bool:
+    return backend not in ["torch_native", "intel_amx"]
+
+
+try:
+    import sgl_kernel
+
+    is_intel_amx_backend_available = hasattr(
+        torch.ops.sgl_kernel, "convert_weight_packed"
+    )
+except:
+    is_intel_amx_backend_available = False
+
+
+def cpu_has_amx_support():
+    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
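A small standalone check of the two helpers added here. The sgl_kernel op probe is omitted so the snippet runs without sgl-kernel installed, and torch._C._cpu._is_amx_tile_supported is a private API that assumes a recent PyTorch build (both assumptions, mirroring cpu_has_amx_support above):

import torch


def support_triton(backend: str) -> bool:
    # Same predicate as the helper added to sglang.srt.utils above.
    return backend not in ["torch_native", "intel_amx"]


for backend in ["flashinfer", "fa3", "torch_native", "intel_amx"]:
    print(f"{backend}: uses Triton paths = {support_triton(backend)}")

# Hardware half of cpu_has_amx_support(); the full helper additionally requires
# sgl-kernel to expose torch.ops.sgl_kernel.convert_weight_packed.
print("AMX tiles supported:", torch._C._cpu._is_amx_tile_supported())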