Qwen3.6-27B iluvatar bi-v100 adaptation
This commit is contained in:
13
Dockerfile
13
Dockerfile
@@ -1,15 +1,6 @@
|
||||
FROM git.modelhub.org.cn:9443/enginex-iluvatar/bi100-3.2.3-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.3
|
||||
|
||||
RUN pip install --no-cache-dir triton==2.1.0
|
||||
|
||||
COPY pkgs/triton /usr/local/corex/lib64/python3/dist-packages/triton
|
||||
COPY pkgs/triton-2.1.0+corex.4.1.2.dist-info /usr/local/corex/lib64/python3/dist-packages/triton-2.1.0+corex.4.1.2.dist-info
|
||||
|
||||
COPY paged_attn.py /usr/local/corex/lib64/python3/dist-packages/vllm/attention/ops/paged_attn.py
|
||||
COPY __init__.py /usr/local/corex/lib64/python3/dist-packages/vllm/triton_utils/__init__.py
|
||||
|
||||
RUN mkdir /workspace
|
||||
WORKDIR /workspace/
|
||||
|
||||
COPY ./launch_service /workspace/launch_service
|
||||
|
||||
COPY ./qwen3_6_scripts /workspace/qwen3_6_scripts
|
||||
RUN cd ./qwen3_6_scripts && ./patch_ops.sh
|
||||
42
README_qwen3_6.md
Normal file
42
README_qwen3_6.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# 天数智芯 天垓100 文本生成引擎(基于 vLLM 优化适配Qwen3.6-27B)
|
||||
|
||||
```
|
||||
# 本地构建
|
||||
docker build -t enginex-iluvatar-vllm:bi100-qwen3.6 -f Dockerfile .
|
||||
```
|
||||
|
||||
|
||||
启动容器镜像
|
||||
|
||||
下载Qwen3.6-27B模型,并且需要将模型的config.json文件中architectures字段改成
|
||||
```json
|
||||
"architectures": [
|
||||
"Qwen3_5ForCausalLM"
|
||||
]
|
||||
```
|
||||
|
||||
```bash
|
||||
docker run -dit --network=host --ipc=host \
|
||||
-v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev --privileged \
|
||||
--name vllm-iluvatar \
|
||||
-v /mnt/models/Qwen3.6-27B:/model:ro --entrypoint=python3 \
|
||||
enginex-iluvatar-vllm:bi100 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
--model /model --port 1111 --served-model-name llm \
|
||||
--max-model-len 10000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95
|
||||
```
|
||||
|
||||
请求
|
||||
```bash
|
||||
curl http://localhost:1111/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llm",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Can you tell me the story of Snow White?"}
|
||||
],
|
||||
"max_tokens": 200,
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
224
qwen3_6_scripts/mamba_cache.py
Normal file
224
qwen3_6_scripts/mamba_cache.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
|
||||
|
||||
class MambaCacheManager:
|
||||
|
||||
def __init__(self, dtype, num_mamba_layers, max_batch_size,
|
||||
conv_state_shape, temporal_state_shape):
|
||||
|
||||
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||
conv_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
temporal_state = torch.zeros(size=(num_mamba_layers, max_batch_size) +
|
||||
temporal_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
self.mamba_cache = (conv_state, temporal_state)
|
||||
|
||||
# Maps between the request id and a dict that maps between the seq_id
|
||||
# and its index inside the self.mamba_cache
|
||||
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
|
||||
|
||||
def current_run_tensors(self, input_ids: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata, **kwargs):
|
||||
"""
|
||||
Return the tensors for the current run's conv and ssm state.
|
||||
"""
|
||||
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||
# We get here only on Prefill/Eager mode runs
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
mamba_cache_tensors = self._prepare_current_run_mamba_cache(
|
||||
request_ids_to_seq_ids, finished_requests_ids)
|
||||
|
||||
else:
|
||||
# CUDA graph capturing runs
|
||||
mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
|
||||
|
||||
return mamba_cache_tensors
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
"""
|
||||
Copy the relevant Mamba cache into the CUDA graph input buffer
|
||||
that was provided during the capture runs
|
||||
(JambaForCausalLM.mamba_gc_cache_buffer).
|
||||
"""
|
||||
assert all(
|
||||
key in kwargs
|
||||
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||
finished_requests_ids)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
"""
|
||||
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
||||
The buffer is used to maintain the Mamba Cache during the CUDA graph
|
||||
replay runs.
|
||||
"""
|
||||
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
|
||||
|
||||
def _swap_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, [to_index,from_index]] = \
|
||||
cache_t[:, [from_index,to_index]]
|
||||
|
||||
def _copy_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
||||
|
||||
def _move_out_if_already_occupied(self, index: int,
|
||||
all_occupied_indices: List[int]):
|
||||
if index in all_occupied_indices:
|
||||
first_free_index = self._first_free_index_in_mamba_cache()
|
||||
# In case occupied, move the occupied to a new empty block
|
||||
self._move_cache_index_and_mappings(from_index=index,
|
||||
to_index=first_free_index)
|
||||
|
||||
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
|
||||
seq_id: int,
|
||||
destination_index: int):
|
||||
"""
|
||||
Assign (req_id,seq_id) pair to a `destination_index` index, if
|
||||
already occupied, move the occupying index to a free index.
|
||||
"""
|
||||
all_occupied_indices = self._get_all_occupied_indices()
|
||||
if cur_rid not in self.mamba_cache_indices_mapping:
|
||||
self._move_out_if_already_occupied(
|
||||
index=destination_index,
|
||||
all_occupied_indices=all_occupied_indices)
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, destination_index].zero_()
|
||||
self.mamba_cache_indices_mapping[cur_rid] = {
|
||||
seq_id: destination_index
|
||||
}
|
||||
elif seq_id not in (seq_ids2indices :=
|
||||
self.mamba_cache_indices_mapping[cur_rid]):
|
||||
# parallel sampling , where n > 1, assume prefill have
|
||||
# already happened now we only need to copy the already
|
||||
# existing cache into the siblings seq_ids caches
|
||||
self._move_out_if_already_occupied(
|
||||
index=destination_index,
|
||||
all_occupied_indices=all_occupied_indices)
|
||||
index_exists = list(seq_ids2indices.values())[0]
|
||||
# case of decoding n>1, copy prefill cache to decoding indices
|
||||
self._copy_mamba_cache(from_index=index_exists,
|
||||
to_index=destination_index)
|
||||
self.mamba_cache_indices_mapping[cur_rid][
|
||||
seq_id] = destination_index
|
||||
else:
|
||||
# already exists
|
||||
cache_index_already_exists = self.mamba_cache_indices_mapping[
|
||||
cur_rid][seq_id]
|
||||
if cache_index_already_exists != destination_index:
|
||||
# In case the seq id already exists but not in
|
||||
# the right destination, swap it with what's occupying it
|
||||
self._swap_pair_indices_and_mappings(
|
||||
from_index=cache_index_already_exists,
|
||||
to_index=destination_index)
|
||||
|
||||
def _prepare_current_run_mamba_cache(
|
||||
self, request_ids_to_seq_ids: Dict[str, list[int]],
|
||||
finished_requests_ids: List[str]):
|
||||
running_indices = []
|
||||
request_ids_to_seq_ids_flatten = [
|
||||
(req_id, seq_id)
|
||||
for req_id, seq_ids in request_ids_to_seq_ids.items()
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
batch_size = len(request_ids_to_seq_ids_flatten)
|
||||
for dest_index, (request_id,
|
||||
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
|
||||
if request_id in finished_requests_ids:
|
||||
# Do not allocate cache index for requests that run
|
||||
# and finish right after
|
||||
continue
|
||||
self._assign_seq_id_to_mamba_cache_in_specific_dest(
|
||||
request_id, seq_id, dest_index)
|
||||
running_indices.append(dest_index)
|
||||
|
||||
self._clean_up_first_bs_blocks(batch_size, running_indices)
|
||||
conv_state = self.mamba_cache[0][:, :batch_size]
|
||||
temporal_state = self.mamba_cache[1][:, :batch_size]
|
||||
|
||||
return (conv_state, temporal_state)
|
||||
|
||||
def _get_all_occupied_indices(self):
|
||||
return [
|
||||
cache_idx
|
||||
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
|
||||
for cache_idx in seq_ids2indices.values()
|
||||
]
|
||||
|
||||
def _clean_up_first_bs_blocks(self, batch_size: int,
|
||||
indices_for_current_run: List[int]):
|
||||
# move out all of the occupied but currently not running blocks
|
||||
# outside of the first n blocks
|
||||
destination_indices = range(batch_size)
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
for destination_index in destination_indices:
|
||||
if destination_index in self._get_all_occupied_indices() and \
|
||||
destination_index not in indices_for_current_run:
|
||||
# move not running indices outside of the batch
|
||||
all_other_indices = list(
|
||||
range(batch_size, max_possible_batch_size))
|
||||
first_avail_index = self._first_free_index_in_mamba_cache(
|
||||
all_other_indices)
|
||||
self._swap_indices(from_index=destination_index,
|
||||
to_index=first_avail_index)
|
||||
|
||||
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
|
||||
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
|
||||
self._update_mapping_index(from_index=from_index, to_index=to_index)
|
||||
|
||||
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
|
||||
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
|
||||
self._swap_mapping_index(from_index=from_index, to_index=to_index)
|
||||
|
||||
def _swap_mapping_index(self, from_index: int, to_index: int):
|
||||
for seq_ids2index in self.mamba_cache_indices_mapping.values():
|
||||
for seq_id, index in seq_ids2index.items():
|
||||
if from_index == index:
|
||||
seq_ids2index.update({seq_id: to_index})
|
||||
elif to_index == index:
|
||||
seq_ids2index.update({seq_id: from_index})
|
||||
|
||||
def _update_mapping_index(self, from_index: int, to_index: int):
|
||||
for seq_ids2index in self.mamba_cache_indices_mapping.values():
|
||||
for seq_id, index in seq_ids2index.items():
|
||||
if from_index == index:
|
||||
seq_ids2index.update({seq_id: to_index})
|
||||
return
|
||||
|
||||
def _release_finished_requests(self,
|
||||
finished_seq_groups_req_ids: List[str]):
|
||||
for req_id in finished_seq_groups_req_ids:
|
||||
if req_id in self.mamba_cache_indices_mapping:
|
||||
self.mamba_cache_indices_mapping.pop(req_id)
|
||||
|
||||
def _first_free_index_in_mamba_cache(
|
||||
self, indices_range: Optional[List[int]] = None) -> int:
|
||||
assert self.mamba_cache is not None
|
||||
if indices_range is None:
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
indices_range = list(range(max_possible_batch_size))
|
||||
all_occupied_indices = self._get_all_occupied_indices()
|
||||
for i in indices_range:
|
||||
if i not in all_occupied_indices:
|
||||
return i
|
||||
raise Exception("Couldn't find a free spot in the mamba cache! This"
|
||||
"should never happen")
|
||||
10
qwen3_6_scripts/patch_ops.sh
Executable file
10
qwen3_6_scripts/patch_ops.sh
Executable file
@@ -0,0 +1,10 @@
|
||||
pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/
|
||||
python3 ./patch_transformers_qwen3_5.py
|
||||
|
||||
cp ./mamba_cache.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/
|
||||
cp ./qwen3_5.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/
|
||||
python3 ./patch_vllm_qwen3_5.py
|
||||
|
||||
# 此步骤脚本四选一(默认 matmul+seq策略)
|
||||
python3 ./patch_xformers_sdpa_seq.py
|
||||
97
qwen3_6_scripts/patch_transformers_qwen3_5.py
Normal file
97
qwen3_6_scripts/patch_transformers_qwen3_5.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Patches transformers 4.55.3 to register the qwen3_5 model type.
|
||||
|
||||
Deploy steps on the remote machine:
|
||||
1. cp -r modified_scripts/qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/qwen3_5
|
||||
2. python3 modified_scripts/patch_transformers_qwen3_5.py
|
||||
|
||||
Target: pip-installed transformers at /usr/local/lib/python3.10/site-packages/transformers/
|
||||
(Not the corex pre-installed path at /usr/local/corex/lib64/python3/dist-packages/)
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
TRANSFORMERS_ROOT = "/usr/local/lib/python3.10/site-packages/transformers"
|
||||
AUTO_CONFIG = f"{TRANSFORMERS_ROOT}/models/auto/configuration_auto.py"
|
||||
MODELS_INIT = f"{TRANSFORMERS_ROOT}/models/__init__.py"
|
||||
|
||||
|
||||
def patch_file(path, replacements):
|
||||
with open(path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
patched = False
|
||||
for old, new in replacements:
|
||||
if new in content:
|
||||
print(f" [skip] already patched: {repr(new[:60])}")
|
||||
continue
|
||||
if old not in content:
|
||||
print(f" [warn] anchor not found: {repr(old[:60])}")
|
||||
continue
|
||||
content = content.replace(old, new, 1)
|
||||
patched = True
|
||||
print(f" [ok] inserted after: {repr(old[:60])}")
|
||||
|
||||
if patched:
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def main():
|
||||
print(f"=== Patching {AUTO_CONFIG} ===")
|
||||
patch_file(AUTO_CONFIG, [
|
||||
# CONFIG_MAPPING_NAMES: insert qwen3_5 right after qwen3
|
||||
(
|
||||
'("qwen3", "Qwen3Config"),',
|
||||
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),',
|
||||
),
|
||||
# Some versions don't have trailing comma — handle that too
|
||||
(
|
||||
'("qwen3", "Qwen3Config")\n',
|
||||
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),\n',
|
||||
),
|
||||
# MODEL_NAMES_MAPPING (model_type -> human readable name, used by docstring generator)
|
||||
(
|
||||
'("qwen3", "Qwen3"),',
|
||||
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),',
|
||||
),
|
||||
(
|
||||
'("qwen3", "Qwen3")\n',
|
||||
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),\n',
|
||||
),
|
||||
])
|
||||
|
||||
print(f"\n=== Patching {MODELS_INIT} ===")
|
||||
patch_file(MODELS_INIT, [
|
||||
(
|
||||
"from .qwen3 import *\n",
|
||||
"from .qwen3 import *\n from .qwen3_5 import *\n",
|
||||
),
|
||||
])
|
||||
|
||||
# Verification
|
||||
print("\n=== Verification ===")
|
||||
try:
|
||||
import importlib.util, types
|
||||
|
||||
# Quick smoke-test: import the config class directly
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"configuration_qwen3_5",
|
||||
f"{TRANSFORMERS_ROOT}/models/qwen3_5/configuration_qwen3_5.py",
|
||||
)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
# Provide minimal parent package stubs so relative imports resolve
|
||||
pkg = types.ModuleType("transformers")
|
||||
pkg.__path__ = [TRANSFORMERS_ROOT]
|
||||
sys.modules.setdefault("transformers", pkg)
|
||||
spec.loader.exec_module(mod)
|
||||
cfg = mod.Qwen3_5Config()
|
||||
print(f" Qwen3_5Config() smoke-test OK (model_type={cfg.model_type})")
|
||||
except Exception as e:
|
||||
print(f" [warn] smoke-test failed (may be fine at runtime): {e}")
|
||||
|
||||
print("\nDone.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
73
qwen3_6_scripts/patch_vllm_qwen3_5.py
Normal file
73
qwen3_6_scripts/patch_vllm_qwen3_5.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Patches the vLLM model registry and deploys the Qwen3_5 model file.
|
||||
|
||||
Deploy steps on the remote machine:
|
||||
1. cp modified_scripts/qwen3_5.py \
|
||||
/usr/local/corex/lib64/python3/dist-packages/vllm/model_executor/models/qwen3_5.py
|
||||
2. python3 modified_scripts/patch_vllm_qwen3_5.py
|
||||
|
||||
Also edit your model config.json to set:
|
||||
"architectures": ["Qwen3_5ForCausalLM"]
|
||||
|
||||
Target: vLLM at /usr/local/corex/lib64/python3/dist-packages/vllm/
|
||||
"""
|
||||
|
||||
VLLM_ROOT = "/usr/local/corex/lib64/python3/dist-packages/vllm"
|
||||
REGISTRY = f"{VLLM_ROOT}/model_executor/models/registry.py"
|
||||
|
||||
|
||||
def patch_file(path, replacements):
|
||||
with open(path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
patched = False
|
||||
for old, new in replacements:
|
||||
if new in content:
|
||||
print(f" [skip] already patched: {repr(new[:70])}")
|
||||
continue
|
||||
if old not in content:
|
||||
print(f" [warn] anchor not found: {repr(old[:70])}")
|
||||
continue
|
||||
content = content.replace(old, new, 1)
|
||||
patched = True
|
||||
print(f" [ok] patched after: {repr(old[:70])}")
|
||||
|
||||
if patched:
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def main():
|
||||
print(f"=== Patching {REGISTRY} ===")
|
||||
patch_file(REGISTRY, [
|
||||
(
|
||||
' "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),\n'
|
||||
' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),',
|
||||
' "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),\n'
|
||||
' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),\n'
|
||||
' "Qwen3_5ForCausalLM": ("qwen3_5", "Qwen3_5ForCausalLM"),',
|
||||
),
|
||||
])
|
||||
|
||||
print("\n=== Verification ===")
|
||||
try:
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"qwen3_5",
|
||||
f"{VLLM_ROOT}/model_executor/models/qwen3_5.py",
|
||||
)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
# Quick check: does the class exist?
|
||||
spec.loader.exec_module(mod)
|
||||
cls = mod.Qwen3_5ForCausalLM
|
||||
print(f" Qwen3_5ForCausalLM found: {cls}")
|
||||
except Exception as e:
|
||||
print(f" [warn] verification failed (may be OK at runtime): {e}")
|
||||
|
||||
print("\nDone. Remember to:")
|
||||
print(" 1. Set config.json 'architectures': ['Qwen3_5ForCausalLM']")
|
||||
print(" 2. Run patch_transformers_qwen3_5.py if not already done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
192
qwen3_6_scripts/patch_xformers_sdpa_batch.py
Normal file
192
qwen3_6_scripts/patch_xformers_sdpa_batch.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
策略:批量(block-diagonal)fallback — 纯 PyTorch 数学实现
|
||||
=============================================================
|
||||
构建块对角 causal mask,对整批序列一次 matmul + softmax,
|
||||
完全绕开所有硬件 flash attention kernel。
|
||||
|
||||
背景:
|
||||
ixformer flshattF: head_dim > 128 报错拒绝
|
||||
cudnnFlashAttnForward: 接受 head_dim=256,但数值结果错误(输出全"!")
|
||||
两者大概率是同一硬件单元,ixformer 提前拦截了硬件不支持的配置。
|
||||
纯 matmul 路径完全绕开硬件 flash attention,数值正确。
|
||||
|
||||
优点:
|
||||
数值正确。
|
||||
并发请求 prefill attention 在 GPU 上真正并行(一次大 matmul)。
|
||||
|
||||
缺点:
|
||||
峰值显存 = total_tokens² × H × dtype_size
|
||||
total_tokens 受 --max-num-batched-tokens 控制,max-model-len 控制不住。
|
||||
|
||||
内存参考(fp16,H_local=6,--max-num-batched-tokens=T):
|
||||
T=2048 → 峰值 ~50 MB
|
||||
T=4096 → 峰值 ~200 MB
|
||||
T=8192 → 峰值 ~800 MB
|
||||
T=16384 → 峰值 ~3.2 GB
|
||||
|
||||
Deploy:
|
||||
python3 modified_scripts/patch_xformers_sdpa_batch.py
|
||||
"""
|
||||
|
||||
XFORMERS_PATH = (
|
||||
"/usr/local/corex/lib64/python3/dist-packages/"
|
||||
"vllm/attention/backends/xformers.py"
|
||||
)
|
||||
|
||||
FALLBACK_METHOD = '''
|
||||
def _run_sdpa_fallback(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: "XFormersMetadata",
|
||||
) -> torch.Tensor:
|
||||
"""批量纯数学 attention fallback。
|
||||
|
||||
构建块对角 causal mask(等价于 ixformer BlockDiagonalCausalMask),
|
||||
对整批序列一次 matmul + softmax,GPU 并行处理所有序列。
|
||||
|
||||
块对角 mask 结构(seq1 len=3,seq2 len=2):
|
||||
s1,0 s1,1 s1,2 s2,0 s2,1
|
||||
s1,0 [ 0 -inf -inf -inf -inf ]
|
||||
s1,1 [ 0 0 -inf -inf -inf ]
|
||||
s1,2 [ 0 0 0 -inf -inf ]
|
||||
s2,0 [-inf -inf -inf 0 -inf ]
|
||||
s2,1 [-inf -inf -inf 0 0 ]
|
||||
|
||||
softmax 在 float32 下计算防止 float16 溢出,结果转回原始 dtype。
|
||||
|
||||
Args:
|
||||
query : [1, total_prefill_tokens, num_heads, head_dim]
|
||||
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
Returns:
|
||||
[1, total_prefill_tokens, num_heads, head_dim]
|
||||
"""
|
||||
assert attn_metadata.seq_lens is not None
|
||||
orig_dtype = query.dtype
|
||||
total_tokens = query.shape[1]
|
||||
|
||||
# ── 构建块对角 causal mask [T, T] ────────────────────────────────
|
||||
# 全部初始化为 -inf,再对每条序列的对角块填入下三角 0
|
||||
mask = torch.full(
|
||||
(total_tokens, total_tokens),
|
||||
float("-inf"),
|
||||
dtype=torch.float32,
|
||||
device=query.device,
|
||||
)
|
||||
start = 0
|
||||
for seq_len in attn_metadata.seq_lens:
|
||||
end = start + seq_len
|
||||
mask[start:end, start:end] = torch.tril(
|
||||
torch.zeros(seq_len, seq_len,
|
||||
dtype=torch.float32, device=query.device)
|
||||
)
|
||||
start = end
|
||||
|
||||
# ── [1, H, T, D],.contiguous() ──────────────────────────────────
|
||||
q_all = query.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
k_all = key.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
v_all = value.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
|
||||
# ── GQA:展开 KV heads ────────────────────────────────────────────
|
||||
if k_all.shape[1] != q_all.shape[1]:
|
||||
n = q_all.shape[1] // k_all.shape[1]
|
||||
k_all = k_all.repeat_interleave(n, dim=1).contiguous()
|
||||
v_all = v_all.repeat_interleave(n, dim=1).contiguous()
|
||||
|
||||
# ── 纯数学 attention(float32 防溢出)────────────────────────────
|
||||
# [1, H, T, T]
|
||||
attn_w = torch.matmul(q_all.float(), k_all.float().transpose(-2, -1))
|
||||
attn_w = attn_w * self.scale
|
||||
attn_w = attn_w + mask # 加法广播:mask [T,T] → [1, H, T, T]
|
||||
attn_w = torch.softmax(attn_w, dim=-1)
|
||||
|
||||
out = torch.matmul(attn_w, v_all.float()).to(orig_dtype)
|
||||
# [1, H, T, D] → [1, T, H, D]
|
||||
return out.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
|
||||
'''
|
||||
|
||||
OLD_XFORMER_BLOCK = """\
|
||||
self.attn_op = xops.fmha.flash.FwOp()
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op = self.attn_op
|
||||
)
|
||||
return out.view_as(original_query)\
|
||||
"""
|
||||
|
||||
NEW_XFORMER_BLOCK = """\
|
||||
self.attn_op = xops.fmha.flash.FwOp()
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
if self.head_size > 128:
|
||||
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
|
||||
else:
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
return out.view_as(original_query)\
|
||||
"""
|
||||
|
||||
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
|
||||
|
||||
|
||||
def patch_file(path):
|
||||
with open(path, "r") as f:
|
||||
content = f.read()
|
||||
changed = False
|
||||
|
||||
if "_run_sdpa_fallback" in content:
|
||||
print(" [skip] _run_sdpa_fallback already present")
|
||||
elif INJECT_ANCHOR not in content:
|
||||
print(" [warn] inject anchor not found")
|
||||
else:
|
||||
content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
|
||||
print(" [ok] injected _run_sdpa_fallback (batch, pure-math)")
|
||||
changed = True
|
||||
|
||||
if NEW_XFORMER_BLOCK in content:
|
||||
print(" [skip] dispatch block already patched")
|
||||
elif OLD_XFORMER_BLOCK in content:
|
||||
content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
|
||||
print(" [ok] patched dispatch block")
|
||||
changed = True
|
||||
else:
|
||||
print(" [warn] dispatch block anchor not found")
|
||||
|
||||
if changed:
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
print(f" Written: {path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=== patch_xformers_sdpa_batch (batch, pure-math) ===")
|
||||
print(f"Target: {XFORMERS_PATH}")
|
||||
patch_file(XFORMERS_PATH)
|
||||
print("\nDone.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
191
qwen3_6_scripts/patch_xformers_sdpa_batch_kernel.py
Normal file
191
qwen3_6_scripts/patch_xformers_sdpa_batch_kernel.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
策略:批量(block-diagonal)— F.scaled_dot_product_attention,可走硬件 kernel
|
||||
=============================================================================
|
||||
构建块对角 causal mask,对整批序列一次 F.scaled_dot_product_attention。
|
||||
与 patch_xformers_sdpa_batch.py(纯 matmul)的区别:
|
||||
SDPA 会根据 PyTorch/驱动能力分发到最优 kernel(Flash Attention /
|
||||
mem-efficient attention / math fallback),而不是固定走 cublas matmul。
|
||||
|
||||
历史说明:
|
||||
该方案最早因输出全"!"而被弃用,后续排查确认"!"由 mamba_cache.py bug
|
||||
引起,与 attention 实现无关。当前恢复此方案用于性能对比测试。
|
||||
|
||||
已知硬件限制(BI-V100):
|
||||
cudnnFlashAttnForward 不支持 is_causal=True(报错)。
|
||||
本实现使用 is_causal=False + 显式块对角 additive mask 规避此限制。
|
||||
若 SDPA 仍分发到有问题的 kernel,回退到 patch_xformers_sdpa_batch.py。
|
||||
|
||||
优点(vs 纯 matmul):
|
||||
SDPA 可分发到 Flash Attention kernel → O(L) 显存、更快的 CUDA kernel。
|
||||
|
||||
缺点:
|
||||
依赖硬件 kernel 行为,若 kernel 有 bug 则数值错误(需与 matmul 版对比验证)。
|
||||
|
||||
Deploy:
|
||||
python3 modified_scripts/patch_xformers_sdpa_batch_kernel.py
|
||||
"""
|
||||
|
||||
XFORMERS_PATH = (
|
||||
"/usr/local/corex/lib64/python3/dist-packages/"
|
||||
"vllm/attention/backends/xformers.py"
|
||||
)
|
||||
|
||||
FALLBACK_METHOD = '''
|
||||
def _run_sdpa_fallback(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: "XFormersMetadata",
|
||||
) -> torch.Tensor:
|
||||
"""批量 F.scaled_dot_product_attention fallback(可走硬件 kernel)。
|
||||
|
||||
构建块对角 causal mask,对整批序列一次 SDPA 调用。
|
||||
SDPA 可分发到 Flash Attention / mem-efficient attention kernel。
|
||||
is_causal=False + 显式 additive mask,规避 cudnnFlashAttnForward
|
||||
不支持 is_causal=True 的限制。
|
||||
|
||||
块对角 mask(seq1 len=3,seq2 len=2):
|
||||
s1,0 s1,1 s1,2 s2,0 s2,1
|
||||
s1,0 [ 0 -inf -inf -inf -inf ]
|
||||
s1,1 [ 0 0 -inf -inf -inf ]
|
||||
s1,2 [ 0 0 0 -inf -inf ]
|
||||
s2,0 [-inf -inf -inf 0 -inf ]
|
||||
s2,1 [-inf -inf -inf 0 0 ]
|
||||
|
||||
Args:
|
||||
query : [1, total_prefill_tokens, num_heads, head_dim]
|
||||
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
Returns:
|
||||
[1, total_prefill_tokens, num_heads, head_dim]
|
||||
"""
|
||||
import torch.nn.functional as F
|
||||
|
||||
assert attn_metadata.seq_lens is not None
|
||||
orig_dtype = query.dtype
|
||||
total_tokens = query.shape[1]
|
||||
|
||||
# ── 块对角 causal mask [T, T] ─────────────────────────────────────
|
||||
mask = torch.full(
|
||||
(total_tokens, total_tokens),
|
||||
float("-inf"),
|
||||
dtype=orig_dtype,
|
||||
device=query.device,
|
||||
)
|
||||
start = 0
|
||||
for seq_len in attn_metadata.seq_lens:
|
||||
end = start + seq_len
|
||||
mask[start:end, start:end] = torch.tril(
|
||||
torch.zeros(seq_len, seq_len, dtype=orig_dtype, device=query.device)
|
||||
)
|
||||
start = end
|
||||
|
||||
# ── [1, H, T, D] ──────────────────────────────────────────────────
|
||||
q_all = query.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
k_all = key.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
v_all = value.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
|
||||
# ── GQA:展开 KV heads ────────────────────────────────────────────
|
||||
if k_all.shape[1] != q_all.shape[1]:
|
||||
n = q_all.shape[1] // k_all.shape[1]
|
||||
k_all = k_all.repeat_interleave(n, dim=1).contiguous()
|
||||
v_all = v_all.repeat_interleave(n, dim=1).contiguous()
|
||||
|
||||
# ── F.scaled_dot_product_attention(可走硬件 kernel)─────────────
|
||||
# is_causal=False:避免 cudnnFlashAttnForward "not support causal mode"
|
||||
# attn_mask 传 additive float mask(非 bool),SDPA 选择 math/kernel 路径
|
||||
out = F.scaled_dot_product_attention(
|
||||
q_all, k_all, v_all,
|
||||
attn_mask=mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
scale=self.scale,
|
||||
)
|
||||
# [1, H, T, D] → [1, T, H, D]
|
||||
return out.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
|
||||
'''
|
||||
|
||||
OLD_XFORMER_BLOCK = """\
|
||||
self.attn_op = xops.fmha.flash.FwOp()
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op = self.attn_op
|
||||
)
|
||||
return out.view_as(original_query)\
|
||||
"""
|
||||
|
||||
NEW_XFORMER_BLOCK = """\
|
||||
self.attn_op = xops.fmha.flash.FwOp()
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
if self.head_size > 128:
|
||||
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
|
||||
else:
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
return out.view_as(original_query)\
|
||||
"""
|
||||
|
||||
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
|
||||
|
||||
|
||||
def patch_file(path):
|
||||
with open(path, "r") as f:
|
||||
content = f.read()
|
||||
changed = False
|
||||
|
||||
if "_run_sdpa_fallback" in content:
|
||||
print(" [skip] _run_sdpa_fallback already present")
|
||||
elif INJECT_ANCHOR not in content:
|
||||
print(" [warn] inject anchor not found")
|
||||
else:
|
||||
content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
|
||||
print(" [ok] injected _run_sdpa_fallback (batch, F.sdpa kernel)")
|
||||
changed = True
|
||||
|
||||
if NEW_XFORMER_BLOCK in content:
|
||||
print(" [skip] dispatch block already patched")
|
||||
elif OLD_XFORMER_BLOCK in content:
|
||||
content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
|
||||
print(" [ok] patched dispatch block")
|
||||
changed = True
|
||||
else:
|
||||
print(" [warn] dispatch block anchor not found")
|
||||
|
||||
if changed:
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
print(f" Written: {path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=== patch_xformers_sdpa_batch_kernel (batch, F.sdpa + kernel dispatch) ===")
|
||||
print(f"Target: {XFORMERS_PATH}")
|
||||
patch_file(XFORMERS_PATH)
|
||||
print("\nDone.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
186
qwen3_6_scripts/patch_xformers_sdpa_seq.py
Normal file
186
qwen3_6_scripts/patch_xformers_sdpa_seq.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
策略:顺序(per-sequence)fallback — 纯 PyTorch 数学实现
|
||||
==========================================================
|
||||
逐条序列用 matmul + softmax 手写 attention,完全绕开所有硬件
|
||||
flash attention kernel(ixformer / cudnnFlashAttnForward)。
|
||||
|
||||
背景:
|
||||
Iluvatar cudnnFlashAttnForward 存在两个已知问题:
|
||||
1. 不支持 is_causal=True(报错)
|
||||
2. 使用 attn_mask 路径时数值结果不正确(静默错误,输出全为"!")
|
||||
与华为昇腾 910B4 上 llama.cpp --flash-attn off 修复同类问题的原理相同。
|
||||
纯数学路径(matmul + softmax)在任何 PyTorch 后端上结果都正确。
|
||||
|
||||
优点:
|
||||
数值正确,不依赖任何硬件特定 attention kernel。
|
||||
峰值显存 = max(seq_len)² × H × dtype_size,由 --max-model-len 控制。
|
||||
|
||||
缺点:
|
||||
并发请求的 prefill attention 串行执行。
|
||||
O(L²) 显存(无 flash attention 的 O(L) 优化)。
|
||||
|
||||
内存参考(fp16,H_local=6):
|
||||
max-model-len=4096 → 峰值 ~200 MB
|
||||
max-model-len=8192 → 峰值 ~800 MB
|
||||
max-model-len=16384 → 峰值 ~3.2 GB
|
||||
|
||||
Deploy:
|
||||
python3 modified_scripts/patch_xformers_sdpa_seq.py
|
||||
"""
|
||||
|
||||
XFORMERS_PATH = (
|
||||
"/usr/local/corex/lib64/python3/dist-packages/"
|
||||
"vllm/attention/backends/xformers.py"
|
||||
)
|
||||
|
||||
FALLBACK_METHOD = '''
|
||||
def _run_sdpa_fallback(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: "XFormersMetadata",
|
||||
) -> torch.Tensor:
|
||||
"""顺序纯数学 attention fallback。
|
||||
|
||||
完全绕开 ixformer / cudnnFlashAttnForward,用 matmul + softmax
|
||||
手写 attention。Iluvatar cudnnFlashAttnForward 的 attn_mask 路径
|
||||
存在静默数值错误(输出全为"!"),纯数学路径结果正确。
|
||||
|
||||
softmax 在 float32 下计算以防止 float16 溢出,结果转回原始 dtype。
|
||||
|
||||
Args:
|
||||
query : [1, total_prefill_tokens, num_heads, head_dim]
|
||||
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
Returns:
|
||||
[1, total_prefill_tokens, num_heads, head_dim]
|
||||
"""
|
||||
assert attn_metadata.seq_lens is not None
|
||||
orig_dtype = query.dtype
|
||||
|
||||
q_flat = query.squeeze(0) # [T, H, D]
|
||||
k_flat = key.squeeze(0) # [T, Hkv, D]
|
||||
v_flat = value.squeeze(0)
|
||||
|
||||
output = torch.empty_like(q_flat)
|
||||
start = 0
|
||||
for seq_len in attn_metadata.seq_lens:
|
||||
end = start + seq_len
|
||||
# [1, H, L, D]
|
||||
q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
|
||||
# GQA:展开 KV heads 至与 query heads 一致
|
||||
if k_s.shape[1] != q_s.shape[1]:
|
||||
n = q_s.shape[1] // k_s.shape[1]
|
||||
k_s = k_s.repeat_interleave(n, dim=1).contiguous()
|
||||
v_s = v_s.repeat_interleave(n, dim=1).contiguous()
|
||||
|
||||
# 纯数学 attention:完全绕开硬件 flash attention kernel
|
||||
# [1, H, L, L]
|
||||
attn_w = torch.matmul(q_s.float(), k_s.float().transpose(-2, -1))
|
||||
attn_w = attn_w * self.scale
|
||||
|
||||
# 上三角填 -inf(future tokens)
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_w.device),
|
||||
diagonal=1,
|
||||
)
|
||||
attn_w = attn_w.masked_fill(causal_mask, float("-inf"))
|
||||
|
||||
# float32 softmax 防止 float16 溢出
|
||||
attn_w = torch.softmax(attn_w, dim=-1)
|
||||
|
||||
out_s = torch.matmul(attn_w, v_s.float()).to(orig_dtype)
|
||||
# [1, H, L, D] → [L, H, D]
|
||||
output[start:end] = out_s.squeeze(0).permute(1, 0, 2)
|
||||
start = end
|
||||
|
||||
return output.unsqueeze(0) # [1, T, H, D]
|
||||
|
||||
'''
|
||||
|
||||
OLD_XFORMER_BLOCK = """\
|
||||
self.attn_op = xops.fmha.flash.FwOp()
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op = self.attn_op
|
||||
)
|
||||
return out.view_as(original_query)\
|
||||
"""
|
||||
|
||||
NEW_XFORMER_BLOCK = """\
|
||||
self.attn_op = xops.fmha.flash.FwOp()
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
if self.head_size > 128:
|
||||
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
|
||||
else:
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
return out.view_as(original_query)\
|
||||
"""
|
||||
|
||||
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
|
||||
|
||||
|
||||
def patch_file(path):
|
||||
with open(path, "r") as f:
|
||||
content = f.read()
|
||||
changed = False
|
||||
|
||||
if "_run_sdpa_fallback" in content:
|
||||
print(" [skip] _run_sdpa_fallback already present")
|
||||
elif INJECT_ANCHOR not in content:
|
||||
print(" [warn] inject anchor not found")
|
||||
else:
|
||||
content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
|
||||
print(" [ok] injected _run_sdpa_fallback (sequential, pure-math)")
|
||||
changed = True
|
||||
|
||||
if NEW_XFORMER_BLOCK in content:
|
||||
print(" [skip] dispatch block already patched")
|
||||
elif OLD_XFORMER_BLOCK in content:
|
||||
content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
|
||||
print(" [ok] patched dispatch block")
|
||||
changed = True
|
||||
else:
|
||||
print(" [warn] dispatch block anchor not found")
|
||||
|
||||
if changed:
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
print(f" Written: {path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===")
|
||||
print(f"Target: {XFORMERS_PATH}")
|
||||
patch_file(XFORMERS_PATH)
|
||||
print("\nDone.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
181
qwen3_6_scripts/patch_xformers_sdpa_seq_kernel.py
Normal file
181
qwen3_6_scripts/patch_xformers_sdpa_seq_kernel.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
策略:顺序(per-sequence)— F.scaled_dot_product_attention,可走硬件 kernel
|
||||
=============================================================================
|
||||
逐条序列调用 F.scaled_dot_product_attention,is_causal=False + 显式因果 mask。
|
||||
与 patch_xformers_sdpa_seq.py(纯 matmul)的区别:
|
||||
SDPA 可分发到 Flash Attention / mem-efficient attention kernel,
|
||||
而纯 matmul 固定走 cublas。
|
||||
|
||||
硬件限制(BI-V100):
|
||||
cudnnFlashAttnForward 不支持 is_causal=True(直接报错)。
|
||||
必须使用 is_causal=False + 显式 additive causal mask。
|
||||
每条序列单独构造上三角 -inf mask,peak 显存 = max(seq_len)² × dtype,
|
||||
比 batch 版的 total_tokens² 小得多。
|
||||
|
||||
与 batch_kernel 的对比:
|
||||
seq_kernel: 显存小,peak = max_single_seq²;并发 prefill 串行排队
|
||||
batch_kernel: 显存大,peak = total_tokens²;并发 prefill 一次并行处理,
|
||||
通过 --max-num-batched-tokens 控制 total_tokens 上限
|
||||
|
||||
Deploy:
|
||||
python3 modified_scripts/patch_xformers_sdpa_seq_kernel.py
|
||||
"""
|
||||
|
||||
XFORMERS_PATH = (
|
||||
"/usr/local/corex/lib64/python3/dist-packages/"
|
||||
"vllm/attention/backends/xformers.py"
|
||||
)
|
||||
|
||||
FALLBACK_METHOD = '''
|
||||
def _run_sdpa_fallback(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: "XFormersMetadata",
|
||||
) -> torch.Tensor:
|
||||
"""顺序 F.scaled_dot_product_attention fallback(可走硬件 kernel)。
|
||||
|
||||
逐条序列调用 SDPA,is_causal=False + 显式上三角 additive mask。
|
||||
cudnnFlashAttnForward 不支持 is_causal=True,必须用显式 mask。
|
||||
逐序列构造 mask,peak 显存 = max(seq_len)² × dtype(远小于 batch 版)。
|
||||
|
||||
Args:
|
||||
query : [1, total_prefill_tokens, num_heads, head_dim]
|
||||
key : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
value : [1, total_prefill_tokens, num_kv_heads, head_dim]
|
||||
Returns:
|
||||
[1, total_prefill_tokens, num_heads, head_dim]
|
||||
"""
|
||||
import torch.nn.functional as F
|
||||
|
||||
assert attn_metadata.seq_lens is not None
|
||||
orig_dtype = query.dtype
|
||||
|
||||
q_flat = query.squeeze(0) # [T, H, D]
|
||||
k_flat = key.squeeze(0) # [T, Hkv, D]
|
||||
v_flat = value.squeeze(0)
|
||||
|
||||
output = torch.empty_like(q_flat)
|
||||
start = 0
|
||||
for seq_len in attn_metadata.seq_lens:
|
||||
end = start + seq_len
|
||||
# [1, H, L, D]
|
||||
q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0)
|
||||
|
||||
# GQA:展开 KV heads
|
||||
if k_s.shape[1] != q_s.shape[1]:
|
||||
n = q_s.shape[1] // k_s.shape[1]
|
||||
k_s = k_s.repeat_interleave(n, dim=1).contiguous()
|
||||
v_s = v_s.repeat_interleave(n, dim=1).contiguous()
|
||||
|
||||
# 逐序列因果 mask [L, L],上三角 -inf
|
||||
causal_mask = torch.tril(
|
||||
torch.zeros(seq_len, seq_len, dtype=orig_dtype, device=q_s.device)
|
||||
)
|
||||
causal_mask = causal_mask.masked_fill(
|
||||
torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool,
|
||||
device=q_s.device), diagonal=1),
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
# is_causal=False + 显式 mask,规避 cudnnFlashAttnForward 不支持 is_causal=True
|
||||
out_s = F.scaled_dot_product_attention(
|
||||
q_s, k_s, v_s,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
scale=self.scale,
|
||||
)
|
||||
# [1, H, L, D] → [L, H, D]
|
||||
output[start:end] = out_s.squeeze(0).permute(1, 0, 2).to(orig_dtype)
|
||||
start = end
|
||||
|
||||
return output.unsqueeze(0) # [1, T, H, D]
|
||||
|
||||
'''
|
||||
|
||||
OLD_XFORMER_BLOCK = """\
|
||||
self.attn_op = xops.fmha.flash.FwOp()
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op = self.attn_op
|
||||
)
|
||||
return out.view_as(original_query)\
|
||||
"""
|
||||
|
||||
NEW_XFORMER_BLOCK = """\
|
||||
self.attn_op = xops.fmha.flash.FwOp()
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
if self.head_size > 128:
|
||||
out = self._run_sdpa_fallback(query, key, value, attn_metadata)
|
||||
else:
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
return out.view_as(original_query)\
|
||||
"""
|
||||
|
||||
INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward("
|
||||
|
||||
|
||||
def patch_file(path):
|
||||
with open(path, "r") as f:
|
||||
content = f.read()
|
||||
changed = False
|
||||
|
||||
if "_run_sdpa_fallback" in content:
|
||||
print(" [skip] _run_sdpa_fallback already present")
|
||||
elif INJECT_ANCHOR not in content:
|
||||
print(" [warn] inject anchor not found")
|
||||
else:
|
||||
content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1)
|
||||
print(" [ok] injected _run_sdpa_fallback (seq, F.sdpa kernel)")
|
||||
changed = True
|
||||
|
||||
if NEW_XFORMER_BLOCK in content:
|
||||
print(" [skip] dispatch block already patched")
|
||||
elif OLD_XFORMER_BLOCK in content:
|
||||
content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1)
|
||||
print(" [ok] patched dispatch block")
|
||||
changed = True
|
||||
else:
|
||||
print(" [warn] dispatch block anchor not found")
|
||||
|
||||
if changed:
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
print(f" Written: {path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=== patch_xformers_sdpa_seq_kernel (seq, F.sdpa + kernel dispatch) ===")
|
||||
print(f"Target: {XFORMERS_PATH}")
|
||||
patch_file(XFORMERS_PATH)
|
||||
print("\nDone.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
894
qwen3_6_scripts/qwen3_5.py
Normal file
894
qwen3_6_scripts/qwen3_5.py
Normal file
@@ -0,0 +1,894 @@
|
||||
# Inference-only Qwen3.6-27B (Qwen3_5 architecture) for Iluvatar BI-V100.
|
||||
# Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency).
|
||||
# Text-only (no VL, no MTP).
|
||||
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheManager
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
||||
_get_graph_batch_size)
|
||||
|
||||
from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure-PyTorch DeltaNet kernels (fallbacks from transformers 5.2.0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
|
||||
return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
||||
|
||||
|
||||
def _torch_causal_conv1d_update(
|
||||
hidden_states: torch.Tensor, # (batch, channels, seq=1)
|
||||
conv_state: torch.Tensor, # (batch, channels, state_len) modified in-place
|
||||
weight: torch.Tensor, # (channels, kernel_size)
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
_, channels, seq_len = hidden_states.shape
|
||||
state_len = conv_state.shape[-1]
|
||||
cat = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
|
||||
conv_state.copy_(cat[:, :, -state_len:])
|
||||
out = F.conv1d(cat, weight.unsqueeze(1), bias, padding=0, groups=channels)
|
||||
out = out[:, :, -seq_len:]
|
||||
if activation is not None:
|
||||
out = F.silu(out)
|
||||
return out.to(hidden_states.dtype)
|
||||
|
||||
|
||||
def _torch_chunk_gated_delta_rule(
|
||||
query: torch.Tensor, # (batch, seq, num_heads, head_k_dim)
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor, # (batch, seq, num_heads, head_v_dim)
|
||||
g: torch.Tensor, # (batch, seq, num_heads)
|
||||
beta: torch.Tensor, # (batch, seq, num_heads)
|
||||
chunk_size: int = 64,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = _l2norm(query)
|
||||
key = _l2norm(key)
|
||||
# Transpose to (batch, num_heads, seq, dim)
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (query, key, value, beta, g)
|
||||
]
|
||||
batch, num_heads, seq_len, k_dim = key.shape
|
||||
v_dim = value.shape[-1]
|
||||
pad = (chunk_size - seq_len % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad))
|
||||
key = F.pad(key, (0, 0, 0, pad))
|
||||
value = F.pad(value, (0, 0, 0, pad))
|
||||
beta = F.pad(beta, (0, pad))
|
||||
g = F.pad(g, (0, pad))
|
||||
total_len = seq_len + pad
|
||||
scale = 1.0 / (query.shape[-1] ** 0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
query, key, value, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
|
||||
for x in (query, key, value, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask_upper = torch.triu(
|
||||
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
|
||||
diagonal=0)
|
||||
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask_upper, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
|
||||
last_state = (
|
||||
torch.zeros(batch, num_heads, k_dim, v_dim, dtype=value.dtype, device=value.device)
|
||||
if initial_state is None
|
||||
else initial_state.to(value)
|
||||
)
|
||||
core_out = torch.zeros_like(value)
|
||||
mask_upper2 = torch.triu(
|
||||
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
|
||||
diagonal=1)
|
||||
|
||||
for i in range(total_len // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask_upper2, 0)
|
||||
v_prime = k_cumdecay[:, :, i] @ last_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_state
|
||||
core_out[:, :, i] = attn_inter + attn_i @ v_new
|
||||
last_state = (
|
||||
last_state * g[:, :, i, -1, None, None].exp()
|
||||
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None])
|
||||
.transpose(-1, -2) @ v_new
|
||||
)
|
||||
|
||||
if not output_final_state:
|
||||
last_state = None
|
||||
core_out = core_out.reshape(batch, num_heads, -1, v_dim)[:, :, :seq_len]
|
||||
core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
return core_out, last_state
|
||||
|
||||
def _torch_recurrent_gated_delta_rule(
|
||||
query: torch.Tensor, # (batch, 1, num_heads, head_k_dim)
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
g: torch.Tensor, # (batch, 1, num_heads)
|
||||
beta: torch.Tensor,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = _l2norm(query)
|
||||
key = _l2norm(key)
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (query, key, value, beta, g)
|
||||
]
|
||||
batch, num_heads, seq_len, k_dim = key.shape
|
||||
v_dim = value.shape[-1]
|
||||
scale = 1.0 / (query.shape[-1] ** 0.5)
|
||||
query = query * scale
|
||||
|
||||
core_out = torch.zeros(batch, num_heads, seq_len, v_dim,
|
||||
dtype=value.dtype, device=value.device)
|
||||
last_state = (
|
||||
torch.zeros(batch, num_heads, k_dim, v_dim,
|
||||
dtype=value.dtype, device=value.device)
|
||||
if initial_state is None
|
||||
else initial_state.to(value)
|
||||
)
|
||||
for t in range(seq_len):
|
||||
q_t = query[:, :, t]
|
||||
k_t = key[:, :, t]
|
||||
v_t = value[:, :, t]
|
||||
g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1)
|
||||
beta_t = beta[:, :, t].unsqueeze(-1)
|
||||
last_state = last_state * g_t
|
||||
kv_mem = (last_state * k_t.unsqueeze(-1)).sum(dim=-2)
|
||||
delta = (v_t - kv_mem) * beta_t
|
||||
last_state = last_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
|
||||
core_out[:, :, t] = (last_state * q_t.unsqueeze(-1)).sum(dim=-2)
|
||||
|
||||
if not output_final_state:
|
||||
last_state = None
|
||||
core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
return core_out, last_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gated RMSNorm (for DeltaNet output normalisation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5RMSNormGated(nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
gate: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
hs = hidden_states.to(torch.float32)
|
||||
variance = hs.pow(2).mean(-1, keepdim=True)
|
||||
hs = hs * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hs = self.weight * hs.to(input_dtype)
|
||||
return (hs * F.silu(gate.to(torch.float32))).to(input_dtype)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gated DeltaNet (linear_attention layers)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GatedDeltaNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_cfg,
|
||||
layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = text_cfg.hidden_size
|
||||
self.num_v_heads = text_cfg.linear_num_value_heads # 48
|
||||
self.num_k_heads = text_cfg.linear_num_key_heads # 16
|
||||
self.head_k_dim = text_cfg.linear_key_head_dim # 128
|
||||
self.head_v_dim = text_cfg.linear_value_head_dim # 128
|
||||
self.key_dim = self.num_k_heads * self.head_k_dim # 2048
|
||||
self.value_dim = self.num_v_heads * self.head_v_dim # 6144
|
||||
self.conv_dim = self.key_dim * 2 + self.value_dim # 10240
|
||||
self.conv_kernel_size = text_cfg.linear_conv_kernel_dim # 4
|
||||
self.head_expand_ratio = self.num_v_heads // self.num_k_heads # 3
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# Sharded projections — MergedColumnParallelLinear shards each of q/k/v
|
||||
# independently so each TP rank gets [q_shard, k_shard, v_shard].
|
||||
# Plain ColumnParallelLinear would shard contiguously, giving rank 0
|
||||
# [q_all, k_partial] — completely wrong Q/K/V after the split below.
|
||||
self.in_proj_qkv = MergedColumnParallelLinear(
|
||||
self.hidden_size, [self.key_dim, self.key_dim, self.value_dim],
|
||||
bias=False, quant_config=quant_config)
|
||||
self.in_proj_z = ColumnParallelLinear(
|
||||
self.hidden_size, self.value_dim,
|
||||
bias=False, quant_config=quant_config)
|
||||
self.in_proj_b = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_v_heads,
|
||||
bias=False, quant_config=quant_config)
|
||||
self.in_proj_a = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_v_heads,
|
||||
bias=False, quant_config=quant_config)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.value_dim, self.hidden_size,
|
||||
bias=False, quant_config=quant_config)
|
||||
|
||||
# Depthwise conv weight — sharded along channel dim (dim 0)
|
||||
local_conv_dim = self.conv_dim // tp_size
|
||||
self.conv1d_weight = nn.Parameter(
|
||||
torch.empty(local_conv_dim, 1, self.conv_kernel_size))
|
||||
set_weight_attrs(self.conv1d_weight, {
|
||||
"weight_loader": self._conv1d_weight_loader})
|
||||
|
||||
# Per-head scalar parameters — sharded along dim 0
|
||||
local_num_v = self.num_v_heads // tp_size
|
||||
self.A_log = nn.Parameter(torch.zeros(local_num_v))
|
||||
self.dt_bias = nn.Parameter(torch.zeros(local_num_v))
|
||||
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
|
||||
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
|
||||
|
||||
# Gated RMSNorm on head_v_dim — replicated (head_v_dim=128 is small)
|
||||
self.norm = Qwen3_5RMSNormGated(self.head_v_dim,
|
||||
eps=text_cfg.rms_norm_eps)
|
||||
|
||||
def _conv1d_weight_loader(self, param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
# loaded_weight: (conv_dim=10240, 1, kernel) ordered as [q, k, v] channels
|
||||
# Must gather channels in the same non-contiguous pattern that
|
||||
# MergedColumnParallelLinear uses for in_proj_qkv, so that each rank's
|
||||
# conv1d_weight[i] applies to the correct in_proj_qkv output channel.
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
key_local = self.key_dim // tp_size # 512 with TP=4
|
||||
val_local = self.value_dim // tp_size # 1536 with TP=4
|
||||
q_s = loaded_weight[tp_rank * key_local : (tp_rank + 1) * key_local]
|
||||
k_s = loaded_weight[self.key_dim + tp_rank * key_local :
|
||||
self.key_dim + (tp_rank + 1) * key_local]
|
||||
v_s = loaded_weight[2 * self.key_dim + tp_rank * val_local :
|
||||
2 * self.key_dim + (tp_rank + 1) * val_local]
|
||||
param.data.copy_(torch.cat([q_s, k_s, v_s], dim=0))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # (total_tokens, hidden_size)
|
||||
attn_metadata: AttentionMetadata,
|
||||
conv_state: torch.Tensor, # (batch, local_conv_dim, kernel-1) in-place
|
||||
temporal_state: torch.Tensor, # (batch, local_v_heads, k_dim, v_dim) in-place
|
||||
) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
local_key_dim = self.key_dim // tp_size
|
||||
local_val_dim = self.value_dim // tp_size
|
||||
local_num_v = self.num_v_heads // tp_size
|
||||
local_num_k = self.num_k_heads // tp_size
|
||||
local_conv_dim = self.conv_dim // tp_size
|
||||
|
||||
is_prefill = attn_metadata.num_prefill_tokens > 0
|
||||
|
||||
# Compute all projections for every token at once (batched, efficient)
|
||||
mixed_qkv_all, _ = self.in_proj_qkv(hidden_states) # (total, local_conv_dim)
|
||||
z_all, _ = self.in_proj_z(hidden_states) # (total, local_val_dim)
|
||||
b_all, _ = self.in_proj_b(hidden_states) # (total, local_num_v)
|
||||
a_all, _ = self.in_proj_a(hidden_states) # (total, local_num_v)
|
||||
|
||||
if is_prefill:
|
||||
seq_starts = attn_metadata.query_start_loc.tolist()
|
||||
outputs = []
|
||||
state_len = self.conv_kernel_size - 1
|
||||
weight_2d = self.conv1d_weight.squeeze(1) # (local_conv_dim, kernel)
|
||||
|
||||
for si in range(len(seq_starts) - 1):
|
||||
s, e = int(seq_starts[si]), int(seq_starts[si + 1])
|
||||
seq_len = e - s
|
||||
|
||||
# Shape: (1, local_conv_dim, seq_len)
|
||||
mixed_qkv = (mixed_qkv_all[s:e]
|
||||
.transpose(0, 1).unsqueeze(0)
|
||||
.to(weight_2d.dtype))
|
||||
|
||||
# Save conv state (last state_len positions)
|
||||
if seq_len >= state_len:
|
||||
conv_state[si].copy_(mixed_qkv[0, :, -state_len:])
|
||||
else:
|
||||
conv_state[si, :, state_len - seq_len:].copy_(
|
||||
mixed_qkv[0])
|
||||
conv_state[si, :, :state_len - seq_len] = 0
|
||||
|
||||
# Causal conv (left-pad with zeros, then convolve)
|
||||
padded = F.pad(mixed_qkv, (state_len, 0))
|
||||
mixed_qkv_conv = F.conv1d(
|
||||
padded, self.conv1d_weight,
|
||||
bias=None, padding=0, groups=local_conv_dim)
|
||||
mixed_qkv_conv = F.silu(mixed_qkv_conv)
|
||||
# (1, seq_len, local_conv_dim)
|
||||
mixed_qkv_conv = mixed_qkv_conv.squeeze(0).transpose(0, 1).unsqueeze(0)
|
||||
|
||||
q, k, v = torch.split(
|
||||
mixed_qkv_conv,
|
||||
[local_key_dim, local_key_dim, local_val_dim], dim=-1)
|
||||
q = q.reshape(1, seq_len, local_num_k, self.head_k_dim)
|
||||
k = k.reshape(1, seq_len, local_num_k, self.head_k_dim)
|
||||
v = v.reshape(1, seq_len, local_num_v, self.head_v_dim)
|
||||
|
||||
beta = b_all[s:e].sigmoid().unsqueeze(0) # (1, seq_len, local_num_v)
|
||||
g = (-self.A_log.float().exp()
|
||||
* F.softplus(a_all[s:e].float() + self.dt_bias)
|
||||
).unsqueeze(0) # (1, seq_len, local_num_v)
|
||||
|
||||
# Expand k/q to match num_v_heads
|
||||
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||
|
||||
core_out, last_state = _torch_chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=temporal_state[si:si + 1],
|
||||
output_final_state=True,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
if last_state is not None:
|
||||
temporal_state[si].copy_(last_state[0])
|
||||
|
||||
# Gate + norm + output proj
|
||||
z = z_all[s:e].reshape(seq_len, local_num_v, self.head_v_dim)
|
||||
core_out = core_out.reshape(seq_len, local_num_v, self.head_v_dim)
|
||||
normed = self.norm(
|
||||
core_out.reshape(-1, self.head_v_dim),
|
||||
z.reshape(-1, self.head_v_dim))
|
||||
normed = normed.reshape(seq_len, -1)
|
||||
out, _ = self.out_proj(normed)
|
||||
outputs.append(out)
|
||||
|
||||
result = torch.cat(outputs, dim=0)
|
||||
assert not torch.isnan(result).any(), f"NaN in prefill layer {self.layer_idx}"
|
||||
return result
|
||||
|
||||
else:
|
||||
# Decode: one token per sequence
|
||||
num_seqs = hidden_states.shape[0]
|
||||
weight_2d = self.conv1d_weight.squeeze(1)
|
||||
|
||||
# (num_seqs, local_conv_dim, 1)
|
||||
mixed_qkv = (mixed_qkv_all
|
||||
.to(weight_2d.dtype)
|
||||
.unsqueeze(-1))
|
||||
|
||||
mixed_qkv_conv = _torch_causal_conv1d_update(
|
||||
mixed_qkv, conv_state, weight_2d,
|
||||
bias=None, activation='silu')
|
||||
# (num_seqs, local_conv_dim, 1) → (num_seqs, 1, local_conv_dim)
|
||||
mixed_qkv_conv = mixed_qkv_conv.squeeze(-1).unsqueeze(1)
|
||||
|
||||
q, k, v = torch.split(
|
||||
mixed_qkv_conv,
|
||||
[local_key_dim, local_key_dim, local_val_dim], dim=-1)
|
||||
q = q.reshape(num_seqs, 1, local_num_k, self.head_k_dim)
|
||||
k = k.reshape(num_seqs, 1, local_num_k, self.head_k_dim)
|
||||
v = v.reshape(num_seqs, 1, local_num_v, self.head_v_dim)
|
||||
|
||||
beta = b_all.sigmoid().unsqueeze(1) # (num_seqs, 1, local_num_v)
|
||||
g = (-self.A_log.float().exp()
|
||||
* F.softplus(a_all.float() + self.dt_bias)
|
||||
).unsqueeze(1) # (num_seqs, 1, local_num_v)
|
||||
|
||||
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||
|
||||
core_out, last_state = _torch_recurrent_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=temporal_state,
|
||||
output_final_state=True,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
if last_state is not None:
|
||||
temporal_state.copy_(last_state)
|
||||
|
||||
z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
|
||||
core_out = core_out.reshape(num_seqs, local_num_v, self.head_v_dim)
|
||||
normed = self.norm(
|
||||
core_out.reshape(-1, self.head_v_dim),
|
||||
z.reshape(-1, self.head_v_dim))
|
||||
normed = normed.reshape(num_seqs, -1)
|
||||
out, _ = self.out_proj(normed)
|
||||
assert not torch.isnan(out).any(), f"NaN in layer {self.layer_idx}"
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full Attention (with gated q — unique to Qwen3.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5FullAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_cfg,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = text_cfg.hidden_size # 5120
|
||||
self.num_heads = text_cfg.num_attention_heads # 24
|
||||
self.num_kv_heads = text_cfg.num_key_value_heads # 4
|
||||
self.head_dim = text_cfg.head_dim # 256
|
||||
self.rms_norm_eps = text_cfg.rms_norm_eps
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.local_num_heads = self.num_heads // tp_size
|
||||
self.local_num_kv_heads = max(1, self.num_kv_heads // tp_size)
|
||||
self.local_q_dim = self.local_num_heads * self.head_dim
|
||||
self.local_kv_dim = self.local_num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# q_proj includes gate: output = num_heads * head_dim * 2
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_heads * self.head_dim * 2,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
self.k_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.k_proj")
|
||||
self.v_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.v_proj")
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim, self.hidden_size,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
self.q_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps)
|
||||
self.k_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps)
|
||||
|
||||
# Partial RoPE: rotary_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64
|
||||
rope_params = getattr(text_cfg, "rope_parameters", {}) or {}
|
||||
rope_theta = rope_params.get("rope_theta", 10_000_000)
|
||||
partial_factor = rope_params.get("partial_rotary_factor", 0.25)
|
||||
rotary_dim = int(self.head_dim * partial_factor)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=text_cfg.max_position_embeddings,
|
||||
base=rope_theta,
|
||||
)
|
||||
|
||||
self.attn = Attention(
|
||||
self.local_num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.local_num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
total_tokens = hidden_states.shape[0]
|
||||
|
||||
# q_proj output includes gate (dim doubled)
|
||||
qg, _ = self.q_proj(hidden_states) # (total, local_num_heads * head_dim * 2)
|
||||
qg = qg.view(total_tokens, self.local_num_heads, self.head_dim * 2)
|
||||
q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
|
||||
gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)
|
||||
|
||||
k, _ = self.k_proj(hidden_states) # (total, local_kv_dim)
|
||||
v, _ = self.v_proj(hidden_states)
|
||||
|
||||
# Per-head RMSNorm
|
||||
q = self.q_norm.forward_cuda(
|
||||
q.view(total_tokens, self.local_num_heads, self.head_dim)
|
||||
.contiguous()).view(total_tokens, -1)
|
||||
k = self.k_norm.forward_cuda(
|
||||
k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
|
||||
.contiguous()).view(total_tokens, -1)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
|
||||
# Multiply by sigmoid gate before output projection
|
||||
attn_out = attn_out * torch.sigmoid(gate.float()).to(attn_out.dtype)
|
||||
output, _ = self.o_proj(attn_out)
|
||||
return output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MLP (SwiGLU, same as Qwen2/Qwen3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False, quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size, hidden_size,
|
||||
bias=False, quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_cfg,
|
||||
layer_idx: int,
|
||||
layer_type: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_type = layer_type
|
||||
self.input_layernorm = GemmaRMSNorm(text_cfg.hidden_size,
|
||||
eps=text_cfg.rms_norm_eps)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(text_cfg.hidden_size,
|
||||
eps=text_cfg.rms_norm_eps)
|
||||
|
||||
if layer_type == "linear_attention":
|
||||
self.linear_attn = GatedDeltaNet(text_cfg, layer_idx,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.self_attn = Qwen3_5FullAttention(
|
||||
text_cfg, layer_idx,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"layers.{layer_idx}.self_attn",
|
||||
)
|
||||
|
||||
self.mlp = Qwen3_5MLP(
|
||||
hidden_size=text_cfg.hidden_size,
|
||||
intermediate_size=text_cfg.intermediate_size,
|
||||
hidden_act=text_cfg.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
# Only for linear_attention layers:
|
||||
conv_state: Optional[torch.Tensor] = None,
|
||||
temporal_state: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
if self.layer_type == "linear_attention":
|
||||
hidden_states = self.linear_attn(
|
||||
hidden_states, attn_metadata, conv_state, temporal_state)
|
||||
else:
|
||||
hidden_states = self.self_attn(
|
||||
positions, hidden_states, kv_cache, attn_metadata)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full transformer model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_cfg,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.text_cfg = text_cfg
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
text_cfg.vocab_size, text_cfg.hidden_size)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen3_5DecoderLayer(
|
||||
text_cfg, i, text_cfg.layer_types[i],
|
||||
cache_config=cache_config, quant_config=quant_config)
|
||||
for i in range(text_cfg.num_hidden_layers)
|
||||
])
|
||||
self.norm = GemmaRMSNorm(text_cfg.hidden_size, eps=text_cfg.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
conv_states: torch.Tensor, # (num_linear_layers, batch, ...)
|
||||
temporal_states: torch.Tensor, # (num_linear_layers, batch, ...)
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
|
||||
attn_idx = 0
|
||||
linear_idx = 0
|
||||
for layer in self.layers:
|
||||
if layer.layer_type == "linear_attention":
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states,
|
||||
kv_cache=None,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
conv_state=conv_states[linear_idx],
|
||||
temporal_state=temporal_states[linear_idx],
|
||||
)
|
||||
linear_idx += 1
|
||||
else:
|
||||
kv_cache = kv_caches[attn_idx]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
)
|
||||
attn_idx += 1
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level CausalLM wrapper with MambaCacheManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
|
||||
has_inner_state = True
|
||||
supports_lora = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
supported_lora_modules = [
|
||||
"gate_up_proj",
|
||||
"down_proj",
|
||||
"o_proj",
|
||||
]
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config, # Qwen3_5Config (top-level)
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
scheduler_config: Optional[SchedulerConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.scheduler_config = scheduler_config
|
||||
|
||||
# The text config holds all architecture parameters
|
||||
text_cfg = config.text_config
|
||||
self.text_cfg = text_cfg
|
||||
|
||||
# Pre-compute counts
|
||||
self.num_linear_layers = sum(
|
||||
1 for lt in text_cfg.layer_types if lt == "linear_attention")
|
||||
self.num_attn_layers = sum(
|
||||
1 for lt in text_cfg.layer_types if lt == "full_attention")
|
||||
|
||||
# DeltaNet state dimensions (per layer, per sequence, TP-sharded)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.conv_dim = (text_cfg.linear_num_key_heads * text_cfg.linear_key_head_dim * 2
|
||||
+ text_cfg.linear_num_value_heads * text_cfg.linear_value_head_dim)
|
||||
self.num_v_heads = text_cfg.linear_num_value_heads
|
||||
self.head_k_dim = text_cfg.linear_key_head_dim
|
||||
self.head_v_dim = text_cfg.linear_value_head_dim
|
||||
self.conv_kernel_size = text_cfg.linear_conv_kernel_dim
|
||||
|
||||
self.model = Qwen3_5Model(
|
||||
text_cfg,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
text_cfg.vocab_size, text_cfg.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(text_cfg.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialised in first forward call
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
def _get_mamba_cache_shape(self):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
# Each sequence's state is stored in float32
|
||||
conv_state_shape = (self.conv_dim // tp_size, self.conv_kernel_size - 1)
|
||||
temporal_state_shape = (
|
||||
self.num_v_heads // tp_size, self.head_k_dim, self.head_v_dim)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if self.mamba_cache is None:
|
||||
if self.scheduler_config is not None:
|
||||
max_batch_size = _get_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs)
|
||||
else:
|
||||
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + 2
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
torch.float32,
|
||||
self.num_linear_layers,
|
||||
max_batch_size,
|
||||
*self._get_mamba_cache_shape(),
|
||||
)
|
||||
|
||||
mamba_tensors = self.mamba_cache.current_run_tensors(
|
||||
input_ids, attn_metadata, **kwargs)
|
||||
# conv_states: (num_linear_layers, batch, local_conv_dim, kernel-1)
|
||||
# temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim)
|
||||
conv_states, temporal_states = mamba_tensors
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, kv_caches, attn_metadata,
|
||||
conv_states, temporal_states)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.sampler(logits, sampling_metadata)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, weight_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# Skip vision and MTP branches
|
||||
if (name.startswith("model.visual")
|
||||
or name.startswith("mtp.")
|
||||
or name.startswith("model.mtp")):
|
||||
continue
|
||||
|
||||
# Remap checkpoint prefix → module path
|
||||
# Checkpoint: "model.language_model.{rest}" → our module: "model.{rest}"
|
||||
# Checkpoint: "lm_head.weight" → our module: "lm_head.weight"
|
||||
if name.startswith("model.language_model."):
|
||||
name = "model." + name[len("model.language_model."):]
|
||||
# lm_head is already at top level — no change needed
|
||||
|
||||
# Skip positional embedding caches
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
# Remap conv1d.weight → conv1d_weight
|
||||
# The conv has depth (1) dim in the checkpoint that we handle separately
|
||||
if ".linear_attn.conv1d.weight" in name:
|
||||
name = name.replace(".linear_attn.conv1d.weight",
|
||||
".linear_attn.conv1d_weight")
|
||||
|
||||
# Stacked param loading (gate_up_proj)
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
break
|
||||
if name not in params_dict:
|
||||
break
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
3
qwen3_6_scripts/qwen3_5/__init__.py
Normal file
3
qwen3_6_scripts/qwen3_5/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig
|
||||
|
||||
__all__ = ["Qwen3_5Config", "Qwen3_5TextConfig", "Qwen3_5VisionConfig"]
|
||||
188
qwen3_6_scripts/qwen3_5/configuration_qwen3_5.py
Normal file
188
qwen3_6_scripts/qwen3_5/configuration_qwen3_5.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# Adapted from transformers 5.2.0 for compatibility with transformers 4.55.3 + torch 2.1.0
|
||||
# Stubs layer_type_validation and RopeParameters which do not exist in 4.55.3
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from ...configuration_utils import PretrainedConfig as PreTrainedConfig
|
||||
|
||||
# --- Local stubs for APIs not present in transformers 4.55.3 ---
|
||||
# Always use these definitions; do NOT import from the older transformers
|
||||
# as same-named functions there have incompatible signatures.
|
||||
|
||||
def layer_type_validation(layer_types, num_hidden_layers=None, attention=True):
|
||||
allowed = {"full_attention", "linear_attention"}
|
||||
if not all(lt in allowed for lt in layer_types):
|
||||
raise ValueError(f"layer_types entries must be in {allowed}, got {layer_types}")
|
||||
if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
|
||||
raise ValueError(
|
||||
f"num_hidden_layers ({num_hidden_layers}) != len(layer_types) ({len(layer_types)})"
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
class RopeParameters(TypedDict, total=False):
|
||||
rope_theta: float
|
||||
rope_type: str
|
||||
partial_rotary_factor: float
|
||||
factor: float
|
||||
except Exception:
|
||||
RopeParameters = dict
|
||||
|
||||
# --- End stubs ---
|
||||
|
||||
|
||||
class Qwen3_5TextConfig(PreTrainedConfig):
|
||||
r"""
|
||||
Configuration for the text backbone of Qwen3.5 / Qwen3.6-27B models.
|
||||
model_type is "qwen3_5_text" (used internally by the nested config).
|
||||
"""
|
||||
|
||||
model_type = "qwen3_5_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=248320,
|
||||
hidden_size=4096,
|
||||
intermediate_size=12288,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=4,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_parameters=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
head_dim=256,
|
||||
linear_conv_kernel_dim=4,
|
||||
linear_key_head_dim=128,
|
||||
linear_value_head_dim=128,
|
||||
linear_num_key_heads=16,
|
||||
linear_num_value_heads=32,
|
||||
layer_types=None,
|
||||
pad_token_id=None,
|
||||
bos_token_id=None,
|
||||
eos_token_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.head_dim = head_dim
|
||||
self.rope_parameters = rope_parameters
|
||||
kwargs.setdefault("partial_rotary_factor", 0.25)
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
interval_pattern = kwargs.get("full_attention_interval", 4)
|
||||
self.layer_types = [
|
||||
"linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types, self.num_hidden_layers)
|
||||
|
||||
self.linear_conv_kernel_dim = linear_conv_kernel_dim
|
||||
self.linear_key_head_dim = linear_key_head_dim
|
||||
self.linear_value_head_dim = linear_value_head_dim
|
||||
self.linear_num_key_heads = linear_num_key_heads
|
||||
self.linear_num_value_heads = linear_num_value_heads
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Qwen3_5VisionConfig(PreTrainedConfig):
|
||||
model_type = "qwen3_5_vision"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=27,
|
||||
hidden_size=1152,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
intermediate_size=4304,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=16,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
out_hidden_size=3584,
|
||||
num_position_embeddings=2304,
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.num_position_embeddings = num_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class Qwen3_5Config(PreTrainedConfig):
|
||||
r"""
|
||||
Top-level configuration for Qwen3.5 / Qwen3.6-27B.
|
||||
model_type = "qwen3_5" matches the model card / config.json.
|
||||
Wraps Qwen3_5TextConfig (and optionally Qwen3_5VisionConfig for multimodal use).
|
||||
For vLLM text-only inference only text_config is consumed.
|
||||
"""
|
||||
|
||||
model_type = "qwen3_5"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
image_token_id=248056,
|
||||
video_token_id=248057,
|
||||
vision_start_token_id=248053,
|
||||
vision_end_token_id=248054,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(text_config, dict):
|
||||
self.text_config = Qwen3_5TextConfig(**text_config)
|
||||
elif text_config is None:
|
||||
self.text_config = Qwen3_5TextConfig()
|
||||
else:
|
||||
self.text_config = text_config
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = Qwen3_5VisionConfig(**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = Qwen3_5VisionConfig()
|
||||
else:
|
||||
self.vision_config = vision_config
|
||||
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
self.vision_start_token_id = vision_start_token_id
|
||||
self.vision_end_token_id = vision_end_token_id
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["Qwen3_5Config", "Qwen3_5TextConfig", "Qwen3_5VisionConfig"]
|
||||
Reference in New Issue
Block a user