[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)

This commit is contained in:
PGFLMG
2025-08-07 14:49:36 +08:00
committed by GitHub
parent a69b637014
commit b7cd743038
15 changed files with 2121 additions and 4 deletions

View File

@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
get_context_length,
get_generation_config,
get_hf_text_config,
get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
@@ -270,6 +271,9 @@ class ModelConfig:
# Verify quantization
self._verify_quantization()
# Verify dual-chunk attention config
self._verify_dual_chunk_attention_config()
# Cache attributes
self.hf_eos_token_id = self.get_hf_eos_token_id()
@@ -297,6 +301,13 @@ class ModelConfig:
**kwargs,
)
def get_total_num_attention_heads(self) -> int:
return self.num_attention_heads
def get_num_attention_heads(self, tensor_parallel_size) -> int:
total_num_attention_heads = self.num_attention_heads
return max(1, total_num_attention_heads // tensor_parallel_size)
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
@@ -484,6 +495,23 @@ class ModelConfig:
self.quantization,
)
def _verify_dual_chunk_attention_config(self) -> None:
if hasattr(self.hf_config, "dual_chunk_attention_config"):
# Try loading the sparse attention config
sparse_attn_config = get_sparse_attention_config(self.model_path)
if not sparse_attn_config:
return
self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
sparse_attn_config
)
if (
"sparse_attention_enabled"
not in self.hf_config.dual_chunk_attention_config
):
self.hf_config.dual_chunk_attention_config[
"sparse_attention_enabled"
] = True
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
eos_ids = getattr(self.hf_config, "eos_token_id", None)
if eos_ids is not None:

View File

@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
req_pool_indices, dtype=torch.int64, device=self.device
)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
self.orig_seq_lens = torch.tensor(
seq_lens, dtype=torch.int32, device=self.device
)
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)

View File

