enable auto-round quantization model (#6226)
Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
This commit is contained in:
@@ -40,6 +40,93 @@ python3 -m sglang.launch_server \
|
||||
|
||||
### Examples of Offline Model Quantization
|
||||
|
||||
|
||||
#### Using [auto-round](https://github.com/intel/auto-round)
|
||||
|
||||
```bash
|
||||
# Install
|
||||
pip install auto-round
|
||||
```
|
||||
|
||||
- LLM quantization
|
||||
|
||||
```py
|
||||
# for LLM
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from auto_round import AutoRound
|
||||
model_id = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
quant_path = "Llama-3.2-1B-Instruct-autoround-4bit"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
bits, group_size, sym = 4, 128, True # set quantize args
|
||||
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym)
|
||||
format='auto_round'
|
||||
autoround.quantize_and_save(quant_path, format=format) # quantize and save
|
||||
|
||||
```
|
||||
|
||||
- VLM quantization
|
||||
```py
|
||||
# for VLMs
|
||||
from auto_round import AutoRoundMLLM
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer
|
||||
model_name = "Qwen/Qwen2-VL-2B-Instruct"
|
||||
quant_path = "Qwen2-VL-2B-Instruct-autoround-4bit"
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, trust_remote_code=True, torch_dtype="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
bits, group_size, sym = 4, 128, True
|
||||
autoround = AutoRoundMLLM(model, tokenizer, processor,
|
||||
bits=bits, group_size=group_size, sym=sym)
|
||||
format='auto_round'
|
||||
autoround.quantize_and_save(quant_path, format=format) # quantize and save
|
||||
|
||||
```
|
||||
|
||||
- Command Line Usage (Gaudi/CPU/Intel GPU/CUDA)
|
||||
|
||||
```bash
|
||||
auto-round \
|
||||
--model meta-llama/Llama-3.2-1B-Instruct \
|
||||
--bits 4 \
|
||||
--group_size 128 \
|
||||
--format "auto_gptq,auto_awq,auto_round" \
|
||||
--output_dir ./tmp_autoround
|
||||
```
|
||||
|
||||
- known issues
|
||||
|
||||
Several limitations currently affect offline quantized model loading in sglang, These issues might be resolved in future updates of sglang. If you experience any problems, consider using Hugging Face Transformers as an alternative.
|
||||
|
||||
1. Mixed-bit Quantization Limitations
|
||||
|
||||
Mixed-bit quantization is not fully supported. Due to vLLM's layer fusion (e.g., QKV fusion), applying different bit-widths to components within the same fused layer can lead to compatibility issues.
|
||||
|
||||
|
||||
2. Limited Support for Quantized MoE Models
|
||||
|
||||
Quantized MoE models may encounter inference issues due to kernel limitations (e.g., lack of support for mlp.gate layer quantization). To avoid such errors, please skip quantizing gate layers when processing quantization to MoE modules.
|
||||
|
||||
|
||||
3. Limited Support for Quantized VLMs
|
||||
<details>
|
||||
<summary>VLM failure cases</summary>
|
||||
|
||||
Qwen2.5-VL-7B
|
||||
|
||||
auto_round:auto_gptq format: Accuracy is close to zero.
|
||||
|
||||
GPTQ format: Fails with:
|
||||
```
|
||||
The output size is not aligned with the quantized weight shape
|
||||
```
|
||||
|
||||
auto_round:auto_awq and AWQ format: These work as expected.
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
#### Using [GPTQModel](https://github.com/ModelCloud/GPTQModel)
|
||||
|
||||
```bash
|
||||
@@ -150,3 +237,4 @@ python3 -m sglang.launch_server \
|
||||
- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
|
||||
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
|
||||
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)
|
||||
- [auto-round](https://github.com/intel/auto-round)
|
||||
|
||||
Reference in New Issue
Block a user