diff --git a/Dockerfile b/Dockerfile index 51d3df7..da24151 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,15 +1,6 @@ FROM git.modelhub.org.cn:9443/enginex-iluvatar/bi100-3.2.3-x86-ubuntu20.04-py3.10-poc-llm-infer:v1.2.3 -RUN pip install --no-cache-dir triton==2.1.0 - -COPY pkgs/triton /usr/local/corex/lib64/python3/dist-packages/triton -COPY pkgs/triton-2.1.0+corex.4.1.2.dist-info /usr/local/corex/lib64/python3/dist-packages/triton-2.1.0+corex.4.1.2.dist-info - -COPY paged_attn.py /usr/local/corex/lib64/python3/dist-packages/vllm/attention/ops/paged_attn.py -COPY __init__.py /usr/local/corex/lib64/python3/dist-packages/vllm/triton_utils/__init__.py - RUN mkdir /workspace WORKDIR /workspace/ - -COPY ./launch_service /workspace/launch_service - +COPY ./qwen3_6_scripts /workspace/qwen3_6_scripts +RUN cd ./qwen3_6_scripts && ./patch_ops.sh \ No newline at end of file diff --git a/README_qwen3_6.md b/README_qwen3_6.md new file mode 100644 index 0000000..edd388f --- /dev/null +++ b/README_qwen3_6.md @@ -0,0 +1,42 @@ +# 天数智芯 天垓100 文本生成引擎(基于 vLLM 优化适配Qwen3.6-27B) + +``` +# 本地构建 +docker build -t enginex-iluvatar-vllm:bi100-qwen3.6 -f Dockerfile . +``` + + +启动容器镜像 + +下载Qwen3.6-27B模型,并且需要将模型的config.json文件中architectures字段改成 +```json +"architectures": [ + "Qwen3_5ForCausalLM" + ] +``` + +```bash +docker run -dit --network=host --ipc=host \ + -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev --privileged \ + --name vllm-iluvatar \ + -v /mnt/models/Qwen3.6-27B:/model:ro --entrypoint=python3 \ + enginex-iluvatar-vllm:bi100 \ + -m vllm.entrypoints.openai.api_server \ + --model /model --port 1111 --served-model-name llm \ + --max-model-len 10000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 +``` + +请求 +```bash +curl http://localhost:1111/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llm", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Can you tell me the story of Snow White?"} + ], + "max_tokens": 200, + "temperature": 0.7 + }' +``` \ No newline at end of file diff --git a/qwen3_6_scripts/mamba_cache.py b/qwen3_6_scripts/mamba_cache.py new file mode 100644 index 0000000..7537b9e --- /dev/null +++ b/qwen3_6_scripts/mamba_cache.py @@ -0,0 +1,224 @@ +from typing import Dict, List, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionMetadata + + +class MambaCacheManager: + + def __init__(self, dtype, num_mamba_layers, max_batch_size, + conv_state_shape, temporal_state_shape): + + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda") + temporal_state = torch.zeros(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda") + + self.mamba_cache = (conv_state, temporal_state) + + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + + def current_run_tensors(self, input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, **kwargs): + """ + Return the tensors for the current run's conv and ssm state. + """ + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + + self._release_finished_requests(finished_requests_ids) + mamba_cache_tensors = self._prepare_current_run_mamba_cache( + request_ids_to_seq_ids, finished_requests_ids) + + else: + # CUDA graph capturing runs + mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] + + return mamba_cache_tensors + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (JambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + + self._release_finished_requests(finished_requests_ids) + self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + finished_requests_ids) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. + """ + return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) + + def _swap_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, [to_index,from_index]] = \ + cache_t[:, [from_index,to_index]] + + def _copy_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) + + def _move_out_if_already_occupied(self, index: int, + all_occupied_indices: List[int]): + if index in all_occupied_indices: + first_free_index = self._first_free_index_in_mamba_cache() + # In case occupied, move the occupied to a new empty block + self._move_cache_index_and_mappings(from_index=index, + to_index=first_free_index) + + def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, + seq_id: int, + destination_index: int): + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. + """ + all_occupied_indices = self._get_all_occupied_indices() + if cur_rid not in self.mamba_cache_indices_mapping: + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + for cache_t in self.mamba_cache: + cache_t[:, destination_index].zero_() + self.mamba_cache_indices_mapping[cur_rid] = { + seq_id: destination_index + } + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + # parallel sampling , where n > 1, assume prefill have + # already happened now we only need to copy the already + # existing cache into the siblings seq_ids caches + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + index_exists = list(seq_ids2indices.values())[0] + # case of decoding n>1, copy prefill cache to decoding indices + self._copy_mamba_cache(from_index=index_exists, + to_index=destination_index) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = destination_index + else: + # already exists + cache_index_already_exists = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + if cache_index_already_exists != destination_index: + # In case the seq id already exists but not in + # the right destination, swap it with what's occupying it + self._swap_pair_indices_and_mappings( + from_index=cache_index_already_exists, + to_index=destination_index) + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], + finished_requests_ids: List[str]): + running_indices = [] + request_ids_to_seq_ids_flatten = [ + (req_id, seq_id) + for req_id, seq_ids in request_ids_to_seq_ids.items() + for seq_id in seq_ids + ] + batch_size = len(request_ids_to_seq_ids_flatten) + for dest_index, (request_id, + seq_id) in enumerate(request_ids_to_seq_ids_flatten): + if request_id in finished_requests_ids: + # Do not allocate cache index for requests that run + # and finish right after + continue + self._assign_seq_id_to_mamba_cache_in_specific_dest( + request_id, seq_id, dest_index) + running_indices.append(dest_index) + + self._clean_up_first_bs_blocks(batch_size, running_indices) + conv_state = self.mamba_cache[0][:, :batch_size] + temporal_state = self.mamba_cache[1][:, :batch_size] + + return (conv_state, temporal_state) + + def _get_all_occupied_indices(self): + return [ + cache_idx + for seq_ids2indices in self.mamba_cache_indices_mapping.values() + for cache_idx in seq_ids2indices.values() + ] + + def _clean_up_first_bs_blocks(self, batch_size: int, + indices_for_current_run: List[int]): + # move out all of the occupied but currently not running blocks + # outside of the first n blocks + destination_indices = range(batch_size) + max_possible_batch_size = self.mamba_cache[0].shape[1] + for destination_index in destination_indices: + if destination_index in self._get_all_occupied_indices() and \ + destination_index not in indices_for_current_run: + # move not running indices outside of the batch + all_other_indices = list( + range(batch_size, max_possible_batch_size)) + first_avail_index = self._first_free_index_in_mamba_cache( + all_other_indices) + self._swap_indices(from_index=destination_index, + to_index=first_avail_index) + + def _move_cache_index_and_mappings(self, from_index: int, to_index: int): + self._copy_mamba_cache(from_index=from_index, to_index=to_index) + self._update_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): + self._swap_mamba_cache(from_index=from_index, to_index=to_index) + self._swap_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + elif to_index == index: + seq_ids2index.update({seq_id: from_index}) + + def _update_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + return + + def _release_finished_requests(self, + finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache( + self, indices_range: Optional[List[int]] = None) -> int: + assert self.mamba_cache is not None + if indices_range is None: + max_possible_batch_size = self.mamba_cache[0].shape[1] + indices_range = list(range(max_possible_batch_size)) + all_occupied_indices = self._get_all_occupied_indices() + for i in indices_range: + if i not in all_occupied_indices: + return i + raise Exception("Couldn't find a free spot in the mamba cache! This" + "should never happen") diff --git a/qwen3_6_scripts/patch_ops.sh b/qwen3_6_scripts/patch_ops.sh new file mode 100755 index 0000000..4048cd5 --- /dev/null +++ b/qwen3_6_scripts/patch_ops.sh @@ -0,0 +1,10 @@ +pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple +cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/ +python3 ./patch_transformers_qwen3_5.py + +cp ./mamba_cache.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/ +cp ./qwen3_5.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/ +python3 ./patch_vllm_qwen3_5.py + +# 此步骤脚本四选一(默认 matmul+seq策略) +python3 ./patch_xformers_sdpa_seq.py diff --git a/qwen3_6_scripts/patch_transformers_qwen3_5.py b/qwen3_6_scripts/patch_transformers_qwen3_5.py new file mode 100644 index 0000000..c065b03 --- /dev/null +++ b/qwen3_6_scripts/patch_transformers_qwen3_5.py @@ -0,0 +1,97 @@ +""" +Patches transformers 4.55.3 to register the qwen3_5 model type. + +Deploy steps on the remote machine: + 1. cp -r modified_scripts/qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/qwen3_5 + 2. python3 modified_scripts/patch_transformers_qwen3_5.py + +Target: pip-installed transformers at /usr/local/lib/python3.10/site-packages/transformers/ +(Not the corex pre-installed path at /usr/local/corex/lib64/python3/dist-packages/) +""" + +import sys + +TRANSFORMERS_ROOT = "/usr/local/lib/python3.10/site-packages/transformers" +AUTO_CONFIG = f"{TRANSFORMERS_ROOT}/models/auto/configuration_auto.py" +MODELS_INIT = f"{TRANSFORMERS_ROOT}/models/__init__.py" + + +def patch_file(path, replacements): + with open(path, "r") as f: + content = f.read() + + patched = False + for old, new in replacements: + if new in content: + print(f" [skip] already patched: {repr(new[:60])}") + continue + if old not in content: + print(f" [warn] anchor not found: {repr(old[:60])}") + continue + content = content.replace(old, new, 1) + patched = True + print(f" [ok] inserted after: {repr(old[:60])}") + + if patched: + with open(path, "w") as f: + f.write(content) + + +def main(): + print(f"=== Patching {AUTO_CONFIG} ===") + patch_file(AUTO_CONFIG, [ + # CONFIG_MAPPING_NAMES: insert qwen3_5 right after qwen3 + ( + '("qwen3", "Qwen3Config"),', + '("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),', + ), + # Some versions don't have trailing comma — handle that too + ( + '("qwen3", "Qwen3Config")\n', + '("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),\n', + ), + # MODEL_NAMES_MAPPING (model_type -> human readable name, used by docstring generator) + ( + '("qwen3", "Qwen3"),', + '("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),', + ), + ( + '("qwen3", "Qwen3")\n', + '("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),\n', + ), + ]) + + print(f"\n=== Patching {MODELS_INIT} ===") + patch_file(MODELS_INIT, [ + ( + "from .qwen3 import *\n", + "from .qwen3 import *\n from .qwen3_5 import *\n", + ), + ]) + + # Verification + print("\n=== Verification ===") + try: + import importlib.util, types + + # Quick smoke-test: import the config class directly + spec = importlib.util.spec_from_file_location( + "configuration_qwen3_5", + f"{TRANSFORMERS_ROOT}/models/qwen3_5/configuration_qwen3_5.py", + ) + mod = importlib.util.module_from_spec(spec) + # Provide minimal parent package stubs so relative imports resolve + pkg = types.ModuleType("transformers") + pkg.__path__ = [TRANSFORMERS_ROOT] + sys.modules.setdefault("transformers", pkg) + spec.loader.exec_module(mod) + cfg = mod.Qwen3_5Config() + print(f" Qwen3_5Config() smoke-test OK (model_type={cfg.model_type})") + except Exception as e: + print(f" [warn] smoke-test failed (may be fine at runtime): {e}") + + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/qwen3_6_scripts/patch_vllm_qwen3_5.py b/qwen3_6_scripts/patch_vllm_qwen3_5.py new file mode 100644 index 0000000..d788ad9 --- /dev/null +++ b/qwen3_6_scripts/patch_vllm_qwen3_5.py @@ -0,0 +1,73 @@ +""" +Patches the vLLM model registry and deploys the Qwen3_5 model file. + +Deploy steps on the remote machine: + 1. cp modified_scripts/qwen3_5.py \ + /usr/local/corex/lib64/python3/dist-packages/vllm/model_executor/models/qwen3_5.py + 2. python3 modified_scripts/patch_vllm_qwen3_5.py + +Also edit your model config.json to set: + "architectures": ["Qwen3_5ForCausalLM"] + +Target: vLLM at /usr/local/corex/lib64/python3/dist-packages/vllm/ +""" + +VLLM_ROOT = "/usr/local/corex/lib64/python3/dist-packages/vllm" +REGISTRY = f"{VLLM_ROOT}/model_executor/models/registry.py" + + +def patch_file(path, replacements): + with open(path, "r") as f: + content = f.read() + + patched = False + for old, new in replacements: + if new in content: + print(f" [skip] already patched: {repr(new[:70])}") + continue + if old not in content: + print(f" [warn] anchor not found: {repr(old[:70])}") + continue + content = content.replace(old, new, 1) + patched = True + print(f" [ok] patched after: {repr(old[:70])}") + + if patched: + with open(path, "w") as f: + f.write(content) + + +def main(): + print(f"=== Patching {REGISTRY} ===") + patch_file(REGISTRY, [ + ( + ' "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),\n' + ' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),', + ' "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),\n' + ' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),\n' + ' "Qwen3_5ForCausalLM": ("qwen3_5", "Qwen3_5ForCausalLM"),', + ), + ]) + + print("\n=== Verification ===") + try: + import importlib.util + spec = importlib.util.spec_from_file_location( + "qwen3_5", + f"{VLLM_ROOT}/model_executor/models/qwen3_5.py", + ) + mod = importlib.util.module_from_spec(spec) + # Quick check: does the class exist? + spec.loader.exec_module(mod) + cls = mod.Qwen3_5ForCausalLM + print(f" Qwen3_5ForCausalLM found: {cls}") + except Exception as e: + print(f" [warn] verification failed (may be OK at runtime): {e}") + + print("\nDone. Remember to:") + print(" 1. Set config.json 'architectures': ['Qwen3_5ForCausalLM']") + print(" 2. Run patch_transformers_qwen3_5.py if not already done") + + +if __name__ == "__main__": + main() diff --git a/qwen3_6_scripts/patch_xformers_sdpa_batch.py b/qwen3_6_scripts/patch_xformers_sdpa_batch.py new file mode 100644 index 0000000..a585b4d --- /dev/null +++ b/qwen3_6_scripts/patch_xformers_sdpa_batch.py @@ -0,0 +1,192 @@ +""" +策略:批量(block-diagonal)fallback — 纯 PyTorch 数学实现 +============================================================= +构建块对角 causal mask,对整批序列一次 matmul + softmax, +完全绕开所有硬件 flash attention kernel。 + +背景: + ixformer flshattF: head_dim > 128 报错拒绝 + cudnnFlashAttnForward: 接受 head_dim=256,但数值结果错误(输出全"!") + 两者大概率是同一硬件单元,ixformer 提前拦截了硬件不支持的配置。 + 纯 matmul 路径完全绕开硬件 flash attention,数值正确。 + +优点: + 数值正确。 + 并发请求 prefill attention 在 GPU 上真正并行(一次大 matmul)。 + +缺点: + 峰值显存 = total_tokens² × H × dtype_size + total_tokens 受 --max-num-batched-tokens 控制,max-model-len 控制不住。 + +内存参考(fp16,H_local=6,--max-num-batched-tokens=T): + T=2048 → 峰值 ~50 MB + T=4096 → 峰值 ~200 MB + T=8192 → 峰值 ~800 MB + T=16384 → 峰值 ~3.2 GB + +Deploy: + python3 modified_scripts/patch_xformers_sdpa_batch.py +""" + +XFORMERS_PATH = ( + "/usr/local/corex/lib64/python3/dist-packages/" + "vllm/attention/backends/xformers.py" +) + +FALLBACK_METHOD = ''' + def _run_sdpa_fallback( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: "XFormersMetadata", + ) -> torch.Tensor: + """批量纯数学 attention fallback。 + + 构建块对角 causal mask(等价于 ixformer BlockDiagonalCausalMask), + 对整批序列一次 matmul + softmax,GPU 并行处理所有序列。 + + 块对角 mask 结构(seq1 len=3,seq2 len=2): + s1,0 s1,1 s1,2 s2,0 s2,1 + s1,0 [ 0 -inf -inf -inf -inf ] + s1,1 [ 0 0 -inf -inf -inf ] + s1,2 [ 0 0 0 -inf -inf ] + s2,0 [-inf -inf -inf 0 -inf ] + s2,1 [-inf -inf -inf 0 0 ] + + softmax 在 float32 下计算防止 float16 溢出,结果转回原始 dtype。 + + Args: + query : [1, total_prefill_tokens, num_heads, head_dim] + key : [1, total_prefill_tokens, num_kv_heads, head_dim] + value : [1, total_prefill_tokens, num_kv_heads, head_dim] + Returns: + [1, total_prefill_tokens, num_heads, head_dim] + """ + assert attn_metadata.seq_lens is not None + orig_dtype = query.dtype + total_tokens = query.shape[1] + + # ── 构建块对角 causal mask [T, T] ──────────────────────────────── + # 全部初始化为 -inf,再对每条序列的对角块填入下三角 0 + mask = torch.full( + (total_tokens, total_tokens), + float("-inf"), + dtype=torch.float32, + device=query.device, + ) + start = 0 + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + mask[start:end, start:end] = torch.tril( + torch.zeros(seq_len, seq_len, + dtype=torch.float32, device=query.device) + ) + start = end + + # ── [1, H, T, D],.contiguous() ────────────────────────────────── + q_all = query.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0) + k_all = key.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0) + v_all = value.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0) + + # ── GQA:展开 KV heads ──────────────────────────────────────────── + if k_all.shape[1] != q_all.shape[1]: + n = q_all.shape[1] // k_all.shape[1] + k_all = k_all.repeat_interleave(n, dim=1).contiguous() + v_all = v_all.repeat_interleave(n, dim=1).contiguous() + + # ── 纯数学 attention(float32 防溢出)──────────────────────────── + # [1, H, T, T] + attn_w = torch.matmul(q_all.float(), k_all.float().transpose(-2, -1)) + attn_w = attn_w * self.scale + attn_w = attn_w + mask # 加法广播:mask [T,T] → [1, H, T, T] + attn_w = torch.softmax(attn_w, dim=-1) + + out = torch.matmul(attn_w, v_all.float()).to(orig_dtype) + # [1, H, T, D] → [1, T, H, D] + return out.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0) + +''' + +OLD_XFORMER_BLOCK = """\ + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op = self.attn_op + ) + return out.view_as(original_query)\ +""" + +NEW_XFORMER_BLOCK = """\ + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + if self.head_size > 128: + out = self._run_sdpa_fallback(query, key, value, attn_metadata) + else: + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op=self.attn_op, + ) + return out.view_as(original_query)\ +""" + +INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward(" + + +def patch_file(path): + with open(path, "r") as f: + content = f.read() + changed = False + + if "_run_sdpa_fallback" in content: + print(" [skip] _run_sdpa_fallback already present") + elif INJECT_ANCHOR not in content: + print(" [warn] inject anchor not found") + else: + content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1) + print(" [ok] injected _run_sdpa_fallback (batch, pure-math)") + changed = True + + if NEW_XFORMER_BLOCK in content: + print(" [skip] dispatch block already patched") + elif OLD_XFORMER_BLOCK in content: + content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1) + print(" [ok] patched dispatch block") + changed = True + else: + print(" [warn] dispatch block anchor not found") + + if changed: + with open(path, "w") as f: + f.write(content) + print(f" Written: {path}") + + +def main(): + print("=== patch_xformers_sdpa_batch (batch, pure-math) ===") + print(f"Target: {XFORMERS_PATH}") + patch_file(XFORMERS_PATH) + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/qwen3_6_scripts/patch_xformers_sdpa_batch_kernel.py b/qwen3_6_scripts/patch_xformers_sdpa_batch_kernel.py new file mode 100644 index 0000000..e7f647f --- /dev/null +++ b/qwen3_6_scripts/patch_xformers_sdpa_batch_kernel.py @@ -0,0 +1,191 @@ +""" +策略:批量(block-diagonal)— F.scaled_dot_product_attention,可走硬件 kernel +============================================================================= +构建块对角 causal mask,对整批序列一次 F.scaled_dot_product_attention。 +与 patch_xformers_sdpa_batch.py(纯 matmul)的区别: + SDPA 会根据 PyTorch/驱动能力分发到最优 kernel(Flash Attention / + mem-efficient attention / math fallback),而不是固定走 cublas matmul。 + +历史说明: + 该方案最早因输出全"!"而被弃用,后续排查确认"!"由 mamba_cache.py bug + 引起,与 attention 实现无关。当前恢复此方案用于性能对比测试。 + +已知硬件限制(BI-V100): + cudnnFlashAttnForward 不支持 is_causal=True(报错)。 + 本实现使用 is_causal=False + 显式块对角 additive mask 规避此限制。 + 若 SDPA 仍分发到有问题的 kernel,回退到 patch_xformers_sdpa_batch.py。 + +优点(vs 纯 matmul): + SDPA 可分发到 Flash Attention kernel → O(L) 显存、更快的 CUDA kernel。 + +缺点: + 依赖硬件 kernel 行为,若 kernel 有 bug 则数值错误(需与 matmul 版对比验证)。 + +Deploy: + python3 modified_scripts/patch_xformers_sdpa_batch_kernel.py +""" + +XFORMERS_PATH = ( + "/usr/local/corex/lib64/python3/dist-packages/" + "vllm/attention/backends/xformers.py" +) + +FALLBACK_METHOD = ''' + def _run_sdpa_fallback( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: "XFormersMetadata", + ) -> torch.Tensor: + """批量 F.scaled_dot_product_attention fallback(可走硬件 kernel)。 + + 构建块对角 causal mask,对整批序列一次 SDPA 调用。 + SDPA 可分发到 Flash Attention / mem-efficient attention kernel。 + is_causal=False + 显式 additive mask,规避 cudnnFlashAttnForward + 不支持 is_causal=True 的限制。 + + 块对角 mask(seq1 len=3,seq2 len=2): + s1,0 s1,1 s1,2 s2,0 s2,1 + s1,0 [ 0 -inf -inf -inf -inf ] + s1,1 [ 0 0 -inf -inf -inf ] + s1,2 [ 0 0 0 -inf -inf ] + s2,0 [-inf -inf -inf 0 -inf ] + s2,1 [-inf -inf -inf 0 0 ] + + Args: + query : [1, total_prefill_tokens, num_heads, head_dim] + key : [1, total_prefill_tokens, num_kv_heads, head_dim] + value : [1, total_prefill_tokens, num_kv_heads, head_dim] + Returns: + [1, total_prefill_tokens, num_heads, head_dim] + """ + import torch.nn.functional as F + + assert attn_metadata.seq_lens is not None + orig_dtype = query.dtype + total_tokens = query.shape[1] + + # ── 块对角 causal mask [T, T] ───────────────────────────────────── + mask = torch.full( + (total_tokens, total_tokens), + float("-inf"), + dtype=orig_dtype, + device=query.device, + ) + start = 0 + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + mask[start:end, start:end] = torch.tril( + torch.zeros(seq_len, seq_len, dtype=orig_dtype, device=query.device) + ) + start = end + + # ── [1, H, T, D] ────────────────────────────────────────────────── + q_all = query.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0) + k_all = key.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0) + v_all = value.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0) + + # ── GQA:展开 KV heads ──────────────────────────────────────────── + if k_all.shape[1] != q_all.shape[1]: + n = q_all.shape[1] // k_all.shape[1] + k_all = k_all.repeat_interleave(n, dim=1).contiguous() + v_all = v_all.repeat_interleave(n, dim=1).contiguous() + + # ── F.scaled_dot_product_attention(可走硬件 kernel)───────────── + # is_causal=False:避免 cudnnFlashAttnForward "not support causal mode" + # attn_mask 传 additive float mask(非 bool),SDPA 选择 math/kernel 路径 + out = F.scaled_dot_product_attention( + q_all, k_all, v_all, + attn_mask=mask, + dropout_p=0.0, + is_causal=False, + scale=self.scale, + ) + # [1, H, T, D] → [1, T, H, D] + return out.squeeze(0).permute(1, 0, 2).contiguous().unsqueeze(0) + +''' + +OLD_XFORMER_BLOCK = """\ + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op = self.attn_op + ) + return out.view_as(original_query)\ +""" + +NEW_XFORMER_BLOCK = """\ + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + if self.head_size > 128: + out = self._run_sdpa_fallback(query, key, value, attn_metadata) + else: + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op=self.attn_op, + ) + return out.view_as(original_query)\ +""" + +INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward(" + + +def patch_file(path): + with open(path, "r") as f: + content = f.read() + changed = False + + if "_run_sdpa_fallback" in content: + print(" [skip] _run_sdpa_fallback already present") + elif INJECT_ANCHOR not in content: + print(" [warn] inject anchor not found") + else: + content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1) + print(" [ok] injected _run_sdpa_fallback (batch, F.sdpa kernel)") + changed = True + + if NEW_XFORMER_BLOCK in content: + print(" [skip] dispatch block already patched") + elif OLD_XFORMER_BLOCK in content: + content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1) + print(" [ok] patched dispatch block") + changed = True + else: + print(" [warn] dispatch block anchor not found") + + if changed: + with open(path, "w") as f: + f.write(content) + print(f" Written: {path}") + + +def main(): + print("=== patch_xformers_sdpa_batch_kernel (batch, F.sdpa + kernel dispatch) ===") + print(f"Target: {XFORMERS_PATH}") + patch_file(XFORMERS_PATH) + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/qwen3_6_scripts/patch_xformers_sdpa_seq.py b/qwen3_6_scripts/patch_xformers_sdpa_seq.py new file mode 100644 index 0000000..8f36a0c --- /dev/null +++ b/qwen3_6_scripts/patch_xformers_sdpa_seq.py @@ -0,0 +1,186 @@ +""" +策略:顺序(per-sequence)fallback — 纯 PyTorch 数学实现 +========================================================== +逐条序列用 matmul + softmax 手写 attention,完全绕开所有硬件 +flash attention kernel(ixformer / cudnnFlashAttnForward)。 + +背景: + Iluvatar cudnnFlashAttnForward 存在两个已知问题: + 1. 不支持 is_causal=True(报错) + 2. 使用 attn_mask 路径时数值结果不正确(静默错误,输出全为"!") + 与华为昇腾 910B4 上 llama.cpp --flash-attn off 修复同类问题的原理相同。 + 纯数学路径(matmul + softmax)在任何 PyTorch 后端上结果都正确。 + +优点: + 数值正确,不依赖任何硬件特定 attention kernel。 + 峰值显存 = max(seq_len)² × H × dtype_size,由 --max-model-len 控制。 + +缺点: + 并发请求的 prefill attention 串行执行。 + O(L²) 显存(无 flash attention 的 O(L) 优化)。 + +内存参考(fp16,H_local=6): + max-model-len=4096 → 峰值 ~200 MB + max-model-len=8192 → 峰值 ~800 MB + max-model-len=16384 → 峰值 ~3.2 GB + +Deploy: + python3 modified_scripts/patch_xformers_sdpa_seq.py +""" + +XFORMERS_PATH = ( + "/usr/local/corex/lib64/python3/dist-packages/" + "vllm/attention/backends/xformers.py" +) + +FALLBACK_METHOD = ''' + def _run_sdpa_fallback( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: "XFormersMetadata", + ) -> torch.Tensor: + """顺序纯数学 attention fallback。 + + 完全绕开 ixformer / cudnnFlashAttnForward,用 matmul + softmax + 手写 attention。Iluvatar cudnnFlashAttnForward 的 attn_mask 路径 + 存在静默数值错误(输出全为"!"),纯数学路径结果正确。 + + softmax 在 float32 下计算以防止 float16 溢出,结果转回原始 dtype。 + + Args: + query : [1, total_prefill_tokens, num_heads, head_dim] + key : [1, total_prefill_tokens, num_kv_heads, head_dim] + value : [1, total_prefill_tokens, num_kv_heads, head_dim] + Returns: + [1, total_prefill_tokens, num_heads, head_dim] + """ + assert attn_metadata.seq_lens is not None + orig_dtype = query.dtype + + q_flat = query.squeeze(0) # [T, H, D] + k_flat = key.squeeze(0) # [T, Hkv, D] + v_flat = value.squeeze(0) + + output = torch.empty_like(q_flat) + start = 0 + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + # [1, H, L, D] + q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0) + k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0) + v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0) + + # GQA:展开 KV heads 至与 query heads 一致 + if k_s.shape[1] != q_s.shape[1]: + n = q_s.shape[1] // k_s.shape[1] + k_s = k_s.repeat_interleave(n, dim=1).contiguous() + v_s = v_s.repeat_interleave(n, dim=1).contiguous() + + # 纯数学 attention:完全绕开硬件 flash attention kernel + # [1, H, L, L] + attn_w = torch.matmul(q_s.float(), k_s.float().transpose(-2, -1)) + attn_w = attn_w * self.scale + + # 上三角填 -inf(future tokens) + causal_mask = torch.triu( + torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_w.device), + diagonal=1, + ) + attn_w = attn_w.masked_fill(causal_mask, float("-inf")) + + # float32 softmax 防止 float16 溢出 + attn_w = torch.softmax(attn_w, dim=-1) + + out_s = torch.matmul(attn_w, v_s.float()).to(orig_dtype) + # [1, H, L, D] → [L, H, D] + output[start:end] = out_s.squeeze(0).permute(1, 0, 2) + start = end + + return output.unsqueeze(0) # [1, T, H, D] + +''' + +OLD_XFORMER_BLOCK = """\ + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op = self.attn_op + ) + return out.view_as(original_query)\ +""" + +NEW_XFORMER_BLOCK = """\ + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + if self.head_size > 128: + out = self._run_sdpa_fallback(query, key, value, attn_metadata) + else: + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op=self.attn_op, + ) + return out.view_as(original_query)\ +""" + +INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward(" + + +def patch_file(path): + with open(path, "r") as f: + content = f.read() + changed = False + + if "_run_sdpa_fallback" in content: + print(" [skip] _run_sdpa_fallback already present") + elif INJECT_ANCHOR not in content: + print(" [warn] inject anchor not found") + else: + content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1) + print(" [ok] injected _run_sdpa_fallback (sequential, pure-math)") + changed = True + + if NEW_XFORMER_BLOCK in content: + print(" [skip] dispatch block already patched") + elif OLD_XFORMER_BLOCK in content: + content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1) + print(" [ok] patched dispatch block") + changed = True + else: + print(" [warn] dispatch block anchor not found") + + if changed: + with open(path, "w") as f: + f.write(content) + print(f" Written: {path}") + + +def main(): + print("=== patch_xformers_sdpa_seq (sequential, pure-math) ===") + print(f"Target: {XFORMERS_PATH}") + patch_file(XFORMERS_PATH) + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/qwen3_6_scripts/patch_xformers_sdpa_seq_kernel.py b/qwen3_6_scripts/patch_xformers_sdpa_seq_kernel.py new file mode 100644 index 0000000..82df8d0 --- /dev/null +++ b/qwen3_6_scripts/patch_xformers_sdpa_seq_kernel.py @@ -0,0 +1,181 @@ +""" +策略:顺序(per-sequence)— F.scaled_dot_product_attention,可走硬件 kernel +============================================================================= +逐条序列调用 F.scaled_dot_product_attention,is_causal=False + 显式因果 mask。 +与 patch_xformers_sdpa_seq.py(纯 matmul)的区别: + SDPA 可分发到 Flash Attention / mem-efficient attention kernel, + 而纯 matmul 固定走 cublas。 + +硬件限制(BI-V100): + cudnnFlashAttnForward 不支持 is_causal=True(直接报错)。 + 必须使用 is_causal=False + 显式 additive causal mask。 + 每条序列单独构造上三角 -inf mask,peak 显存 = max(seq_len)² × dtype, + 比 batch 版的 total_tokens² 小得多。 + +与 batch_kernel 的对比: + seq_kernel: 显存小,peak = max_single_seq²;并发 prefill 串行排队 + batch_kernel: 显存大,peak = total_tokens²;并发 prefill 一次并行处理, + 通过 --max-num-batched-tokens 控制 total_tokens 上限 + +Deploy: + python3 modified_scripts/patch_xformers_sdpa_seq_kernel.py +""" + +XFORMERS_PATH = ( + "/usr/local/corex/lib64/python3/dist-packages/" + "vllm/attention/backends/xformers.py" +) + +FALLBACK_METHOD = ''' + def _run_sdpa_fallback( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: "XFormersMetadata", + ) -> torch.Tensor: + """顺序 F.scaled_dot_product_attention fallback(可走硬件 kernel)。 + + 逐条序列调用 SDPA,is_causal=False + 显式上三角 additive mask。 + cudnnFlashAttnForward 不支持 is_causal=True,必须用显式 mask。 + 逐序列构造 mask,peak 显存 = max(seq_len)² × dtype(远小于 batch 版)。 + + Args: + query : [1, total_prefill_tokens, num_heads, head_dim] + key : [1, total_prefill_tokens, num_kv_heads, head_dim] + value : [1, total_prefill_tokens, num_kv_heads, head_dim] + Returns: + [1, total_prefill_tokens, num_heads, head_dim] + """ + import torch.nn.functional as F + + assert attn_metadata.seq_lens is not None + orig_dtype = query.dtype + + q_flat = query.squeeze(0) # [T, H, D] + k_flat = key.squeeze(0) # [T, Hkv, D] + v_flat = value.squeeze(0) + + output = torch.empty_like(q_flat) + start = 0 + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + # [1, H, L, D] + q_s = q_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0) + k_s = k_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0) + v_s = v_flat[start:end].permute(1, 0, 2).contiguous().unsqueeze(0) + + # GQA:展开 KV heads + if k_s.shape[1] != q_s.shape[1]: + n = q_s.shape[1] // k_s.shape[1] + k_s = k_s.repeat_interleave(n, dim=1).contiguous() + v_s = v_s.repeat_interleave(n, dim=1).contiguous() + + # 逐序列因果 mask [L, L],上三角 -inf + causal_mask = torch.tril( + torch.zeros(seq_len, seq_len, dtype=orig_dtype, device=q_s.device) + ) + causal_mask = causal_mask.masked_fill( + torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, + device=q_s.device), diagonal=1), + float("-inf"), + ) + + # is_causal=False + 显式 mask,规避 cudnnFlashAttnForward 不支持 is_causal=True + out_s = F.scaled_dot_product_attention( + q_s, k_s, v_s, + attn_mask=causal_mask, + dropout_p=0.0, + is_causal=False, + scale=self.scale, + ) + # [1, H, L, D] → [L, H, D] + output[start:end] = out_s.squeeze(0).permute(1, 0, 2).to(orig_dtype) + start = end + + return output.unsqueeze(0) # [1, T, H, D] + +''' + +OLD_XFORMER_BLOCK = """\ + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op = self.attn_op + ) + return out.view_as(original_query)\ +""" + +NEW_XFORMER_BLOCK = """\ + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + if self.head_size > 128: + out = self._run_sdpa_fallback(query, key, value, attn_metadata) + else: + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op=self.attn_op, + ) + return out.view_as(original_query)\ +""" + +INJECT_ANCHOR = " def _run_memory_efficient_xformers_forward(" + + +def patch_file(path): + with open(path, "r") as f: + content = f.read() + changed = False + + if "_run_sdpa_fallback" in content: + print(" [skip] _run_sdpa_fallback already present") + elif INJECT_ANCHOR not in content: + print(" [warn] inject anchor not found") + else: + content = content.replace(INJECT_ANCHOR, FALLBACK_METHOD + INJECT_ANCHOR, 1) + print(" [ok] injected _run_sdpa_fallback (seq, F.sdpa kernel)") + changed = True + + if NEW_XFORMER_BLOCK in content: + print(" [skip] dispatch block already patched") + elif OLD_XFORMER_BLOCK in content: + content = content.replace(OLD_XFORMER_BLOCK, NEW_XFORMER_BLOCK, 1) + print(" [ok] patched dispatch block") + changed = True + else: + print(" [warn] dispatch block anchor not found") + + if changed: + with open(path, "w") as f: + f.write(content) + print(f" Written: {path}") + + +def main(): + print("=== patch_xformers_sdpa_seq_kernel (seq, F.sdpa + kernel dispatch) ===") + print(f"Target: {XFORMERS_PATH}") + patch_file(XFORMERS_PATH) + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/qwen3_6_scripts/qwen3_5.py b/qwen3_6_scripts/qwen3_5.py new file mode 100644 index 0000000..839e3db --- /dev/null +++ b/qwen3_6_scripts/qwen3_5.py @@ -0,0 +1,894 @@ +# Inference-only Qwen3.6-27B (Qwen3_5 architecture) for Iluvatar BI-V100. +# Pure-PyTorch DeltaNet (no fla / causal_conv1d dependency). +# Text-only (no VL, no MTP). + +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) + +from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA + + +# --------------------------------------------------------------------------- +# Pure-PyTorch DeltaNet kernels (fallbacks from transformers 5.2.0) +# --------------------------------------------------------------------------- + +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + + +def _torch_causal_conv1d_update( + hidden_states: torch.Tensor, # (batch, channels, seq=1) + conv_state: torch.Tensor, # (batch, channels, state_len) modified in-place + weight: torch.Tensor, # (channels, kernel_size) + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, +) -> torch.Tensor: + _, channels, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + cat = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + conv_state.copy_(cat[:, :, -state_len:]) + out = F.conv1d(cat, weight.unsqueeze(1), bias, padding=0, groups=channels) + out = out[:, :, -seq_len:] + if activation is not None: + out = F.silu(out) + return out.to(hidden_states.dtype) + + +def _torch_chunk_gated_delta_rule( + query: torch.Tensor, # (batch, seq, num_heads, head_k_dim) + key: torch.Tensor, + value: torch.Tensor, # (batch, seq, num_heads, head_v_dim) + g: torch.Tensor, # (batch, seq, num_heads) + beta: torch.Tensor, # (batch, seq, num_heads) + chunk_size: int = 64, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = _l2norm(query) + key = _l2norm(key) + # Transpose to (batch, num_heads, seq, dim) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + batch, num_heads, seq_len, k_dim = key.shape + v_dim = value.shape[-1] + pad = (chunk_size - seq_len % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad)) + key = F.pad(key, (0, 0, 0, pad)) + value = F.pad(value, (0, 0, 0, pad)) + beta = F.pad(beta, (0, pad)) + g = F.pad(g, (0, pad)) + total_len = seq_len + pad + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask_upper = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0) + + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask_upper, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_state = ( + torch.zeros(batch, num_heads, k_dim, v_dim, dtype=value.dtype, device=value.device) + if initial_state is None + else initial_state.to(value) + ) + core_out = torch.zeros_like(value) + mask_upper2 = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1) + + for i in range(total_len // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask_upper2, 0) + v_prime = k_cumdecay[:, :, i] @ last_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_state + core_out[:, :, i] = attn_inter + attn_i @ v_new + last_state = ( + last_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]) + .transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_state = None + core_out = core_out.reshape(batch, num_heads, -1, v_dim)[:, :, :seq_len] + core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_out, last_state + +def _torch_recurrent_gated_delta_rule( + query: torch.Tensor, # (batch, 1, num_heads, head_k_dim) + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, # (batch, 1, num_heads) + beta: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = _l2norm(query) + key = _l2norm(key) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + batch, num_heads, seq_len, k_dim = key.shape + v_dim = value.shape[-1] + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + core_out = torch.zeros(batch, num_heads, seq_len, v_dim, + dtype=value.dtype, device=value.device) + last_state = ( + torch.zeros(batch, num_heads, k_dim, v_dim, + dtype=value.dtype, device=value.device) + if initial_state is None + else initial_state.to(value) + ) + for t in range(seq_len): + q_t = query[:, :, t] + k_t = key[:, :, t] + v_t = value[:, :, t] + g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, t].unsqueeze(-1) + last_state = last_state * g_t + kv_mem = (last_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_state = last_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_out[:, :, t] = (last_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_state = None + core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_out, last_state + + +# --------------------------------------------------------------------------- +# Gated RMSNorm (for DeltaNet output normalisation) +# --------------------------------------------------------------------------- + +class Qwen3_5RMSNormGated(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor, + gate: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hs = hidden_states.to(torch.float32) + variance = hs.pow(2).mean(-1, keepdim=True) + hs = hs * torch.rsqrt(variance + self.variance_epsilon) + hs = self.weight * hs.to(input_dtype) + return (hs * F.silu(gate.to(torch.float32))).to(input_dtype) + + +# --------------------------------------------------------------------------- +# Gated DeltaNet (linear_attention layers) +# --------------------------------------------------------------------------- + +class GatedDeltaNet(nn.Module): + def __init__( + self, + text_cfg, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = text_cfg.hidden_size + self.num_v_heads = text_cfg.linear_num_value_heads # 48 + self.num_k_heads = text_cfg.linear_num_key_heads # 16 + self.head_k_dim = text_cfg.linear_key_head_dim # 128 + self.head_v_dim = text_cfg.linear_value_head_dim # 128 + self.key_dim = self.num_k_heads * self.head_k_dim # 2048 + self.value_dim = self.num_v_heads * self.head_v_dim # 6144 + self.conv_dim = self.key_dim * 2 + self.value_dim # 10240 + self.conv_kernel_size = text_cfg.linear_conv_kernel_dim # 4 + self.head_expand_ratio = self.num_v_heads // self.num_k_heads # 3 + + tp_size = get_tensor_model_parallel_world_size() + + # Sharded projections — MergedColumnParallelLinear shards each of q/k/v + # independently so each TP rank gets [q_shard, k_shard, v_shard]. + # Plain ColumnParallelLinear would shard contiguously, giving rank 0 + # [q_all, k_partial] — completely wrong Q/K/V after the split below. + self.in_proj_qkv = MergedColumnParallelLinear( + self.hidden_size, [self.key_dim, self.key_dim, self.value_dim], + bias=False, quant_config=quant_config) + self.in_proj_z = ColumnParallelLinear( + self.hidden_size, self.value_dim, + bias=False, quant_config=quant_config) + self.in_proj_b = ColumnParallelLinear( + self.hidden_size, self.num_v_heads, + bias=False, quant_config=quant_config) + self.in_proj_a = ColumnParallelLinear( + self.hidden_size, self.num_v_heads, + bias=False, quant_config=quant_config) + self.out_proj = RowParallelLinear( + self.value_dim, self.hidden_size, + bias=False, quant_config=quant_config) + + # Depthwise conv weight — sharded along channel dim (dim 0) + local_conv_dim = self.conv_dim // tp_size + self.conv1d_weight = nn.Parameter( + torch.empty(local_conv_dim, 1, self.conv_kernel_size)) + set_weight_attrs(self.conv1d_weight, { + "weight_loader": self._conv1d_weight_loader}) + + # Per-head scalar parameters — sharded along dim 0 + local_num_v = self.num_v_heads // tp_size + self.A_log = nn.Parameter(torch.zeros(local_num_v)) + self.dt_bias = nn.Parameter(torch.zeros(local_num_v)) + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + # Gated RMSNorm on head_v_dim — replicated (head_v_dim=128 is small) + self.norm = Qwen3_5RMSNormGated(self.head_v_dim, + eps=text_cfg.rms_norm_eps) + + def _conv1d_weight_loader(self, param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + # loaded_weight: (conv_dim=10240, 1, kernel) ordered as [q, k, v] channels + # Must gather channels in the same non-contiguous pattern that + # MergedColumnParallelLinear uses for in_proj_qkv, so that each rank's + # conv1d_weight[i] applies to the correct in_proj_qkv output channel. + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + key_local = self.key_dim // tp_size # 512 with TP=4 + val_local = self.value_dim // tp_size # 1536 with TP=4 + q_s = loaded_weight[tp_rank * key_local : (tp_rank + 1) * key_local] + k_s = loaded_weight[self.key_dim + tp_rank * key_local : + self.key_dim + (tp_rank + 1) * key_local] + v_s = loaded_weight[2 * self.key_dim + tp_rank * val_local : + 2 * self.key_dim + (tp_rank + 1) * val_local] + param.data.copy_(torch.cat([q_s, k_s, v_s], dim=0)) + + def forward( + self, + hidden_states: torch.Tensor, # (total_tokens, hidden_size) + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, # (batch, local_conv_dim, kernel-1) in-place + temporal_state: torch.Tensor, # (batch, local_v_heads, k_dim, v_dim) in-place + ) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + local_key_dim = self.key_dim // tp_size + local_val_dim = self.value_dim // tp_size + local_num_v = self.num_v_heads // tp_size + local_num_k = self.num_k_heads // tp_size + local_conv_dim = self.conv_dim // tp_size + + is_prefill = attn_metadata.num_prefill_tokens > 0 + + # Compute all projections for every token at once (batched, efficient) + mixed_qkv_all, _ = self.in_proj_qkv(hidden_states) # (total, local_conv_dim) + z_all, _ = self.in_proj_z(hidden_states) # (total, local_val_dim) + b_all, _ = self.in_proj_b(hidden_states) # (total, local_num_v) + a_all, _ = self.in_proj_a(hidden_states) # (total, local_num_v) + + if is_prefill: + seq_starts = attn_metadata.query_start_loc.tolist() + outputs = [] + state_len = self.conv_kernel_size - 1 + weight_2d = self.conv1d_weight.squeeze(1) # (local_conv_dim, kernel) + + for si in range(len(seq_starts) - 1): + s, e = int(seq_starts[si]), int(seq_starts[si + 1]) + seq_len = e - s + + # Shape: (1, local_conv_dim, seq_len) + mixed_qkv = (mixed_qkv_all[s:e] + .transpose(0, 1).unsqueeze(0) + .to(weight_2d.dtype)) + + # Save conv state (last state_len positions) + if seq_len >= state_len: + conv_state[si].copy_(mixed_qkv[0, :, -state_len:]) + else: + conv_state[si, :, state_len - seq_len:].copy_( + mixed_qkv[0]) + conv_state[si, :, :state_len - seq_len] = 0 + + # Causal conv (left-pad with zeros, then convolve) + padded = F.pad(mixed_qkv, (state_len, 0)) + mixed_qkv_conv = F.conv1d( + padded, self.conv1d_weight, + bias=None, padding=0, groups=local_conv_dim) + mixed_qkv_conv = F.silu(mixed_qkv_conv) + # (1, seq_len, local_conv_dim) + mixed_qkv_conv = mixed_qkv_conv.squeeze(0).transpose(0, 1).unsqueeze(0) + + q, k, v = torch.split( + mixed_qkv_conv, + [local_key_dim, local_key_dim, local_val_dim], dim=-1) + q = q.reshape(1, seq_len, local_num_k, self.head_k_dim) + k = k.reshape(1, seq_len, local_num_k, self.head_k_dim) + v = v.reshape(1, seq_len, local_num_v, self.head_v_dim) + + beta = b_all[s:e].sigmoid().unsqueeze(0) # (1, seq_len, local_num_v) + g = (-self.A_log.float().exp() + * F.softplus(a_all[s:e].float() + self.dt_bias) + ).unsqueeze(0) # (1, seq_len, local_num_v) + + # Expand k/q to match num_v_heads + q = q.repeat_interleave(self.head_expand_ratio, dim=2) + k = k.repeat_interleave(self.head_expand_ratio, dim=2) + + core_out, last_state = _torch_chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=temporal_state[si:si + 1], + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + if last_state is not None: + temporal_state[si].copy_(last_state[0]) + + # Gate + norm + output proj + z = z_all[s:e].reshape(seq_len, local_num_v, self.head_v_dim) + core_out = core_out.reshape(seq_len, local_num_v, self.head_v_dim) + normed = self.norm( + core_out.reshape(-1, self.head_v_dim), + z.reshape(-1, self.head_v_dim)) + normed = normed.reshape(seq_len, -1) + out, _ = self.out_proj(normed) + outputs.append(out) + + result = torch.cat(outputs, dim=0) + assert not torch.isnan(result).any(), f"NaN in prefill layer {self.layer_idx}" + return result + + else: + # Decode: one token per sequence + num_seqs = hidden_states.shape[0] + weight_2d = self.conv1d_weight.squeeze(1) + + # (num_seqs, local_conv_dim, 1) + mixed_qkv = (mixed_qkv_all + .to(weight_2d.dtype) + .unsqueeze(-1)) + + mixed_qkv_conv = _torch_causal_conv1d_update( + mixed_qkv, conv_state, weight_2d, + bias=None, activation='silu') + # (num_seqs, local_conv_dim, 1) → (num_seqs, 1, local_conv_dim) + mixed_qkv_conv = mixed_qkv_conv.squeeze(-1).unsqueeze(1) + + q, k, v = torch.split( + mixed_qkv_conv, + [local_key_dim, local_key_dim, local_val_dim], dim=-1) + q = q.reshape(num_seqs, 1, local_num_k, self.head_k_dim) + k = k.reshape(num_seqs, 1, local_num_k, self.head_k_dim) + v = v.reshape(num_seqs, 1, local_num_v, self.head_v_dim) + + beta = b_all.sigmoid().unsqueeze(1) # (num_seqs, 1, local_num_v) + g = (-self.A_log.float().exp() + * F.softplus(a_all.float() + self.dt_bias) + ).unsqueeze(1) # (num_seqs, 1, local_num_v) + + q = q.repeat_interleave(self.head_expand_ratio, dim=2) + k = k.repeat_interleave(self.head_expand_ratio, dim=2) + + core_out, last_state = _torch_recurrent_gated_delta_rule( + q, k, v, g, beta, + initial_state=temporal_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + if last_state is not None: + temporal_state.copy_(last_state) + + z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim) + core_out = core_out.reshape(num_seqs, local_num_v, self.head_v_dim) + normed = self.norm( + core_out.reshape(-1, self.head_v_dim), + z.reshape(-1, self.head_v_dim)) + normed = normed.reshape(num_seqs, -1) + out, _ = self.out_proj(normed) + assert not torch.isnan(out).any(), f"NaN in layer {self.layer_idx}" + return out + + +# --------------------------------------------------------------------------- +# Full Attention (with gated q — unique to Qwen3.5) +# --------------------------------------------------------------------------- + +class Qwen3_5FullAttention(nn.Module): + def __init__( + self, + text_cfg, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = text_cfg.hidden_size # 5120 + self.num_heads = text_cfg.num_attention_heads # 24 + self.num_kv_heads = text_cfg.num_key_value_heads # 4 + self.head_dim = text_cfg.head_dim # 256 + self.rms_norm_eps = text_cfg.rms_norm_eps + + tp_size = get_tensor_model_parallel_world_size() + self.local_num_heads = self.num_heads // tp_size + self.local_num_kv_heads = max(1, self.num_kv_heads // tp_size) + self.local_q_dim = self.local_num_heads * self.head_dim + self.local_kv_dim = self.local_num_kv_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + + # q_proj includes gate: output = num_heads * head_dim * 2 + self.q_proj = ColumnParallelLinear( + self.hidden_size, self.num_heads * self.head_dim * 2, + bias=False, quant_config=quant_config, + prefix=f"{prefix}.q_proj") + self.k_proj = ColumnParallelLinear( + self.hidden_size, self.num_kv_heads * self.head_dim, + bias=False, quant_config=quant_config, + prefix=f"{prefix}.k_proj") + self.v_proj = ColumnParallelLinear( + self.hidden_size, self.num_kv_heads * self.head_dim, + bias=False, quant_config=quant_config, + prefix=f"{prefix}.v_proj") + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, self.hidden_size, + bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.q_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, eps=self.rms_norm_eps) + + # Partial RoPE: rotary_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64 + rope_params = getattr(text_cfg, "rope_parameters", {}) or {} + rope_theta = rope_params.get("rope_theta", 10_000_000) + partial_factor = rope_params.get("partial_rotary_factor", 0.25) + rotary_dim = int(self.head_dim * partial_factor) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=rotary_dim, + max_position=text_cfg.max_position_embeddings, + base=rope_theta, + ) + + self.attn = Attention( + self.local_num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.local_num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + total_tokens = hidden_states.shape[0] + + # q_proj output includes gate (dim doubled) + qg, _ = self.q_proj(hidden_states) # (total, local_num_heads * head_dim * 2) + qg = qg.view(total_tokens, self.local_num_heads, self.head_dim * 2) + q = qg[:, :, :self.head_dim].reshape(total_tokens, -1) + gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1) + + k, _ = self.k_proj(hidden_states) # (total, local_kv_dim) + v, _ = self.v_proj(hidden_states) + + # Per-head RMSNorm + q = self.q_norm.forward_cuda( + q.view(total_tokens, self.local_num_heads, self.head_dim) + .contiguous()).view(total_tokens, -1) + k = self.k_norm.forward_cuda( + k.view(total_tokens, self.local_num_kv_heads, self.head_dim) + .contiguous()).view(total_tokens, -1) + + q, k = self.rotary_emb(positions, q, k) + attn_out = self.attn(q, k, v, kv_cache, attn_metadata) + + # Multiply by sigmoid gate before output projection + attn_out = attn_out * torch.sigmoid(gate.float()).to(attn_out.dtype) + output, _ = self.o_proj(attn_out) + return output + + +# --------------------------------------------------------------------------- +# MLP (SwiGLU, same as Qwen2/Qwen3) +# --------------------------------------------------------------------------- + +class Qwen3_5MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, quant_config=quant_config) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, + bias=False, quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}") + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +# --------------------------------------------------------------------------- +# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention) +# --------------------------------------------------------------------------- + +class Qwen3_5DecoderLayer(nn.Module): + def __init__( + self, + text_cfg, + layer_idx: int, + layer_type: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_type = layer_type + self.input_layernorm = GemmaRMSNorm(text_cfg.hidden_size, + eps=text_cfg.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(text_cfg.hidden_size, + eps=text_cfg.rms_norm_eps) + + if layer_type == "linear_attention": + self.linear_attn = GatedDeltaNet(text_cfg, layer_idx, + quant_config=quant_config) + else: + self.self_attn = Qwen3_5FullAttention( + text_cfg, layer_idx, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"layers.{layer_idx}.self_attn", + ) + + self.mlp = Qwen3_5MLP( + hidden_size=text_cfg.hidden_size, + intermediate_size=text_cfg.intermediate_size, + hidden_act=text_cfg.hidden_act, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + # Only for linear_attention layers: + conv_state: Optional[torch.Tensor] = None, + temporal_state: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states, attn_metadata, conv_state, temporal_state) + else: + hidden_states = self.self_attn( + positions, hidden_states, kv_cache, attn_metadata) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +# --------------------------------------------------------------------------- +# Full transformer model +# --------------------------------------------------------------------------- + +class Qwen3_5Model(nn.Module): + def __init__( + self, + text_cfg, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.text_cfg = text_cfg + self.embed_tokens = VocabParallelEmbedding( + text_cfg.vocab_size, text_cfg.hidden_size) + self.layers = nn.ModuleList([ + Qwen3_5DecoderLayer( + text_cfg, i, text_cfg.layer_types[i], + cache_config=cache_config, quant_config=quant_config) + for i in range(text_cfg.num_hidden_layers) + ]) + self.norm = GemmaRMSNorm(text_cfg.hidden_size, eps=text_cfg.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_states: torch.Tensor, # (num_linear_layers, batch, ...) + temporal_states: torch.Tensor, # (num_linear_layers, batch, ...) + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + + attn_idx = 0 + linear_idx = 0 + for layer in self.layers: + if layer.layer_type == "linear_attention": + hidden_states, residual = layer( + positions, hidden_states, + kv_cache=None, + attn_metadata=attn_metadata, + residual=residual, + conv_state=conv_states[linear_idx], + temporal_state=temporal_states[linear_idx], + ) + linear_idx += 1 + else: + kv_cache = kv_caches[attn_idx] + hidden_states, residual = layer( + positions, hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + ) + attn_idx += 1 + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +# --------------------------------------------------------------------------- +# Top-level CausalLM wrapper with MambaCacheManager +# --------------------------------------------------------------------------- + +class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA): + + has_inner_state = True + supports_lora = True + + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } + + supported_lora_modules = [ + "gate_up_proj", + "down_proj", + "o_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config, # Qwen3_5Config (top-level) + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + + # The text config holds all architecture parameters + text_cfg = config.text_config + self.text_cfg = text_cfg + + # Pre-compute counts + self.num_linear_layers = sum( + 1 for lt in text_cfg.layer_types if lt == "linear_attention") + self.num_attn_layers = sum( + 1 for lt in text_cfg.layer_types if lt == "full_attention") + + # DeltaNet state dimensions (per layer, per sequence, TP-sharded) + tp_size = get_tensor_model_parallel_world_size() + self.conv_dim = (text_cfg.linear_num_key_heads * text_cfg.linear_key_head_dim * 2 + + text_cfg.linear_num_value_heads * text_cfg.linear_value_head_dim) + self.num_v_heads = text_cfg.linear_num_value_heads + self.head_k_dim = text_cfg.linear_key_head_dim + self.head_v_dim = text_cfg.linear_value_head_dim + self.conv_kernel_size = text_cfg.linear_conv_kernel_dim + + self.model = Qwen3_5Model( + text_cfg, + cache_config=cache_config, + quant_config=quant_config, + ) + + self.lm_head = ParallelLMHead( + text_cfg.vocab_size, text_cfg.hidden_size, + quant_config=quant_config, + ) + + self.logits_processor = LogitsProcessor(text_cfg.vocab_size) + self.sampler = Sampler() + + # Lazy initialised in first forward call + self.mamba_cache: Optional[MambaCacheManager] = None + + def _get_mamba_cache_shape(self): + tp_size = get_tensor_model_parallel_world_size() + # Each sequence's state is stored in float32 + conv_state_shape = (self.conv_dim // tp_size, self.conv_kernel_size - 1) + temporal_state_shape = ( + self.num_v_heads // tp_size, self.head_k_dim, self.head_v_dim) + return conv_state_shape, temporal_state_shape + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs, + ) -> torch.Tensor: + if self.mamba_cache is None: + if self.scheduler_config is not None: + max_batch_size = _get_graph_batch_size( + self.scheduler_config.max_num_seqs) + else: + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + 2 + self.mamba_cache = MambaCacheManager( + torch.float32, + self.num_linear_layers, + max_batch_size, + *self._get_mamba_cache_shape(), + ) + + mamba_tensors = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) + # conv_states: (num_linear_layers, batch, local_conv_dim, kernel-1) + # temporal_states: (num_linear_layers, batch, local_num_v, k_dim, v_dim) + conv_states, temporal_states = mamba_tensors + + hidden_states = self.model( + input_ids, positions, kv_caches, attn_metadata, + conv_states, temporal_states) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.sampler(logits, sampling_metadata) + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + # Skip vision and MTP branches + if (name.startswith("model.visual") + or name.startswith("mtp.") + or name.startswith("model.mtp")): + continue + + # Remap checkpoint prefix → module path + # Checkpoint: "model.language_model.{rest}" → our module: "model.{rest}" + # Checkpoint: "lm_head.weight" → our module: "lm_head.weight" + if name.startswith("model.language_model."): + name = "model." + name[len("model.language_model."):] + # lm_head is already at top level — no change needed + + # Skip positional embedding caches + if "rotary_emb.inv_freq" in name: + continue + + # Remap conv1d.weight → conv1d_weight + # The conv has depth (1) dim in the checkpoint that we handle separately + if ".linear_attn.conv1d.weight" in name: + name = name.replace(".linear_attn.conv1d.weight", + ".linear_attn.conv1d_weight") + + # Stacked param loading (gate_up_proj) + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + break + if name not in params_dict: + break + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/qwen3_6_scripts/qwen3_5/__init__.py b/qwen3_6_scripts/qwen3_5/__init__.py new file mode 100644 index 0000000..168e15f --- /dev/null +++ b/qwen3_6_scripts/qwen3_5/__init__.py @@ -0,0 +1,3 @@ +from .configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig + +__all__ = ["Qwen3_5Config", "Qwen3_5TextConfig", "Qwen3_5VisionConfig"] diff --git a/qwen3_6_scripts/qwen3_5/configuration_qwen3_5.py b/qwen3_6_scripts/qwen3_5/configuration_qwen3_5.py new file mode 100644 index 0000000..afe21f7 --- /dev/null +++ b/qwen3_6_scripts/qwen3_5/configuration_qwen3_5.py @@ -0,0 +1,188 @@ +# Adapted from transformers 5.2.0 for compatibility with transformers 4.55.3 + torch 2.1.0 +# Stubs layer_type_validation and RopeParameters which do not exist in 4.55.3 + +from typing import Optional, List + +from ...configuration_utils import PretrainedConfig as PreTrainedConfig + +# --- Local stubs for APIs not present in transformers 4.55.3 --- +# Always use these definitions; do NOT import from the older transformers +# as same-named functions there have incompatible signatures. + +def layer_type_validation(layer_types, num_hidden_layers=None, attention=True): + allowed = {"full_attention", "linear_attention"} + if not all(lt in allowed for lt in layer_types): + raise ValueError(f"layer_types entries must be in {allowed}, got {layer_types}") + if num_hidden_layers is not None and num_hidden_layers != len(layer_types): + raise ValueError( + f"num_hidden_layers ({num_hidden_layers}) != len(layer_types) ({len(layer_types)})" + ) + +try: + from typing import TypedDict + class RopeParameters(TypedDict, total=False): + rope_theta: float + rope_type: str + partial_rotary_factor: float + factor: float +except Exception: + RopeParameters = dict + +# --- End stubs --- + + +class Qwen3_5TextConfig(PreTrainedConfig): + r""" + Configuration for the text backbone of Qwen3.5 / Qwen3.6-27B models. + model_type is "qwen3_5_text" (used internally by the nested config). + """ + + model_type = "qwen3_5_text" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=248320, + hidden_size=4096, + intermediate_size=12288, + num_hidden_layers=32, + num_attention_heads=16, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_parameters=None, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + layer_types=None, + pad_token_id=None, + bos_token_id=None, + eos_token_id=None, + **kwargs, + ): + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + self.rope_parameters = rope_parameters + kwargs.setdefault("partial_rotary_factor", 0.25) + + self.layer_types = layer_types + if self.layer_types is None: + interval_pattern = kwargs.get("full_attention_interval", 4) + self.layer_types = [ + "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + super().__init__(**kwargs) + + +class Qwen3_5VisionConfig(PreTrainedConfig): + model_type = "qwen3_5_vision" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + + +class Qwen3_5Config(PreTrainedConfig): + r""" + Top-level configuration for Qwen3.5 / Qwen3.6-27B. + model_type = "qwen3_5" matches the model card / config.json. + Wraps Qwen3_5TextConfig (and optionally Qwen3_5VisionConfig for multimodal use). + For vLLM text-only inference only text_config is consumed. + """ + + model_type = "qwen3_5" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + vision_end_token_id=248054, + tie_word_embeddings=False, + **kwargs, + ): + if isinstance(text_config, dict): + self.text_config = Qwen3_5TextConfig(**text_config) + elif text_config is None: + self.text_config = Qwen3_5TextConfig() + else: + self.text_config = text_config + + if isinstance(vision_config, dict): + self.vision_config = Qwen3_5VisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen3_5VisionConfig() + else: + self.vision_config = vision_config + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.tie_word_embeddings = tie_word_embeddings + super().__init__(**kwargs) + + +__all__ = ["Qwen3_5Config", "Qwen3_5TextConfig", "Qwen3_5VisionConfig"]