Add intel_amx backend for Radix Attention for CPU (#6408)
Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com> Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
This commit is contained in:
@@ -39,7 +39,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
|
||||
from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
@@ -351,7 +351,7 @@ class ForwardBatch:
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
if model_runner.server_args.attention_backend != "torch_native":
|
||||
if support_triton(model_runner.server_args.attention_backend):
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens,
|
||||
|
||||
@@ -91,6 +91,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import (
|
||||
MultiprocessingSerializer,
|
||||
cpu_has_amx_support,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
get_bool_env_var,
|
||||
@@ -317,6 +318,16 @@ class ModelRunner:
|
||||
def model_specific_adjustment(self):
|
||||
server_args = self.server_args
|
||||
|
||||
if (
|
||||
server_args.attention_backend == "intel_amx"
|
||||
and server_args.device == "cpu"
|
||||
and not cpu_has_amx_support()
|
||||
):
|
||||
logger.info(
|
||||
"The current platform does not support Intel AMX, will fallback to torch_native backend."
|
||||
)
|
||||
server_args.attention_backend = "torch_native"
|
||||
|
||||
if server_args.attention_backend is None:
|
||||
"""
|
||||
Auto select the fastest attention backend.
|
||||
@@ -369,7 +380,10 @@ class ModelRunner:
|
||||
f"Invalid attention backend for MLA: {server_args.attention_backend}"
|
||||
)
|
||||
else:
|
||||
raise ValueError("MLA optimization not supported on CPU.")
|
||||
if server_args.attention_backend != "intel_amx":
|
||||
raise ValueError(
|
||||
"MLA optimization not supported on CPU except for intel_amx backend."
|
||||
)
|
||||
|
||||
if (
|
||||
server_args.attention_backend == "fa3"
|
||||
@@ -1067,6 +1081,13 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
return CutlassMLABackend(self)
|
||||
elif self.server_args.attention_backend == "intel_amx":
|
||||
from sglang.srt.layers.attention.intel_amx_backend import (
|
||||
IntelAMXAttnBackend,
|
||||
)
|
||||
|
||||
logger.info(f"Intel AMX attention backend is enabled.")
|
||||
return IntelAMXAttnBackend(self)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: {self.server_args.attention_backend}"
|
||||
|
||||
Reference in New Issue
Block a user