Add intel_amx backend for Radix Attention for CPU (#6408)
Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com> Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
This commit is contained in:
@@ -109,3 +109,7 @@ class AttentionBackend(ABC):
|
|||||||
):
|
):
|
||||||
"""Run a forward for extend."""
|
"""Run a forward for extend."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def support_triton(self):
|
||||||
|
"""Check if the current backend supports triton."""
|
||||||
|
return True
|
||||||
|
|||||||
128
python/sglang/srt/layers/attention/intel_amx_backend.py
Normal file
128
python/sglang/srt/layers/attention/intel_amx_backend.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
|
class IntelAMXAttnBackend(AttentionBackend):
|
||||||
|
def __init__(self, model_runner: ModelRunner):
|
||||||
|
import sgl_kernel
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.forward_metadata = None
|
||||||
|
self.device = model_runner.device
|
||||||
|
|
||||||
|
self.num_head = (
|
||||||
|
model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
||||||
|
|
||||||
|
self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu
|
||||||
|
self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
"""Init the metadata for a forward pass."""
|
||||||
|
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
attn_logits = torch.zeros(
|
||||||
|
(
|
||||||
|
bs,
|
||||||
|
self.num_head,
|
||||||
|
8, # self.num_kv_splits,
|
||||||
|
self.v_head_dim + 1,
|
||||||
|
),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
|
max_extend_len = None
|
||||||
|
else:
|
||||||
|
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||||
|
self.forward_metadata = (attn_logits, max_extend_len)
|
||||||
|
|
||||||
|
def forward_extend(
|
||||||
|
self,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache=True,
|
||||||
|
):
|
||||||
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||||
|
else:
|
||||||
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
|
)
|
||||||
|
|
||||||
|
_, max_extend_len = self.forward_metadata
|
||||||
|
|
||||||
|
self.extend_attention_fwd(
|
||||||
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
|
forward_batch.req_to_token_pool.req_to_token,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.extend_seq_lens,
|
||||||
|
forward_batch.extend_start_loc,
|
||||||
|
max_extend_len,
|
||||||
|
layer.scaling,
|
||||||
|
layer.logit_cap,
|
||||||
|
)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def forward_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache=True,
|
||||||
|
):
|
||||||
|
attn_logits, _ = self.forward_metadata
|
||||||
|
|
||||||
|
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||||
|
|
||||||
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||||
|
else:
|
||||||
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
|
self.decode_attention_fwd(
|
||||||
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
|
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
forward_batch.out_cache_loc,
|
||||||
|
attn_logits,
|
||||||
|
forward_batch.req_to_token_pool.req_to_token,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
layer.scaling,
|
||||||
|
layer.logit_cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o
|
||||||
|
|
||||||
|
def support_triton(self):
|
||||||
|
return False
|
||||||
@@ -265,3 +265,6 @@ class TorchNativeAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def support_triton(self):
|
||||||
|
return False
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
|
|||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
|
from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
@@ -1257,7 +1257,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||||
|
|
||||||
# Write to req_to_token_pool
|
# Write to req_to_token_pool
|
||||||
if global_server_args_dict["attention_backend"] != "torch_native":
|
if support_triton(global_server_args_dict.get("attention_backend")):
|
||||||
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
||||||
|
|
||||||
write_req_to_token_pool_triton[(bs,)](
|
write_req_to_token_pool_triton[(bs,)](
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
|
from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
@@ -351,7 +351,7 @@ class ForwardBatch:
|
|||||||
ret.extend_prefix_lens = torch.tensor(
|
ret.extend_prefix_lens = torch.tensor(
|
||||||
batch.extend_prefix_lens, dtype=torch.int32
|
batch.extend_prefix_lens, dtype=torch.int32
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
if model_runner.server_args.attention_backend != "torch_native":
|
if support_triton(model_runner.server_args.attention_backend):
|
||||||
ret.extend_num_tokens = batch.extend_num_tokens
|
ret.extend_num_tokens = batch.extend_num_tokens
|
||||||
positions, ret.extend_start_loc = compute_position_triton(
|
positions, ret.extend_start_loc = compute_position_triton(
|
||||||
ret.extend_prefix_lens,
|
ret.extend_prefix_lens,
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
|||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
MultiprocessingSerializer,
|
MultiprocessingSerializer,
|
||||||
|
cpu_has_amx_support,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
@@ -317,6 +318,16 @@ class ModelRunner:
|
|||||||
def model_specific_adjustment(self):
|
def model_specific_adjustment(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
|
if (
|
||||||
|
server_args.attention_backend == "intel_amx"
|
||||||
|
and server_args.device == "cpu"
|
||||||
|
and not cpu_has_amx_support()
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"The current platform does not support Intel AMX, will fallback to torch_native backend."
|
||||||
|
)
|
||||||
|
server_args.attention_backend = "torch_native"
|
||||||
|
|
||||||
if server_args.attention_backend is None:
|
if server_args.attention_backend is None:
|
||||||
"""
|
"""
|
||||||
Auto select the fastest attention backend.
|
Auto select the fastest attention backend.
|
||||||
@@ -369,7 +380,10 @@ class ModelRunner:
|
|||||||
f"Invalid attention backend for MLA: {server_args.attention_backend}"
|
f"Invalid attention backend for MLA: {server_args.attention_backend}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("MLA optimization not supported on CPU.")
|
if server_args.attention_backend != "intel_amx":
|
||||||
|
raise ValueError(
|
||||||
|
"MLA optimization not supported on CPU except for intel_amx backend."
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
server_args.attention_backend == "fa3"
|
server_args.attention_backend == "fa3"
|
||||||
@@ -1067,6 +1081,13 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return CutlassMLABackend(self)
|
return CutlassMLABackend(self)
|
||||||
|
elif self.server_args.attention_backend == "intel_amx":
|
||||||
|
from sglang.srt.layers.attention.intel_amx_backend import (
|
||||||
|
IntelAMXAttnBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Intel AMX attention backend is enabled.")
|
||||||
|
return IntelAMXAttnBackend(self)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||||
|
|||||||
@@ -323,6 +323,11 @@ class ServerArgs:
|
|||||||
self.sampling_backend = "pytorch"
|
self.sampling_backend = "pytorch"
|
||||||
|
|
||||||
# Set kernel backends
|
# Set kernel backends
|
||||||
|
if self.device == "cpu":
|
||||||
|
if self.attention_backend is None:
|
||||||
|
self.attention_backend = "intel_amx"
|
||||||
|
self.sampling_backend = "pytorch"
|
||||||
|
|
||||||
if self.sampling_backend is None:
|
if self.sampling_backend is None:
|
||||||
self.sampling_backend = (
|
self.sampling_backend = (
|
||||||
"flashinfer" if is_flashinfer_available() else "pytorch"
|
"flashinfer" if is_flashinfer_available() else "pytorch"
|
||||||
@@ -993,6 +998,7 @@ class ServerArgs:
|
|||||||
"fa3",
|
"fa3",
|
||||||
"flashmla",
|
"flashmla",
|
||||||
"cutlass_mla",
|
"cutlass_mla",
|
||||||
|
"intel_amx",
|
||||||
],
|
],
|
||||||
default=ServerArgs.attention_backend,
|
default=ServerArgs.attention_backend,
|
||||||
help="Choose the kernels for attention layers.",
|
help="Choose the kernels for attention layers.",
|
||||||
|
|||||||
@@ -2225,3 +2225,21 @@ def bind_or_assign(target, source):
|
|||||||
return target
|
return target
|
||||||
else:
|
else:
|
||||||
return source
|
return source
|
||||||
|
|
||||||
|
|
||||||
|
def support_triton(backend: str) -> bool:
|
||||||
|
return backend not in ["torch_native", "intel_amx"]
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sgl_kernel
|
||||||
|
|
||||||
|
is_intel_amx_backend_available = hasattr(
|
||||||
|
torch.ops.sgl_kernel, "convert_weight_packed"
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
is_intel_amx_backend_available = False
|
||||||
|
|
||||||
|
|
||||||
|
def cpu_has_amx_support():
|
||||||
|
return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
|
||||||
|
|||||||
Reference in New Issue
Block a user