@@ -14,10 +14,11 @@
"""Utilities for Huggingface Transformers."""
import contextlib
import json
import os
import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union
from typing import Any, Dict, Optional, Type, Union
import torch
from huggingface_hub import snapshot_download
@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
AutoConfig.register(name, cls)
def download_from_hf(model_path: str):
def download_from_hf(
model_path: str,
allow_patterns: Optional[Union[str, list]] = None,
):
if os.path.exists(model_path):
return model_path
return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
if not allow_patterns:
allow_patterns = ["*.json", "*.bin", "*.model"]
return snapshot_download(model_path, allow_patterns=allow_patterns)
def get_hf_text_config(config: PretrainedConfig):
@@ -171,6 +178,26 @@ def get_generation_config(
return None
# Qwen-1M related
def get_sparse_attention_config(
model: str,
sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
is_local = os.path.isdir(model)
if not is_local:
# Download the config files.
model = download_from_hf(model, allow_patterns=["*.json"])
config_file = os.path.join(model, sparse_attention_config_filename)
if not os.path.exists(config_file):
return {}
# Load the sparse attention config.
with open(config_file) as f:
config = json.load(f)
return config
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we

File diff suppressed because it is too large Load Diff

View File

@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
)
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
chunk_size: int,
local_size: int,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.chunk_size = chunk_size
self.local_size = local_size
self.dtype = dtype
self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
(q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
self._compute_cos_sin_cache()
)
self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
self.register_buffer(
"cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
)
self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
chunk_len = self.chunk_size - self.local_size
q_t = torch.arange(chunk_len, dtype=torch.float)
qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
max=self.chunk_size
)
k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
# count from chunk_len, no clamp(self.chunk_size) restriction
qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
# count from self.chunk_size for q_inter's rope
q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
q_freqs = torch.outer(q_t, inv_freq)
qc_freqs = torch.outer(qc_t, inv_freq)
k_freqs = torch.outer(k_t, inv_freq)
qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.outer(q_inter_t, inv_freq)
q_cos = q_freqs.cos()
q_sin = q_freqs.sin()
qc_cos = qc_freqs.cos()
qc_sin = qc_freqs.sin()
k_cos = k_freqs.cos()
k_sin = k_freqs.sin()
qc_no_clamp_cos = qc_no_clamp_freqs.cos()
qc_no_clamp_sin = qc_no_clamp_freqs.sin()
q_inter_cos = q_inter_freqs.cos()
q_inter_sin = q_inter_freqs.sin()
q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
dtype=self.dtype, device=self.device
)
qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
dtype=self.dtype, device=self.device
)
k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
dtype=self.dtype, device=self.device
)
qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
dtype=self.dtype, device=self.device
)
q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
dtype=self.dtype, device=self.device
)
return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
key_rot = key[..., : self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim :]
key_pass = key[..., self.rotary_dim :]
else:
query_pass = None
key_pass = None
positions_with_offsets = (
torch.add(positions, offsets) if offsets is not None else positions
)
key = self._apply_rotary_embedding(
self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
)
chunk_len = self.chunk_size - self.local_size
query = self._apply_rotary_embedding(
self.cos_sin_q_cache[positions_with_offsets % chunk_len],
query_rot,
query_pass,
)
query_succ = self._apply_rotary_embedding(
self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
query_rot,
query_pass,
)
query_inter = self._apply_rotary_embedding(
self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
query_rot,
query_pass,
)
query_succ_critical = self._apply_rotary_embedding(
self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
query_rot,
query_pass,
)
query_inter_critical = self._apply_rotary_embedding(
self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
query_rot,
query_pass,
)
# merge query into one tensor to simplify the interfaces
query = torch.cat(
(
query,
query_succ,
query_inter,
query_succ_critical,
query_inter_critical,
),
dim=-1,
)
return query, key
def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
if self.rotary_dim < self.head_size:
hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
else:
hidden = hidden_rot
return hidden.flatten(-2).squeeze(0)
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
return s
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
@@ -1184,6 +1380,7 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
@@ -1195,6 +1392,17 @@ def get_rope(
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in dual_chunk_attention_config.items()
if k != "sparse_attention_config"
}
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
else:
dual_chunk_attention_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
@@ -1204,12 +1412,28 @@ def get_rope(
base,
is_neox_style,
rope_scaling_args,
dual_chunk_attention_args,
dtype,
)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
for k, v in dual_chunk_attention_config.items()
if k in ("chunk_size", "local_size")
}
rotary_emb = DualChunkRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
**extra_kwargs,
)
elif rope_scaling is None:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)

View File

