[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)

@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
     get_context_length,
     get_generation_config,
     get_hf_text_config,
+    get_sparse_attention_config,
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
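For context: `get_sparse_attention_config` is the helper this hunk newly imports. A minimal sketch of what such a loader could look like, assuming the sparse attention settings ship as a `sparse_attention_config.json` next to the model weights; the file name and return shape are illustrative assumptions, not sglang's actual implementation:

import json
import os
from typing import Any, Dict


def get_sparse_attention_config(model_path: str) -> Dict[str, Any]:
    # Hypothetical loader: return the sparse attention config bundled with
    # the model, or an empty dict when the model does not provide one.
    # The file name is an assumption made for this sketch.
    config_file = os.path.join(model_path, "sparse_attention_config.json")
    if not os.path.isfile(config_file):
        return {}
    with open(config_file) as f:
        return json.load(f)

Returning a falsy value for models without a sparse config is what lets `_verify_dual_chunk_attention_config` below bail out early.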
@@ -270,6 +271,9 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()
 
+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
 
@@ -297,6 +301,13 @@ class ModelConfig:
             **kwargs,
         )
 
+    def get_total_num_attention_heads(self) -> int:
+        return self.num_attention_heads
+
+    def get_num_attention_heads(self, tensor_parallel_size) -> int:
+        total_num_attention_heads = self.num_attention_heads
+        return max(1, total_num_attention_heads // tensor_parallel_size)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
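For reference, the new `get_num_attention_heads` is plain integer division of the head count across tensor-parallel ranks, floored at one so a small model on a large TP group still gets a head per rank. A standalone illustration of the arithmetic (the function name here is ours, not sglang's):

def heads_per_rank(total_num_attention_heads: int, tensor_parallel_size: int) -> int:
    # Same arithmetic as get_num_attention_heads above: integer division
    # across TP ranks, clamped so every rank keeps at least one head.
    return max(1, total_num_attention_heads // tensor_parallel_size)


assert heads_per_rank(40, 8) == 5  # e.g. a 40-head model split over TP=8
assert heads_per_rank(2, 8) == 1   # clamp prevents a zero-head rank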
@@ -484,6 +495,23 @@ class ModelConfig:
                 self.quantization,
             )
 
+    def _verify_dual_chunk_attention_config(self) -> None:
+        if hasattr(self.hf_config, "dual_chunk_attention_config"):
+            # Try loading the sparse attention config
+            sparse_attn_config = get_sparse_attention_config(self.model_path)
+            if not sparse_attn_config:
+                return
+            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
+                sparse_attn_config
+            )
+            if (
+                "sparse_attention_enabled"
+                not in self.hf_config.dual_chunk_attention_config
+            ):
+                self.hf_config.dual_chunk_attention_config[
+                    "sparse_attention_enabled"
+                ] = True
+
     def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
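In effect: when the HuggingFace config declares `dual_chunk_attention_config` and a sparse attention config can be loaded, the new method attaches it and turns `sparse_attention_enabled` on unless the config already set that key. A self-contained sketch of the resulting mutation, with the loader stubbed and every value invented for illustration:

from types import SimpleNamespace


def get_sparse_attention_config(model_path):
    # Stub standing in for the real loader; the returned dict is made up.
    return {"pattern": "vertical_and_slash"}


hf_config = SimpleNamespace(dual_chunk_attention_config={"chunk_size": 262144})

# Same logic as _verify_dual_chunk_attention_config above, using setdefault
# in place of the explicit membership check.
if hasattr(hf_config, "dual_chunk_attention_config"):
    sparse_attn_config = get_sparse_attention_config("/path/to/model")
    if sparse_attn_config:
        hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
            sparse_attn_config
        )
        hf_config.dual_chunk_attention_config.setdefault(
            "sparse_attention_enabled", True
        )

print(hf_config.dual_chunk_attention_config)
# {'chunk_size': 262144, 'sparse_attention_config': {'pattern': 'vertical_and_slash'},
#  'sparse_attention_enabled': True}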