[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
     get_context_length,
     get_generation_config,
     get_hf_text_config,
+    get_sparse_attention_config,
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
@@ -270,6 +271,9 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()

+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
@@ -297,6 +301,13 @@ class ModelConfig:
             **kwargs,
         )

+    def get_total_num_attention_heads(self) -> int:
+        return self.num_attention_heads
+
+    def get_num_attention_heads(self, tensor_parallel_size) -> int:
+        total_num_attention_heads = self.num_attention_heads
+        return max(1, total_num_attention_heads // tensor_parallel_size)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
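A quick sanity check of the new head-partitioning helper (toy numbers, not part of the commit):

    # 28 attention heads across tensor_parallel_size=4 ranks -> 7 heads per
    # rank; the max(1, ...) guard covers the degenerate case where the TP
    # size exceeds the head count.
    assert max(1, 28 // 4) == 7
    assert max(1, 2 // 8) == 1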
@@ -484,6 +495,23 @@ class ModelConfig:
             self.quantization,
         )

+    def _verify_dual_chunk_attention_config(self) -> None:
+        if hasattr(self.hf_config, "dual_chunk_attention_config"):
+            # Try loading the sparse attention config
+            sparse_attn_config = get_sparse_attention_config(self.model_path)
+            if not sparse_attn_config:
+                return
+            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
+                sparse_attn_config
+            )
+            if (
+                "sparse_attention_enabled"
+                not in self.hf_config.dual_chunk_attention_config
+            ):
+                self.hf_config.dual_chunk_attention_config[
+                    "sparse_attention_enabled"
+                ] = True
+
     def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
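For readers skimming the hunk above, a hedged sketch of its effect on the HF config (the keys and values below are hypothetical stand-ins, not taken from a real model):

    # Before verification, a Qwen-1M style hf_config carries something like:
    dual_chunk = {"chunk_size": 262144, "local_size": 8192}
    # If the model repo ships sparse_attention_config.json,
    # _verify_dual_chunk_attention_config() attaches it and enables the feature:
    dual_chunk["sparse_attention_config"] = {"...": "contents of that JSON file"}
    dual_chunk.setdefault("sparse_attention_enabled", True)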
@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.tensor(
+            seq_lens, dtype=torch.int32, device=self.device
+        )
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
@@ -14,10 +14,11 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
 import json
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union

 import torch
 from huggingface_hub import snapshot_download

@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
     AutoConfig.register(name, cls)


-def download_from_hf(model_path: str):
+def download_from_hf(
+    model_path: str,
+    allow_patterns: Optional[Union[str, list]] = None,
+):
     if os.path.exists(model_path):
         return model_path

-    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
+    if not allow_patterns:
+        allow_patterns = ["*.json", "*.bin", "*.model"]
+
+    return snapshot_download(model_path, allow_patterns=allow_patterns)


 def get_hf_text_config(config: PretrainedConfig):

@@ -171,6 +178,26 @@ def get_generation_config(
     return None


+# Qwen-1M related
+def get_sparse_attention_config(
+    model: str,
+    sparse_attention_config_filename: str = "sparse_attention_config.json",
+) -> Dict[str, Any]:
+    is_local = os.path.isdir(model)
+    if not is_local:
+        # Download the config files.
+        model = download_from_hf(model, allow_patterns=["*.json"])
+
+    config_file = os.path.join(model, sparse_attention_config_filename)
+    if not os.path.exists(config_file):
+        return {}
+
+    # Load the sparse attention config.
+    with open(config_file) as f:
+        config = json.load(f)
+    return config
+
+
 # Models don't use the same configuration key for determining the maximum
 # context length. Store them here so we can sanely check them.
 # NOTE: The ordering here is important. Some models have two of these and we
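A hedged usage sketch of the new helper (the checkpoint path is illustrative; the import path matches the first hunk of this commit):

    from sglang.srt.hf_transformers_utils import get_sparse_attention_config

    # Returns the parsed JSON dict, or {} when the checkpoint ships no
    # sparse_attention_config.json next to its weights.
    cfg = get_sparse_attention_config("/path/to/local/qwen-1m-checkpoint")
    print(cfg or "no sparse attention config for this model")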
[File diff suppressed because it is too large]
@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
         )


+class DualChunkRotaryEmbedding(CustomOp):
+    """Rotary positional embedding for Dual Chunk Attention."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        chunk_size: int,
+        local_size: int,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.chunk_size = chunk_size
+        self.local_size = local_size
+        self.dtype = dtype
+        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
+            self._compute_cos_sin_cache()
+        )
+
+        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
+        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
+        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
+        self.register_buffer(
+            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
+        )
+        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        chunk_len = self.chunk_size - self.local_size
+        q_t = torch.arange(chunk_len, dtype=torch.float)
+        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
+            max=self.chunk_size
+        )
+        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
+
+        # count from chunk_len, no clamp(self.chunk_size) restriction
+        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
+        # count from self.chunk_size for q_inter's rope
+        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
+
+        q_freqs = torch.outer(q_t, inv_freq)
+        qc_freqs = torch.outer(qc_t, inv_freq)
+        k_freqs = torch.outer(k_t, inv_freq)
+        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
+        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
+
+        q_cos = q_freqs.cos()
+        q_sin = q_freqs.sin()
+        qc_cos = qc_freqs.cos()
+        qc_sin = qc_freqs.sin()
+        k_cos = k_freqs.cos()
+        k_sin = k_freqs.sin()
+
+        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
+        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
+        q_inter_cos = q_inter_freqs.cos()
+        q_inter_sin = q_inter_freqs.sin()
+
+        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+        else:
+            query_pass = None
+            key_pass = None
+
+        positions_with_offsets = (
+            torch.add(positions, offsets) if offsets is not None else positions
+        )
+        key = self._apply_rotary_embedding(
+            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
+        )
+        chunk_len = self.chunk_size - self.local_size
+        query = self._apply_rotary_embedding(
+            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_succ = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
+            query_rot,
+            query_pass,
+        )
+        query_succ_critical = self._apply_rotary_embedding(
+            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter_critical = self._apply_rotary_embedding(
+            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+
+        # merge query into one tensor to simplify the interfaces
+        query = torch.cat(
+            (
+                query,
+                query_succ,
+                query_inter,
+                query_succ_critical,
+                query_inter_critical,
+            ),
+            dim=-1,
+        )
+        return query, key
+
+    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
+        else:
+            hidden = hidden_rot
+        return hidden.flatten(-2).squeeze(0)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
+        return s
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
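Note how forward() packs five rotated query variants (query, query_succ, query_inter, query_succ_critical, query_inter_critical) into one tensor along the last dim. A hedged unpacking sketch for consumers (shapes illustrative, not from the commit):

    import torch

    num_tokens, num_heads, head_size = 4, 8, 128
    packed = torch.randn(num_tokens, num_heads * head_size * 5)  # forward() output
    q, q_succ, q_inter, q_succ_crit, q_inter_crit = torch.chunk(packed, 5, dim=-1)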
@@ -1184,6 +1380,7 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()

@@ -1195,6 +1392,17 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None

+    if dual_chunk_attention_config is not None:
+        dual_chunk_attention_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in dual_chunk_attention_config.items()
+            if k != "sparse_attention_config"
+        }
+        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+    else:
+        dual_chunk_attention_args = None
+
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (

@@ -1204,12 +1412,28 @@ def get_rope(
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
         dtype,
     )
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]

-    if rope_scaling is None:
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
         rotary_emb = RotaryEmbedding(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
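A hedged call-site sketch (module path assumed to be sglang's rotary embedding layer; the numbers are illustrative, and DualChunkRotaryEmbedding builds its caches on the current CUDA device):

    from sglang.srt.layers.rotary_embedding import get_rope

    rotary_emb = get_rope(
        head_size=128,
        rotary_dim=128,
        max_position=1048576,
        base=10000000,
        is_neox_style=True,
        dual_chunk_attention_config={"chunk_size": 262144, "local_size": 8192},
    )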
@@ -846,6 +846,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

     # The sum of all sequence lengths
     seq_lens_sum: int = None
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: torch.Tensor = None  # shape: [b], int32

     # For DP attention
     global_num_tokens: Optional[List[int]] = None
@@ -1131,6 +1133,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
+        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
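The distinction between the two length lists matters under chunked prefill; a toy illustration (numbers invented):

    # One request whose 10_000-token prompt is being prefilled in chunks;
    # only 4_096 tokens have been scheduled (fill_ids) so far.
    fill_ids = list(range(4_096))
    origin_input_ids = list(range(10_000))
    seq_len = len(fill_ids)                                    # 4096
    orig_seq_len = max(len(fill_ids), len(origin_input_ids))   # 10000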
@@ -1147,6 +1150,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )

@@ -1260,6 +1266,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)

@@ -1507,6 +1514,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0

@@ -1561,9 +1569,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs

         # free memory
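The two decode-step updates above differ deliberately; a toy contrast (grounded in the code comments, tensors illustrative):

    import torch

    seq_lens = torch.tensor([5, 9])
    # Overlap mode: the out-of-place add leaves the old tensor untouched for
    # any consumer still holding a reference to it.
    new_seq_lens = seq_lens + 1
    # Normal mode: faster in-place update of the same storage.
    seq_lens.add_(1)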
@@ -1627,6 +1637,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]

@@ -1659,6 +1670,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:

@@ -1733,6 +1745,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
+            orig_seq_lens=self.orig_seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_cpu=seq_lens_cpu,
             seq_lens_sum=self.seq_lens_sum,
@@ -1900,6 +1913,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo

+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # The input Embeds
     input_embeds: Optional[torch.Tensor] = None
@@ -589,6 +589,7 @@ class CudaGraphRunner:
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
             next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
@@ -180,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None

@@ -321,6 +324,7 @@ class ForwardBatch:
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             seq_lens_cpu=batch.seq_lens_cpu,
+            orig_seq_lens=batch.orig_seq_lens,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
@@ -1467,6 +1467,12 @@ class ModelRunner:

             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()

@@ -158,6 +159,7 @@ class Qwen2Attention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,

@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -208,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
         max_position_embeddings: int = 8192,
         qkv_bias: int = True,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()

@@ -267,6 +268,7 @@ class Qwen2MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,

@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         qkv_bias = getattr(config, "qkv_bias", True)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -317,6 +322,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             qkv_bias=qkv_bias,
             prefix=add_prefix("self_attn", prefix),
         )
@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()

@@ -353,6 +354,7 @@ class Qwen3MoeAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,

@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -471,6 +476,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            dual_chunk_attention_config=dual_chunk_attention_config,
             alt_stream=alt_stream,
         )
@@ -502,6 +502,20 @@ class ServerArgs:
             # use bf16 for mxfp4 triton kernels
             self.dtype = "bfloat16"

+        if self.attention_backend == "dual_chunk_flash_attn":
+            logger.warning(
+                "Mixed chunk is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Radix cache is disabled because of using dual chunk flash attention backend"
+            )
+            logger.warning(
+                "Cuda graph is disabled because of using dual chunk flash attention backend"
+            )
+            self.enable_mixed_chunk = False
+            self.disable_cuda_graph = True
+            self.disable_radix_cache = True
+
         # Set page size
         if self.page_size is None:
             self.page_size = 1
@@ -1337,6 +1351,7 @@ class ServerArgs:
                 "triton",
                 "trtllm_mla",
                 "trtllm_mha",
+                "dual_chunk_flash_attn",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
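Taken together with the backend registration above, the feature is opt-in from the CLI; an illustrative launch (the model ID is an example, not mandated by this diff):

    python3 -m sglang.launch_server \
        --model-path Qwen/Qwen2.5-7B-Instruct-1M \
        --attention-backend dual_chunk_flash_attn

As the warnings note, selecting this backend currently trades away mixed chunking, the radix cache, and CUDA graphs.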
@@ -661,6 +661,7 @@ class TboForwardBatchPreparer:
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
             "split_index",  # for split prefill
+            "orig_seq_lens",  # only used by qwen-1m, thus not care
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():