[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (#5949)

This commit is contained in:
PGFLMG
2025-08-07 14:49:36 +08:00
committed by GitHub
parent a69b637014
commit b7cd743038
15 changed files with 2121 additions and 4 deletions

View File

@@ -14,10 +14,11 @@
"""Utilities for Huggingface Transformers."""
import contextlib
import json
import os
import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union
from typing import Any, Dict, Optional, Type, Union
import torch
from huggingface_hub import snapshot_download
@@ -62,11 +63,17 @@ for name, cls in _CONFIG_REGISTRY.items():
AutoConfig.register(name, cls)
def download_from_hf(
    model_path: str,
    allow_patterns: Optional[Union[str, list]] = None,
):
    """Resolve *model_path* to a local directory, downloading it if needed.

    Args:
        model_path: A local filesystem path or a Hugging Face Hub repo id.
        allow_patterns: Glob pattern(s) forwarded to ``snapshot_download`` to
            limit which files are fetched. When ``None``, defaults to config,
            weight, and tokenizer files.

    Returns:
        The local path: ``model_path`` itself when it already exists, else the
        path of the downloaded snapshot.
    """
    if os.path.exists(model_path):
        # Already local -- nothing to download.
        return model_path
    if allow_patterns is None:
        # Explicit None check: a deliberately empty pattern list must not be
        # silently replaced by the defaults (the old truthiness test did that).
        allow_patterns = ["*.json", "*.bin", "*.model"]
    return snapshot_download(model_path, allow_patterns=allow_patterns)
def get_hf_text_config(config: PretrainedConfig):
@@ -171,6 +178,26 @@ def get_generation_config(
return None
# Qwen-1M related
def get_sparse_attention_config(
    model: str,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    """Load a model's sparse-attention config, returning ``{}`` when absent.

    Args:
        model: A local model directory or a Hugging Face Hub repo id.
        sparse_attention_config_filename: Name of the JSON config file to
            look for inside the model directory.

    Returns:
        The parsed JSON config as a dict, or an empty dict if the model does
        not ship a sparse-attention config file.
    """
    if not os.path.isdir(model):
        # Remote repo id: fetch just the JSON files into a local snapshot.
        model = download_from_hf(model, allow_patterns=["*.json"])

    config_path = os.path.join(model, sparse_attention_config_filename)
    if not os.path.exists(config_path):
        return {}

    with open(config_path) as fp:
        return json.load(fp)
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we