Upload folder using huggingface_hub

.gitattributes (vendored): 1 line added
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text

README.md: 541 lines, new file
@@ -0,0 +1,541 @@
---
license: gemma
library_name: transformers
pipeline_tag: image-text-to-text
extra_gated_heading: Access Gemma on Hugging Face
extra_gated_prompt: To access Gemma on Hugging Face, you’re required to review and
  agree to Google’s usage license. To do this, please ensure you’re logged in to Hugging
  Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
base_model: google/gemma-3-27b-it
---

# Gemma 3 27B Instruction-tuned INT4

This is the quantization-aware trained (QAT) INT4 Flax checkpoint from Kaggle, converted to HF + AWQ format for ease of use. AWQ itself was NOT used for quantization; only the AWQ INT4 weight layout is reused. You can find the conversion script `convert_flax.py` in this model repo.
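
For reference, here is a minimal loading sketch for this converted checkpoint (assumptions: a Transformers build with Gemma 3 support plus the AutoAWQ kernels, since `config.json` declares `quant_method: "awq"`; the repo id below is a placeholder for this repository):

```python
# Hedged sketch, not an official recipe: load the INT4 (AWQ-layout) weights from this repo
# in place of the bf16 weights used in the original snippets below.
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "<this-repo-id>"  # placeholder: substitute the actual Hub id of this repository

model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto").eval()
processor = AutoProcessor.from_pretrained(model_id)
```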

Below is the original model card from https://huggingface.co/google/gemma-3-27b-it

# Gemma 3 model card

**Model Page**: [Gemma](https://ai.google.dev/gemma/docs/core)

**Resources and Technical Documentation**:

* [Gemma 3 Technical Report][g3-tech-report]
* [Responsible Generative AI Toolkit][rai-toolkit]
* [Gemma on Kaggle][kaggle-gemma]
* [Gemma on Vertex Model Garden][vertex-mg-gemma3]

**Terms of Use**: [Terms][terms]

**Authors**: Google DeepMind

## Model Information

Summary description and brief definition of inputs and outputs.

### Description

Gemma is a family of lightweight, state-of-the-art open models from Google,
built from the same research and technology used to create the Gemini models.
Gemma 3 models are multimodal, handling text and image input and generating text
output, with open weights for both pre-trained variants and instruction-tuned
variants. Gemma 3 has a large, 128K context window, multilingual support in over
140 languages, and is available in more sizes than previous versions. Gemma 3
models are well-suited for a variety of text generation and image understanding
tasks, including question answering, summarization, and reasoning. Their
relatively small size makes it possible to deploy them in environments with
limited resources such as laptops, desktops or your own cloud infrastructure,
democratizing access to state-of-the-art AI models and helping foster innovation
for everyone.

### Inputs and outputs

- **Input:**
    - Text string, such as a question, a prompt, or a document to be summarized
    - Images, normalized to 896 x 896 resolution and encoded to 256 tokens
      each (see the token-count sketch after this list)
    - Total input context of 128K tokens for the 4B, 12B, and 27B sizes, and
      32K tokens for the 1B size

- **Output:**
    - Generated text in response to the input, such as an answer to a
      question, analysis of image content, or a summary of a document
    - Total output context of 8192 tokens
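
A small sanity-check sketch for the per-image token budget above (an assumption of this illustration is that `Gemma3Processor` expands each image into 256 `<image_soft_token>` placeholders, matching `image_seq_length` in `preprocessor_config.json`):

```python
# Hedged sketch: count the image placeholder tokens produced for one image.
from transformers import AutoProcessor
from PIL import Image

processor = AutoProcessor.from_pretrained("google/gemma-3-27b-it")
messages = [{"role": "user", "content": [
    {"type": "image", "image": Image.new("RGB", (1024, 768))},  # any size; resized to 896 x 896
    {"type": "text", "text": "Describe this image."},
]}]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
)
image_token_id = processor.tokenizer.convert_tokens_to_ids("<image_soft_token>")
print(int((inputs["input_ids"] == image_token_id).sum()))  # expected: 256 per image
```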

### Usage

Below are some code snippets to help you get started quickly with the model. First, install the Transformers library with the version made for Gemma 3:

```sh
$ pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
```

Then, copy the snippet from the section that is relevant for your use case.

#### Running with the `pipeline` API

You can initialize the model and processor for inference with `pipeline` as follows.

```python
from transformers import pipeline
import torch

pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-27b-it",
    device="cuda",
    torch_dtype=torch.bfloat16
)
```

With instruction-tuned models, you need to use chat templates to process your inputs first. Then, you can pass them to the pipeline.

```python
messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    }
]

output = pipe(text=messages, max_new_tokens=200)
print(output[0][0]["generated_text"][-1]["content"])
# Okay, let's take a look!
# Based on the image, the animal on the candy is a **turtle**.
# You can see the shell shape and the head and legs.
```

#### Running the model on a single/multi GPU

```python
# pip install accelerate

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch

model_id = "google/gemma-3-27b-it"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."}
        ]
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)

# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.
```

### Citation

```none
@article{gemma_2025,
    title={Gemma 3},
    url={https://goo.gle/Gemma3Report},
    publisher={Kaggle},
    author={Gemma Team},
    year={2025}
}
```

## Model Data

Data used for model training and how the data was processed.

### Training Dataset

These models were trained on a dataset of text data that includes a wide variety
of sources. The 27B model was trained with 14 trillion tokens, the 12B model with
12 trillion tokens, the 4B model with 4 trillion tokens, and the 1B with 2 trillion
tokens. Here are the key components:

- Web Documents: A diverse collection of web text ensures the model is
  exposed to a broad range of linguistic styles, topics, and vocabulary. The
  training dataset includes content in over 140 languages.
- Code: Exposing the model to code helps it to learn the syntax and
  patterns of programming languages, which improves its ability to generate
  code and understand code-related questions.
- Mathematics: Training on mathematical text helps the model learn logical
  reasoning, symbolic representation, and to address mathematical queries.
- Images: A wide range of images enables the model to perform image
  analysis and visual data extraction tasks.

The combination of these diverse data sources is crucial for training a powerful
multimodal model that can handle a wide variety of different tasks and data
formats.

### Data Preprocessing

Here are the key data cleaning and filtering methods applied to the training
data:

- CSAM Filtering: Rigorous CSAM (Child Sexual Abuse Material) filtering
  was applied at multiple stages in the data preparation process to ensure
  the exclusion of harmful and illegal content.
- Sensitive Data Filtering: As part of making Gemma pre-trained models
  safe and reliable, automated techniques were used to filter out certain
  personal information and other sensitive data from training sets.
- Additional methods: Filtering based on content quality and safety in
  line with [our policies][safety-policies].

## Implementation Information

Details about the model internals.

### Hardware

Gemma was trained using [Tensor Processing Unit (TPU)][tpu] hardware (TPUv4p,
TPUv5p and TPUv5e). Training vision-language models (VLMs) requires significant
computational power. TPUs, designed specifically for matrix operations common in
machine learning, offer several advantages in this domain:

- Performance: TPUs are specifically designed to handle the massive
  computations involved in training VLMs. They can speed up training
  considerably compared to CPUs.
- Memory: TPUs often come with large amounts of high-bandwidth memory,
  allowing for the handling of large models and batch sizes during training.
  This can lead to better model quality.
- Scalability: TPU Pods (large clusters of TPUs) provide a scalable
  solution for handling the growing complexity of large foundation models.
  You can distribute training across multiple TPU devices for faster and more
  efficient processing.
- Cost-effectiveness: In many scenarios, TPUs can provide a more
  cost-effective solution for training large models compared to CPU-based
  infrastructure, especially when considering the time and resources saved
  due to faster training.
- These advantages are aligned with
  [Google's commitments to operate sustainably][sustainability].

### Software

Training was done using [JAX][jax] and [ML Pathways][ml-pathways].

JAX allows researchers to take advantage of the latest generation of hardware,
including TPUs, for faster and more efficient training of large models. ML
Pathways is Google's latest effort to build artificially intelligent systems
capable of generalizing across multiple tasks. This is especially suitable for
foundation models, including large language models like these ones.

Together, JAX and ML Pathways are used as described in the
[paper about the Gemini family of models][gemini-2-paper]; *"the 'single
controller' programming model of Jax and Pathways allows a single Python
process to orchestrate the entire training run, dramatically simplifying the
development workflow."*

## Evaluation
|
||||||
|
|
||||||
|
Model evaluation metrics and results.
|
||||||
|
|
||||||
|
### Benchmark Results
|
||||||
|
|
||||||
|
These models were evaluated against a large collection of different datasets and
|
||||||
|
metrics to cover different aspects of text generation:
|
||||||
|
|
||||||
|
#### Reasoning and factuality
|
||||||
|
|
||||||
|
| Benchmark | Metric | Gemma 3 PT 1B | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
|
||||||
|
| ------------------------------ |----------------|:--------------:|:-------------:|:--------------:|:--------------:|
|
||||||
|
| [HellaSwag][hellaswag] | 10-shot | 62.3 | 77.2 | 84.2 | 85.6 |
|
||||||
|
| [BoolQ][boolq] | 0-shot | 63.2 | 72.3 | 78.8 | 82.4 |
|
||||||
|
| [PIQA][piqa] | 0-shot | 73.8 | 79.6 | 81.8 | 83.3 |
|
||||||
|
| [SocialIQA][socialiqa] | 0-shot | 48.9 | 51.9 | 53.4 | 54.9 |
|
||||||
|
| [TriviaQA][triviaqa] | 5-shot | 39.8 | 65.8 | 78.2 | 85.5 |
|
||||||
|
| [Natural Questions][naturalq] | 5-shot | 9.48 | 20.0 | 31.4 | 36.1 |
|
||||||
|
| [ARC-c][arc] | 25-shot | 38.4 | 56.2 | 68.9 | 70.6 |
|
||||||
|
| [ARC-e][arc] | 0-shot | 73.0 | 82.4 | 88.3 | 89.0 |
|
||||||
|
| [WinoGrande][winogrande] | 5-shot | 58.2 | 64.7 | 74.3 | 78.8 |
|
||||||
|
| [BIG-Bench Hard][bbh] | few-shot | 28.4 | 50.9 | 72.6 | 77.7 |
|
||||||
|
| [DROP][drop] | 1-shot | 42.4 | 60.1 | 72.2 | 77.2 |
|
||||||
|
|
||||||
|
[hellaswag]: https://arxiv.org/abs/1905.07830
|
||||||
|
[boolq]: https://arxiv.org/abs/1905.10044
|
||||||
|
[piqa]: https://arxiv.org/abs/1911.11641
|
||||||
|
[socialiqa]: https://arxiv.org/abs/1904.09728
|
||||||
|
[triviaqa]: https://arxiv.org/abs/1705.03551
|
||||||
|
[naturalq]: https://github.com/google-research-datasets/natural-questions
|
||||||
|
[arc]: https://arxiv.org/abs/1911.01547
|
||||||
|
[winogrande]: https://arxiv.org/abs/1907.10641
|
||||||
|
[bbh]: https://paperswithcode.com/dataset/bbh
|
||||||
|
[drop]: https://arxiv.org/abs/1903.00161
|
||||||
|
|
||||||
|
#### STEM and code
|
||||||
|
|
||||||
|
| Benchmark | Metric | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
|
||||||
|
| ------------------------------ |----------------|:-------------:|:--------------:|:--------------:|
|
||||||
|
| [MMLU][mmlu] | 5-shot | 59.6 | 74.5 | 78.6 |
|
||||||
|
| [MMLU][mmlu] (Pro COT) | 5-shot | 29.2 | 45.3 | 52.2 |
|
||||||
|
| [AGIEval][agieval] | 3-5-shot | 42.1 | 57.4 | 66.2 |
|
||||||
|
| [MATH][math] | 4-shot | 24.2 | 43.3 | 50.0 |
|
||||||
|
| [GSM8K][gsm8k] | 8-shot | 38.4 | 71.0 | 82.6 |
|
||||||
|
| [GPQA][gpqa] | 5-shot | 15.0 | 25.4 | 24.3 |
|
||||||
|
| [MBPP][mbpp] | 3-shot | 46.0 | 60.4 | 65.6 |
|
||||||
|
| [HumanEval][humaneval] | 0-shot | 36.0 | 45.7 | 48.8 |
|
||||||
|
|
||||||
|
[mmlu]: https://arxiv.org/abs/2009.03300
|
||||||
|
[agieval]: https://arxiv.org/abs/2304.06364
|
||||||
|
[math]: https://arxiv.org/abs/2103.03874
|
||||||
|
[gsm8k]: https://arxiv.org/abs/2110.14168
|
||||||
|
[gpqa]: https://arxiv.org/abs/2311.12022
|
||||||
|
[mbpp]: https://arxiv.org/abs/2108.07732
|
||||||
|
[humaneval]: https://arxiv.org/abs/2107.03374
|
||||||
|
|
||||||
|
#### Multilingual
|
||||||
|
|
||||||
|
| Benchmark | Gemma 3 PT 1B | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
|
||||||
|
| ------------------------------------ |:-------------:|:-------------:|:--------------:|:--------------:|
|
||||||
|
| [MGSM][mgsm] | 2.04 | 34.7 | 64.3 | 74.3 |
|
||||||
|
| [Global-MMLU-Lite][global-mmlu-lite] | 24.9 | 57.0 | 69.4 | 75.7 |
|
||||||
|
| [WMT24++][wmt24pp] (ChrF) | 36.7 | 48.4 | 53.9 | 55.7 |
|
||||||
|
| [FloRes][flores] | 29.5 | 39.2 | 46.0 | 48.8 |
|
||||||
|
| [XQuAD][xquad] (all) | 43.9 | 68.0 | 74.5 | 76.8 |
|
||||||
|
| [ECLeKTic][eclektic] | 4.69 | 11.0 | 17.2 | 24.4 |
|
||||||
|
| [IndicGenBench][indicgenbench] | 41.4 | 57.2 | 61.7 | 63.4 |
|
||||||
|
|
||||||
|
[mgsm]: https://arxiv.org/abs/2210.03057
|
||||||
|
[flores]: https://arxiv.org/abs/2106.03193
|
||||||
|
[xquad]: https://arxiv.org/abs/1910.11856v3
|
||||||
|
[global-mmlu-lite]: https://huggingface.co/datasets/CohereForAI/Global-MMLU-Lite
|
||||||
|
[wmt24pp]: https://arxiv.org/abs/2502.12404v1
|
||||||
|
[eclektic]: https://arxiv.org/abs/2502.21228
|
||||||
|
[indicgenbench]: https://arxiv.org/abs/2404.16816
|
||||||
|
|
||||||
|
#### Multimodal
|
||||||
|
|
||||||
|
| Benchmark | Gemma 3 PT 4B | Gemma 3 PT 12B | Gemma 3 PT 27B |
|
||||||
|
| ------------------------------ |:-------------:|:--------------:|:--------------:|
|
||||||
|
| [COCOcap][coco-cap] | 102 | 111 | 116 |
|
||||||
|
| [DocVQA][docvqa] (val) | 72.8 | 82.3 | 85.6 |
|
||||||
|
| [InfoVQA][info-vqa] (val) | 44.1 | 54.8 | 59.4 |
|
||||||
|
| [MMMU][mmmu] (pt) | 39.2 | 50.3 | 56.1 |
|
||||||
|
| [TextVQA][textvqa] (val) | 58.9 | 66.5 | 68.6 |
|
||||||
|
| [RealWorldQA][realworldqa] | 45.5 | 52.2 | 53.9 |
|
||||||
|
| [ReMI][remi] | 27.3 | 38.5 | 44.8 |
|
||||||
|
| [AI2D][ai2d] | 63.2 | 75.2 | 79.0 |
|
||||||
|
| [ChartQA][chartqa] | 63.6 | 74.7 | 76.3 |
|
||||||
|
| [VQAv2][vqav2] | 63.9 | 71.2 | 72.9 |
|
||||||
|
| [BLINK][blinkvqa] | 38.0 | 35.9 | 39.6 |
|
||||||
|
| [OKVQA][okvqa] | 51.0 | 58.7 | 60.2 |
|
||||||
|
| [TallyQA][tallyqa] | 42.5 | 51.8 | 54.3 |
|
||||||
|
| [SpatialSense VQA][ss-vqa] | 50.9 | 60.0 | 59.4 |
|
||||||
|
| [CountBenchQA][countbenchqa] | 26.1 | 17.8 | 68.0 |
|
||||||
|
|
||||||
|
[coco-cap]: https://cocodataset.org/#home
|
||||||
|
[docvqa]: https://www.docvqa.org/
|
||||||
|
[info-vqa]: https://arxiv.org/abs/2104.12756
|
||||||
|
[mmmu]: https://arxiv.org/abs/2311.16502
|
||||||
|
[textvqa]: https://textvqa.org/
|
||||||
|
[realworldqa]: https://paperswithcode.com/dataset/realworldqa
|
||||||
|
[remi]: https://arxiv.org/html/2406.09175v1
|
||||||
|
[ai2d]: https://allenai.org/data/diagrams
|
||||||
|
[chartqa]: https://arxiv.org/abs/2203.10244
|
||||||
|
[vqav2]: https://visualqa.org/index.html
|
||||||
|
[blinkvqa]: https://arxiv.org/abs/2404.12390
|
||||||
|
[okvqa]: https://okvqa.allenai.org/
|
||||||
|
[tallyqa]: https://arxiv.org/abs/1810.12440
|
||||||
|
[ss-vqa]: https://arxiv.org/abs/1908.02660
|
||||||
|
[countbenchqa]: https://github.com/google-research/big_vision/blob/main/big_vision/datasets/countbenchqa/
|
||||||
|
|
||||||
|
## Ethics and Safety
|
||||||
|
|
||||||
|
Ethics and safety evaluation approach and results.
|
||||||
|
|
||||||
|
### Evaluation Approach
|
||||||
|
|
||||||
|
Our evaluation methods include structured evaluations and internal red-teaming
|
||||||
|
testing of relevant content policies. Red-teaming was conducted by a number of
|
||||||
|
different teams, each with different goals and human evaluation metrics. These
|
||||||
|
models were evaluated against a number of different categories relevant to
|
||||||
|
ethics and safety, including:
|
||||||
|
|
||||||
|
- **Child Safety**: Evaluation of text-to-text and image-to-text prompts
|
||||||
|
covering child safety policies, including child sexual abuse and
|
||||||
|
exploitation.
|
||||||
|
- **Content Safety:** Evaluation of text-to-text and image-to-text prompts
|
||||||
|
covering safety policies including harassment, violence and gore, and hate
|
||||||
|
speech.
|
||||||
|
- **Representational Harms**: Evaluation of text-to-text and image-to-text
|
||||||
|
prompts covering safety policies including bias, stereotyping, and harmful
|
||||||
|
associations or inaccuracies.
|
||||||
|
|
||||||
|
In addition to development level evaluations, we conduct "assurance
|
||||||
|
evaluations" which are our 'arms-length' internal evaluations for responsibility
|
||||||
|
governance decision making. They are conducted separately from the model
|
||||||
|
development team, to inform decision making about release. High level findings
|
||||||
|
are fed back to the model team, but prompt sets are held-out to prevent
|
||||||
|
overfitting and preserve the results' ability to inform decision making.
|
||||||
|
Assurance evaluation results are reported to our Responsibility & Safety Council
|
||||||
|
as part of release review.
|
||||||
|
|
||||||
|
### Evaluation Results
|
||||||
|
|
||||||
|
For all areas of safety testing, we saw major improvements in the categories of
|
||||||
|
child safety, content safety, and representational harms relative to previous
|
||||||
|
Gemma models. All testing was conducted without safety filters to evaluate the
|
||||||
|
model capabilities and behaviors. For both text-to-text and image-to-text, and
|
||||||
|
across all model sizes, the model produced minimal policy violations, and showed
|
||||||
|
significant improvements over previous Gemma models' performance with respect
|
||||||
|
to ungrounded inferences. A limitation of our evaluations was that they included only
|
||||||
|
English language prompts.
|
||||||
|
|
||||||
|
## Usage and Limitations
|
||||||
|
|
||||||
|
These models have certain limitations that users should be aware of.
|
||||||
|
|
||||||
|
### Intended Usage
|
||||||
|
|
||||||
|
Open vision-language models (VLMs) have a wide range of applications
|
||||||
|
across various industries and domains. The following list of potential uses is
|
||||||
|
not comprehensive. The purpose of this list is to provide contextual information
|
||||||
|
about the possible use-cases that the model creators considered as part of model
|
||||||
|
training and development.
|
||||||
|
|
||||||
|
- Content Creation and Communication
|
||||||
|
- Text Generation: These models can be used to generate creative text
|
||||||
|
formats such as poems, scripts, code, marketing copy, and email drafts.
|
||||||
|
- Chatbots and Conversational AI: Power conversational interfaces
|
||||||
|
for customer service, virtual assistants, or interactive applications.
|
||||||
|
- Text Summarization: Generate concise summaries of a text corpus,
|
||||||
|
research papers, or reports.
|
||||||
|
- Image Data Extraction: These models can be used to extract,
|
||||||
|
interpret, and summarize visual data for text communications.
|
||||||
|
- Research and Education
|
||||||
|
- Natural Language Processing (NLP) and VLM Research: These
|
||||||
|
models can serve as a foundation for researchers to experiment with VLM
|
||||||
|
and NLP techniques, develop algorithms, and contribute to the
|
||||||
|
advancement of the field.
|
||||||
|
- Language Learning Tools: Support interactive language learning
|
||||||
|
experiences, aiding in grammar correction or providing writing practice.
|
||||||
|
- Knowledge Exploration: Assist researchers in exploring large
|
||||||
|
bodies of text by generating summaries or answering questions about
|
||||||
|
specific topics.
|
||||||
|
|
||||||
|
### Limitations
|
||||||
|
|
||||||
|
- Training Data
|
||||||
|
- The quality and diversity of the training data significantly
|
||||||
|
influence the model's capabilities. Biases or gaps in the training data
|
||||||
|
can lead to limitations in the model's responses.
|
||||||
|
- The scope of the training dataset determines the subject areas
|
||||||
|
the model can handle effectively.
|
||||||
|
- Context and Task Complexity
|
||||||
|
- Models are better at tasks that can be framed with clear
|
||||||
|
prompts and instructions. Open-ended or highly complex tasks might be
|
||||||
|
challenging.
|
||||||
|
- A model's performance can be influenced by the amount of context
|
||||||
|
provided (longer context generally leads to better outputs, up to a
|
||||||
|
certain point).
|
||||||
|
- Language Ambiguity and Nuance
|
||||||
|
- Natural language is inherently complex. Models might struggle
|
||||||
|
to grasp subtle nuances, sarcasm, or figurative language.
|
||||||
|
- Factual Accuracy
|
||||||
|
- Models generate responses based on information they learned
|
||||||
|
from their training datasets, but they are not knowledge bases. They
|
||||||
|
may generate incorrect or outdated factual statements.
|
||||||
|
- Common Sense
|
||||||
|
- Models rely on statistical patterns in language. They might
|
||||||
|
lack the ability to apply common sense reasoning in certain situations.
|
||||||
|
|
||||||
|
### Ethical Considerations and Risks
|
||||||
|
|
||||||
|
The development of vision-language models (VLMs) raises several ethical
|
||||||
|
concerns. In creating an open model, we have carefully considered the following:
|
||||||
|
|
||||||
|
- Bias and Fairness
|
||||||
|
- VLMs trained on large-scale, real-world text and image data can
|
||||||
|
reflect socio-cultural biases embedded in the training material. These
|
||||||
|
models underwent careful scrutiny, input data pre-processing as described,
|
||||||
|
and posterior evaluations reported in this card.
|
||||||
|
- Misinformation and Misuse
|
||||||
|
- VLMs can be misused to generate text that is false, misleading,
|
||||||
|
or harmful.
|
||||||
|
- Guidelines are provided for responsible use with the model, see the
|
||||||
|
[Responsible Generative AI Toolkit][rai-toolkit].
|
||||||
|
- Transparency and Accountability:
|
||||||
|
- This model card summarizes details on the models' architecture,
|
||||||
|
capabilities, limitations, and evaluation processes.
|
||||||
|
- A responsibly developed open model offers the opportunity to
|
||||||
|
share innovation by making VLM technology accessible to developers and
|
||||||
|
researchers across the AI ecosystem.
|
||||||
|
|
||||||
|
Risks identified and mitigations:
|
||||||
|
|
||||||
|
- **Perpetuation of biases**: Developers are encouraged to perform continuous
|
||||||
|
monitoring (using evaluation metrics, human review) and the exploration of
|
||||||
|
de-biasing techniques during model training, fine-tuning, and other use
|
||||||
|
cases.
|
||||||
|
- **Generation of harmful content**: Mechanisms and guidelines for content
|
||||||
|
safety are essential. Developers are encouraged to exercise caution and
|
||||||
|
implement appropriate content safety safeguards based on their specific
|
||||||
|
product policies and application use cases.
|
||||||
|
- **Misuse for malicious purposes**: Technical limitations and developer
|
||||||
|
and end-user education can help mitigate against malicious applications of
|
||||||
|
VLMs. Educational resources and reporting mechanisms for users to flag
|
||||||
|
misuse are provided. Prohibited uses of Gemma models are outlined in the
|
||||||
|
[Gemma Prohibited Use Policy][prohibited-use].
|
||||||
|
- **Privacy violations**: Models were trained on data filtered for removal
|
||||||
|
of certain personal information and other sensitive data. Developers are
|
||||||
|
encouraged to adhere to privacy regulations with privacy-preserving
|
||||||
|
techniques.
|
||||||
|
|
||||||
|
### Benefits
|
||||||
|
|
||||||
|
At the time of release, this family of models provides high-performance open
|
||||||
|
vision-language model implementations designed from the ground up for
|
||||||
|
responsible AI development compared to similarly sized models.
|
||||||
|
|
||||||
|
Using the benchmark evaluation metrics described in this document, these models
|
||||||
|
have been shown to provide superior performance to other, comparably sized open model
|
||||||
|
alternatives.

[g3-tech-report]: https://goo.gle/Gemma3Report
[rai-toolkit]: https://ai.google.dev/responsible
[kaggle-gemma]: https://www.kaggle.com/models/google/gemma-3
[vertex-mg-gemma3]: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/gemma3
[terms]: https://ai.google.dev/gemma/terms
[safety-policies]: https://ai.google/static/documents/ai-responsibility-update-published-february-2025.pdf
[prohibited-use]: https://ai.google.dev/gemma/prohibited_use_policy
[tpu]: https://cloud.google.com/tpu/docs/intro-to-tpu
[sustainability]: https://sustainability.google/operating-sustainably/
[jax]: https://github.com/jax-ml/jax
[ml-pathways]: https://blog.google/technology/ai/introducing-pathways-next-generation-ai-architecture/
[gemini-2-paper]: https://arxiv.org/abs/2312.11805

added_tokens.json: 3 lines, new file
@@ -0,0 +1,3 @@
{
  "<image_soft_token>": 262144
}
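
A quick check of the mapping above (a hedged sketch; the repo id is a placeholder for this repository):

```python
# Hedged sketch: the single added token is the image placeholder used by the multimodal stack.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<this-repo-id>")  # placeholder Hub id
assert tok.convert_tokens_to_ids("<image_soft_token>") == 262144
```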

chat_template.json: 3 lines, new file
@@ -0,0 +1,3 @@
{
"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
}
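
A hedged sketch of how this template renders (it assumes the tokenizer in this repo carries the template; note that the template maps the `assistant` role to `model` and folds any system message into the first user turn):

```python
# Hedged sketch: render a text-only conversation with the Gemma 3 chat template above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<this-repo-id>")  # placeholder Hub id
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# <bos><start_of_turn>user
# You are a helpful assistant.
#
# Hello!<end_of_turn>
# <start_of_turn>model
```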

config.json: 53 lines, new file
@@ -0,0 +1,53 @@
{
  "architectures": [
    "Gemma3ForConditionalGeneration"
  ],
  "boi_token_index": 255999,
  "eoi_token_index": 256000,
  "eos_token_id": [
    1,
    106
  ],
  "image_token_index": 262144,
  "initializer_range": 0.02,
  "mm_tokens_per_image": 256,
  "model_type": "gemma3",
  "text_config": {
    "head_dim": 128,
    "hidden_size": 5376,
    "intermediate_size": 21504,
    "model_type": "gemma3_text",
    "num_attention_heads": 32,
    "num_hidden_layers": 62,
    "num_key_value_heads": 16,
    "query_pre_attn_scalar": 168,
    "rope_scaling": {
      "factor": 8.0,
      "rope_type": "linear"
    },
    "sliding_window": 1024
  },
  "torch_dtype": "bfloat16",
  "transformers_version": "4.50.0.dev0",
  "vision_config": {
    "hidden_size": 1152,
    "image_size": 896,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "vision_use_head": false
  },
  "quantization_config": {
    "bits": 4,
    "group_size": 32,
    "quant_method": "awq",
    "version": "gemm",
    "zero_point": true,
    "modules_to_not_convert": [
      "lm_head",
      "vision_tower"
    ]
  }
}
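
A hedged inspection sketch for the `quantization_config` block above (placeholder repo id; key names exactly as declared in the file):

```python
# Hedged sketch: the config declares 4-bit, group-size-32 weights in AWQ layout, with the
# LM head and the SigLIP vision tower left unquantized.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("<this-repo-id>")  # placeholder Hub id
q = cfg.quantization_config
print(q["bits"], q["group_size"], q["quant_method"])  # 4 32 awq
print(q["modules_to_not_convert"])                    # ['lm_head', 'vision_tower']
```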

convert_flax.py: 281 lines, new file
@@ -0,0 +1,281 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import orbax.checkpoint as ocp
|
||||||
|
from safetensors.flax import save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
SIGLIP_PREFIX = "SigLiPFromPatches_0/siglip_encoder"
|
||||||
|
|
||||||
|
|
||||||
|
def flatten(x: np.ndarray, start: int = 0, end: int = -1):
|
||||||
|
if start < 0:
|
||||||
|
start += x.ndim
|
||||||
|
if end < 0:
|
||||||
|
end += x.ndim
|
||||||
|
new_shape = x.shape[:start] + (-1,) + x.shape[end + 1 :]
|
||||||
|
return x.reshape(new_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def unflatten(x: np.ndarray, dim: int, sizes: tuple[int, ...]):
|
||||||
|
new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1 :]
|
||||||
|
return x.reshape(new_shape)
|
||||||
|
|
||||||
|
|
||||||
|
# With the correct quantization parameters, the quantization error is 0 (or close to 0)
|
||||||
|
def check_groups(groups: np.ndarray, scales: np.ndarray, dim: int):
|
||||||
|
# groups: (a, b, c, 32, d, e, f)
|
||||||
|
# scales: (a, b, c, 1, d, e, f)
|
||||||
|
inv_scale = 1.0 / scales.clip(1e-12)
|
||||||
|
q_group = np.round(groups * inv_scale)
|
||||||
|
max_diff = np.abs(q_group * scales - groups).max(dim, keepdims=True)
|
||||||
|
return max_diff < 1e-6, max_diff
|
||||||
|
|
||||||
|
|
||||||
|
def find_scales(w: np.ndarray, dim: int):
|
||||||
|
w = unflatten(w, dim, (-1, 32))
|
||||||
|
group_range = w.max(dim + 1, keepdims=True) - w.min(dim + 1, keepdims=True)
|
||||||
|
|
||||||
|
scales = np.zeros_like(group_range)
|
||||||
|
for q in range(15, 0, -1):
|
||||||
|
try_scale = group_range / q
|
||||||
|
ok, _ = check_groups(w, try_scale, dim + 1)
|
||||||
|
scales[ok] = try_scale[ok]
|
||||||
|
|
||||||
|
ok, _ = check_groups(w, scales, dim + 1)
|
||||||
|
assert ok.all()
|
||||||
|
|
||||||
|
return scales.squeeze(dim + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_siglip(params, num_layers: int):
|
||||||
|
state_dict = dict()
|
||||||
|
|
||||||
|
def convert_layer(prefix: str, layer: dict[str, np.ndarray]):
|
||||||
|
bias = layer["bias"]
|
||||||
|
|
||||||
|
if "kernel" in layer:
|
||||||
|
w = layer["kernel"]
|
||||||
|
if w.ndim == 2: # linear layer
|
||||||
|
w = w.T
|
||||||
|
|
||||||
|
elif w.ndim == 3: # attn projection
|
||||||
|
# qkv projection - (dim, num_heads, head_dim)
|
||||||
|
if bias.ndim == 2:
|
||||||
|
w = flatten(w, 1, 2).T
|
||||||
|
bias = bias.reshape(-1)
|
||||||
|
|
||||||
|
# o projection - (num_heads, head_dim, dim)
|
||||||
|
elif bias.ndim == 1:
|
||||||
|
w = flatten(w, 0, 1).T
|
||||||
|
|
||||||
|
elif w.ndim == 4: # conv2d layer
|
||||||
|
w = w.transpose(3, 2, 0, 1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported {w.shape=}")
|
||||||
|
|
||||||
|
elif "scale" in layer: # layer norm
|
||||||
|
w = layer["scale"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
state_dict[f"{prefix}weight"] = w
|
||||||
|
state_dict[f"{prefix}bias"] = bias
|
||||||
|
|
||||||
|
convert_layer("embeddings.patch_embedding.", params[f"{SIGLIP_PREFIX}/embedding"])
|
||||||
|
state_dict["embeddings.position_embedding.weight"] = params[SIGLIP_PREFIX]["pos_embedding"].squeeze(0)
|
||||||
|
convert_layer("post_layernorm.", params[f"{SIGLIP_PREFIX}/Transformer/encoder_norm"])
|
||||||
|
|
||||||
|
for layer_idx in range(num_layers):
|
||||||
|
prefix = f"encoder.layers.{layer_idx}."
|
||||||
|
layer_prefix = f"{SIGLIP_PREFIX}/Transformer/encoderblock_{layer_idx}/"
|
||||||
|
|
||||||
|
convert_layer(f"{prefix}layer_norm1.", params[f"{layer_prefix}LayerNorm_0"])
|
||||||
|
convert_layer(f"{prefix}layer_norm2.", params[f"{layer_prefix}LayerNorm_1"])
|
||||||
|
|
||||||
|
attn_prefix = f"{layer_prefix}MultiHeadDotProductAttention_0/"
|
||||||
|
convert_layer(f"{prefix}self_attn.q_proj.", params[f"{attn_prefix}query"])
|
||||||
|
convert_layer(f"{prefix}self_attn.k_proj.", params[f"{attn_prefix}key"])
|
||||||
|
convert_layer(f"{prefix}self_attn.v_proj.", params[f"{attn_prefix}value"])
|
||||||
|
convert_layer(f"{prefix}self_attn.out_proj.", params[f"{attn_prefix}out"])
|
||||||
|
|
||||||
|
mlp_prefix = f"{layer_prefix}MlpBlock_0/"
|
||||||
|
convert_layer(f"{prefix}mlp.fc1.", params[f"{mlp_prefix}Dense_0"])
|
||||||
|
convert_layer(f"{prefix}mlp.fc2.", params[f"{mlp_prefix}Dense_1"])
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
# convert to HF format first, then apply quantization
|
||||||
|
def convert_to_hf(path: Path):
|
||||||
|
path = path.absolute() # orbax only works with absolute path
|
||||||
|
ckpt = ocp.StandardCheckpointer()
|
||||||
|
metadata = dict(ckpt.metadata(path))
|
||||||
|
metadata = jax.tree.map(ocp.utils.to_shape_dtype_struct, metadata)
|
||||||
|
|
||||||
|
num_layers = num_siglip_layers = 0
|
||||||
|
while f"transformer/layer_{num_layers}/attn/_key_norm" in metadata:
|
||||||
|
num_layers += 1
|
||||||
|
while f"{SIGLIP_PREFIX}/Transformer/encoderblock_{num_siglip_layers}/LayerNorm_0" in metadata:
|
||||||
|
num_siglip_layers += 1
|
||||||
|
print(f"{num_layers=}")
|
||||||
|
print(f"{num_siglip_layers=}")
|
||||||
|
|
||||||
|
# NOTE: all gemma3 models use tied embeddings, even for the 27B version.
|
||||||
|
params = ckpt.restore(path)
|
||||||
|
state_dict = dict()
|
||||||
|
|
||||||
|
if num_siglip_layers > 0:
|
||||||
|
# The HF checkpoint pads the embedding table with 64 extra (unused) token rows; match that here.
|
||||||
|
embed = params["transformer/embedder"]["input_embedding"]
|
||||||
|
params["transformer/embedder"]["input_embedding"] = np.pad(embed, ((0, 64), (0, 0)))
|
||||||
|
gemma_prefix = "language_model."
|
||||||
|
|
||||||
|
prefix = "multi_modal_projector.mm_"
|
||||||
|
jax_prefix = "transformer/embedder/"
|
||||||
|
state_dict[f"{prefix}input_projection_weight"] = params[f"{jax_prefix}mm_input_projection"]["w"]
|
||||||
|
state_dict[f"{prefix}soft_emb_norm.weight"] = params[f"{jax_prefix}mm_soft_embedding_norm"]["scale"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
gemma_prefix = ""
|
||||||
|
|
||||||
|
state_dict[f"{gemma_prefix}model.embed_tokens.weight"] = params["transformer/embedder"]["input_embedding"]
|
||||||
|
state_dict[f"{gemma_prefix}model.norm.weight"] = params["transformer/final_norm"]["scale"]
|
||||||
|
|
||||||
|
yield state_dict
|
||||||
|
|
||||||
|
for layer_idx in range(num_layers):
|
||||||
|
jax_prefix = f"transformer/layer_{layer_idx}/"
|
||||||
|
|
||||||
|
state_dict = dict()
|
||||||
|
prefix = f"{gemma_prefix}model.layers.{layer_idx}."
|
||||||
|
state_dict[f"{prefix}input_layernorm.weight"] = params[f"{jax_prefix}pre_attention_norm"]["scale"]
|
||||||
|
state_dict[f"{prefix}post_attention_layernorm.weight"] = params[f"{jax_prefix}post_attention_norm"]["scale"]
|
||||||
|
state_dict[f"{prefix}pre_feedforward_layernorm.weight"] = params[f"{jax_prefix}pre_ffw_norm"]["scale"]
|
||||||
|
state_dict[f"{prefix}post_feedforward_layernorm.weight"] = params[f"{jax_prefix}post_ffw_norm"]["scale"]
|
||||||
|
|
||||||
|
prefix = f"{gemma_prefix}model.layers.{layer_idx}.self_attn."
|
||||||
|
jax_prefix = f"transformer/layer_{layer_idx}/attn/"
|
||||||
|
state_dict[f"{prefix}q_norm.weight"] = params[f"{jax_prefix}_query_norm"]["scale"]
|
||||||
|
state_dict[f"{prefix}k_norm.weight"] = params[f"{jax_prefix}_key_norm"]["scale"]
|
||||||
|
|
||||||
|
# (num_heads, hidden_size, head_dim) -> (num_heads * head_dim, hidden_size)
|
||||||
|
state_dict[f"{prefix}q_proj.weight"] = flatten(params[f"{jax_prefix}q_einsum"]["w"].transpose(0, 2, 1), end=1)
|
||||||
|
state_dict[f"{prefix}k_proj.weight"] = flatten(
|
||||||
|
params[f"{jax_prefix}kv_einsum"]["w"][0].transpose(0, 2, 1), end=1
|
||||||
|
)
|
||||||
|
state_dict[f"{prefix}v_proj.weight"] = flatten(
|
||||||
|
params[f"{jax_prefix}kv_einsum"]["w"][1].transpose(0, 2, 1), end=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# (num_heads, head_dim, hidden_size) -> (hidden_size, num_heads * head_dim)
|
||||||
|
state_dict[f"{prefix}o_proj.weight"] = flatten(params[f"{jax_prefix}attn_vec_einsum"]["w"], end=1).T
|
||||||
|
|
||||||
|
prefix = f"{gemma_prefix}model.layers.{layer_idx}.mlp."
|
||||||
|
jax_prefix = f"transformer/layer_{layer_idx}/mlp/"
|
||||||
|
state_dict[f"{prefix}gate_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][0]
|
||||||
|
state_dict[f"{prefix}up_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][1]
|
||||||
|
state_dict[f"{prefix}down_proj.weight"] = params[f"{jax_prefix}linear"]["w"].T
|
||||||
|
|
||||||
|
yield state_dict
|
||||||
|
|
||||||
|
# vision tower
|
||||||
|
if num_siglip_layers > 0:
|
||||||
|
siglip_state_dict = convert_siglip(params, num_siglip_layers)
|
||||||
|
for k, v in siglip_state_dict.items():
|
||||||
|
state_dict[f"vision_tower.vision_model.{k}"] = v
|
||||||
|
yield state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_awq(state_dict: dict[str, np.ndarray]):
|
||||||
|
awq_state_dict = dict()
|
||||||
|
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if (
|
||||||
|
k.endswith("model.embed_tokens.weight") # AWQ doesn't support INT4 embeddings
|
||||||
|
or k.startswith(("vision_tower", "multi_modal_projector")) # vision tower is not quantized
|
||||||
|
or v.ndim == 1
|
||||||
|
):
|
||||||
|
awq_state_dict[k] = v.astype(jnp.bfloat16)
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert v.ndim == 2
|
||||||
|
v = v.T  # AWQ stores the weight transposed, i.e. (in_features, out_features)
|
||||||
|
|
||||||
|
K, N = v.shape
|
||||||
|
scales = find_scales(v, dim=0) # (K/32, N)
|
||||||
|
inv_scale = 1 / scales.clip(1e-12)
|
||||||
|
qweight = np.round(v.reshape(K // 32, 32, N) * inv_scale[:, None])
|
||||||
|
|
||||||
|
# AWQ is actually UINT4 (instead of INT4)
|
||||||
|
# hence, we will shift qweight up by 8 (even though Google AQT only uses [-7,7])
|
||||||
|
# and set zero_point = 8
|
||||||
|
qweight = (qweight + 8).astype(np.uint32)
|
||||||
|
|
||||||
|
# AWQ packs 8 int4 values into one UINT32 in the following layout (from high bits to low bits)
|
||||||
|
# [7 5 3 1 6 4 2 0] along the 2nd dim
|
||||||
|
qweight = qweight.reshape(K, N // 8, 8)
|
||||||
|
qweight_packed = (
|
||||||
|
(qweight[..., 7] << (7 * 4))
|
||||||
|
| (qweight[..., 5] << (6 * 4))
|
||||||
|
| (qweight[..., 3] << (5 * 4))
|
||||||
|
| (qweight[..., 1] << (4 * 4))
|
||||||
|
| (qweight[..., 6] << (3 * 4))
|
||||||
|
| (qweight[..., 4] << (2 * 4))
|
||||||
|
| (qweight[..., 2] << (1 * 4))
|
||||||
|
| (qweight[..., 0] << (0 * 4))
|
||||||
|
)
|
||||||
|
qweight_packed = qweight_packed.view(np.int32).reshape(K, N // 8)
|
||||||
|
|
||||||
|
prefix = k.removesuffix(".weight")
|
||||||
|
awq_state_dict[f"{prefix}.qweight"] = qweight_packed
|
||||||
|
awq_state_dict[f"{prefix}.qzeros"] = np.full((K // 32, N // 8), 0x8888_8888, dtype=np.uint32).view(np.int32)
|
||||||
|
awq_state_dict[f"{prefix}.scales"] = scales.astype(jnp.bfloat16)
|
||||||
|
|
||||||
|
return awq_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--ckpt_dir", required=True, type=Path)
|
||||||
|
parser.add_argument("--save_dir", required=True, type=Path)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args.save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
total_size = 0
|
||||||
|
weight_map = dict()
|
||||||
|
|
||||||
|
state_dict = dict()
|
||||||
|
size = 0
|
||||||
|
shard_idx = 0
|
||||||
|
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
||||||
|
for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
|
||||||
|
sub_state_dict = convert_awq(sub_state_dict)
|
||||||
|
new_size = sum(v.nbytes for v in sub_state_dict.values())
|
||||||
|
|
||||||
|
if size + new_size > 5e9:
|
||||||
|
save_file(state_dict, args.save_dir / filename)
|
||||||
|
state_dict = dict()
|
||||||
|
size = 0
|
||||||
|
shard_idx += 1
|
||||||
|
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
||||||
|
|
||||||
|
# assume that new_size < 5e9
|
||||||
|
size += new_size
|
||||||
|
total_size += new_size
|
||||||
|
for k, v in sub_state_dict.items():
|
||||||
|
state_dict[k] = v
|
||||||
|
weight_map[k] = filename
|
||||||
|
|
||||||
|
save_file(state_dict, args.save_dir / filename)
|
||||||
|
json.dump(
|
||||||
|
dict(metadata=dict(total_size=total_size), weight_map=weight_map),
|
||||||
|
open(args.save_dir / "model.safetensors.index.json", "w"),
|
||||||
|
)
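
Two details of `convert_flax.py` worth auditing are the exact-scale recovery in `find_scales` (it relies on the QAT checkpoint already lying on an INT4 grid, so `group_range / q` for some integer q in [1, 15] reproduces the scale exactly) and the AWQ nibble layout in `convert_awq`. Below is a hedged, self-contained sketch (NumPy only, synthetic data) that mirrors the `[7 5 3 1 6 4 2 0]` packing and checks that it round-trips:

```python
# Hedged verification sketch for the AWQ nibble layout used in convert_awq(); it never
# touches the real checkpoint, only synthetic unsigned 4-bit values.
import numpy as np

rng = np.random.default_rng(0)
q = rng.integers(0, 16, size=(4, 16, 8), dtype=np.uint32)  # shape (K, N // 8, 8)

# Logical index -> nibble position, i.e. [7 5 3 1 6 4 2 0] from high bits to low bits.
layout = {7: 7, 5: 6, 3: 5, 1: 4, 6: 3, 4: 2, 2: 1, 0: 0}

packed = np.zeros(q.shape[:2], dtype=np.uint32)
for idx, pos in layout.items():
    packed |= q[..., idx] << (4 * pos)

unpacked = np.zeros_like(q)
for idx, pos in layout.items():
    unpacked[..., idx] = (packed >> (4 * pos)) & 0xF

assert np.array_equal(unpacked, q)  # the layout round-trips
```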

generation_config.json: 13 lines, new file
@@ -0,0 +1,13 @@
{
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "do_sample": true,
  "eos_token_id": [
    1,
    106
  ],
  "pad_token_id": 0,
  "top_k": 64,
  "top_p": 0.95,
  "transformers_version": "4.50.0.dev0"
}
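
A hedged sketch of how these defaults are picked up at generation time (model and processor loading follow the README snippets; the repo id is a placeholder):

```python
# Hedged sketch: generation_config.json supplies the sampling defaults (do_sample=True,
# top_k=64, top_p=0.95) and the hybrid KV cache used for Gemma 3's sliding-window layers.
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "<this-repo-id>"  # placeholder Hub id
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto").eval()
processor = AutoProcessor.from_pretrained(model_id)
print(model.generation_config.top_k, model.generation_config.top_p)  # 64 0.95

messages = [{"role": "user", "content": [{"type": "text", "text": "Write a haiku about TPUs."}]}]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)
out = model.generate(**inputs, max_new_tokens=64)  # sampling defaults come from generation_config.json
print(processor.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```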

model-00001-of-00004.safetensors: 3 lines, new file (LFS pointer)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfe46693f149bea38e2c089161ab180ea08f03f2a647d9ade2d63c9f290c0f77
size 4980332880

model-00002-of-00004.safetensors: 3 lines, new file (LFS pointer)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7e265b1583962613677420aa8de95afd96d7cbfd47fee19d2f4959fadd978a9
size 4774828736

model-00003-of-00004.safetensors: 3 lines, new file (LFS pointer)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fc7283de8174bc0d3242ce8530eb05d636208d6164ecdeeefbf3fb672c118d12
size 4774828760

model-00004-of-00004.safetensors: 3 lines, new file (LFS pointer)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0cff33cb628f34e3b11b07952bcc696fd0f2528141b251f12a15024334e7b577
size 3937431096

model.safetensors.index.json: 2122 lines, new file
(File diff suppressed because it is too large)
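
Although the diff is suppressed, the file is a standard safetensors shard index; a hedged sketch of reading it from a local clone:

```python
# Hedged sketch: map tensor names to the shard files listed above.
import json

with open("model.safetensors.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])            # total bytes across the four shards
name, shard = next(iter(index["weight_map"].items()))
print(name, "->", shard)                          # which shard holds a given tensor
```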

preprocessor_config.json: 29 lines, new file
@@ -0,0 +1,29 @@
{
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_pan_and_scan": null,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "Gemma3ImageProcessor",
  "image_seq_length": 256,
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "pan_and_scan_max_num_crops": null,
  "pan_and_scan_min_crop_size": null,
  "pan_and_scan_min_ratio_to_activate": null,
  "processor_class": "Gemma3Processor",
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 896,
    "width": 896
  }
}
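
A hedged sketch tying these settings to the image handling described in the model card (resample=2 is bilinear in PIL's numbering, rescale_factor is 1/255, and mean/std of 0.5 normalize values to roughly [-1, 1]):

```python
# Hedged sketch: run the Gemma 3 image processor stand-alone with the settings above.
from transformers import AutoImageProcessor
from PIL import Image

image_processor = AutoImageProcessor.from_pretrained("<this-repo-id>")  # placeholder Hub id
pixel_values = image_processor(Image.new("RGB", (1280, 720)), return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # expected: (1, 3, 896, 896)
```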

processor_config.json: 4 lines, new file
@@ -0,0 +1,4 @@
{
  "image_seq_length": 256,
  "processor_class": "Gemma3Processor"
}

special_tokens_map.json: 33 lines, new file
@@ -0,0 +1,33 @@
{
  "boi_token": "<start_of_image>",
  "bos_token": {
    "content": "<bos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eoi_token": "<end_of_image>",
  "eos_token": {
    "content": "<eos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "image_token": "<image_soft_token>",
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}

tokenizer.json: 3 lines, new file (LFS pointer)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
size 33384568

tokenizer.model: 3 lines, new file (LFS pointer)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
size 4689074

tokenizer_config.json: 51346 lines, new file
(File diff suppressed because it is too large)