commit d07e3fc172001e2254c6ab7c4dcb9d7e55c5f9c2 Author: ModelHub XC Date: Mon Jun 8 01:30:13 2026 +0800 初始化项目,由ModelHub XC社区提供模型 Model: Efficient-Large-Model/Fast_dLLM_1.5B Source: Original Platform diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..54b1208 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,53 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*.tfevents* filter=lfs diff=lfs merge=lfs -text +*.db* filter=lfs diff=lfs merge=lfs -text +*.ark* filter=lfs diff=lfs merge=lfs -text +**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text +**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text +**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text + +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.gguf* filter=lfs diff=lfs merge=lfs -text +*.ggml filter=lfs diff=lfs merge=lfs -text +*.llamafile* filter=lfs diff=lfs merge=lfs -text +*.pt2 filter=lfs diff=lfs 
merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text + +tokenizer.json filter=lfs diff=lfs merge=lfs -text +merges.txt filter=lfs diff=lfs merge=lfs -text +vocab.json filter=lfs diff=lfs merge=lfs -text +model.safetensors filter=lfs diff=lfs merge=lfs -text +assets/visualization_animation.gif filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..a9a9052 --- /dev/null +++ b/README.md @@ -0,0 +1,150 @@ +--- +license: apache-2.0 +language: +- en +base_model: +- Qwen/Qwen2.5-1.5B-Instruct +--- + +# Fast-dLLM v2 (1.5B) — Efficient Block-Diffusion LLM + +## 📖 Introduction + +Autoregressive (AR) large language models (LLMs) have achieved remarkable performance across a wide range of natural language tasks, yet their **inherent sequential decoding limits inference efficiency**. + +We present **Fast-dLLM v2** — a carefully designed **block diffusion language model (dLLM)** that efficiently adapts a pretrained AR model (**Qwen2.5-1.5B-Instruct**) into a diffusion-style decoder for **parallel text generation**. + +Our approach introduces a novel decoding recipe incorporating a complementary attention mask and block diffusion mechanism, which together enable blockwise bidirectional context modeling while preserving the original AR training objectives and performance. To further enhance inference speed, we design a hierarchical caching mechanism: a block-level cache that stores historical context representations and a sub-block level cache that supports efficient parallel decoding within partially generated blocks. 
+ +### ✨ Key Innovations +- **Block Diffusion Mechanism + Complementary Attention Mask** + Enables **blockwise bidirectional context modeling** without sacrificing AR objectives. +- **Hierarchical Caching** + - **Block-level cache**: Stores historical context representations across blocks. + - **Sub-block cache**: Parallel decoding within partially generated blocks. +- **Token Shift Mechanism** + Retains autoregressive characteristics while supporting bidirectional context within blocks. +- **Parallel Decoding Pipeline** + Achieves up to **2.5× speedup** over standard AR decoding **without compromising quality**. + +> 🚀 Fast-dLLM v2 uses **only ~1B tokens** for fine-tuning — a **500× reduction** vs. full-attention diffusion LLMs (Dream: 580B tokens) — while **matching or surpassing AR baselines** in accuracy. + + +![Generation Process](assets/visualization_animation.gif) + +--- + +## 🛠 Model Overview +- **Type**: Block Diffusion Language Model (dLLM) +- **Base Model**: `Qwen/Qwen2.5-1.5B-Instruct` +- **Architecture**: Transformer w/ RoPE, SwiGLU, RMSNorm, Attention QKV bias, tied embeddings +- **Params**: 1.54B (non-embedding: 1.31B) +- **Layers**: 28 +- **Attention Heads**: 12 (Q), 2 (KV, GQA) +- **Key Feature**: Parallel **block-wise decoding** + **hierarchical caching** + +--- + +## 📦 Installation +You will need `transformers`, `torch`, `einops` (imported by the bundled `modeling.py`), and our **custom generation function**: + +```bash +pip install transformers torch numpy einops +``` + +--- + +## 🚀 Quickstart + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "Efficient-Large-Model/Fast_dLLM_1.5B" + +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map="auto", + trust_remote_code=True +) + +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + +prompt = "Give me a short introduction to large language model." 
+messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} +] + +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) +inputs = tokenizer([text], return_tensors="pt").to(model.device) + +# Fast-dLLM v2 parallel decoding +gen_ids = model.generate( + inputs["input_ids"], + tokenizer=tokenizer, + max_new_tokens=512, + small_block_size=8, + threshold=0.9, +) + +response = tokenizer.decode( + gen_ids[0][inputs["input_ids"].shape[1]:], + skip_special_tokens=True +) +print(response) +``` + +--- + +## 📊 Performance & Benchmarks + +### ▶ Real-time Throughput +Fast-dLLM v2 offers **up to 2.54× higher throughput** than Qwen2.5-7B-Instruct, **without loss in quality**. + +![Throughput Comparison](assets/throughput.png) + +--- + +### 🏆 Benchmark Results +We compare Fast-dLLM v2 against AR baselines and previous diffusion LLMs on diverse tasks: +HumanEval, MBPP (code), GSM8K, Math (reasoning), IFEval (instruction), MMLU, GPQA (knowledge QA). + +- **1B group**: Fast-dLLM v2 (1.5B) achieves **best average score: 45.0**. +- **7B group**: Fast-dLLM v2 (7B) achieves **best average score: 60.3**, surpassing LLaDA and Dream models. + +![Benchmark Results](assets/benchmark_results.png) + +--- + +## 📜 Citation + +If you use Fast-dLLM v2 in your research or products, please cite: + +```bibtex +@misc{wu2025fastdllmv2efficientblockdiffusion, + title={Fast-dLLM v2: Efficient Block-Diffusion LLM}, + author={Chengyue Wu and Hao Zhang and Shuchen Xue and Shizhe Diao and Yonggan Fu and Zhijian Liu and Pavlo Molchanov and Ping Luo and Song Han and Enze Xie}, + year={2025}, + eprint={2509.26328}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2509.26328}, +} +``` + +--- + +## 📄 License +Released under **Apache 2.0**, following the base Qwen2.5 license. 
+ +--- + +## 🔗 Resources +- 📄 [Paper](https://arxiv.org/abs/2509.26328) +- 💻 [Code](https://github.com/NVlabs/Fast-dLLM) +- 🤗 [HuggingFace Model](https://huggingface.co/Efficient-Large-Model/Fast_dLLM_1.5B) \ No newline at end of file diff --git a/added_tokens.json b/added_tokens.json new file mode 100644 index 0000000..c24cd79 --- /dev/null +++ b/added_tokens.json @@ -0,0 +1,25 @@ +{ + "</tool_call>": 151658, + "<tool_call>": 151657, + "<|box_end|>": 151649, + "<|box_start|>": 151648, + "<|endoftext|>": 151643, + "<|file_sep|>": 151664, + "<|fim_middle|>": 151660, + "<|fim_pad|>": 151662, + "<|fim_prefix|>": 151659, + "<|fim_suffix|>": 151661, + "<|im_end|>": 151645, + "<|im_start|>": 151644, + "<|image_pad|>": 151655, + "<|object_ref_end|>": 151647, + "<|object_ref_start|>": 151646, + "<|quad_end|>": 151651, + "<|quad_start|>": 151650, + "<|repo_name|>": 151663, + "<|video_pad|>": 151656, + "<|vision_end|>": 151653, + "<|vision_pad|>": 151654, + "<|vision_start|>": 151652, + "|<MASK>|": 151665 +} diff --git a/assets/benchmark_results.png b/assets/benchmark_results.png new file mode 100644 index 0000000..c9dff34 Binary files /dev/null and b/assets/benchmark_results.png differ diff --git a/assets/throughput.png b/assets/throughput.png new file mode 100644 index 0000000..948f34d Binary files /dev/null and b/assets/throughput.png differ diff --git a/assets/training_recipe.png b/assets/training_recipe.png new file mode 100644 index 0000000..54516dc Binary files /dev/null and b/assets/training_recipe.png differ diff --git a/assets/visualization_animation.gif b/assets/visualization_animation.gif new file mode 100644 index 0000000..ca3d01f --- /dev/null +++ b/assets/visualization_animation.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c4c7fb54af204ea8cc03a8dadc9dde8dc8fb5ac514ead026a5d2833ee3aad37 +size 1362150 diff --git a/chat_template.jinja b/chat_template.jinja new file mode 100644 index 0000000..28028c0 --- /dev/null +++ b/chat_template.jinja @@ -0,0 +1,54 @@ +{%- 
if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'You are a helpful assistant.' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n<tool_call>\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n</tool_call>' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n<tool_response>\n' }} + {{- message.content }} + {{- '\n</tool_response>' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} 
diff --git a/config.json b/config.json new file mode 100644 index 0000000..6049219 --- /dev/null +++ b/config.json @@ -0,0 +1,65 @@ +{ + "architectures": [ + "Fast_dLLM_QwenForCausalLM" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration.Fast_dLLM_QwenConfig", + "AutoModel": "modeling.Fast_dLLM_QwenModel", + "AutoModelForCausalLM": "modeling.Fast_dLLM_QwenForCausalLM" + }, + "bd_size": 32, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 1536, + "initializer_range": 0.02, + "intermediate_size": 8960, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 32768, + "max_window_layers": 21, + "model_type": "Fast_dLLM_Qwen", + "num_attention_heads": 12, + "num_hidden_layers": 28, + "num_key_value_heads": 2, + "pad_token_id": 151645, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936 +} diff --git a/configuration.json b/configuration.json new file mode 100644 index 0000000..159097f --- /dev/null +++ b/configuration.json @@ -0,0 +1 @@ +{"framework": "pytorch", "task": "others", "allow_remote": true} \ No newline at end of file diff --git a/configuration.py b/configuration.py new file mode 100644 index 0000000..00b1ee0 --- /dev/null +++ 
class Fast_dLLM_QwenConfig(PretrainedConfig):
    """Configuration for the `Fast_dLLM_Qwen` model family.

    Stores the transformer hyper-parameters passed to the constructor plus
    `bd_size`, which this class only records. `sliding_window` is kept only when
    `use_sliding_window` is truthy; otherwise it is stored as `None`. When
    `layer_types` is not supplied it is derived per layer from `sliding_window`
    and `max_window_layers`.
    """

    model_type = "Fast_dLLM_Qwen"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Fast_dLLM_Qwen`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        layer_types=None,
        attention_dropout=0.0,
        bd_size=32,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # The window size is only meaningful when sliding-window attention is enabled.
        self.use_sliding_window = use_sliding_window
        if self.use_sliding_window:
            self.sliding_window = sliding_window
        else:
            self.sliding_window = None
        self.max_window_layers = max_window_layers

        # Backward compatibility: a missing KV-head count means one KV head per query head.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.bd_size = bd_size

        # Backward compatibility: mirror a legacy 'type' entry under 'rope_type'
        # before the rotary-embedding parameters are validated.
        scaling = self.rope_scaling
        if scaling is not None and "type" in scaling:
            scaling["rope_type"] = scaling["type"]
        rope_config_validation(self)

        # Without an explicit list, a layer uses sliding attention only when a
        # window is configured and its index has reached `max_window_layers`.
        if layer_types is None:
            layer_types = []
            for layer_idx in range(self.num_hidden_layers):
                windowed = self.sliding_window is not None and layer_idx >= self.max_window_layers
                layer_types.append("sliding_attention" if windowed else "full_attention")
        self.layer_types = layer_types
        layer_type_validation(self.layer_types)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
    """
    Mask predicate for block-diffusion training over a concatenated ``[xt | x0]`` sequence.

    Positions ``< n`` are treated as the noised copy (``xt``) and positions
    ``>= n`` as the clean copy (``x0``); block indices are computed inside each
    copy. The result is the union of three masks:
    - **Block Diagonal Mask (M_BD)**: query and key lie in the same block of the
      same copy (self-attention within a block).
    - **Offset Block Causal Mask (M_OBC)**: an ``xt`` query attends ``x0`` keys
      from strictly earlier blocks (conditional context).
    - **Block Causal Mask (M_BC)**: an ``x0`` query attends ``x0`` keys from the
      same or earlier blocks.

    Args:
        b, h: Batch and head indices (ignored for mask logic).
        q_idx, kv_idx: Query and key position indices, given as tensors.
        block_size: Number of positions per block. Required.
        n: Position at which the ``x0`` copy starts, i.e. the length of the
            ``xt`` copy. Required.

    Returns:
        A boolean attention mask; ``True`` means the query may attend the key.

    Raises:
        ValueError: If ``block_size`` or ``n`` is not provided.
    """
    if block_size is None or n is None:
        raise ValueError("block_diff_mask requires both `block_size` and `n`")

    # Which copy each position belongs to: True -> x0, False -> xt.
    q_in_x0 = q_idx >= n
    kv_in_x0 = kv_idx >= n

    # Block index relative to the start of the copy the position lives in.
    block_q = torch.where(q_in_x0, (q_idx - n) // block_size, q_idx // block_size)
    block_kv = torch.where(kv_in_x0, (kv_idx - n) // block_size, kv_idx // block_size)

    # 1. Block Diagonal Mask (M_BD)
    block_diagonal = (block_q == block_kv) & (q_in_x0 == kv_in_x0)

    # 2. Offset Block-Causal Mask (M_OBC)
    offset_block_causal = (block_q > block_kv) & kv_in_x0 & ~q_in_x0

    # 3. Block-Causal Mask (M_BC)
    block_causal = (block_q >= block_kv) & kv_in_x0 & q_in_x0

    # 4. Combine masks
    return block_diagonal | offset_block_causal | block_causal
class Fast_dLLM_QwenAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Query, key and value projections carry a bias; the output projection does
    not. Besides the regular key/value cache (`past_key_value`), `forward`
    accepts a second cache (`block_past_key_values`) whose entries can be
    overwritten in place at `replace_position`.
    """

    def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        update_past_key_values: Optional[bool] = False,
        block_past_key_values: Optional[Cache] = None,
        replace_position: Optional[int] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Run self-attention and return the output-projected hidden states.

        Args:
            hidden_states: Layer input; its last axis is projected to heads.
            position_embeddings: `(cos, sin)` pair used for rotary embedding.
            attention_mask: Passed to flex attention as the block mask in
                training mode, and to the SDPA attention function otherwise.
            past_key_value: Regular key/value cache. It is extended through
                `update` only when `update_past_key_values` is true; otherwise,
                if it already holds this layer, its entries are prepended to
                the current keys/values without modifying the cache.
            cache_position: Forwarded to cache `update` calls.
            update_past_key_values: Whether to write into `past_key_value`.
            block_past_key_values: Second cache. If it has no entry for this
                layer yet, the current keys/values are appended; otherwise the
                cached tensors are overwritten in place starting at
                `replace_position` and used as the keys/values.
            replace_position: Start index along the sequence axis for the
                in-place overwrite of the block cache.
            **kwargs: Forwarded to the SDPA attention function (non-training only).

        Returns:
            The attention output after `o_proj`, with the leading axes of
            `hidden_states` preserved.

        Raises:
            ValueError: If the block cache already holds this layer but
                `replace_position` is None.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the last axis into heads, then swap axes 1 and 2 so that heads
        # come before the sequence axis.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        if self.training:
            # The sequence axis is split into two halves that are rotated
            # independently with the same cos/sin, then concatenated again.
            half = query_states.shape[2] // 2
            q_1, q_2 = query_states[:, :, :half], query_states[:, :, half:]
            k_half = key_states.shape[2] // 2
            k_1, k_2 = key_states[:, :, :k_half], key_states[:, :, k_half:]
            q_1, k_1 = apply_rotary_pos_emb(q_1, k_1, cos, sin)
            q_2, k_2 = apply_rotary_pos_emb(q_2, k_2, cos, sin)
            query_states = torch.cat((q_1, q_2), dim=-2)
            key_states = torch.cat((k_1, k_2), dim=-2)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if block_past_key_values is not None:
            if len(block_past_key_values) <= self.layer_idx:
                # First visit of this layer: append to the block cache.
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = block_past_key_values.update(
                    key_states, value_states, self.layer_idx, cache_kwargs
                )
            else:
                if replace_position is None:
                    raise ValueError(
                        "`replace_position` is required when the block cache already holds this layer"
                    )
                # Overwrite the cached slice in place and attend over the whole cached block.
                block_cache_key_states = block_past_key_values[self.layer_idx][0]
                block_cache_value_states = block_past_key_values[self.layer_idx][1]

                block_cache_key_states[:, :, replace_position:replace_position + key_states.shape[2]] = key_states
                block_cache_value_states[:, :, replace_position:replace_position + value_states.shape[2]] = value_states
                key_states = block_cache_key_states
                value_states = block_cache_value_states

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            if update_past_key_values:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            elif len(past_key_value) > self.layer_idx:
                # Read-only use: prepend the cached keys/values without storing the new ones.
                key_states = torch.cat((past_key_value[self.layer_idx][0], key_states), dim=-2)
                value_states = torch.cat((past_key_value[self.layer_idx][1], value_states), dim=-2)

        if self.training:
            attn_output = fused_flex_attention(query_states, key_states, value_states, mask=attention_mask)
            attn_output = attn_output.transpose(1, 2).contiguous()
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]

            # This branch only runs outside training, so dropout is always disabled.
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                is_causal=False,
                dropout=0.0,
                scaling=self.scaling,
                sliding_window=self.sliding_window,  # main diff with Llama
                **kwargs,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output
class Fast_dLLM_QwenDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: RMSNorm -> self-attention -> residual add,
    then RMSNorm -> MLP -> residual add."""

    def __init__(self, config: Fast_dLLM_QwenConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Fast_dLLM_QwenAttention(config=config, layer_idx=layer_idx)

        self.mlp = Fast_dLLM_QwenMLP(config)
        self.input_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        update_past_key_values: Optional[bool] = False,
        use_block_cache: Optional[bool] = False,
        block_past_key_values: Optional[Cache] = None,
        replace_position: Optional[int] = None,
        **kwargs
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual connection.

        Every argument other than `hidden_states` is forwarded unchanged to the
        attention module. `position_ids`, `use_cache` and `use_block_cache` are
        not named parameters of `Fast_dLLM_QwenAttention.forward`, so they end
        up in its `**kwargs`.

        Returns:
            The hidden states after both sub-layers (a single tensor, not a tuple).
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            update_past_key_values=update_past_key_values,
            use_block_cache=use_block_cache,
            block_past_key_values=block_past_key_values,
            replace_position=replace_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
class Fast_dLLM_QwenModel(Fast_dLLM_QwenPreTrainedModel):
    """Decoder stack (embeddings, decoder layers, final norm) for block-diffusion decoding.

    In training mode the sequence axis is expected to be twice the label length
    (positions and the attention mask are derived from ``inputs_embeds.shape[1] // 2``
    and ``labels.shape``); in eval mode a block-wise mask is built from the current
    input length and the length of the key/value cache.
    """

    def __init__(self, config: Fast_dLLM_QwenConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.bd_size = config.bd_size  # block size used for the training mask

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Fast_dLLM_QwenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Fast_dLLM_QwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Fast_dLLM_QwenRotaryEmbedding(config=config)
        # NOTE(review): enabled unconditionally at construction time - confirm this is intended.
        self.gradient_checkpointing = True

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def eval_mask(self, seqlen, block_size, cache_seq_len):
        """Build the inference mask for ``seqlen`` new tokens on top of ``cache_seq_len`` cached ones.

        Query indices are offset by ``cache_seq_len``; key indices run over cache plus
        new tokens.  The actual rule is delegated to ``eval_block_diff_mask``.
        The index tensors are created on the default device; the caller moves the
        resulting mask.
        """
        q_indices = torch.arange(seqlen) + cache_seq_len
        k_indices = torch.arange(seqlen + cache_seq_len)
        return eval_block_diff_mask(
            q_idx=q_indices[:, None],
            kv_idx=k_indices[None, :],
            block_size=block_size
        )

    def gen_mask(self, seqlen, block_size, B, H):
        """Build the training mask via ``create_block_mask`` over a ``2 * seqlen`` sequence."""
        return create_block_mask(
            partial(block_diff_mask, block_size=block_size, n=seqlen),
            B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        update_past_key_values: Optional[bool] = False,
        block_size: Optional[int] = 32,
        use_block_cache: Optional[bool] = False,
        block_past_key_values: Optional[Cache] = None,
        replace_position: Optional[int] = None,
        **kwargs
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Args:
            input_ids / inputs_embeds: exactly one of the two must be given.
            labels: only used in training mode, and only for its shape (mask size).
            attention_mask: accepted for interface compatibility but always
                overwritten by the mask computed here.
            position_ids / cache_position: derived from the cache length when omitted.
            past_key_values: key/value cache; created when ``use_cache`` is set and
                none is supplied.
            block_size: block size of the eval-time mask.
            use_block_cache / block_past_key_values / replace_position /
                update_past_key_values: forwarded to every decoder layer;
                ``replace_position`` additionally offsets ``cache_position`` when the
                block cache is in use.

        Returns:
            ``BaseModelOutputWithPastAndBlockCache`` with the normalised hidden states
            and whichever caches were requested.

        Raises:
            ValueError: if not exactly one of ``input_ids`` / ``inputs_embeds`` is
                given, or if the module is in training mode and ``labels`` is missing.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if self.training and labels is None:
            # The training mask is sized from labels; fail early with a clear message.
            raise ValueError("labels must be provided when the model is in training mode")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if use_block_cache and block_past_key_values is None:
            block_past_key_values = DynamicCache()

        seq_len = inputs_embeds.shape[1]
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

        if cache_position is None:
            if self.training:
                # Positions cover only half of the (doubled) training sequence.
                first, count = past_seen_tokens, seq_len // 2
            elif use_block_cache and replace_position is not None:
                # Re-decoding a sub-block: positions start inside the current block.
                first, count = past_seen_tokens + replace_position, seq_len
            else:
                first, count = past_seen_tokens, seq_len
            cache_position = torch.arange(first, first + count, device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if self.training:
            attention_mask = self.gen_mask(
                labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads
            ).to(device=inputs_embeds.device)
        elif use_block_cache and block_past_key_values.get_seq_length() != 0:
            # A populated block cache means only a sub-block is re-decoded; no mask is built.
            attention_mask = None
        else:
            # Use the embedding length so that callers passing inputs_embeds
            # (input_ids is None) work as well.
            attention_mask = self.eval_mask(seq_len, block_size, past_seen_tokens).to(device=inputs_embeds.device)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                update_past_key_values=update_past_key_values,
                use_block_cache=use_block_cache,
                block_past_key_values=block_past_key_values,
                replace_position=replace_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPastAndBlockCache(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            block_past_key_values=block_past_key_values if use_block_cache else None,
        )
@can_return_tuple
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    update_past_key_values: Optional[bool] = False,
    block_size: Optional[int] = 32,
    use_block_cache: Optional[bool] = False,
    block_past_key_values: Optional[Cache] = None,
    replace_position: Optional[int] = None,
    mask_id: Optional[int] = 151665,
    **kwargs
) -> CausalLMOutputWithPastAndBlockCache:
    """Language-model forward pass with block-diffusion noising in training mode.

    Training mode builds two complementary noisy views of every sample:
      * the sequence is cut into rows of ``self.model.bd_size`` tokens and one
        masking probability in ``[eps, 1)`` is drawn per row;
      * view 1 replaces the selected supervised tokens with ``mask_id``, view 2
        uses the inverted selection;
      * each view is concatenated with the clean tokens along the sequence axis
        and both views are stacked along the batch axis;
      * labels are kept only where the corresponding view holds ``mask_id``.
    The loss is computed on the first (noisy) half of the hidden states.

    In eval mode the inputs are forwarded to ``self.model`` unchanged.

    Args:
        labels: target ids; ``-100`` marks unsupervised positions.  The caller's
            tensor is not modified.
        mask_id: token id used as the mask token.
        logits_to_keep: int ``k`` keeps the last ``k`` positions (``0`` keeps all);
            a tensor is used directly as the index along the sequence axis.
        Remaining arguments are passed through to ``self.model``.

    Raises:
        ValueError: in training mode when ``input_ids`` or ``labels`` is missing.
    """
    if self.training:
        if input_ids is None or labels is None:
            raise ValueError("Training requires both input_ids and labels")

        bd_size = self.model.bd_size
        original_labels = labels
        # Private copy: the masking below writes -100 in place and must not leak
        # into the tensor owned by the caller.
        labels = labels.clone()
        original_input_ids = input_ids.clone()

        noisy_input_ids = input_ids.clone()

        # One row per diffusion block so that each block gets its own mask ratio.
        input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // bd_size, bd_size)
        b, l = input_ids.shape
        t = torch.rand((b,), device=input_ids.device)
        eps = 1e-3
        p_mask = (1 - eps) * t + eps
        p_mask = p_mask[:, None].repeat(1, l)

        mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
        x_t = torch.where(mask_indices, mask_id, input_ids).reshape(labels.shape)
        supervised = labels != -100
        noisy_input_ids[supervised] = x_t[supervised]
        # Only masked positions contribute to the loss.
        labels[noisy_input_ids != mask_id] = -100
        input_ids = torch.cat([noisy_input_ids, input_ids.reshape(labels.shape)], dim=1)

        # Complementary view: every token that stayed visible above is masked here.
        complementary_noisy_input_ids = original_input_ids.clone()
        complementary_labels = original_labels.clone()

        complementary_input_ids = original_input_ids.reshape(
            original_input_ids.shape[0] * original_input_ids.shape[1] // bd_size, bd_size
        )

        complementary_mask_indices = ~mask_indices
        complementary_x_t = torch.where(
            complementary_mask_indices, mask_id, complementary_input_ids
        ).reshape(labels.shape)
        complementary_supervised = complementary_labels != -100
        complementary_noisy_input_ids[complementary_supervised] = complementary_x_t[complementary_supervised]
        complementary_labels[complementary_noisy_input_ids != mask_id] = -100
        complementary_input_ids = torch.cat(
            [complementary_noisy_input_ids, complementary_input_ids.reshape(complementary_labels.shape)], dim=1
        )

        input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
        labels = torch.cat([labels, complementary_labels], dim=0)

    outputs: BaseModelOutputWithPastAndBlockCache = self.model(
        input_ids=input_ids,
        labels=labels,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        cache_position=cache_position,
        update_past_key_values=update_past_key_values,
        block_size=block_size,
        use_block_cache=use_block_cache,
        block_past_key_values=block_past_key_values,
        replace_position=replace_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    if self.training:
        # Keep the noisy half only; the clean half served as attention context.
        hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = self.lm_head(hidden_states[:, slice_indices, :])

    loss = None
    if labels is not None:
        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

    return CausalLMOutputWithPastAndBlockCache(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=hidden_states,
        attentions=outputs.attentions,
        block_past_key_values=outputs.block_past_key_values,
    )
@torch.no_grad()
def generate(
    self,
    input_ids,
    max_new_tokens=None,
    max_length=None,
    tokenizer=None,
    mask_id=151665,
    threshold=1,
    small_block_size=8,
    block_size=32,
    stop_token=151645,
    stopping_criteria=None,
    top_p=0.95,
    temperature=0,
    use_block_cache=False,
    return_dict_in_generate=False,
    output_scores=False,
    output_hidden_states=False,
    **kwargs
):
    """Block-diffusion decoding: fill masked blocks of ``block_size`` tokens in parallel.

    Procedure:
      1. Prefill: the largest block-aligned prefix of the prompt is run once to
         populate the key/value cache.  If the prompt is itself block-aligned, the
         first token of the next block is taken greedily from the prefill logits.
      2. For each of ``max_new_tokens // block_size`` blocks, the sequence is padded
         with ``mask_id`` up to the next block boundary and the block is decoded
         sub-block by sub-block (``small_block_size`` tokens each).  In every step
         the most confident masked position is unmasked, plus every masked position
         whose probability exceeds ``threshold``.
      3. Once a block holds no mask it is run again with
         ``update_past_key_values=True`` to extend the cache and to obtain the first
         token of the following block.
      4. Decoding stops early when ``stop_token`` has been produced and no mask
         precedes it; the result is cut right after the stop token.

    Args:
        input_ids: prompt token ids, indexed as (batch, sequence).
        max_new_tokens / max_length: generation budget; at least one is required.
            Only whole blocks are decoded.
        tokenizer, stopping_criteria, **kwargs: accepted for call compatibility;
            not used by this method.
        mask_id: id of the mask token.
        threshold: confidence above which additional tokens are unmasked per step.
        small_block_size: sub-block width.
        block_size: block width.
        stop_token: id that terminates generation.
        top_p, temperature: passed to ``self.sample_with_top_p``.
        use_block_cache: reuse a per-block cache while refining later sub-blocks.
        return_dict_in_generate, output_scores, output_hidden_states: when set,
            return a ``GenerateDecoderOnlyOutput`` carrying the collected tensors.

    Returns:
        The token ids (prompt included), or a ``GenerateDecoderOnlyOutput``.

    Raises:
        ValueError: if neither ``max_new_tokens`` nor ``max_length`` is given.

    Note:
        The stop position is taken from the first match returned by ``nonzero()``,
        so all rows of a batch share one cut-off column.
    """
    if max_new_tokens is None and max_length is None:
        raise ValueError("Either max_new_tokens or max_length must be specified")
    if max_new_tokens is None:
        max_new_tokens = max_length - input_ids.shape[1]

    scores_list = [] if output_scores else None
    decoder_hidden_states = [] if output_hidden_states else None

    num_blocks = max_new_tokens // block_size
    original_input_length = input_ids.shape[1]

    # ">=" (not ">"): a prompt of exactly one block must be prefilled as well,
    # otherwise the first generated block would be decoded without any prompt
    # context (no cache, and the prompt lies outside the last-block window).
    if input_ids.shape[1] >= block_size:
        output = self.forward(
            input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)],
            use_cache=True,
            update_past_key_values=True,
            block_size=block_size
        )
        logits, past_key_values = output.logits, output.past_key_values

        if output_scores:
            scores_list.append(logits)
        if output_hidden_states and hasattr(output, 'hidden_states'):
            decoder_hidden_states.append(output.hidden_states)

        if input_ids.shape[1] % block_size == 0:
            next_token = logits[:, -1:, :].argmax(dim=-1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
    else:
        past_key_values = None

    num_small_blocks = block_size // small_block_size

    for block_idx in range(num_blocks):
        if stop_token in input_ids[:, original_input_length:]:
            break
        prompt_length = input_ids.shape[1]
        # Initialize x_init with mask_id
        x_init = mask_id * torch.ones(
            (input_ids.shape[0], block_size-prompt_length%block_size),
            device=self.device,
            dtype=torch.long
        )
        x_init = torch.cat([input_ids, x_init], dim=1)

        x_t = x_init.clone()
        block_past_key_values = None

        while True:
            if stop_token in x_t[:, prompt_length:]:
                stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
                if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
                    break
            mask_idx = (x_t[:, -block_size:] == mask_id)

            # Decode a complete block, update cache, and generate the next token
            if mask_idx.sum() == 0:
                output = self.forward(
                    input_ids=x_t[:, -block_size:],
                    use_cache=True,
                    past_key_values=past_key_values,
                    update_past_key_values=True,
                    block_size=block_size
                )
                logits, past_key_values = output.logits, output.past_key_values

                # Collect the requested outputs
                if output_scores:
                    scores_list.append(logits)
                if output_hidden_states and hasattr(output, 'hidden_states'):
                    decoder_hidden_states.append(output.hidden_states)

                next_token = logits[:, -1:, :].argmax(dim=-1)
                x_t = torch.cat([x_t, next_token], dim=1)
                break

            for small_block_idx in range(num_small_blocks):
                small_block_start_idx = small_block_idx * small_block_size
                small_block_end_idx = small_block_start_idx + small_block_size

                # Negative offsets relative to the end of x_t; ``None`` selects
                # "up to the end" for the last sub-block.
                start = -block_size + small_block_start_idx
                end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx

                while True:
                    mask_idx = (x_t[:, -block_size:] == mask_id)
                    if mask_idx[:, start:end].sum() == 0:
                        break
                    if stop_token in x_t[:, prompt_length:]:
                        stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
                        if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
                            break

                    if use_block_cache:
                        if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
                            # No usable block cache yet: run the whole block and keep its cache.
                            output = self.forward(
                                input_ids=x_t[:, -block_size:],
                                use_cache=True,
                                past_key_values=past_key_values,
                                update_past_key_values=False,
                                use_block_cache=True,
                            )
                            logits, block_past_key_values = output.logits, output.block_past_key_values
                            # Shift right by one: logits at position i score token i + 1.
                            logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
                            logits = logits[:, start:end]
                        else:
                            # Re-run only the current sub-block against the block cache.
                            output = self.forward(
                                input_ids=x_t[:,start:end],
                                use_cache=True,
                                past_key_values=past_key_values,
                                update_past_key_values=False,
                                use_block_cache=True,
                                block_past_key_values=block_past_key_values,
                                replace_position=small_block_start_idx
                            )
                            logits = output.logits
                            logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
                    else:
                        output = self.forward(
                            input_ids=x_t[:, -block_size:],
                            use_cache=True,
                            past_key_values=past_key_values,
                            update_past_key_values=False
                        )
                        logits = output.logits
                        logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
                        logits = logits[:, start:end]

                    if output_scores:
                        scores_list.append(logits)
                    if output_hidden_states and hasattr(output, 'hidden_states'):
                        decoder_hidden_states.append(output.hidden_states)

                    x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
                    # Select tokens with probability greater than threshold from p_1t
                    x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
                    # Already-decoded positions get -inf so they can never be selected.
                    x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)

                    unmask_idx = (x1_p > threshold)
                    # Always unmask the most confident position so every step makes progress.
                    max_prob_idx = x1_p.argmax(dim=-1)
                    unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
                    unmask_idx = unmask_idx & mask_idx[:, start:end]

                    x_t[:, start:end][unmask_idx] = x_1[unmask_idx]

        input_ids = x_t

    # Truncate stop_token
    if stop_token in input_ids[:, original_input_length:]:
        stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
        input_ids = input_ids[:, :stop_token_idx+original_input_length+1]

    if return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=tuple(scores_list) if output_scores and scores_list else None,
            hidden_states=tuple(decoder_hidden_states) if output_hidden_states and decoder_hidden_states else None,
        )
    else:
        return input_ids
def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
    """Pick one token per position, greedily or by nucleus (top-p) sampling.

    Args:
        logits: scores whose last axis is the vocabulary; any number of leading
            axes (e.g. batch and sequence) is supported.
        top_p: cumulative-probability cut-off of the nucleus.
        temperature: ``<= 0`` selects greedy decoding; otherwise the logits are
            divided by it before the softmax.

    Returns:
        ``(x_1, p_1t)``: ``x_1`` holds the chosen token ids with the shape of
        ``logits`` minus its last axis; ``p_1t`` holds the probabilities the choice
        was made from (plain softmax when greedy, renormalised nucleus otherwise).
    """
    if temperature <= 0:
        # Greedy decoding: no temperature scaling, no truncation.
        p_1t = torch.softmax(logits, dim=-1)
        return p_1t.argmax(dim=-1), p_1t

    probs = torch.softmax(logits / temperature, dim=-1)

    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Drop everything after the cumulative mass exceeds top_p.  The flags are
    # shifted right by one so that the token crossing the threshold is kept, and
    # the most likely token is always kept.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the flags from sorted order back to vocabulary order.
    indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    probs = probs.masked_fill(indices_to_remove, 0.0)

    # Renormalise the surviving tokens.  The sum cannot be zero because the top-1
    # token is never removed.
    p_1t = probs / probs.sum(dim=-1, keepdim=True)

    # Sample every (batch, position) row, not only the rows of batch element 0:
    # flatten the leading axes for multinomial, then restore them.
    flat_probs = p_1t.reshape(-1, p_1t.shape[-1])
    x_1 = torch.multinomial(flat_probs, num_samples=1).reshape(p_1t.shape[:-1])

    return x_1, p_1t
false, + "rstrip": false, + "single_word": false + } +} diff --git a/tokenizer.json b/tokenizer.json new file mode 100644 index 0000000..013fc14 --- /dev/null +++ b/tokenizer.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb2105b66192c5a532e2a098dc899df86eca233b4faa48461211e4312c8b3568 +size 11422081 diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 0000000..d200a00 --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1,204 @@ +{ + "add_bos_token": false, + "add_prefix_space": false, + "added_tokens_decoder": { + "151643": { + "content": "<|endoftext|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151644": { + "content": "<|im_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151645": { + "content": "<|im_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151646": { + "content": "<|object_ref_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151647": { + "content": "<|object_ref_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151648": { + "content": "<|box_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151649": { + "content": "<|box_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151650": { + "content": "<|quad_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151651": { + "content": "<|quad_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151652": { + "content": 
"<|vision_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151653": { + "content": "<|vision_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151654": { + "content": "<|vision_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151655": { + "content": "<|image_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151656": { + "content": "<|video_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151657": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151658": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151659": { + "content": "<|fim_prefix|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151660": { + "content": "<|fim_middle|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151661": { + "content": "<|fim_suffix|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151662": { + "content": "<|fim_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151663": { + "content": "<|repo_name|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151664": { + "content": "<|file_sep|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151665": { + "content": "||", + "lstrip": false, + 
"normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "||" + ], + "bos_token": null, + "clean_up_tokenization_spaces": false, + "eos_token": "<|im_end|>", + "errors": "replace", + "extra_special_tokens": {}, + "model_max_length": 131072, + "pad_token": "<|endoftext|>", + "padding_side": "right", + "split_special_tokens": false, + "tokenizer_class": "Qwen2Tokenizer", + "unk_token": null +} diff --git a/vocab.json b/vocab.json new file mode 100644 index 0000000..6c49fc6 --- /dev/null +++ b/vocab.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910 +size 2776833