Add intel_amx backend for Radix Attention for CPU (#6408)

Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
This commit is contained in:
YanbingJiang
2025-05-31 12:37:42 +08:00
committed by GitHub
parent e39bca0756
commit 888cb175a6
8 changed files with 185 additions and 5 deletions

View File

@@ -39,7 +39,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -351,7 +351,7 @@ class ForwardBatch:
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)
-if model_runner.server_args.attention_backend != "torch_native":
+if support_triton(model_runner.server_args.attention_backend):
ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens,

View File

@@ -91,6 +91,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
MultiprocessingSerializer,
+cpu_has_amx_support,
enable_show_time_cost,
get_available_gpu_memory,
get_bool_env_var,
@@ -317,6 +318,16 @@ class ModelRunner:
def model_specific_adjustment(self):
server_args = self.server_args
+if (
+server_args.attention_backend == "intel_amx"
+and server_args.device == "cpu"
+and not cpu_has_amx_support()
+):
+logger.info(
+"The current platform does not support Intel AMX, will fallback to torch_native backend."
+)
+server_args.attention_backend = "torch_native"
if server_args.attention_backend is None:
"""
Auto select the fastest attention backend.
@@ -369,7 +380,10 @@ class ModelRunner:
f"Invalid attention backend for MLA: {server_args.attention_backend}"
)
else:
-raise ValueError("MLA optimization not supported on CPU.")
+if server_args.attention_backend != "intel_amx":
+raise ValueError(
+"MLA optimization not supported on CPU except for intel_amx backend."
+)
if (
server_args.attention_backend == "fa3"
@@ -1067,6 +1081,13 @@ class ModelRunner:
)
return CutlassMLABackend(self)
+elif self.server_args.attention_backend == "intel_amx":
+from sglang.srt.layers.attention.intel_amx_backend import (
+IntelAMXAttnBackend,
+)
+logger.info(f"Intel AMX attention backend is enabled.")
+return IntelAMXAttnBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"