diff --git a/examples/runtime/engine/offline_batch_inference_qwen_1m.py b/examples/runtime/engine/offline_batch_inference_qwen_1m.py new file mode 100644 index 000000000..664efa6d7 --- /dev/null +++ b/examples/runtime/engine/offline_batch_inference_qwen_1m.py @@ -0,0 +1,74 @@ +""" +Usage: +python3 offline_batch_inference.py +""" + +from urllib.request import urlopen + +import sglang as sgl + + +def load_prompt() -> str: + # Test cases with various lengths can be found at: + # + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt + + with urlopen( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com" + "/Qwen2.5-1M/test-data/64k.txt", + timeout=5, + ) as response: + prompt = response.read().decode("utf-8") + return prompt + + +# Processing the prompt. +def process_requests(llm: sgl.Engine, prompts: list[str]) -> None: + # Create a sampling params object. + sampling_params = { + "temperature": 0.7, + "top_p": 0.8, + "top_k": 20, + "repetition_penalty": 1.05, + "max_new_tokens": 256, + } + # Generate texts from the prompts. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt_token_ids = output["meta_info"]["prompt_tokens"] + generated_text = output["text"] + print( + f"Prompt length: {prompt_token_ids}, " f"Generated text: {generated_text!r}" + ) + + +# Create an LLM. +def initialize_engine() -> sgl.Engine: + llm = sgl.Engine( + model_path="Qwen/Qwen2.5-7B-Instruct-1M", + context_length=1048576, + page_size=256, + attention_backend="dual_chunk_flash_attn", + tp_size=4, + disable_radix_cache=True, + enable_mixed_chunk=False, + enable_torch_compile=False, + chunked_prefill_size=131072, + mem_fraction_static=0.6, + log_level="DEBUG", + ) + return llm + + +def main(): + llm = initialize_engine() + prompt = load_prompt() + process_requests(llm, [prompt]) + + +if __name__ == "__main__": + main() diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 3091ed4fe..e03b32ca0 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import ( get_context_length, get_generation_config, get_hf_text_config, + get_sparse_attention_config, ) from sglang.srt.layers.quantization import QUANTIZATION_METHODS from sglang.srt.server_args import ServerArgs @@ -270,6 +271,9 @@ class ModelConfig: # Verify quantization self._verify_quantization() + # Verify dual-chunk attention config + self._verify_dual_chunk_attention_config() + # Cache attributes self.hf_eos_token_id = self.get_hf_eos_token_id() @@ -297,6 +301,13 @@ class ModelConfig: **kwargs, ) + def get_total_num_attention_heads(self) -> int: + return self.num_attention_heads + + def get_num_attention_heads(self, tensor_parallel_size) -> int: + total_num_attention_heads = self.num_attention_heads + return max(1, total_num_attention_heads // tensor_parallel_size) + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" @@ -484,6 +495,23 @@ class ModelConfig: self.quantization, ) + def _verify_dual_chunk_attention_config(self) -> None: + if hasattr(self.hf_config, 
"dual_chunk_attention_config"): + # Try loading the sparse attention config + sparse_attn_config = get_sparse_attention_config(self.model_path) + if not sparse_attn_config: + return + self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = ( + sparse_attn_config + ) + if ( + "sparse_attention_enabled" + not in self.hf_config.dual_chunk_attention_config + ): + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_enabled" + ] = True + def get_hf_eos_token_id(self) -> Optional[Set[int]]: eos_ids = getattr(self.hf_config, "eos_token_id", None) if eos_ids is not None: diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index 872d8a741..c1cb17c04 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin: req_pool_indices, dtype=torch.int64, device=self.device ) self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device) + self.orig_seq_lens = torch.tensor( + seq_lens, dtype=torch.int32, device=self.device + ) self.out_cache_loc = out_cache_loc self.seq_lens_sum = sum(seq_lens) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index e4c87d573..1e9b32f01 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -14,10 +14,11 @@ """Utilities for Huggingface Transformers.""" import contextlib +import json import os import warnings from pathlib import Path -from typing import Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union import torch from huggingface_hub import snapshot_download @@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items(): AutoConfig.register(name, cls) -def download_from_hf(model_path: str): +def download_from_hf( + model_path: str, + allow_patterns: Optional[Union[str, list]] = None, +): if os.path.exists(model_path): return model_path - return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"]) + if not allow_patterns: + allow_patterns = ["*.json", "*.bin", "*.model"] + + return snapshot_download(model_path, allow_patterns=allow_patterns) def get_hf_text_config(config: PretrainedConfig): @@ -171,6 +178,26 @@ def get_generation_config( return None +# Qwen-1M related +def get_sparse_attention_config( + model: str, + sparse_attention_config_filename: str = "sparse_attention_config.json", +) -> Dict[str, Any]: + is_local = os.path.isdir(model) + if not is_local: + # Download the config files. + model = download_from_hf(model, allow_patterns=["*.json"]) + + config_file = os.path.join(model, sparse_attention_config_filename) + if not os.path.exists(config_file): + return {} + + # Load the sparse attention config. + with open(config_file) as f: + config = json.load(f) + return config + + # Models don't use the same configuration key for determining the maximum # context length. Store them here so we can sanely check them. # NOTE: The ordering here is important. 
Some models have two of these and we diff --git a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py new file mode 100644 index 000000000..ea97ada22 --- /dev/null +++ b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py @@ -0,0 +1,1700 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with Dual chunk flash attention and sparse attention. +""" +import functools +import logging +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from sgl_kernel.sparse_flash_attn import ( + convert_vertical_slash_indexes, + convert_vertical_slash_indexes_mergehead, + sparse_attn_func, +) + +from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + + +logger = logging.getLogger(__name__) + + +@dataclass +class DualChunkFlashAttentionMetadata: + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] = None + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_seq_len: int = None + + # (batch_size,). The orig sequence length per sequence. + orig_seq_lens: Optional[List[int]] = None + + # orig_seq_lens stored as a tensor. + orig_seq_lens_tensor: Optional[torch.Tensor] = None + + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # Length scaling factor + scaling_factor: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for intra attention. + seq_lens_intra: Optional[torch.Tensor] = None + + # Max sequence length for intra attention. + max_seq_len_intra: Optional[int] = None + + # (batch_size, num_blocks). Block table for intra attention. 
+ block_tables_intra: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for succ attention. + seq_lens_succ: Optional[torch.Tensor] = None + + # Max sequence length for succ attention. + max_seq_len_succ: Optional[int] = None + + # (batch_size, num_blocks). Block table for succ attention. + block_tables_succ: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for inter attention. + seq_lens_inter: Optional[torch.Tensor] = None + + # Max sequence length for inter attention. + max_seq_len_inter: Optional[int] = None + + +class DualChunkFlashAttentionBackend(AttentionBackend): + def __init__( + self, + model_runner: "ModelRunner", + ) -> None: + self.forward_metadata: FlashAttentionMetadata = None + self.device = model_runner.device + self.max_context_len = model_runner.model_config.context_len + self.num_heads = model_runner.model_config.get_num_attention_heads( + model_runner.server_args.tp_size + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + model_runner.server_args.tp_size + ) + self.head_size = model_runner.model_config.head_dim + + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.kv_cache_dtype = model_runner.kv_cache_dtype + self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype + self.page_size = model_runner.page_size + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + dual_chunk_attention_config = getattr( + model_runner.model_config.hf_config, "dual_chunk_attention_config", None + ) + assert dual_chunk_attention_config is not None + self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) + self.local_size = dual_chunk_attention_config.get("local_size", 1024) + self.original_max_position_embeddings = dual_chunk_attention_config.get( + "original_max_position_embeddings", 0 + ) + self.sparse_attention_config = dual_chunk_attention_config.get( + "sparse_attention_config", None + ) + if not self.sparse_attention_config: + logger.warning_once( + "Sparse attention will not be enabled as " + "sparse attention config is not provided." 
+ ) + self.sparse_attention_enabled = dual_chunk_attention_config.get( + "sparse_attention_enabled", self.sparse_attention_config is not None + ) + self.sparse_attention_threshold = dual_chunk_attention_config.get( + "sparse_attention_threshold", 32768 + ) + self.sparse_attention_last_q = dual_chunk_attention_config.get( + "sparse_attention_last_q", 64 + ) + self.dual_chunk_attention_config = dual_chunk_attention_config + + if self.sparse_attention_enabled: + self.arange = torch.arange(self.sparse_attention_last_q, device="cuda") + self.last_q_mask = ( + self.arange[None, None, :, None] >= self.arange[None, None, None, :] + ) + + @functools.lru_cache() + def get_sparse_attention_config(self, layer_idx) -> List[Dict[str, Any]]: + layer_sparse_attention_config = { + int(i): j for i, j in self.sparse_attention_config[layer_idx].items() + } + start_head = self.num_heads * get_tensor_model_parallel_rank() + end_head = start_head + self.num_heads + return [layer_sparse_attention_config[i] for i in range(start_head, end_head)] + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Initialize forward metadata hence all layers in the forward pass can reuse it.""" + + forward_mode: ForwardMode = forward_batch.forward_mode + assert forward_mode.is_prefill() or forward_mode.is_decode() + batch_size = forward_batch.batch_size + + metadata = DualChunkFlashAttentionMetadata() + metadata.seq_lens_tensor = forward_batch.seq_lens.to(torch.int32) + metadata.seq_lens = forward_batch.seq_lens.tolist() + metadata.max_seq_len = forward_batch.seq_lens.max().item() + + metadata.orig_seq_lens_tensor = forward_batch.orig_seq_lens + metadata.orig_seq_lens = forward_batch.orig_seq_lens.tolist() + + metadata.block_tables = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : metadata.max_seq_len + ] + # Convert the block table to a strided format. 
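+ # req_to_token stores one KV-cache slot index per token; when page_size > 1,
+ # keeping every page_size-th entry and dividing by page_size turns the per-token
+ # table into the per-page block indices expected by the paged attention kernels.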
+ if self.page_size > 1: + strided_indices = torch.arange( + 0, metadata.block_tables.shape[1], self.page_size, device=self.device + ) + metadata.block_tables = ( + metadata.block_tables[:, strided_indices] // self.page_size + ) + + metadata.query_start_loc = torch.zeros( + batch_size + 1, dtype=torch.int32, device=metadata.seq_lens_tensor.device + ) + if forward_mode.is_prefill(): + metadata.query_start_loc[1:] = torch.cumsum( + forward_batch.extend_seq_lens.to(torch.int32), dim=0, dtype=torch.int32 + ) + else: + metadata.query_start_loc[1:] = torch.cumsum( + torch.arange( + batch_size, + dtype=metadata.query_start_loc.dtype, + device=metadata.query_start_loc.device, + ), + dim=0, + dtype=torch.int32, + ) + metadata.seq_start_loc = torch.zeros( + batch_size + 1, dtype=torch.int32, device=metadata.seq_lens_tensor.device + ) + metadata.seq_start_loc[1:] = torch.cumsum( + metadata.seq_lens_tensor, dim=0, dtype=torch.int32 + ) + + if self.original_max_position_embeddings > 0: + if forward_mode.is_prefill(): + metadata.scaling_factor = ( + 0.1 + * torch.log( + metadata.orig_seq_lens_tensor + / self.original_max_position_embeddings + ) + + 1.0 + ).clip(min=1) + else: + metadata.scaling_factor = ( + 0.1 + * torch.log( + metadata.orig_seq_lens_tensor + / self.original_max_position_embeddings + ) + + 1.0 + ).clip(min=1) + + if forward_mode.is_decode(): + cache_seq_lens = metadata.orig_seq_lens_tensor + + chunk_len = self.chunk_size - self.local_size + chunk_num_curr = (cache_seq_lens - 1) // chunk_len + + seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len + max_seq_len_intra = seq_lens_intra.max().item() + metadata.seq_lens_intra = seq_lens_intra + metadata.max_seq_len_intra = max_seq_len_intra + + block_tables_intra = torch.zeros( + batch_size, + (max_seq_len_intra - 1) // self.page_size + 1, + dtype=metadata.block_tables.dtype, + device=metadata.block_tables.device, + ) + for i in range(batch_size): + st = chunk_num_curr[i] * chunk_len // self.page_size + ed = min( + st + (max_seq_len_intra - 1) // self.page_size + 1, + (cache_seq_lens[i] - 1) // self.page_size + 1, + ) + block_tables_intra[i, : ed - st] = metadata.block_tables[i, st:ed] + metadata.block_tables_intra = block_tables_intra + + metadata.seq_lens_succ = ( + chunk_num_curr - (chunk_num_curr - 1).clip(min=0) + ) * chunk_len + metadata.max_seq_len_succ = metadata.seq_lens_succ.max().item() + if metadata.max_seq_len_succ: + block_tables_succ = torch.zeros( + batch_size, + (metadata.max_seq_len_succ - 1) // self.page_size + 1, + dtype=metadata.block_tables.dtype, + device=metadata.block_tables.device, + ) + for i in range(batch_size): + start = ( + (chunk_num_curr[i] - 1).clip(min=0) + * chunk_len + // self.page_size + ) + end = min( + start + (metadata.max_seq_len_succ - 1) // self.page_size + 1, + (cache_seq_lens[i] - 1) // self.page_size + 1, + ) + block_tables_succ[i, : end - start] = metadata.block_tables[ + i, start:end + ] + metadata.block_tables_succ = block_tables_succ + + metadata.seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len + metadata.max_seq_len_inter = metadata.seq_lens_inter.max().item() + + self.forward_metadata = metadata + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: "RadixAttention", + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + # Use precomputed metadata across all layers + metadata = self.forward_metadata + + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ) = torch.split(q, q.shape[-1] 
// 5, dim=-1) + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + query_succ = query_succ.view(-1, self.num_heads, self.head_size) + query_inter = query_inter.view(-1, self.num_heads, self.head_size) + query_succ_critical = query_succ_critical.view( + -1, self.num_heads, self.head_size + ) + query_inter_critical = query_inter_critical.view( + -1, self.num_heads, self.head_size + ) + key = k.view(-1, self.num_kv_heads, self.head_size) + value = v.view(-1, self.num_kv_heads, self.head_size) + + # apply DCA scaling + if self.original_max_position_embeddings > 0: + assert metadata.scaling_factor is not None + assert metadata.query_start_loc is not None + assert metadata.orig_seq_lens is not None + current_start = 0 + query_start_loc_cpu = metadata.query_start_loc.cpu() + for i in range(len(metadata.orig_seq_lens)): + current_end = ( + current_start + + (query_start_loc_cpu[i + 1] - query_start_loc_cpu[i]).item() + ) + key[current_start:current_end].mul_(metadata.scaling_factor[i]) + current_start = current_end + assert current_end <= self.max_context_len + + # Do multi-head attention + key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + + if key is not None and value is not None: + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + forward_batch.out_cache_loc, + key, + value, + layer.k_scale, + layer.v_scale, + ) + + if not save_kv_cache: + # profile run + o = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=metadata.seq_start_loc, + cu_seqlens_k=metadata.seq_start_loc, + max_seqlen_q=metadata.max_seq_len, + max_seqlen_k=metadata.max_seq_len, + softmax_scale=layer.scaling, + causal=True, + ) + else: + # prefill/chunked-prefill + # get per layer sparse attention config + if self.sparse_attention_enabled: + self.layer_sparse_attention_config = self.get_sparse_attention_config( + layer.layer_id + ) + assert metadata.orig_seq_lens is not None + o = self._dual_chunk_flash_attn_prefill( + q=query, + q_succ=query_succ, + q_inter=query_inter, + q_succ_critical=query_succ_critical, + q_inter_critical=query_inter_critical, + k=key_cache, + v=value_cache, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.seq_start_loc, + orig_seq_lens=metadata.orig_seq_lens, + scaling_factor=metadata.scaling_factor, + softmax_scale=layer.scaling, + causal=True, + window_size=(-1, -1), + block_table=metadata.block_tables, + chunk_size=self.chunk_size, + local_size=self.local_size, + ) + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: "RadixAttention", + forward_batch: ForwardBatch, + save_kv_cache=True, + ) -> torch.Tensor: + # Use precomputed metadata across all layers + metadata = self.forward_metadata + + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ) = torch.split(q, q.shape[-1] // 5, dim=-1) + + # Reshape the query, key, and value tensors. 
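+ # The dual chunk rotary embedding packs five query variants
+ # (query, query_succ, query_inter, query_succ_critical, query_inter_critical)
+ # along the head dimension, so q is first split into five equal chunks and each
+ # part is then reshaped to (num_tokens, num_heads, head_size).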
+ query = query.view(-1, self.num_heads, self.head_size) + query_succ = query_succ.view(-1, self.num_heads, self.head_size) + query_inter = query_inter.view(-1, self.num_heads, self.head_size) + query_succ_critical = query_succ_critical.view( + -1, self.num_heads, self.head_size + ) + query_inter_critical = query_inter_critical.view( + -1, self.num_heads, self.head_size + ) + key = k.view(-1, self.num_kv_heads, self.head_size) + value = v.view(-1, self.num_kv_heads, self.head_size) + + key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + key_cache = key_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + value_cache = value_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + + if key is not None and value is not None: + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + forward_batch.out_cache_loc, + key, + value, + layer.k_scale, + layer.v_scale, + ) + + # apply DCA scaling + if self.original_max_position_embeddings > 0: + assert metadata.scaling_factor is not None + scaling_factor = metadata.scaling_factor + key.mul_(scaling_factor.unsqueeze(-1).unsqueeze(-1)) + + o = self._dual_chunk_flash_attn_decoding( + query.unsqueeze(1), + query_succ.unsqueeze(1), + query_inter.unsqueeze(1), + key_cache, + value_cache, + block_table=metadata.block_tables, + cache_seqlens=metadata.seq_lens_tensor, + softmax_scale=layer.scaling, + causal=True, + chunk_size=self.chunk_size, + local_size=self.local_size, + original_max_position_embeddings=self.original_max_position_embeddings, + decode_meta=metadata, + ).squeeze(1) + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + def init_cuda_graph_state(self, max_bs: int): + """Initialize CUDA graph state for the attention backend. + + Args: + max_bs (int): Maximum batch size to support in CUDA graphs + + This creates fixed-size tensors that will be reused during CUDA graph replay + to avoid memory allocations. 
+ """ + self.decode_metadata = { + "seq_lens_tensor": torch.zeros( + max_bs, dtype=torch.int32, device=self.device + ), + "orig_seq_lens_tensor": torch.zeros( + max_bs, dtype=torch.int32, device=self.device + ), + "scaling_factor": torch.zeros( + max_bs, dtype=torch.float32, device=self.device + ), + "block_tables": torch.zeros( + max_bs, + (self.max_context_len - 1) // self.page_size + 1, + dtype=torch.int32, + device=self.device, + ), + "block_tables_intra": torch.zeros( + max_bs, + (self.max_context_len - 1) // self.page_size + 1, + dtype=torch.int32, + device=self.device, + ), + "seq_lens_intra": torch.zeros( + max_bs, dtype=torch.int32, device=self.device + ), + "block_tables_succ": torch.zeros( + max_bs, + (self.max_context_len - 1) // self.page_size + 1, + dtype=torch.int32, + device=self.device, + ), + "seq_lens_succ": torch.zeros(max_bs, dtype=torch.int32, device=self.device), + "seq_lens_inter": torch.zeros( + max_bs, dtype=torch.int32, device=self.device + ), + } + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[None], + ): + metadata = DualChunkFlashAttentionMetadata() + + if forward_mode.is_decode_or_idle(): + if self.original_max_position_embeddings > 0: + metadata.scaling_factor = self.decode_metadata["scaling_factor"][:bs] + + metadata.seq_lens_tensor = self.decode_metadata["seq_lens_tensor"][:bs] + metadata.orig_seq_lens_tensor = self.decode_metadata[ + "orig_seq_lens_tensor" + ][:bs] + metadata.max_seq_len = self.max_context_len + metadata.block_tables = self.decode_metadata["block_tables"][ + req_pool_indices, : + ] + + # intra + metadata.max_seq_len_intra = self.max_context_len + metadata.seq_lens_intra = self.decode_metadata["seq_lens_intra"][:bs] + + metadata.block_tables_intra = self.decode_metadata["block_tables_intra"][ + :bs, : + ] + + # succ + metadata.seq_lens_succ = self.decode_metadata["seq_lens_succ"][:bs] + metadata.max_seq_len_succ = self.max_context_len + + metadata.block_tables_succ = self.decode_metadata["block_tables_succ"][ + :bs, : + ] + + metadata.seq_lens_inter = self.decode_metadata["seq_lens_inter"][:bs] + metadata.max_seq_len_inter = self.max_context_len + + self.decode_metadata[bs] = metadata + + self.forward_metadata = metadata + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[None], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: torch.Tensor = None, + ): + """Initialize forward metadata for replaying CUDA graph.""" + assert forward_mode.is_decode() + seq_lens = seq_lens[:bs] + req_pool_indices = req_pool_indices[:bs] + metadata = self.decode_metadata[bs] + + metadata.seq_lens_tensor.copy_(seq_lens.to(torch.int32)) + metadata.seq_lens = seq_lens.tolist() + metadata.max_seq_len = seq_lens.max().item() + + metadata.orig_seq_lens_tensor.copy_(seq_lens) + metadata.orig_seq_lens = seq_lens.tolist() + + block_tables = self.req_to_token[req_pool_indices, : metadata.max_seq_len] + # Convert the block table to a strided format. 
+ if self.page_size > 1: + strided_indices = torch.arange( + 0, block_tables.shape[1], self.page_size, device=self.device + ) + block_tables = block_tables[:, strided_indices] // self.page_size + metadata.block_tables.fill_(0) + metadata.block_tables[: block_tables.shape[0], : block_tables.shape[1]].copy_( + block_tables + ) + + if self.original_max_position_embeddings > 0: + scaling_factor = ( + 0.1 + * torch.log( + metadata.orig_seq_lens_tensor + / self.original_max_position_embeddings + ) + + 1.0 + ).clip(min=1) + metadata.scaling_factor.copy_(scaling_factor) + + cache_seq_lens = metadata.orig_seq_lens_tensor + + chunk_len = self.chunk_size - self.local_size + chunk_num_curr = (cache_seq_lens - 1) // chunk_len + + seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len + max_seq_len_intra = seq_lens_intra.max().item() + metadata.seq_lens_intra.copy_(seq_lens_intra) + metadata.max_seq_len_intra = max_seq_len_intra + + metadata.block_tables_intra.fill_(0) + for i in range(bs): + st = chunk_num_curr[i] * chunk_len // self.page_size + ed = min( + st + (max_seq_len_intra - 1) // self.page_size + 1, + (cache_seq_lens[i] - 1) // self.page_size + 1, + ) + metadata.block_tables_intra[i, : ed - st] = metadata.block_tables[i, st:ed] + + seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len + metadata.seq_lens_succ.copy_(seq_lens_succ) + metadata.max_seq_len_succ = metadata.seq_lens_succ.max().item() + if metadata.max_seq_len_succ: + metadata.block_tables_succ.fill_(0) + for i in range(bs): + start = ( + (chunk_num_curr[i] - 1).clip(min=0) * chunk_len // self.page_size + ) + end = min( + start + (metadata.max_seq_len_succ - 1) // self.page_size + 1, + (cache_seq_lens[i] - 1) // self.page_size + 1, + ) + metadata.block_tables_succ[i, : end - start] = metadata.block_tables[ + i, start:end + ] + + seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len + metadata.seq_lens_inter.copy_(seq_lens_inter) + metadata.max_seq_len_inter = metadata.seq_lens_inter.max().item() + + self.forward_metadata = metadata + + def get_cuda_graph_seq_len_fill_value(self): + """Get the fill value for sequence length in CUDA graph.""" + return 1 + + def _dual_chunk_flash_attn_prefill( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + orig_seq_lens: List[int], + scaling_factor: torch.Tensor, + softmax_scale: float, + causal: Optional[bool] = True, + window_size: Tuple[int, int] = (-1, -1), + block_table: Optional[torch.Tensor] = None, + chunk_size: int = 8192, + local_size: int = 1024, + ): + if not causal: + raise ValueError("Dual Chunk Attention does not support causal=False") + if window_size != (-1, -1): + raise ValueError("Dual Chunk Attention does not support window_size") + + cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() + cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() + all_outputs = [] + + for i in range(0, len(cu_seqlens_q_cpu) - 1): + qs = cu_seqlens_q_cpu[i] + qe = cu_seqlens_q_cpu[i : i + 2][-1] + ks = cu_seqlens_k_cpu[i] + ke = cu_seqlens_k_cpu[i : i + 2][-1] + + current_q = q[qs:qe] + current_q_succ = q_succ[qs:qe] + current_q_inter = q_inter[qs:qe] + current_q_succ_critical = q_succ_critical[qs:qe] + current_q_inter_critical = q_inter_critical[qs:qe] + + if block_table is None: + current_k = k[ks:ke] + current_v = v[ks:ke] + current_block_table = None + current_orig_seq_len = orig_seq_lens[i] + else: + current_block_table = block_table[i] + current_orig_seq_len = orig_seq_lens[i] + current_k = k + current_v 
= v + sparse_attn_enabled = ( + self.sparse_attention_enabled + and current_orig_seq_len > self.sparse_attention_threshold + ) + + if current_q.shape[0] == 0: + continue + + if current_k.shape[0] == 0: + all_outputs.append( + torch.zeros( + (current_q.shape[0], current_q.shape[1], v.shape[2]), + device=q.device, + dtype=q.dtype, + ) + ) + continue + + current_output = torch.empty_like(current_q) + group_size = int(current_q.size(-2) / current_k.size(-2)) + + if sparse_attn_enabled: + num_device_q_heads = current_q.size(-2) + heads_vertical_size = torch.empty( + size=(num_device_q_heads,), dtype=torch.int32 + ) + heads_slash_size = torch.empty( + size=(num_device_q_heads,), dtype=torch.int32 + ) + for head_id in range(current_q.size(-2)): + ( + ty, + vertical_size, + slash_size, + _, + ) = self.layer_sparse_attention_config[head_id] + assert ty == "vertical_and_slash", "only support slash mode" + + if vertical_size == 30: + vertical_size += 100 + heads_vertical_size[head_id] = vertical_size + heads_slash_size[head_id] = slash_size + + current_output = self._dual_chunk_flash_attn_prefill_func( + current_q, # allheads + current_q_succ, + current_q_inter, + current_q_succ_critical, + current_q_inter_critical, + current_k, + current_v, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + heads_vertical_size=heads_vertical_size, + heads_slash_size=heads_slash_size, + group_size=group_size, + ) + else: + for head_id in range(current_q.size(-2)): + # (seq_len, num_heads, head_size) + current_q_head = current_q[:, head_id, :].unsqueeze(1) + current_q_succ_head = current_q_succ[:, head_id, :].unsqueeze(1) + current_q_inter_head = current_q_inter[:, head_id, :].unsqueeze(1) + current_q_succ_head_critical = current_q_succ_critical[ + :, head_id, : + ].unsqueeze(1) + current_q_inter_head_critical = current_q_inter_critical[ + :, head_id, : + ].unsqueeze(1) + if block_table is not None: + current_k_head = current_k[ + ..., head_id // group_size, : + ].unsqueeze(2) + current_v_head = current_v[ + ..., head_id // group_size, : + ].unsqueeze(2) + + else: + current_k_head = current_k[:, head_id, :].unsqueeze(1) + current_v_head = current_v[:, head_id, :].unsqueeze(1) + + current_out = self._dual_chunk_flash_attn_prefill_func( + current_q_head, + current_q_succ_head, + current_q_inter_head, + current_q_succ_head_critical, + current_q_inter_head_critical, + current_k_head, + current_v_head, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + ) + current_output[:, head_id : head_id + 1, :] = current_out + all_outputs.append(current_output) + return torch.cat(all_outputs, dim=0) + + def _dual_chunk_flash_attn_prefill_func( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + block_table, + softmax_scale: float, + chunk_size: int, + local_size: int, + scaling_factor: float, + k_length: int, + sparse_attn_enabled: Optional[bool] = True, + heads_vertical_size=None, + heads_slash_size=None, + group_size=None, + ): + flash_results = [] + chunk_len = chunk_size - local_size + + if block_table is not None: + block_size = v.shape[1] + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + else: + block_size = 1 + + if self.original_max_position_embeddings > 0: + softmax_scale = softmax_scale * scaling_factor + + begin = k_length - q.shape[0] + while 
begin < k_length: + flash_per_chunk = [] + + prev_chunk_end_pos = (begin // chunk_len) * chunk_len + next_chunk_end_pos = prev_chunk_end_pos + chunk_len + end = min(next_chunk_end_pos, k_length) + qbegin = begin - (k_length - q.shape[0]) + qend = end - (k_length - q.shape[0]) + + qk_chunks = [] + q_states_intra = q[qbegin:qend] + # choose critical token + if block_table is not None: + block_tables_intra = _get_block( + block_table, block_size, prev_chunk_end_pos, end + ) + k_states_intra = k[block_tables_intra].view(-1, *k.shape[-2:])[ + : (end - prev_chunk_end_pos) + ] + v_states_intra = v[block_tables_intra].view(-1, *v.shape[-2:])[ + : (end - prev_chunk_end_pos) + ] + else: + block_tables_intra = None + k_states_intra = k[prev_chunk_end_pos:end] + v_states_intra = v[prev_chunk_end_pos:end] + + if sparse_attn_enabled: + last_q_size = min(qend - qbegin, self.sparse_attention_last_q) + _, num_device_k_heads, head_dim = k_states_intra.shape + k_states_intra = ( + k_states_intra.unsqueeze(2) + .repeat(1, 1, group_size, 1) + .reshape(-1, num_device_k_heads * group_size, head_dim) + ) + v_states_intra = ( + v_states_intra.unsqueeze(2) + .repeat(1, 1, group_size, 1) + .reshape(-1, num_device_k_heads * group_size, head_dim) + ) + qk_chunks.append( + (q_states_intra.transpose(0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_intra.permute(1, 2, 0) + ) + + if prev_chunk_end_pos - chunk_len >= 0: + q_states_succ = q_succ[qbegin:qend] + q_states_succ_critical = q_succ_critical[qbegin:qend] + if block_table is not None: + block_tables_succ = _get_block( + block_table, + block_size, + prev_chunk_end_pos - chunk_len, + prev_chunk_end_pos, + ) + k_states_succ = k[block_tables_succ].view(-1, *k.shape[-2:])[ + :chunk_len + ] + v_states_succ = v[block_tables_succ].view(-1, *v.shape[-2:])[ + :chunk_len + ] + else: + k_states_succ = k[ + prev_chunk_end_pos - chunk_len : prev_chunk_end_pos + ] + v_states_succ = v[ + prev_chunk_end_pos - chunk_len : prev_chunk_end_pos + ] + + if sparse_attn_enabled: + k_states_succ = ( + k_states_succ.unsqueeze(2) + .repeat(1, 1, group_size, 1) + .reshape(-1, num_device_k_heads * group_size, head_dim) + ) + v_states_succ = ( + v_states_succ.unsqueeze(2) + .repeat(1, 1, group_size, 1) + .reshape(-1, num_device_k_heads * group_size, head_dim) + ) + qk_chunks.append( + ( + q_states_succ_critical.transpose(0, 1)[:, -last_q_size:] + * softmax_scale + ) + @ k_states_succ.permute(1, 2, 0) + ) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + q_states_inter = q_inter[qbegin:qend] + q_states_inter_critical = q_inter_critical[qbegin:qend] + if block_table is not None: + block_tables_inter = _get_block( + block_table, block_size, 0, prev_chunk_end_pos - chunk_len + ) + k_states_inter = k[block_tables_inter].view(-1, *k.shape[-2:])[ + : (prev_chunk_end_pos - chunk_len) + ] + v_states_inter = v[block_tables_inter].view(-1, *v.shape[-2:])[ + : (prev_chunk_end_pos - chunk_len) + ] + else: + k_states_inter = k[: prev_chunk_end_pos - chunk_len] + v_states_inter = v[: prev_chunk_end_pos - chunk_len] + + if sparse_attn_enabled: + k_states_inter = ( + k_states_inter.unsqueeze(2) + .repeat(1, 1, group_size, 1) + .reshape(-1, num_device_k_heads * group_size, head_dim) + ) + v_states_inter = ( + v_states_inter.unsqueeze(2) + .repeat(1, 1, group_size, 1) + .reshape(-1, num_device_k_heads * group_size, head_dim) + ) + qk_chunks.append( + ( + q_states_inter_critical.transpose(0, 1)[:, -last_q_size:] + * softmax_scale + ) + @ k_states_inter.permute(1, 2, 0) + ) + + if sparse_attn_enabled: + 
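+ # qk_chunks holds scores between the last `sparse_attention_last_q` queries and
+ # the intra/succ/inter key segments gathered above; reversing before concatenation
+ # puts key positions in temporal order. Column sums rank "vertical" tokens and
+ # diagonal sums (via _sum_all_diagonal_matrix) rank "slash" lines; a per-head
+ # top-k of each yields the sparse index sets passed to the kernels below.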
reversed_qk = qk_chunks[::-1] + qk = torch.cat(reversed_qk, dim=-1) + + qk[:, :, -last_q_size:] = torch.where( + self.last_q_mask[..., -last_q_size:, -last_q_size:].to(qk.device), + qk[:, :, -last_q_size:], + -torch.inf, + ) + qk = F.softmax(qk, dim=-1, dtype=torch.float32) + + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + + # Avoid sorting by using the min/max ints to fill the indexer + # buffers. + int32_max = torch.iinfo(torch.int32).max + int32_min = torch.iinfo(torch.int32).min + n_heads = qk.size()[0] + max_slash_topk = torch.max(heads_slash_size).item() + max_vertical_topk = torch.max(heads_vertical_size).item() + # store each head's slash topk, vertical topk + vertical = vertical.reshape((n_heads, -1)) + # prevent out of range when prompt size < max_vertical_topk + max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) + vertical_topk_buffer = torch.topk( + vertical, max_vertical_topk, -1 + ).indices + slash_topk_buffer = torch.empty( + size=(n_heads, max_slash_topk), dtype=torch.int64, device=qk.device + ) + for head_i in range(n_heads): + # (nqheads=1, lastq, k_len) + head_score = qk[head_i : head_i + 1, :, :] + slash_scores = _sum_all_diagonal_matrix(head_score) + if head_score.size(1) != 1: + # drop right up corner + slash_scores = slash_scores[..., : -last_q_size + 1] + slash_scores[..., -100:] = torch.inf + + head_slash_size = heads_slash_size[head_i] + head_slash_size = min(head_slash_size, vertical.size(-1)) + slash_topk = torch.topk(slash_scores, head_slash_size, -1).indices + # (nheads, max_topk) + slash_topk_buffer[head_i, :head_slash_size] = slash_topk + + # reset heads topk + heads_slash_size[head_i] = head_slash_size + heads_vertical_size[head_i] = min( + heads_vertical_size[head_i], max_vertical_topk + ) + + # store + vertical_buffer = torch.full( + (n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device, + ) + slash_buffer = torch.full( + (n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device, + ) + succ_vertical_buffer = torch.full( + (n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device, + ) + succ_slash_buffer = torch.full( + (n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device, + ) + inter_vertical_buffer = torch.full( + (n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device, + ) + inter_slash_buffer = torch.full( + (n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device, + ) + + vertical_size_buffer = torch.empty( + size=(n_heads,), dtype=torch.int32, device=q.device + ) + slash_sizes_buffer = torch.empty( + size=(n_heads,), dtype=torch.int32, device=q.device + ) + succ_vertical_size_buffer = torch.empty( + size=(n_heads,), dtype=torch.int32, device=q.device + ) + succ_slash_sizes_buffer = torch.empty( + size=(n_heads,), dtype=torch.int32, device=q.device + ) + inter_vertical_size_buffer = torch.empty( + size=(n_heads,), dtype=torch.int32, device=q.device + ) + inter_slash_sizes_buffer = torch.empty( + size=(n_heads,), dtype=torch.int32, device=q.device + ) + + for head_i in range(n_heads): + vertical_topk = vertical_topk_buffer[ + head_i, : heads_vertical_size[head_i] + ] + # intra + intra_vertical_indices = ( + vertical_topk[vertical_topk >= prev_chunk_end_pos] + - prev_chunk_end_pos + ) + if intra_vertical_indices.nelement() == 0: + intra_vertical_indices = torch.cat( + [ + intra_vertical_indices, + torch.arange( + 0, + k_states_intra.size(0), + max(1, k_states_intra.size(0) 
/ 5), + dtype=torch.int32, + device=intra_vertical_indices.device, + ), + ] + ) + slash_topk = slash_topk_buffer[head_i, : heads_slash_size[head_i]] + intra_slash_indices = (qk.size(-1) - 1) - slash_topk[ + slash_topk >= prev_chunk_end_pos + ] + # fill buffer + v_count = intra_vertical_indices.nelement() + s_count = intra_slash_indices.nelement() + vertical_size_buffer[head_i] = v_count + slash_sizes_buffer[head_i] = s_count + vertical_buffer[head_i, :v_count].copy_(intra_vertical_indices) + slash_buffer[head_i, :s_count].copy_(intra_slash_indices) + # succ + if prev_chunk_end_pos - chunk_len >= 0: + succ_vertical_indices = vertical_topk[ + (vertical_topk < prev_chunk_end_pos) + & (vertical_topk >= prev_chunk_end_pos - chunk_len) + ] - (prev_chunk_end_pos - chunk_len) + # TODO: support no vertical + if succ_vertical_indices.nelement() == 0: + succ_vertical_indices = torch.cat( + [ + succ_vertical_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device, + ), + ] + ) + succ_slash_indices = ( + prev_chunk_end_pos + (qend - qbegin) - 1 + ) - slash_topk[ + ( + (slash_topk >= (prev_chunk_end_pos - chunk_len)) + & (slash_topk < (prev_chunk_end_pos + (qend - qbegin))) + ) + ] + if succ_slash_indices.nelement() == 0: + succ_slash_indices = torch.cat( + [ + succ_slash_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device, + ), + ] + ) + # fill buffer + v_count = succ_vertical_indices.nelement() + s_count = succ_slash_indices.nelement() + succ_vertical_size_buffer[head_i] = v_count + succ_slash_sizes_buffer[head_i] = s_count + succ_vertical_buffer[head_i, :v_count].copy_( + succ_vertical_indices + ) + succ_slash_buffer[head_i, :s_count].copy_(succ_slash_indices) + + if prev_chunk_end_pos - 2 * chunk_len >= 0: + inter_vertical_indices = vertical_topk[ + vertical_topk < prev_chunk_end_pos - chunk_len + ] + + if inter_vertical_indices.nelement() == 0: + inter_vertical_indices = torch.cat( + [ + inter_vertical_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device, + ), + ] + ) + inter_slash_indices = ( + prev_chunk_end_pos - chunk_len + (qend - qbegin) - 1 + ) - slash_topk[ + slash_topk + < (prev_chunk_end_pos - chunk_len + (qend - qbegin)) + ] + if inter_slash_indices.nelement() == 0: + inter_slash_indices = torch.cat( + [ + inter_slash_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device, + ), + ] + ) + # fill buffer + v_count = inter_vertical_indices.nelement() + s_count = inter_slash_indices.nelement() + inter_vertical_size_buffer[head_i] = v_count + inter_slash_sizes_buffer[head_i] = s_count + inter_vertical_buffer[head_i, :v_count].copy_( + inter_vertical_indices + ) + inter_slash_buffer[head_i, :s_count].copy_(inter_slash_indices) + else: + intra_vertical_indices, intra_slash_indices = None, None + succ_vertical_indices, succ_slash_indices = None, None + inter_vertical_indices, inter_slash_indices = None, None + + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + stage="intra", + vertical_indices=vertical_buffer, + slash_indices=slash_buffer, + vertical_indices_count=vertical_size_buffer, + 
slash_indices_count=slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled, + ) + else: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + stage="intra", + vertical_indices=intra_vertical_indices, + slash_indices=intra_slash_indices, + sparse_attn_enabled=sparse_attn_enabled, + ) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + stage="succ", + vertical_indices=succ_vertical_buffer, + slash_indices=succ_slash_buffer, + vertical_indices_count=succ_vertical_size_buffer, + slash_indices_count=succ_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled, + ) + else: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + stage="succ", + vertical_indices=succ_vertical_indices, + slash_indices=succ_slash_indices, + sparse_attn_enabled=sparse_attn_enabled, + ) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + stage="inter", + vertical_indices=inter_vertical_buffer, + slash_indices=inter_slash_buffer, + vertical_indices_count=inter_vertical_size_buffer, + slash_indices_count=inter_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled, + ) + else: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + stage="inter", + vertical_indices=inter_vertical_indices, + slash_indices=inter_slash_indices, + sparse_attn_enabled=sparse_attn_enabled, + ) + flash_per_chunk.append(flash_result) + + flash_results.append(flash_per_chunk) + begin = end + + attn_output = self._merge_attn_outputs(flash_results) + del flash_results + return attn_output + + def _do_flash_attn( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + softmax_scale: float, + causal: bool = True, + max_seqlen_k: Optional[int] = None, + stage: str = "intra", + vertical_indices: Optional[torch.Tensor] = None, + slash_indices: Optional[torch.Tensor] = None, + vertical_indices_count: Optional[torch.Tensor] = None, + slash_indices_count: Optional[torch.Tensor] = None, + mergehead_softmax_scale: Optional[float] = None, + sparse_attn_enabled: Optional[bool] = False, + ): + if max_seqlen_k is None: + max_seqlen_k = key_states.shape[0] + + q_len = query_states.shape[0] + q_heads = query_states.shape[1] + h_dim = query_states.shape[-1] + + if sparse_attn_enabled: + assert slash_indices is not None + if stage == "intra": + assert causal + else: + assert not causal + + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + value_states = value_states.unsqueeze(0).transpose(1, 2) + + q = query_states + k = key_states + v = value_states + + if vertical_indices_count is not None and slash_indices_count is not None: + assert mergehead_softmax_scale is not None + + res, s_lse = _vertical_slash_sparse_attention( + q, + k, + v, + vertical_indices, + slash_indices, + 
mergehead_softmax_scale, + causal=causal, + stage=stage, + vertical_indices_count=vertical_indices_count, + slash_indices_count=slash_indices_count, + ) + res = res.view(q_heads, q_len, h_dim).transpose( + 0, 1 + ) # (qlen,nhead,h_dim) + s_lse = ( + s_lse.view(q_heads, q_len, 1).squeeze(-1).unsqueeze(0).float() + ) # (1, nhead,qlen) + else: + res, s_lse = _vertical_slash_sparse_attention( + q, + k, + v, + vertical_indices, + slash_indices, + softmax_scale, + causal=causal, + stage=stage, + ) + res = res.view(q_len, q_heads, h_dim) + s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() + return res, s_lse + + output, softmax_lse, *rest = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + softmax_scale=softmax_scale, + cu_seqlens_q=torch.tensor( + [0, query_states.shape[0]], + dtype=torch.int32, + device=query_states.device, + ), + max_seqlen_q=query_states.shape[0], + cu_seqlens_k=torch.tensor( + [0, max_seqlen_k], dtype=torch.int32, device=query_states.device + ), + max_seqlen_k=max_seqlen_k, + causal=causal, + return_softmax_lse=True, + ) + softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, 2).float() + return output, softmax_lse + + def _merge_attn_outputs( + self, + flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], + return_lse: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs_all = [] + logits_all = [] + + for flash_per_chunk in flash_results: + if len(flash_per_chunk) == 1: + attn_outputs_all.append(flash_per_chunk[0][0]) + if return_lse: + logits_all.append(flash_per_chunk[0][1]) + continue + + attn_outputs = torch.stack( + [flash_attn_output[0] for flash_attn_output in flash_per_chunk] + ) + logits = torch.stack( + [flash_attn_output[1] for flash_attn_output in flash_per_chunk] + ) + logits = logits.to(torch.float32) + + if return_lse: + max_val = torch.max(logits, dim=0).values + diff = torch.abs(logits[0] - logits[1]) + log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) + logits_all.append(log_sum_exp) + + max_logits = torch.max(logits, dim=0).values + stable_logits = logits - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) + attn_outputs_all.append(attn_outputs.sum(dim=0)) + + if return_lse: + return (torch.cat(attn_outputs_all, dim=0), torch.cat(logits_all, dim=-1)) + else: + return torch.cat(attn_outputs_all, dim=0) + + def _dual_chunk_flash_attn_decoding( + self, + query: torch.Tensor, + query_succ: torch.Tensor, + query_inter: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + causal: bool, + chunk_size: int, + local_size: int, + original_max_position_embeddings: int, + decode_meta: DualChunkFlashAttentionMetadata, + ): + if not causal: + raise ValueError("Dual Chunk Attention does not support causal=False") + + block_size = value_cache.shape[1] + chunk_len = chunk_size - local_size + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + if original_max_position_embeddings > 0: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + query = (query * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype + ) # possible for numerical issue, need to fused in the kernel + query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to(query.dtype) + query_inter = (query_inter * 
scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype + ) + outputs_list = [] + softmax_lses_list = [] + + # intra-attention + intra_output, intra_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query, + key_cache, + value_cache, + decode_meta.block_tables_intra, + decode_meta.seq_lens_intra, + softmax_scale, + causal=False, + ) + ) + outputs_list.append(intra_output) + softmax_lses_list.append(intra_softmax_lse) + + # succ-attention + if decode_meta.max_seq_len_succ: + succ_output, succ_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_succ, + key_cache, + value_cache, + decode_meta.block_tables_succ, + decode_meta.seq_lens_succ, + softmax_scale, + causal=False, + ) + ) + outputs_list.append(succ_output) + softmax_lses_list.append(succ_softmax_lse) + + # inter-attention + if decode_meta.max_seq_len_inter: + inter_output, inter_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_inter, + key_cache, + value_cache, + block_table[:, : decode_meta.max_seq_len_inter], + decode_meta.seq_lens_inter, + softmax_scale, + causal=False, + ) + ) + outputs_list.append(inter_output) + softmax_lses_list.append(inter_softmax_lse) + outputs = torch.stack(outputs_list, dim=0) + del outputs_list + softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) + del softmax_lses_list + max_logits = torch.max(softmax_lses, dim=0).values + stable_logits = softmax_lses - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + outputs *= lse_s.unsqueeze(-1).transpose(2, 3) + return outputs.sum(0) + + def _dual_chunk_flash_attn_decoding_with_exp_sums( + self, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + causal: bool, + ): + out, softmax_lse, *rest_expand = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + page_table=block_table, + cache_seqlens=cache_seqlens, + softmax_scale=softmax_scale, + causal=causal, + return_softmax_lse=True, + ) + mask = cache_seqlens == 0 + out[mask] = 0 + softmax_lse[mask] = -float("inf") + return out, softmax_lse + + +def _vertical_slash_sparse_attention( + query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + softmax_scale: float, + causal: bool = True, + stage: str = "intra", + block_size_M: int = 64, + block_size_N: int = 64, + vertical_indices_count: torch.Tensor = None, # [N_HEADS,] + slash_indices_count: torch.Tensor = None, +): + if stage == "intra": + assert causal + else: + assert not causal + + batch_size, num_heads, context_size, head_dim = query.shape + _, _, kv_seq_len, _ = key.shape + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim + query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + v_idx = ( + v_idx.to(torch.int32) + .reshape((batch_size, num_heads, -1)) + .sort(dim=-1, descending=False)[0] + ) + s_idx = ( + s_idx.to(torch.int32) + .reshape((batch_size, num_heads, -1)) + .sort(dim=-1, descending=True)[0] + ) + q_seqlens = torch.tensor([context_size], dtype=torch.int32, 
device=query.device) + kv_seqlens = torch.tensor([kv_seq_len], dtype=torch.int32, device=query.device) + + if vertical_indices_count is not None and slash_indices_count is not None: + ( + block_count, + block_offset, + column_count, + column_index, + ) = convert_vertical_slash_indexes_mergehead( + q_seqlens, + kv_seqlens, + v_idx, + s_idx, + vertical_indices_count, + slash_indices_count, + context_size, + block_size_M, + block_size_N, + causal, + ) + else: + ( + block_count, + block_offset, + column_count, + column_index, + ) = convert_vertical_slash_indexes( + q_seqlens, + kv_seqlens, + v_idx, + s_idx, + context_size, + block_size_M, + block_size_N, + causal, + ) + + q = query.transpose(1, 2).contiguous() + k = key.transpose(1, 2).contiguous() + v = value.transpose(1, 2).contiguous() + out, lse = sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + causal=causal, + softmax_scale=softmax_scale, + return_softmax_lse=True, + ) + out = out.transpose(1, 2).contiguous() + softmax_lse = lse.reshape(*lse.shape, 1) + return (out[..., :context_size, :head_dim], softmax_lse[..., :context_size, :]) + + +def _sum_all_diagonal_matrix(mat: torch.tensor): + h, n, m = mat.shape + # Zero matrix used for padding + zero_mat = torch.zeros((h, n, n), device=mat.device) + # pads the matrix on left and right + mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) + # Change the strides + mat_strided = mat_padded.as_strided( + (1, n, n + m), (n * (2 * n + m), 2 * n + m + 1, 1) + ) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 1) + return sum_diags[:, 1:] # drop left bottom corner + + +def _get_block(block_table: torch.Tensor, block_size: int, begin: int, end: int): + begin_block = begin // block_size + end_block = (end - 1) // block_size + 1 + return block_table[begin_block:end_block] diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 0be507f84..252362201 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding): ) +class DualChunkRotaryEmbedding(CustomOp): + """Rotary positional embedding for Dual Chunk Attention.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + chunk_size: int, + local_size: int, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.chunk_size = chunk_size + self.local_size = local_size + self.dtype = dtype + self.device = torch.device(f"cuda:{torch.cuda.current_device()}") + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = ( + self._compute_cos_sin_cache() + ) + + self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) + self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) + self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) + self.register_buffer( + "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False + ) + self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. 
+ # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + chunk_len = self.chunk_size - self.local_size + q_t = torch.arange(chunk_len, dtype=torch.float) + qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp( + max=self.chunk_size + ) + k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len + + # count from chunk_len, no clamp(self.chunk_size) restriction + qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len + # count from self.chunk_size for q_inter's rope + q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size + + q_freqs = torch.outer(q_t, inv_freq) + qc_freqs = torch.outer(qc_t, inv_freq) + k_freqs = torch.outer(k_t, inv_freq) + qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq) + q_inter_freqs = torch.outer(q_inter_t, inv_freq) + + q_cos = q_freqs.cos() + q_sin = q_freqs.sin() + qc_cos = qc_freqs.cos() + qc_sin = qc_freqs.sin() + k_cos = k_freqs.cos() + k_sin = k_freqs.sin() + + qc_no_clamp_cos = qc_no_clamp_freqs.cos() + qc_no_clamp_sin = qc_no_clamp_freqs.sin() + q_inter_cos = q_inter_freqs.cos() + q_inter_sin = q_inter_freqs.sin() + + q_cache = torch.cat((q_cos, q_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + else: + query_pass = None + key_pass = None + + positions_with_offsets = ( + torch.add(positions, offsets) if offsets is not None else positions + ) + key = self._apply_rotary_embedding( + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass + ) + chunk_len = self.chunk_size - self.local_size + query = self._apply_rotary_embedding( + self.cos_sin_q_cache[positions_with_offsets % chunk_len], + query_rot, + query_pass, + ) + query_succ = self._apply_rotary_embedding( + self.cos_sin_qc_cache[positions_with_offsets % chunk_len], + query_rot, + query_pass, + ) + 
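+        # Descriptive summary (inferred from _compute_cos_sin_cache): the remaining
+        # query variants below differ only in which cos/sin cache they index.
+        # query_inter applies a single fixed slot of the clamped cache, while the
+        # *_critical variants apply the unclamped caches; all five are concatenated
+        # along the last dim so the attention backend can pick the one it needs
+        # per key range.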
query_inter = self._apply_rotary_embedding( + self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), + query_rot, + query_pass, + ) + query_succ_critical = self._apply_rotary_embedding( + self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], + query_rot, + query_pass, + ) + query_inter_critical = self._apply_rotary_embedding( + self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], + query_rot, + query_pass, + ) + + # merge query into one tensor to simplify the interfaces + query = torch.cat( + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ), + dim=-1, + ) + return query, key + + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin + + if self.rotary_dim < self.head_size: + hidden = torch.cat((hidden_rot, hidden_pass), dim=-1) + else: + hidden = hidden_rot + return hidden.flatten(-2).squeeze(0) + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + s += f", chunk_size={self.chunk_size}, local_size={self.local_size}" + return s + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -1184,6 +1380,7 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() @@ -1195,6 +1392,17 @@ def get_rope( rope_scaling_args = tuple(rope_scaling_tuple.items()) else: rope_scaling_args = None + + if dual_chunk_attention_config is not None: + dual_chunk_attention_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in dual_chunk_attention_config.items() + if k != "sparse_attention_config" + } + dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) + else: + dual_chunk_attention_args = None + if partial_rotary_factor < 1.0: rotary_dim = int(rotary_dim * partial_rotary_factor) key = ( @@ -1204,12 +1412,28 @@ def get_rope( base, is_neox_style, rope_scaling_args, + dual_chunk_attention_args, dtype, ) if key in _ROPE_DICT: return _ROPE_DICT[key] - if rope_scaling is None: + if dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + **extra_kwargs, + ) + elif rope_scaling is None: rotary_emb = RotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, dtype ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 689ef94b3..da47667bd 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -846,6 +846,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # The sum of 
all sequence lengths seq_lens_sum: int = None + # The original sequence lengths, Qwen-1M related + orig_seq_lens: torch.Tensor = None # shape: [b], int32 # For DP attention global_num_tokens: Optional[List[int]] = None @@ -1131,6 +1133,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] + orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs] prefix_lens = [len(r.prefix_indices) for r in reqs] extend_lens = [r.extend_input_len for r in reqs] @@ -1147,6 +1150,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) + orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to( + self.device, non_blocking=True + ) prefix_lens_tensor = torch.tensor( prefix_lens, dtype=torch.int64, device=self.device ) @@ -1260,6 +1266,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.input_ids = input_ids_tensor self.req_pool_indices = req_pool_indices_tensor self.seq_lens = seq_lens_tensor + self.orig_seq_lens = orig_seq_lens_tensor self.out_cache_loc = out_cache_loc self.input_embeds = ( torch.tensor(input_embeds).to(self.device, non_blocking=True) @@ -1507,6 +1514,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.forward_mode = ForwardMode.IDLE self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device) self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device) + self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) self.seq_lens_sum = 0 @@ -1561,9 +1569,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): if self.enable_overlap: # Do not use in-place operations in the overlap mode self.seq_lens = self.seq_lens + 1 + self.orig_seq_lens = self.orig_seq_lens + 1 else: # A faster in-place version self.seq_lens.add_(1) + self.orig_seq_lens.add_(1) self.seq_lens_sum += bs # free memory @@ -1627,6 +1637,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices] self.req_pool_indices = self.req_pool_indices[keep_indices_device] self.seq_lens = self.seq_lens[keep_indices_device] + self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None self.seq_lens_sum = self.seq_lens.sum().item() self.output_ids = self.output_ids[keep_indices_device] @@ -1659,6 +1670,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): [self.req_pool_indices, other.req_pool_indices] ) self.seq_lens = torch.cat([self.seq_lens, other.seq_lens]) + self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) self.out_cache_loc = None self.seq_lens_sum += other.seq_lens_sum if self.output_ids is not None: @@ -1733,6 +1745,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): input_ids=self.input_ids, req_pool_indices=self.req_pool_indices, seq_lens=self.seq_lens, + orig_seq_lens=self.orig_seq_lens, out_cache_loc=self.out_cache_loc, seq_lens_cpu=seq_lens_cpu, seq_lens_sum=self.seq_lens_sum, @@ -1900,6 +1913,9 @@ class ModelWorkerBatch: # Sampling info sampling_info: SamplingBatchInfo + # The original sequence lengths, Qwen-1M related + 
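+    # (Unlike seq_lens, this length is not reduced by chunked prefill; presumably
+    # the dual-chunk backend needs the full prompt length to derive chunk-relative
+    # positions.)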
orig_seq_lens: Optional[torch.Tensor] = None + # The input Embeds input_embeds: Optional[torch.Tensor] = None diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c4031557b..05599c697 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -589,6 +589,7 @@ class CudaGraphRunner: req_pool_indices=req_pool_indices, seq_lens=seq_lens, next_token_logits_buffer=next_token_logits_buffer, + orig_seq_lens=seq_lens, req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, attn_backend=self.model_runner.attn_backend, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 4c47f319d..e5793a269 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -180,6 +180,9 @@ class ForwardBatch: # The sum of all sequence lengths seq_lens_sum: int + # The original sequence length without being chunked. Qwen-1M related. + orig_seq_lens: Optional[torch.Tensor] = None + # Optional seq_lens on cpu seq_lens_cpu: Optional[torch.Tensor] = None @@ -321,6 +324,7 @@ class ForwardBatch: encoder_out_cache_loc=batch.encoder_out_cache_loc, seq_lens_sum=batch.seq_lens_sum, seq_lens_cpu=batch.seq_lens_cpu, + orig_seq_lens=batch.orig_seq_lens, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, token_ids_logprobs=batch.token_ids_logprobs, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fe5d2c478..923482d72 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1467,6 +1467,12 @@ class ModelRunner: logger.info(f"Intel AMX attention backend is enabled.") return IntelAMXAttnBackend(self) + elif self.server_args.attention_backend == "dual_chunk_flash_attn": + from sglang.srt.layers.attention.dual_chunk_flashattention_backend import ( + DualChunkFlashAttentionBackend, + ) + + return DualChunkFlashAttentionBackend(self) else: raise ValueError(f"Invalid attention backend: {backend_str}") diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 1696bdfa9..556a5bb8f 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module): rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 32768, quant_config: Optional[QuantizationConfig] = None, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, prefix: str = "", ) -> None: super().__init__() @@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module): max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) self.attn = RadixAttention( self.num_heads, @@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module): rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 32768) head_dim = getattr(config, "head_dim", None) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module): rope_scaling=rope_scaling, 
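+            # dual_chunk_attention_config (read from hf_config above) is passed a
+            # few arguments below so get_rope can select DualChunkRotaryEmbedding
+            # when the model defines it.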
max_position_embeddings=max_position_embeddings, quant_config=quant_config, + dual_chunk_attention_config=dual_chunk_attention_config, prefix=add_prefix("self_attn", prefix), ) self.mlp = Qwen2MLP( diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 1463b6afa..2af1e919d 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module): max_position_embeddings: int = 8192, qkv_bias: int = True, quant_config: Optional[QuantizationConfig] = None, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, prefix: str = "", ) -> None: super().__init__() @@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module): max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) self.attn = RadixAttention( self.num_heads, @@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module): rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) qkv_bias = getattr(config, "qkv_bias", True) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) self.self_attn = Qwen2MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + dual_chunk_attention_config=dual_chunk_attention_config, qkv_bias=qkv_bias, prefix=add_prefix("self_attn", prefix), ) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index c7dc17444..d7c9290b2 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module): attention_bias: bool = False, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + dual_chunk_attention_config: Optional[dict[str, Any]] = None, alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() @@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module): max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) self.attn = RadixAttention( self.num_heads, @@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module): ) rms_norm_eps = config.rms_norm_eps attention_bias = config.attention_bias + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) self.self_attn = Qwen3MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module): attention_bias=attention_bias, quant_config=quant_config, prefix=add_prefix("self_attn", prefix), + dual_chunk_attention_config=dual_chunk_attention_config, alt_stream=alt_stream, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8f8774f2a..442403307 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -502,6 +502,20 @@ class ServerArgs: # use bf16 for mxfp4 triton kernels self.dtype = "bfloat16" + if self.attention_backend == "dual_chunk_flash_attn": + logger.warning( + "Mixed chunk is disabled because of using dual chunk flash attention backend" + ) + logger.warning( + "Radix cache is disabled because of using dual chunk flash attention backend" + ) + logger.warning( + "Cuda graph is 
disabled because of using dual chunk flash attention backend" + ) + self.enable_mixed_chunk = False + self.disable_cuda_graph = True + self.disable_radix_cache = True + # Set page size if self.page_size is None: self.page_size = 1 @@ -1337,6 +1351,7 @@ class ServerArgs: "triton", "trtllm_mla", "trtllm_mha", + "dual_chunk_flash_attn", ], default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 7e0602a20..8e84b539b 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -661,6 +661,7 @@ class TboForwardBatchPreparer: "padded_static_len", "mrope_positions", # only used by qwen2-vl, thus not care "split_index", # for split prefill + "orig_seq_lens", # only used by qwen-1m, thus not care ]: output_dict[key] = getattr(batch, key) if not batch.forward_mode.is_target_verify():
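Reviewer note: the decode path above merges the intra-, succ-, and inter-chunk partial results through their softmax log-sum-exp values. The standalone sketch below (illustrative only; tensor shapes and helper names are assumptions, not part of this diff) shows why that merge reproduces attention over the full key set.

import torch


def merge_partial_attention(outputs, lses):
    # outputs[i]: [batch, heads, head_dim] partial attention over one key range
    # lses[i]:    [batch, heads]           log-sum-exp of that range's logits
    out = torch.stack(outputs, dim=0)                 # [parts, b, h, d]
    lse = torch.stack(lses, dim=0).to(torch.float32)  # [parts, b, h]
    # Same stabilized exp / normalize that the backend spells out explicitly.
    weights = torch.softmax(lse, dim=0)
    return (out * weights.unsqueeze(-1)).sum(dim=0)


def attn_with_lse(q, k, v, scale):
    logits = (q @ k.transpose(-1, -2)) * scale            # [b, h, 1, n]
    lse = torch.logsumexp(logits, dim=-1).squeeze(-1)     # [b, h]
    out = (torch.softmax(logits, dim=-1) @ v).squeeze(2)  # [b, h, d]
    return out, lse


# Sanity check: merging two halves of the key set reproduces full attention.
torch.manual_seed(0)
b, h, n, d = 1, 2, 8, 16
q, k, v = (torch.randn(b, h, x, d) for x in (1, n, n))
scale = d ** -0.5
full, _ = attn_with_lse(q, k, v, scale)
o1, l1 = attn_with_lse(q, k[..., : n // 2, :], v[..., : n // 2, :], scale)
o2, l2 = attn_with_lse(q, k[..., n // 2 :, :], v[..., n // 2 :, :], scale)
assert torch.allclose(merge_partial_attention([o1, o2], [l1, l2]), full, atol=1e-5)

The same identity is what lets _dual_chunk_flash_attn_decoding_with_exp_sums mask empty key ranges with an LSE of -inf: their weight in the merge becomes zero.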