@@ -846,6 +846,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# The sum of all sequence lengths
seq_lens_sum: int = None
# The original sequence lengths, Qwen-1M related
orig_seq_lens: torch.Tensor = None # shape: [b], int32
# For DP attention
global_num_tokens: Optional[List[int]] = None
@@ -1131,6 +1133,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = [len(r.fill_ids) for r in reqs]
orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
prefix_lens = [len(r.prefix_indices) for r in reqs]
extend_lens = [r.extend_input_len for r in reqs]
@@ -1147,6 +1150,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
prefix_lens_tensor = torch.tensor(
prefix_lens, dtype=torch.int64, device=self.device
)
@@ -1260,6 +1266,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.input_ids = input_ids_tensor
self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor
self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc
self.input_embeds = (
torch.tensor(input_embeds).to(self.device, non_blocking=True)
@@ -1507,6 +1514,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.forward_mode = ForwardMode.IDLE
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens_sum = 0
@@ -1561,9 +1569,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if self.enable_overlap:
# Do not use in-place operations in the overlap mode
self.seq_lens = self.seq_lens + 1
self.orig_seq_lens = self.orig_seq_lens + 1
else:
# A faster in-place version
self.seq_lens.add_(1)
self.orig_seq_lens.add_(1)
self.seq_lens_sum += bs
# free memory
@@ -1627,6 +1637,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.seq_lens = self.seq_lens[keep_indices_device]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[keep_indices_device]
@@ -1659,6 +1670,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
self.out_cache_loc = None
self.seq_lens_sum += other.seq_lens_sum
if self.output_ids is not None:
@@ -1733,6 +1745,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
orig_seq_lens=self.orig_seq_lens,
out_cache_loc=self.out_cache_loc,
seq_lens_cpu=seq_lens_cpu,
seq_lens_sum=self.seq_lens_sum,
@@ -1900,6 +1913,9 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo
# The original sequence lengths, Qwen-1M related
orig_seq_lens: Optional[torch.Tensor] = None
# The input Embeds
input_embeds: Optional[torch.Tensor] = None

View File

@@ -589,6 +589,7 @@ class CudaGraphRunner:
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
next_token_logits_buffer=next_token_logits_buffer,
orig_seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,

View File

@@ -180,6 +180,9 @@ class ForwardBatch:
# The sum of all sequence lengths
seq_lens_sum: int
# The original sequence length without being chunked. Qwen-1M related.
orig_seq_lens: Optional[torch.Tensor] = None
# Optional seq_lens on cpu
seq_lens_cpu: Optional[torch.Tensor] = None
@@ -321,6 +324,7 @@ class ForwardBatch:
encoder_out_cache_loc=batch.encoder_out_cache_loc,
seq_lens_sum=batch.seq_lens_sum,
seq_lens_cpu=batch.seq_lens_cpu,
orig_seq_lens=batch.orig_seq_lens,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,

View File

@@ -1467,6 +1467,12 @@ class ModelRunner:
logger.info(f"Intel AMX attention backend is enabled.")
return IntelAMXAttnBackend(self)
elif self.server_args.attention_backend == "dual_chunk_flash_attn":
from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
DualChunkFlashAttentionBackend,
)
return DualChunkFlashAttentionBackend(self)
else:
raise ValueError(f"Invalid attention backend: {backend_str}")

View File

@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 32768,
quant_config: Optional[QuantizationConfig] = None,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = RadixAttention(
self.num_heads,
@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
head_dim = getattr(config, "head_dim", None)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
dual_chunk_attention_config=dual_chunk_attention_config,
prefix=add_prefix("self_attn", prefix),
)
self.mlp = Qwen2MLP(

View File

@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
max_position_embeddings: int = 8192,
qkv_bias: int = True,
quant_config: Optional[QuantizationConfig] = None,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = RadixAttention(
self.num_heads,
@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
qkv_bias = getattr(config, "qkv_bias", True)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
self.self_attn = Qwen2MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
dual_chunk_attention_config=dual_chunk_attention_config,
qkv_bias=qkv_bias,
prefix=add_prefix("self_attn", prefix),
)

View File

@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
attention_bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
@@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = RadixAttention(
self.num_heads,
@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
)
rms_norm_eps = config.rms_norm_eps
attention_bias = config.attention_bias
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
attention_bias=attention_bias,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
dual_chunk_attention_config=dual_chunk_attention_config,
alt_stream=alt_stream,
)

View File

@@ -502,6 +502,20 @@ class ServerArgs:
# use bf16 for mxfp4 triton kernels
self.dtype = "bfloat16"
if self.attention_backend == "dual_chunk_flash_attn":
logger.warning(
"Mixed chunk is disabled because of using dual chunk flash attention backend"
)
logger.warning(
"Radix cache is disabled because of using dual chunk flash attention backend"
)
logger.warning(
"Cuda graph is disabled because of using dual chunk flash attention backend"
)
self.enable_mixed_chunk = False
self.disable_cuda_graph = True
self.disable_radix_cache = True
# Set page size
if self.page_size is None:
self.page_size = 1
@@ -1337,6 +1351,7 @@ class ServerArgs:
"triton",
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",

View File

@@ -661,6 +661,7 @@ class TboForwardBatchPreparer:
"padded_static_len",
"mrope_positions", # only used by qwen2-vl, thus not care
"split_index", # for split prefill
"orig_seq_lens", # only used by qwen-1m, thus not care
]:
output_dict[key] = getattr(batch, key)
if not batch.forward_mode.is_target_verify():