[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)
This commit is contained in:
@@ -589,6 +589,7 @@ class CudaGraphRunner:
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
next_token_logits_buffer=next_token_logits_buffer,
|
||||
orig_seq_lens=seq_lens,
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
attn_backend=self.model_runner.attn_backend,
|
||||
|
||||
@@ -180,6 +180,9 @@ class ForwardBatch:
|
||||
# The sum of all sequence lengths
|
||||
seq_lens_sum: int
|
||||
|
||||
# The original sequence lengths before chunking. Qwen-1M related.
|
||||
orig_seq_lens: Optional[torch.Tensor] = None
|
||||
|
||||
# Optional seq_lens on cpu
|
||||
seq_lens_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
@@ -321,6 +324,7 @@ class ForwardBatch:
|
||||
encoder_out_cache_loc=batch.encoder_out_cache_loc,
|
||||
seq_lens_sum=batch.seq_lens_sum,
|
||||
seq_lens_cpu=batch.seq_lens_cpu,
|
||||
orig_seq_lens=batch.orig_seq_lens,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
token_ids_logprobs=batch.token_ids_logprobs,
|
||||
|
||||
@@ -1467,6 +1467,12 @@ class ModelRunner:
|
||||
|
||||
logger.info(f"Intel AMX attention backend is enabled.")
|
||||
return IntelAMXAttnBackend(self)
|
||||
elif self.server_args.attention_backend == "dual_chunk_flash_attn":
|
||||
from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
|
||||
DualChunkFlashAttentionBackend,
|
||||
)
|
||||
|
||||
return DualChunkFlashAttentionBackend(self)
|
||||
else:
|
||||
raise ValueError(f"Invalid attention backend: {backend_str}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user