feat: Add FlexAttention Backend for Efficient Sparse Attention (#9947)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
325
python/sglang/srt/layers/attention/torch_flex_backend.py
Normal file
325
python/sglang/srt/layers/attention/torch_flex_backend.py
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.radix_attention import AttentionType
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
|
class TorchFlexAttnBackend(AttentionBackend):
|
||||||
|
def __init__(self, model_runner: ModelRunner):
|
||||||
|
super().__init__()
|
||||||
|
self.forward_metadata = None
|
||||||
|
self.device = model_runner.device
|
||||||
|
self.flex_attention = torch.compile(flex_attention, dynamic=True)
|
||||||
|
torch._dynamo.config.cache_size_limit = 1024
|
||||||
|
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
"""Init the metadata for a forward pass."""
|
||||||
|
# TODO: find a more elegant way to save memory
|
||||||
|
# Currently maintain the same memory as torch_native_backend
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Provide two block_mask Lists per seq_idx for lower latency, later will support per layer level mask generation
|
||||||
|
self.extend_block_masks = []
|
||||||
|
self.decode_block_masks = []
|
||||||
|
|
||||||
|
if forward_batch.forward_mode.is_extend():
|
||||||
|
for seq_idx in range(forward_batch.seq_lens.shape[0]):
|
||||||
|
seq_len_kv = forward_batch.seq_lens[seq_idx]
|
||||||
|
seq_len_q = seq_len_kv
|
||||||
|
self.extend_block_masks.append(
|
||||||
|
create_block_mask(
|
||||||
|
self._causal_mask,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
seq_len_q,
|
||||||
|
seq_len_kv,
|
||||||
|
device=self.device,
|
||||||
|
_compile=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif forward_batch.forward_mode.is_decode():
|
||||||
|
for seq_idx in range(forward_batch.seq_lens.shape[0]):
|
||||||
|
seq_len_q = 1
|
||||||
|
seq_len_kv = forward_batch.seq_lens[seq_idx]
|
||||||
|
|
||||||
|
self.decode_block_masks.append(
|
||||||
|
create_block_mask(
|
||||||
|
self._decode_mask,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
seq_len_q,
|
||||||
|
seq_len_kv,
|
||||||
|
device=self.device,
|
||||||
|
_compile=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _causal_mask(self, b, h, q_idx, kv_idx):
|
||||||
|
return q_idx >= kv_idx
|
||||||
|
|
||||||
|
def _decode_mask(self, b, h, q_idx, kv_idx):
|
||||||
|
return q_idx <= kv_idx
|
||||||
|
|
||||||
|
def _run_flex_forward_extend(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
extend_prefix_lens: torch.Tensor,
|
||||||
|
extend_seq_lens: torch.Tensor,
|
||||||
|
scaling=None,
|
||||||
|
enable_gqa=False,
|
||||||
|
causal=False,
|
||||||
|
):
|
||||||
|
"""Run the extend forward by using torch flex attention op.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: [num_tokens, num_heads, head_size]
|
||||||
|
output: [num_tokens, num_heads, head_size]
|
||||||
|
k_cache: [max_total_num_tokens, num_heads, head_size]
|
||||||
|
v_cache: [max_total_num_tokens, num_heads, head_size]
|
||||||
|
req_to_token: [max_num_reqs, max_context_len]
|
||||||
|
req_pool_indices: [num_seqs]
|
||||||
|
seq_lens: [num_seqs]
|
||||||
|
extend_prefix_lens: [num_seqs]
|
||||||
|
extend_seq_lens: [num_seqs]
|
||||||
|
scaling: float or None
|
||||||
|
enable_gqa: bool
|
||||||
|
causal: bool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output: [num_tokens, num_heads, head_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
|
||||||
|
assert seq_lens.shape[0] == extend_seq_lens.shape[0]
|
||||||
|
|
||||||
|
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||||
|
query = query.movedim(0, query.dim() - 2)
|
||||||
|
|
||||||
|
start_q, start_kv = 0, 0
|
||||||
|
|
||||||
|
for seq_idx in range(seq_lens.shape[0]):
|
||||||
|
# TODO: this loop process a sequence per iter, this is inefficient.
|
||||||
|
# Need optimize the performance later.
|
||||||
|
extend_seq_len_q = extend_seq_lens[seq_idx]
|
||||||
|
prefill_seq_len_q = extend_prefix_lens[seq_idx]
|
||||||
|
|
||||||
|
seq_len_kv = seq_lens[seq_idx]
|
||||||
|
end_q = start_q + extend_seq_len_q
|
||||||
|
end_kv = start_kv + seq_len_kv
|
||||||
|
|
||||||
|
per_req_query = query[:, start_q:end_q, :]
|
||||||
|
per_req_query_redundant = torch.empty(
|
||||||
|
(per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
|
||||||
|
dtype=per_req_query.dtype,
|
||||||
|
device=per_req_query.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query
|
||||||
|
|
||||||
|
# get key and value from cache. per_req_tokens contains the kv cache
|
||||||
|
# index for each token in the sequence.
|
||||||
|
req_pool_idx = req_pool_indices[seq_idx]
|
||||||
|
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||||
|
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||||
|
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||||
|
|
||||||
|
if not causal:
|
||||||
|
raise NotImplementedError("Non-causal mode is not yet implemented.")
|
||||||
|
|
||||||
|
per_req_out_redundant = (
|
||||||
|
self.flex_attention(
|
||||||
|
per_req_query_redundant.unsqueeze(0),
|
||||||
|
per_req_key.unsqueeze(0),
|
||||||
|
per_req_value.unsqueeze(0),
|
||||||
|
block_mask=self.extend_block_masks[seq_idx],
|
||||||
|
scale=scaling,
|
||||||
|
enable_gqa=enable_gqa,
|
||||||
|
)
|
||||||
|
.squeeze(0)
|
||||||
|
.movedim(query.dim() - 2, 0)
|
||||||
|
)
|
||||||
|
output[start_q:end_q, :, :] = per_req_out_redundant[
|
||||||
|
prefill_seq_len_q:, :, :
|
||||||
|
]
|
||||||
|
start_q, start_kv = end_q, end_kv
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _run_flex_forward_decode(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
scaling=None,
|
||||||
|
enable_gqa=False,
|
||||||
|
causal=False,
|
||||||
|
):
|
||||||
|
"""Run the decode forward by using torch flex attention op.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: [num_tokens, num_heads, head_size]
|
||||||
|
output: [num_tokens, num_heads, head_size]
|
||||||
|
k_cache: [max_total_num_tokens, num_heads, head_size]
|
||||||
|
v_cache: [max_total_num_tokens, num_heads, head_size]
|
||||||
|
req_to_token: [max_num_reqs, max_context_len]
|
||||||
|
req_pool_indices: [num_seqs]
|
||||||
|
seq_lens: [num_seqs]
|
||||||
|
scaling: float or None
|
||||||
|
enable_gqa: bool
|
||||||
|
causal: bool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output: [num_tokens, num_heads, head_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||||
|
query = query.movedim(0, query.dim() - 2)
|
||||||
|
|
||||||
|
start_q, start_kv = 0, 0
|
||||||
|
for seq_idx in range(seq_lens.shape[0]):
|
||||||
|
# TODO: this loop process a sequence per iter, this is inefficient.
|
||||||
|
# Need optimize the performance later.
|
||||||
|
|
||||||
|
seq_len_q = 1
|
||||||
|
seq_len_kv = seq_lens[seq_idx]
|
||||||
|
end_q = start_q + seq_len_q
|
||||||
|
end_kv = start_kv + seq_len_kv
|
||||||
|
|
||||||
|
per_req_query = query[:, start_q:end_q, :]
|
||||||
|
|
||||||
|
# get key and value from cache. per_req_tokens contains the kv cache
|
||||||
|
# index for each token in the sequence.
|
||||||
|
req_pool_idx = req_pool_indices[seq_idx]
|
||||||
|
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||||
|
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||||
|
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||||
|
|
||||||
|
per_req_out = (
|
||||||
|
self.flex_attention(
|
||||||
|
per_req_query.unsqueeze(0),
|
||||||
|
per_req_key.unsqueeze(0),
|
||||||
|
per_req_value.unsqueeze(0),
|
||||||
|
block_mask=self.decode_block_masks[seq_idx],
|
||||||
|
scale=scaling,
|
||||||
|
enable_gqa=enable_gqa,
|
||||||
|
)
|
||||||
|
.squeeze(0)
|
||||||
|
.movedim(query.dim() - 2, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
output[start_q:end_q, :, :] = per_req_out
|
||||||
|
start_q, start_kv = end_q, end_kv
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_extend(
|
||||||
|
self,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache=True,
|
||||||
|
):
|
||||||
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||||
|
else:
|
||||||
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
|
)
|
||||||
|
|
||||||
|
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||||
|
|
||||||
|
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||||
|
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||||
|
|
||||||
|
causal = True
|
||||||
|
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"TorchFlexAttnBackend does not support non-causal attention for now."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._run_flex_forward_extend(
|
||||||
|
q_,
|
||||||
|
o_,
|
||||||
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
|
forward_batch.req_to_token_pool.req_to_token,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.extend_prefix_lens,
|
||||||
|
forward_batch.extend_seq_lens,
|
||||||
|
scaling=layer.scaling,
|
||||||
|
enable_gqa=use_gqa,
|
||||||
|
causal=causal,
|
||||||
|
)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def forward_decode(
|
||||||
|
self,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache=True,
|
||||||
|
):
|
||||||
|
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||||
|
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||||
|
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||||
|
|
||||||
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||||
|
else:
|
||||||
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
|
)
|
||||||
|
|
||||||
|
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||||
|
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||||
|
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||||
|
|
||||||
|
self._run_flex_forward_decode(
|
||||||
|
q_,
|
||||||
|
o_,
|
||||||
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
|
||||||
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
|
||||||
|
forward_batch.req_to_token_pool.req_to_token,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
scaling=layer.scaling,
|
||||||
|
enable_gqa=use_gqa,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o
|
||||||
|
|
||||||
|
def support_triton(self):
|
||||||
|
return False
|
||||||
@@ -1786,6 +1786,12 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return TorchNativeAttnBackend(self)
|
return TorchNativeAttnBackend(self)
|
||||||
|
elif backend_str == "flex_attention":
|
||||||
|
from sglang.srt.layers.attention.torch_flex_backend import (
|
||||||
|
TorchFlexAttnBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TorchFlexAttnBackend(self)
|
||||||
elif backend_str == "flashmla":
|
elif backend_str == "flashmla":
|
||||||
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
|
||||||
|
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ ATTENTION_BACKEND_CHOICES = [
|
|||||||
# Common
|
# Common
|
||||||
"triton",
|
"triton",
|
||||||
"torch_native",
|
"torch_native",
|
||||||
|
"flex_attention",
|
||||||
# NVIDIA specific
|
# NVIDIA specific
|
||||||
"cutlass_mla",
|
"cutlass_mla",
|
||||||
"fa3",
|
"fa3",
|
||||||
@@ -592,6 +593,15 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
self.disable_cuda_graph = True
|
self.disable_cuda_graph = True
|
||||||
|
|
||||||
|
if self.attention_backend == "flex_attention":
|
||||||
|
logger.warning(
|
||||||
|
"Cuda graph is disabled because of using torch Flex Attention backend"
|
||||||
|
)
|
||||||
|
self.disable_cuda_graph = True
|
||||||
|
assert (
|
||||||
|
self.speculative_algorithm is None
|
||||||
|
), "Speculative decoding is currently not supported with Flex Attention backend"
|
||||||
|
|
||||||
if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
|
if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
|
"At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
|
||||||
|
|||||||
49
test/srt/test_torch_flex_attention_backend.py
Normal file
49
test/srt/test_torch_flex_attention_backend.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_torch_flex_attention_backend.TestTorchFlexAttnBackend.test_gsm8k
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTorchFlexAttnBackend(CustomTestCase):
|
||||||
|
def test_gsm8k(self):
|
||||||
|
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
process = popen_launch_server(
|
||||||
|
model,
|
||||||
|
base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=["--attention-backend", "flex_attention"],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=8,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=100,
|
||||||
|
parallel=10,
|
||||||
|
max_new_tokens=512,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.62)
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user