diff --git a/python/sglang/srt/layers/attention/torch_flex_backend.py b/python/sglang/srt/layers/attention/torch_flex_backend.py
new file mode 100644
index 000000000..69f097efd
--- /dev/null
+++ b/python/sglang/srt/layers/attention/torch_flex_backend.py
@@ -0,0 +1,325 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class TorchFlexAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = None
+        self.device = model_runner.device
+        self.flex_attention = torch.compile(flex_attention, dynamic=True)
+        torch._dynamo.config.cache_size_limit = 1024
+        torch._dynamo.config.accumulated_cache_size_limit = 1024
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        # TODO: find a more elegant way to save memory.
+        # Currently this keeps the same memory footprint as torch_native_backend.
+        torch.cuda.empty_cache()
+
+        # Keep two block-mask lists (extend and decode), one mask per sequence,
+        # for lower latency; per-layer mask generation will be supported later.
+        self.extend_block_masks = []
+        self.decode_block_masks = []
+
+        if forward_batch.forward_mode.is_extend():
+            for seq_idx in range(forward_batch.seq_lens.shape[0]):
+                seq_len_kv = forward_batch.seq_lens[seq_idx]
+                seq_len_q = seq_len_kv
+                self.extend_block_masks.append(
+                    create_block_mask(
+                        self._causal_mask,
+                        None,
+                        None,
+                        seq_len_q,
+                        seq_len_kv,
+                        device=self.device,
+                        _compile=False,
+                    )
+                )
+
+        elif forward_batch.forward_mode.is_decode():
+            for seq_idx in range(forward_batch.seq_lens.shape[0]):
+                seq_len_q = 1
+                seq_len_kv = forward_batch.seq_lens[seq_idx]
+
+                self.decode_block_masks.append(
+                    create_block_mask(
+                        self._decode_mask,
+                        None,
+                        None,
+                        seq_len_q,
+                        seq_len_kv,
+                        device=self.device,
+                        _compile=False,
+                    )
+                )
+
+    def _causal_mask(self, b, h, q_idx, kv_idx):
+        return q_idx >= kv_idx
+
+    def _decode_mask(self, b, h, q_idx, kv_idx):
+        return q_idx <= kv_idx
+
+    def _run_flex_forward_extend(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        extend_prefix_lens: torch.Tensor,
+        extend_seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the extend forward pass using the torch flex_attention op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            extend_prefix_lens: [num_seqs]
+            extend_seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
+        assert seq_lens.shape[0] == extend_seq_lens.shape[0]
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop processes one sequence per iteration, which is
+            # inefficient. Optimize the performance later.
+            extend_seq_len_q = extend_seq_lens[seq_idx]
+            prefill_seq_len_q = extend_prefix_lens[seq_idx]
+
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + extend_seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+            per_req_query_redundant = torch.empty(
+                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
+                dtype=per_req_query.dtype,
+                device=per_req_query.device,
+            )
+
+            per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            if not causal:
+                raise NotImplementedError("Non-causal mode is not yet implemented.")
+
+            per_req_out_redundant = (
+                self.flex_attention(
+                    per_req_query_redundant.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    block_mask=self.extend_block_masks[seq_idx],
+                    scale=scaling,
+                    enable_gqa=enable_gqa,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out_redundant[
+                prefill_seq_len_q:, :, :
+            ]
+            start_q, start_kv = end_q, end_kv
+        return output
+
+    def _run_flex_forward_decode(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the decode forward pass using the torch flex_attention op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop processes one sequence per iteration, which is
+            # inefficient. Optimize the performance later.
+
+            seq_len_q = 1
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out = (
+                self.flex_attention(
+                    per_req_query.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    block_mask=self.decode_block_masks[seq_idx],
+                    scale=scaling,
+                    enable_gqa=enable_gqa,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+
+            output[start_q:end_q, :, :] = per_req_out
+            start_q, start_kv = end_q, end_kv
+
+        return output
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            raise NotImplementedError(
+                "TorchFlexAttnBackend does not support non-causal attention for now."
+            )
+
+        self._run_flex_forward_extend(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_prefix_lens,
+            forward_batch.extend_seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=causal,
+        )
+        return o
+
+    def forward_decode(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_flex_forward_decode(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=False,
+        )
+
+        return o
+
+    def support_triton(self):
+        return False
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 8bfea5613..f2df9b134 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1786,6 +1786,12 @@ class ModelRunner:
             )
 
             return TorchNativeAttnBackend(self)
+        elif backend_str == "flex_attention":
+            from sglang.srt.layers.attention.torch_flex_backend import (
+                TorchFlexAttnBackend,
+            )
+
+            return TorchFlexAttnBackend(self)
         elif backend_str == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 3459d67ee..d55b56794 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -93,6 +93,7 @@ ATTENTION_BACKEND_CHOICES = [
     # Common
     "triton",
     "torch_native",
+    "flex_attention",
     # NVIDIA specific
     "cutlass_mla",
     "fa3",
@@ -592,6 +593,15 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True
 
+        if self.attention_backend == "flex_attention":
+            logger.warning(
+                "CUDA graph is disabled because the torch Flex Attention backend is used."
+            )
+            self.disable_cuda_graph = True
+            assert (
+                self.speculative_algorithm is None
+            ), "Speculative decoding is currently not supported with the Flex Attention backend"
+
         if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
             logger.warning(
                 "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
diff --git a/test/srt/test_torch_flex_attention_backend.py b/test/srt/test_torch_flex_attention_backend.py
new file mode 100644
index 000000000..832ac14c4
--- /dev/null
+++ b/test/srt/test_torch_flex_attention_backend.py
@@ -0,0 +1,49 @@
+"""
+Usage:
+python3 -m unittest test_torch_flex_attention_backend.TestTorchFlexAttnBackend.test_gsm8k
+"""
+
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestTorchFlexAttnBackend(CustomTestCase):
+    def test_gsm8k(self):
+        model = DEFAULT_MODEL_NAME_FOR_TEST
+        base_url = DEFAULT_URL_FOR_TEST
+        process = popen_launch_server(
+            model,
+            base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--attention-backend", "flex_attention"],
+        )
+
+        try:
+            args = SimpleNamespace(
+                num_shots=8,
+                data_path=None,
+                num_questions=100,
+                parallel=10,
+                max_new_tokens=512,
+                host="http://127.0.0.1",
+                port=int(base_url.split(":")[-1]),
+            )
+            metrics = run_eval_few_shot_gsm8k(args)
+            print(f"{metrics=}")
+            self.assertGreater(metrics["accuracy"], 0.62)
+        finally:
+            kill_process_tree(process.pid)
+
+
+if __name__ == "__main__":
+    unittest.main()
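
Note: as the test above does through popen_launch_server, the new backend is selected at server launch with --attention-backend flex_attention. The standalone sketch below only illustrates the FlexAttention pattern this backend builds on (a causal create_block_mask plus a flex_attention call, mirroring _causal_mask and init_forward_metadata); the tensor sizes and the CUDA device are assumptions for illustration, not part of this change.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Hypothetical sizes, illustration only: 1 sequence, 8 heads, 128 tokens, head_dim 64.
B, H, S, D = 1, 8, 128, 64
device = "cuda"  # assumed; FlexAttention is primarily exercised on GPU here

q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)

def causal_mask(b, h, q_idx, kv_idx):
    # Same predicate as TorchFlexAttnBackend._causal_mask: token i attends to tokens <= i.
    return q_idx >= kv_idx

# One block mask per (q_len, kv_len) pair, mirroring init_forward_metadata.
block_mask = create_block_mask(causal_mask, None, None, S, S, device=device)

out = flex_attention(q, k, v, block_mask=block_mask)  # [B, H, S, D]
print(out.shape)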