初始化项目,由ModelHub XC社区提供模型
Model: facebook/KernelLLM Source: Original Platform
This commit is contained in:
54
.gitattributes
vendored
Normal file
54
.gitattributes
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
||||
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
*.db* filter=lfs diff=lfs merge=lfs -text
|
||||
*.ark* filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
||||
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.gguf* filter=lfs diff=lfs merge=lfs -text
|
||||
*.ggml filter=lfs diff=lfs merge=lfs -text
|
||||
*.llamafile* filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
model-00003-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
model-00002-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
model-00004-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
model-00001-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
BIN
LICENSE.pdf
Normal file
BIN
LICENSE.pdf
Normal file
Binary file not shown.
196
README.md
Normal file
196
README.md
Normal file
@@ -0,0 +1,196 @@
|
||||
---
|
||||
license: other
|
||||
base_model:
|
||||
- meta-llama/Llama-3.1-8B-Instruct
|
||||
datasets:
|
||||
- ScalingIntelligence/KernelBench
|
||||
- GPUMODE/KernelBook
|
||||
library_name: transformers
|
||||
---
|
||||
|
||||
# KernelLLM
|
||||

|
||||
*On KernelBench-Triton Level 1, our 8B parameter model exceeds models such as GPT-4o and DeepSeek V3 in single-shot performance. With multiple inferences, KernelLLM's performance outperforms DeepSeek R1. This is all from a model with two orders of magnitude fewer parameters than its competitors.*
|
||||
|
||||
## _Updates_:
|
||||
* 2025/06/25: We added an [end-to-end example walkthrough](https://huggingface.co/facebook/KernelLLM/discussions/5#685b0903b3d048882566b17b), where we format a community-provided prompt for KernelLLM to function well.
|
||||
We have received many questions about how to format the prompts such that KernelLLM performs best. We hope this can help!
|
||||
* 2025/06/15 We would like to thank the community for the creation of [multiple](https://huggingface.co/bartowski/facebook_KernelLLM-GGUF) [different](https://huggingface.co/unsloth/KernelLLM-GGUF) [quantizations](https://huggingface.co/unsloth/KernelLLM) and for a total of more than 20k downloads!
|
||||
* 2025/06/03 The startup mako.dev has integrated KernelLLM into their [GPU performance engineering plattform](https://generate.mako.dev/)!
|
||||
|
||||
## Making Kernel Development more accessible with KernelLLM
|
||||
|
||||
We introduce KernelLLM, a large language model based on Llama 3.1 Instruct, which has been trained specifically for the task of authoring GPU kernels using Triton. KernelLLM translates PyTorch modules into Triton kernels and was evaluated on KernelBench-Triton (see [here](https://github.com/ScalingIntelligence/KernelBench/pull/35)).
|
||||
KernelLLM aims to democratize GPU programming by making kernel development more accessible and efficient.
|
||||
|
||||
KernelLLM's vision is to meet the growing demand for high-performance GPU kernels by automating the generation of efficient Triton implementations. As workloads grow larger and more diverse accelerator architectures emerge, the need for tailored kernel solutions has increased significantly. Although a number of [works](https://metr.org/blog/2025-02-14-measuring-automated-kernel-engineering/) [exist](https://cognition.ai/blog/kevin-32b), most of them are limited to [test-time](https://sakana.ai/ai-cuda-engineer/) [optimization](https://developer.nvidia.com/blog/automating-gpu-kernel-generation-with-deepseek-r1-and-inference-time-scaling/), while others tune on solutions traced of KernelBench problems itself, thereby limiting the informativeness of the results towards out-of-distribution generalization. To the best of our knowledge KernelLLM is the first LLM finetuned on external (torch, triton) pairs, and we hope that making our model available can accelerate progress towards intelligent kernel authoring systems.
|
||||
|
||||
|
||||
|
||||

|
||||
|
||||
*KernelLLM Workflow for Triton Kernel Generation: Our approach uses KernelLLM to translate PyTorch code (green) into Triton kernel candidates. Input and output components are marked in bold. The generations are validated against unit tests, which run kernels with random inputs of known shapes. This workflow allows us to evaluate multiple generations (pass@k) by increasing the number of kernel candidate generations. The best kernel implementation is selected and returned (green output).*
|
||||
|
||||
|
||||
The model was trained on approximately 25,000 paired examples of PyTorch modules and their equivalent Triton kernel implementations, and additional synthetically generated samples. Our approach combines filtered code from TheStack [Kocetkov et al. 2022] and synthetic examples generated through `torch.compile()` and additional prompting techniques. The filtered and compiled dataset is [KernelBook]](https://huggingface.co/datasets/GPUMODE/KernelBook).
|
||||
|
||||
We finetuned Llama3.1-8B-Instruct on the created dataset using supervised instruction tuning and measured its ability to generate correct Triton kernels and corresponding calling code on KernelBench-Triton, our newly created variant of KernelBench [Ouyang et al. 2025] targeting Triton kernel generation. The torch code was used with a prompt template containing a format example as instruction during both training and evaluation. The model was trained for 10 epochs with a batch size of 32 and a standard SFT recipe with hyperparameters selected by perplexity on a held-out subset of the training data. Training took circa 12 hours wall clock time on 16 GPUs (192 GPU hours), and we report the best checkpoint's validation results.
|
||||
|
||||
### Model Performance
|
||||
|
||||
|
||||

|
||||
|
||||
| Model | Parameters (B) | Score | Pass@k |
|
||||
|-------|---------------|-------|--------|
|
||||
| KernelLLM | 8 | 20.2 | 1 |
|
||||
| KernelLLM | 8 | 51.8 | 10 |
|
||||
| KernelLLM | 8 | 57.1 | 20 |
|
||||
| DeepSeek V3 | 671 | 16 | 1 |
|
||||
| GPT-4o | ~200 | 15 | 1 |
|
||||
| Qwen2.5 | 32 | 15 | 1 |
|
||||
| Llama 3.3 | 70 | 13 | 1 |
|
||||
| Llama 3.1 | 8 | 14 | 20 |
|
||||
| Llama 3.1 | 8 | 6 | 1 |
|
||||
| Llama R1 Distill | 70 | 11 | reasoning |
|
||||
| DeepSeek R1 | 671 | 30 | 1 |
|
||||
|
||||
*Our 8B parameter model achieves competitive or superior performance compared to much larger models on kernel generation tasks, demonstrating the effectiveness of our specialized training approach on KernelBench Level 1 versus various baselines. KernelLLM inference was run with temperature=1.0 and top_p=0.97.*
|
||||
|
||||
The resulting model is competitive with state of the art LLMs despite its small size. We evaluate our model on KernelBench which is an open-source benchmark to evaluate the ability of LLMs to write efficient GPU kernels. It contains 250 selected PyTorch modules organized into difficulty levels, from single torch operators such as Conv2D or Swish (level 1), to full model architectures (level 3). The benchmark measures both correctness (by comparing against reference PyTorch outputs) and performance (by measuring speedup over baseline implementations). We implemented a new KernelBench-Triton variant that evaluates an LLMs ability to generate Triton kernels, making it an ideal benchmark for evaluating KernelLLM's capabilities. All our measurements were done on Nvidia H100 GPUs.
|
||||
|
||||

|
||||
*KernelLLM shows quasi log-linear scaling behavior during pass@k analysis.*
|
||||
|
||||
|
||||
For more information, please see [Project Popcorn](https://gpu-mode.github.io/popcorn/).
|
||||
|
||||
## Installation
|
||||
|
||||
To use KernelLLM, install the required dependencies:
|
||||
|
||||
```bash
|
||||
pip install transformers accelerate torch triton
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
KernelLLM provides a simple interface for generating Triton kernels from PyTorch code. The included `kernelllm.py` script offers multiple methods for interacting with the model.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from kernelllm import KernelLLM
|
||||
|
||||
# Initialize the model
|
||||
model = KernelLLM()
|
||||
|
||||
# Define your PyTorch module
|
||||
pytorch_code = '''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
A model that computes Hinge Loss for binary classification tasks.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, predictions, targets):
|
||||
return torch.mean(torch.clamp(1 - predictions * targets, min=0))
|
||||
|
||||
batch_size = 128
|
||||
input_shape = (1,)
|
||||
|
||||
def get_inputs():
|
||||
return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]
|
||||
|
||||
def get_init_inputs():
|
||||
return []
|
||||
'''
|
||||
|
||||
# Generate optimized Triton code
|
||||
optimized_code = model.generate_triton(pytorch_code, max_new_tokens=512)
|
||||
print(optimized_code)
|
||||
```
|
||||
|
||||
### Interactive REPL
|
||||
|
||||
You can also use the built-in REPL interface:
|
||||
|
||||
```bash
|
||||
python kernelllm.py
|
||||
```
|
||||
|
||||
This will start an interactive session where you can input your PyTorch code and receive Triton-optimized implementations.
|
||||
|
||||
### Advanced Options
|
||||
|
||||
KernelLLM provides several methods for customizing the generation process:
|
||||
|
||||
```python
|
||||
from kernelllm import KernelLLM
|
||||
|
||||
model = KernelLLM()
|
||||
|
||||
# Stream output in real-time
|
||||
model.stream_raw("Your prompt here", max_new_tokens=2048)
|
||||
|
||||
# Generate raw text without the Triton-specific prompt template
|
||||
raw_output = model.generate_raw("Your prompt here", temperature=1.0, max_new_tokens=2048)
|
||||
```
|
||||
|
||||
## Current Limitations and Future Work
|
||||
|
||||
Despite showing promising results, KernelLLM has several limitations:
|
||||
|
||||
- The model may still produce incorrect API references and syntax errors, and is limited in its instruction following ability.
|
||||
- Generated code structurally resembles compiler-generated output, and the model often fails to implement a meaningful kernel.
|
||||
- Error analysis shows common issues related to instruction following with respect to variable naming, tensor shapes, type handling, and numerical precision.
|
||||
|
||||
## Model Details
|
||||
|
||||
**Model Developers:** Meta.
|
||||
|
||||
**Input:** Models input text only.
|
||||
|
||||
**Output:** Models generate text only.
|
||||
|
||||
**Model Architecture:** KernelLLM is an auto-regressive language model that uses an optimized transformer architecture.
|
||||
|
||||
**Model Dates:** KernelLLM was trained in March 2025.
|
||||
|
||||
**Status:** This is a static model trained on an offline dataset.
|
||||
|
||||
**License:** See LICENSE.pdf for details.
|
||||
|
||||
## Intended Use
|
||||
|
||||
**Intended Use Cases:** KernelLLM is intended for commercial and research use in English, relevant programming languages, Python, and Triton.
|
||||
|
||||
**Out-of-Scope Uses:** Use in any manner that violates applicable laws or regulations (including trade compliance laws). Use in languages other than English. Use in any other way that is prohibited by the [Acceptable Use Policy](https://llama.meta.com/llama3/use-policy) and Licensing Agreement for KernelLLM and its variants.
|
||||
|
||||
## Hardware and Software
|
||||
|
||||
**Training Factors:** We used custom training libraries.
|
||||
|
||||
**Carbon Footprint:** In aggregate, training KernelLLM required 250 hours of computation on hardware of type H100-80GB, not including the training of the base model. 100% of the estimated tCO2eq emissions were offset by Meta's sustainability program.
|
||||
|
||||
## Ethical Considerations and Limitations
|
||||
|
||||
KernelLLM and its variants are a new technology that carries risks with use. Testing conducted to date has been in English, and has not covered, nor could it cover all scenarios. For these reasons, as with all LLMs, KernelLLM's potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate or objectionable responses to user prompts. Therefore, before deploying any applications of KernelLLM, developers should perform safety testing and tuning tailored to their specific applications of the model.
|
||||
|
||||
Please see the Responsible Use Guide available at [https://ai.meta.com/llama/responsible-use-guide](https://ai.meta.com/llama/responsible-use-guide).
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
@software{kernelllm2025,
|
||||
title={KernelLLM: Making Kernel Development More Accessible},
|
||||
author={Fisches, Zacharias V. and Paliskara, Sahan and Guo, Simon and Zhang, Alex and Spisak, Joe and Cummins, Chris and Leather, Hugh and Synnaeve, Gabriel and Isaacson, Joe and Markosyan, Aram and Saroufim, Mark},
|
||||
year={2025},
|
||||
month={6},
|
||||
url={https://huggingface.co/facebook/KernelLLM},
|
||||
}
|
||||
```
|
||||
39
config.json
Normal file
39
config.json
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"architectures": [
|
||||
"LlamaForCausalLM"
|
||||
],
|
||||
"attention_bias": false,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 128000,
|
||||
"eos_token_id": [
|
||||
128001,
|
||||
128008,
|
||||
128009
|
||||
],
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 4096,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 14336,
|
||||
"max_position_embeddings": 131072,
|
||||
"mlp_bias": false,
|
||||
"model_type": "llama",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"pretraining_tp": 1,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": {
|
||||
"factor": 8.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"original_max_position_embeddings": 8192,
|
||||
"rope_type": "llama3"
|
||||
},
|
||||
"rope_theta": 500000.0,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.48.2",
|
||||
"use_cache": true,
|
||||
"vocab_size": 128256
|
||||
}
|
||||
1
configuration.json
Normal file
1
configuration.json
Normal file
@@ -0,0 +1 @@
|
||||
{"framework": "pytorch", "task": "text-generation", "allow_remote": true}
|
||||
10
generation_config.json
Normal file
10
generation_config.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"_from_model_config": true,
|
||||
"bos_token_id": 128000,
|
||||
"eos_token_id": [
|
||||
128001,
|
||||
128008,
|
||||
128009
|
||||
],
|
||||
"transformers_version": "4.48.2"
|
||||
}
|
||||
331
kernelllm.py
Normal file
331
kernelllm.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
KernelLLM
|
||||
|
||||
This script provides a simple interface for the KernelLLM model.
|
||||
It allows users to input PyTorch models and let KernelLLM attempt to implement the corresponding Triton kernels.
|
||||
|
||||
The KernelLLM class provides two types of methods:
|
||||
1. Methods that instruct the model with a suitable prompt to generate Triton kernels.
|
||||
2. "raw" methods that allow the user to interact with the model directly, without any additional prompt wrapping.
|
||||
|
||||
For best results, use the `generate_triton` method to instruct the model the way it was trained.
|
||||
|
||||
Example usage:
|
||||
To run the script from the command line:
|
||||
CUDA_VISIBLE_DEVICES=0 python kernelllm.py
|
||||
|
||||
To use the class in an interactive Python session:
|
||||
$ ipython
|
||||
from kernelllm import KernelLLM
|
||||
model = KernelLLM()
|
||||
model.generate_triton("<your torch module here>", max_new_tokens=128)
|
||||
|
||||
# or stream output directly
|
||||
model.stream_raw("<your text prompt>", max_new_tokens=128)
|
||||
|
||||
|
||||
Full example:
|
||||
```
|
||||
#Generate Triton-optimized code for a PyTorch model:
|
||||
from kernelllm import KernelLLM
|
||||
|
||||
model = KernelLLM()
|
||||
pytorch_code = '''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * 2
|
||||
|
||||
def get_inputs():
|
||||
return [torch.randn(1, 128).cuda()]
|
||||
|
||||
def get_init_inputs():
|
||||
return []
|
||||
'''
|
||||
optimized_code = model.generate_triton(pytorch_code, max_new_tokens=512)
|
||||
print(optimized_code)
|
||||
```
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
||||
|
||||
HF_MODEL = "facebook/KernelLLM"
|
||||
|
||||
REPL_INSTRUCTIONS = """
|
||||
You can paste or write your nn.Module code below (and finish with Ctrl+D).
|
||||
The model will try to optimize it with Triton kernels.
|
||||
|
||||
Make sure that you provide a `get_inputs()` and `get_init_inputs()` function such that your model can be run like this
|
||||
args, kwargs = get_inputs()
|
||||
model = Model(*args, **kwargs)
|
||||
out = model(get_inputs())
|
||||
|
||||
>>>"""
|
||||
|
||||
DEFAULT_MODEL_CODE = """
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
\"\"\"
|
||||
A model that computes Hinge Loss for binary classification tasks.
|
||||
|
||||
Parameters:
|
||||
None
|
||||
\"\"\"
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, predictions, targets):
|
||||
return torch.mean(torch.clamp(1 - predictions * targets, min=0))
|
||||
|
||||
batch_size = 128
|
||||
input_shape = (1,)
|
||||
dim = 1
|
||||
|
||||
def get_inputs():
|
||||
return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]
|
||||
|
||||
def get_init_inputs():
|
||||
return []
|
||||
"""
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
<|begin_of_text|>You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.
|
||||
|
||||
You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.
|
||||
|
||||
|
||||
Here's an example to show you the syntax of inline embedding custom operators from the Triton DSL in torch: The example given architecture is:
|
||||
```
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, a, b):
|
||||
return a + b
|
||||
|
||||
|
||||
def get_inputs():
|
||||
# randomly generate input tensors based on the model architecture
|
||||
a = torch.randn(1, 128).cuda()
|
||||
b = torch.randn(1, 128).cuda()
|
||||
return [a, b]
|
||||
|
||||
|
||||
def get_init_inputs():
|
||||
# randomly generate tensors required for initialization based on the model architecture
|
||||
return []
|
||||
|
||||
```
|
||||
The example new arch with custom Triton kernels looks like this:
|
||||
```
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_kernel(
|
||||
x_ptr, # Pointer to first input
|
||||
y_ptr, # Pointer to second input
|
||||
out_ptr, # Pointer to output
|
||||
n_elements, # Total number of elements in input/output
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Each program handles a contiguous block of data of size BLOCK_SIZE
|
||||
block_start = tl.program_id(0) * BLOCK_SIZE
|
||||
# Create a range of offsets [0..BLOCK_SIZE-1]
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
# Mask to ensure we don't go out of bounds
|
||||
mask = offsets < n_elements
|
||||
# Load input values
|
||||
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
|
||||
y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
|
||||
# Perform the elementwise addition
|
||||
out = x + y
|
||||
# Store the result
|
||||
tl.store(out_ptr + offsets, out, mask=mask)
|
||||
|
||||
|
||||
def triton_add(x: torch.Tensor, y: torch.Tensor):
|
||||
\"\"\"
|
||||
This function wraps the Triton kernel call. It:
|
||||
1. Ensures the inputs are contiguous on GPU.
|
||||
2. Calculates the grid (blocks) needed.
|
||||
3. Launches the Triton kernel.
|
||||
\"\"\"
|
||||
assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA."
|
||||
x = x.contiguous()
|
||||
y = y.contiguous()
|
||||
|
||||
# Prepare output tensor
|
||||
out = torch.empty_like(x)
|
||||
|
||||
# Number of elements in the tensor
|
||||
n_elements = x.numel()
|
||||
BLOCK_SIZE = 128 # Tunable parameter for block size
|
||||
|
||||
# Determine the number of blocks needed
|
||||
grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
|
||||
|
||||
# Launch the Triton kernel
|
||||
add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
|
||||
return out
|
||||
|
||||
|
||||
class ModelNew(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, a, b):
|
||||
# Instead of "return a + b", call our Triton-based addition
|
||||
return triton_add(a, b)
|
||||
|
||||
```
|
||||
|
||||
You are given the following architecture:
|
||||
```
|
||||
{}
|
||||
```
|
||||
|
||||
Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code!
|
||||
"""
|
||||
|
||||
|
||||
class KernelLLM:
|
||||
"""
|
||||
A simple wrapper around the KernelLLM model for generating Triton kernels that allows easy
|
||||
instruction of the model and a streamed repl interface to interact with the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = HF_MODEL,
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, torch_dtype=torch.float16
|
||||
)
|
||||
self.model.to(self.device)
|
||||
|
||||
def generate_raw(
|
||||
self, prompt: str, temperature: float = 0.6, max_new_tokens: int = 2048
|
||||
) -> str:
|
||||
"""
|
||||
Generate text from the model using the given prompt verbatim.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to generate text from.
|
||||
temperature (float): The temperature to use for sampling.
|
||||
max_new_tokens (int): The maximum length of the generated text.
|
||||
Returns:
|
||||
str: The generated text.
|
||||
"""
|
||||
inputs = self.tokenizer([prompt], return_tensors="pt")
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=0,
|
||||
top_p=0.95,
|
||||
do_sample=True,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
)
|
||||
text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
return text[len(prompt) :].strip()
|
||||
|
||||
def stream_raw(self, prompt: str, max_new_tokens: int = 2048):
|
||||
"""
|
||||
Stream and print text from the model using the given prompt verbatim.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to generate text from.
|
||||
max_new_tokens (int): The maximum length of the generated text.
|
||||
"""
|
||||
inputs = self.tokenizer([prompt], return_tensors="pt")
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
streamer = TextStreamer(
|
||||
self.tokenizer, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
self.model.generate(
|
||||
**inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
top_k=0,
|
||||
top_p=0.95,
|
||||
temperature=0.6,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
def generate_triton(
|
||||
self, code: str, temperature: float = 0.6, max_new_tokens: int = 2048
|
||||
) -> str:
|
||||
"""
|
||||
Generate Triton for the given torch module.
|
||||
|
||||
The input code should be a python module that contains a torch Model(nn.Module) class and
|
||||
`get_inputs()` and `get_init_inputs()` functions such that your model can be run like this
|
||||
```
|
||||
args, kwargs = get_inputs()
|
||||
model = Model(*args, **kwargs)
|
||||
out = model(get_inputs())
|
||||
```
|
||||
|
||||
Args:
|
||||
code (str): The torch code to generate Triton for.
|
||||
temperature (float): The temperature to use for sampling.
|
||||
max_new_tokens (int): The maximum length of the generated text.
|
||||
Returns:
|
||||
str: The generated Triton module.
|
||||
"""
|
||||
prompt = PROMPT_TEMPLATE.format(code)
|
||||
return self.generate_raw(prompt, temperature, max_new_tokens)
|
||||
|
||||
def run_repl(self):
|
||||
"""
|
||||
Run a REPL for the model. The user can input code and the model will try to optimize it with Triton kernels.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
print(REPL_INSTRUCTIONS)
|
||||
code = sys.stdin.read().strip()
|
||||
if code.lower() == "exit":
|
||||
return
|
||||
except EOFError:
|
||||
pass
|
||||
|
||||
if not code:
|
||||
print(f"Using default prompt:\n{DEFAULT_MODEL_CODE}\n")
|
||||
code = DEFAULT_MODEL_CODE
|
||||
prompt = PROMPT_TEMPLATE.format(DEFAULT_MODEL_CODE)
|
||||
|
||||
try:
|
||||
self.stream_raw(prompt)
|
||||
except KeyboardInterrupt:
|
||||
print("Aborting...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
kernel_llm = KernelLLM()
|
||||
kernel_llm.run_repl()
|
||||
BIN
media/blog_post_model_performance.png
Normal file
BIN
media/blog_post_model_performance.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 238 KiB |
BIN
media/kernelllm_pass_at_k_scaling.png
Normal file
BIN
media/kernelllm_pass_at_k_scaling.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 495 KiB |
BIN
media/llm_performance_comparison.png
Normal file
BIN
media/llm_performance_comparison.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 360 KiB |
BIN
media/triton-kernel-workflow.png
Normal file
BIN
media/triton-kernel-workflow.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 320 KiB |
3
model-00001-of-00004.safetensors
Normal file
3
model-00001-of-00004.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc5b9790382a99cbaf8dd0312d3a01c5d3e6f58de917ca24be0a110c9dce7515
|
||||
size 4976698672
|
||||
3
model-00002-of-00004.safetensors
Normal file
3
model-00002-of-00004.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:14d4f200f024182d03eebd915d2855014ea7d9937e60a16aa2f849c2488102c3
|
||||
size 4999802720
|
||||
3
model-00003-of-00004.safetensors
Normal file
3
model-00003-of-00004.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8805dad2f4dd0009156271c0a352815197d38b3992969770745727f8d5e21fb0
|
||||
size 4915916176
|
||||
3
model-00004-of-00004.safetensors
Normal file
3
model-00004-of-00004.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:327f53054b3b5ae860bf06e2d5dd8f4f07f627d6886155c1a449069147a742f9
|
||||
size 1168138808
|
||||
298
model.safetensors.index.json
Normal file
298
model.safetensors.index.json
Normal file
@@ -0,0 +1,298 @@
|
||||
{
|
||||
"metadata": {
|
||||
"total_size": 16060522496
|
||||
},
|
||||
"weight_map": {
|
||||
"lm_head.weight": "model-00004-of-00004.safetensors",
|
||||
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
||||
"model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
||||
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
||||
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||
"model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||
"model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||
"model.norm.weight": "model-00004-of-00004.safetensors"
|
||||
}
|
||||
}
|
||||
4
special_tokens_map.json
Normal file
4
special_tokens_map.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"bos_token": "<|begin_of_text|>",
|
||||
"eos_token": "<|end_of_text|>"
|
||||
}
|
||||
3
tokenizer.json
Normal file
3
tokenizer.json
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:01e3be37353fbc0be479c7509d53c76860b7915a6b1852d5e75ec0c92707138b
|
||||
size 17208753
|
||||
2063
tokenizer_config.json
Normal file
2063
tokenizer_config.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user