From 888cb175a6a8a24b4ffe07ee0e1ace1bda8ea850 Mon Sep 17 00:00:00 2001
From: YanbingJiang
Date: Sat, 31 May 2025 12:37:42 +0800
Subject: [PATCH] Add intel_amx backend for Radix Attention for CPU (#6408)

Co-authored-by: Chunyuan WU
Co-authored-by: Thien Tran
---
 .../srt/layers/attention/base_attn_backend.py |   4 +
 .../srt/layers/attention/intel_amx_backend.py | 128 ++++++++++++++++++
 .../layers/attention/torch_native_backend.py  |   3 +
 python/sglang/srt/managers/schedule_batch.py  |   4 +-
 .../srt/model_executor/forward_batch_info.py  |   4 +-
 .../sglang/srt/model_executor/model_runner.py |  23 +++-
 python/sglang/srt/server_args.py              |   6 +
 python/sglang/srt/utils.py                    |  18 +++
 8 files changed, 185 insertions(+), 5 deletions(-)
 create mode 100644 python/sglang/srt/layers/attention/intel_amx_backend.py

diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py
index 52bcd5fba..a38c319c7 100644
--- a/python/sglang/srt/layers/attention/base_attn_backend.py
+++ b/python/sglang/srt/layers/attention/base_attn_backend.py
@@ -109,3 +109,7 @@ class AttentionBackend(ABC):
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
+
+    def support_triton(self):
+        """Check if the current backend supports Triton."""
+        return True
diff --git a/python/sglang/srt/layers/attention/intel_amx_backend.py b/python/sglang/srt/layers/attention/intel_amx_backend.py
new file mode 100644
index 000000000..9f2f7ece4
--- /dev/null
+++ b/python/sglang/srt/layers/attention/intel_amx_backend.py
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class IntelAMXAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        import sgl_kernel  # registers the CPU ops under torch.ops.sgl_kernel
+
+        super().__init__()
+        self.forward_metadata = None
+        self.device = model_runner.device
+
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+
+        self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu
+        self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+
+        bs = forward_batch.batch_size
+        attn_logits = torch.zeros(
+            (
+                bs,
+                self.num_head,
+                8,  # self.num_kv_splits
+                self.v_head_dim + 1,
+            ),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        if forward_batch.forward_mode.is_decode_or_idle():
+            max_extend_len = None
+        else:
+            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+        self.forward_metadata = (attn_logits, max_extend_len)
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        _, max_extend_len = self.forward_metadata
+
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k,
+            v,
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_seq_lens,
+            forward_batch.extend_start_loc,
+            max_extend_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        attn_logits, _ = self.forward_metadata
+
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        self.decode_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            k,
+            v,
+            forward_batch.out_cache_loc,
+            attn_logits,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            layer.scaling,
+            layer.logit_cap,
+        )
+
+        return o
+
+    def support_triton(self):
+        return False
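The backend above binds to exactly two CPU kernels, torch.ops.sgl_kernel.decode_attention_cpu and torch.ops.sgl_kernel.extend_attention_cpu, which only exist after `import sgl_kernel` runs its op registration. A minimal sketch for checking that binding outside the server, assuming a CPU build of sgl_kernel is installed; only the op names come from the patch, the rest is illustrative:

# Hedged sketch: confirm the sgl_kernel CPU ops the backend binds to are
# actually registered. Importing sgl_kernel has the side effect of
# registering them under torch.ops.sgl_kernel, which is why the backend
# imports it in __init__ even though the name itself goes unused.
import torch
import sgl_kernel  # noqa: F401

for op_name in ("decode_attention_cpu", "extend_attention_cpu"):
    assert hasattr(torch.ops.sgl_kernel, op_name), f"missing CPU op: {op_name}"
print("sgl_kernel CPU attention ops are available")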
diff --git a/python/sglang/srt/layers/attention/torch_native_backend.py b/python/sglang/srt/layers/attention/torch_native_backend.py
index 78ed042de..bb06076c1 100644
--- a/python/sglang/srt/layers/attention/torch_native_backend.py
+++ b/python/sglang/srt/layers/attention/torch_native_backend.py
@@ -265,3 +265,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         )
 
         return o
+
+    def support_triton(self):
+        return False
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 59b5471b6..637662713 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -60,7 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -1257,7 +1257,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
         # Write to req_to_token_pool
-        if global_server_args_dict["attention_backend"] != "torch_native":
+        if support_triton(global_server_args_dict.get("attention_backend")):
             # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
 
             write_req_to_token_pool_triton[(bs,)](
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index de462e45d..d2104e41a 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -39,7 +39,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -351,7 +351,7 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            if model_runner.server_args.attention_backend != "torch_native":
+            if support_triton(model_runner.server_args.attention_backend):
                 ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
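Where support_triton(...) is false, ForwardBatch still has to produce `positions` and `extend_start_loc` with the same semantics that compute_position_triton delivers; the non-Triton branch is not shown in this hunk. A pure-PyTorch sketch of that computation, for illustration only (the helper name is hypothetical; just the two input tensors come from the code above):

# Hypothetical pure-PyTorch equivalent of compute_position_triton:
# positions for request i run from prefix_len[i] to
# prefix_len[i] + extend_len[i], and extend_start_loc is the running
# offset of each request's extend segment along the flattened token axis.
import torch


def compute_position_torch_sketch(extend_prefix_lens, extend_seq_lens):
    positions = torch.cat(
        [
            torch.arange(p, p + e, dtype=torch.int64)
            for p, e in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    return positions, extend_start_loc


# e.g. prefix lens [2, 0] and extend lens [3, 2] give
# positions [2, 3, 4, 0, 1] and extend_start_loc [0, 3].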
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 68f570b8e..6d5248376 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -91,6 +91,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
+    cpu_has_amx_support,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -317,6 +318,16 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
+        if (
+            server_args.attention_backend == "intel_amx"
+            and server_args.device == "cpu"
+            and not cpu_has_amx_support()
+        ):
+            logger.info(
+                "The current platform does not support Intel AMX; falling back to the torch_native backend."
+            )
+            server_args.attention_backend = "torch_native"
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.
@@ -369,7 +380,10 @@ class ModelRunner:
                     f"Invalid attention backend for MLA: {server_args.attention_backend}"
                 )
             else:
-                raise ValueError("MLA optimization not supported on CPU.")
+                if server_args.attention_backend != "intel_amx":
+                    raise ValueError(
+                        "MLA optimization is not supported on CPU except with the intel_amx backend."
+                    )
 
         if (
             server_args.attention_backend == "fa3"
@@ -1067,6 +1081,13 @@ class ModelRunner:
             )
 
             return CutlassMLABackend(self)
+        elif self.server_args.attention_backend == "intel_amx":
+            from sglang.srt.layers.attention.intel_amx_backend import (
+                IntelAMXAttnBackend,
+            )
+
+            logger.info("Intel AMX attention backend is enabled.")
+            return IntelAMXAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
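Taken together with the default set in server_args.py further down, the two model_runner.py changes above amount to a simple decision rule for CPU deployments. A condensed, illustrative restatement (pick_cpu_attention_backend is hypothetical; cpu_has_amx_support is the real helper added to utils.py below):

# Condensed, illustrative restatement of the CPU backend-selection rule.
from typing import Optional

from sglang.srt.utils import cpu_has_amx_support


def pick_cpu_attention_backend(requested: Optional[str]) -> str:
    # intel_amx is the CPU default (server_args.py), but it is only kept
    # when both the hardware (AMX tiles) and the sgl_kernel build support
    # it (model_runner.py); otherwise the server falls back to torch_native.
    if requested in (None, "intel_amx"):
        return "intel_amx" if cpu_has_amx_support() else "torch_native"
    return requested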
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index d7feb50d2..1c2958b9c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -323,6 +323,11 @@ class ServerArgs:
             self.sampling_backend = "pytorch"
 
         # Set kernel backends
+        if self.device == "cpu":
+            if self.attention_backend is None:
+                self.attention_backend = "intel_amx"
+            self.sampling_backend = "pytorch"
+
         if self.sampling_backend is None:
             self.sampling_backend = (
                 "flashinfer" if is_flashinfer_available() else "pytorch"
@@ -993,6 +998,7 @@ class ServerArgs:
                 "fa3",
                 "flashmla",
                 "cutlass_mla",
+                "intel_amx",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 69ade4a4f..6b82ca47a 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2225,3 +2225,21 @@ def bind_or_assign(target, source):
         return target
     else:
         return source
+
+
+def support_triton(backend: str) -> bool:
+    return backend not in ["torch_native", "intel_amx"]
+
+
+try:
+    import sgl_kernel
+
+    is_intel_amx_backend_available = hasattr(
+        torch.ops.sgl_kernel, "convert_weight_packed"
+    )
+except Exception:
+    is_intel_amx_backend_available = False
+
+
+def cpu_has_amx_support():
+    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
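A short smoke test for the new helpers, assuming a CPU build of sglang with sgl_kernel installed; it exercises the support_triton contract that schedule_batch.py and forward_batch_info.py now rely on:

# Hedged smoke test for the helpers added to utils.py. Both CPU backends
# must report False so the schedulers take the pure-PyTorch code paths
# instead of launching Triton kernels.
from sglang.srt.utils import cpu_has_amx_support, support_triton

assert support_triton("flashinfer")         # GPU backends keep the Triton paths
assert not support_triton("torch_native")   # pure-PyTorch CPU path
assert not support_triton("intel_amx")      # AMX CPU path skips Triton as well
print("AMX usable on this host:", cpu_has_amx_support())

With these pieces in place, a CPU launch that exercises the new backend would look something like `python -m sglang.launch_server --device cpu --attention-backend intel_amx ...` (the flag name is inferred from the argparse choices registered above); on machines without AMX support the server logs the fallback and switches to torch_native.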