---
license: apache-2.0
tags:
- prismml
- bonsai
- awq
- 4-bit
pipeline_tag: text-generation
base_model: prism-ml/Bonsai-8B-unpacked
library_name: transformers
---
|
# Bonsai-8B — AWQ 4-bit
|
## Summary
|
This repo provides an AWQ 4-bit checkpoint so you can run Bonsai-8B on [sglang](https://github.com/sgl-project/sglang) (or vLLM) until native 1-bit support lands in those engines. The repack from 1-bit to AWQ 4-bit is **lossless**: both formats use a group size of 128, and every binary weight (±d) maps to an exact AWQ INT4 code, so the conversion introduces no rounding error.
|
> **For the best Bonsai experience on edge or consumer-grade hardware, use the native 1-bit releases.** The 1-bit format is where Bonsai's memory and energy savings come from.
>
> - **[Bonsai-8B MLX 1-bit](https://huggingface.co/prism-ml/Bonsai-8B-mlx-1bit)**: 1-bit MLX for Apple Silicon.
> - **[Bonsai-8B-gguf](https://huggingface.co/prism-ml/Bonsai-8B-gguf)**: 1-bit GGUF, supported by llama.cpp across many backends (GPU, Metal, CPU, Vulkan, etc.).
> - **[Bonsai-8B FP16](https://huggingface.co/prism-ml/Bonsai-8B-unpacked)**: FP16 safetensors for stock Hugging Face tooling.
|
## How It Works
|
Each Bonsai weight is ±d (binary), where d is a scale shared by a group of 128 weights. INT4 can represent both values exactly. The embedding and `lm_head` stay in FP16 because of sglang limitations.
|
AWQ INT4 dequantization is `weight = scale × (int4 − zero)`. Setting scale = d and zero = 8 for every group maps each binary weight to an INT4 code:
|
```
+d → scale=d, int4=9, zero=8 → d × (9-8) = +d
-d → scale=d, int4=7, zero=8 → d × (7-8) = -d
```
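To make the mapping concrete, here is a minimal bash sketch (illustrative only, not the conversion script used for this repo). It assumes a sign bit of 1 means +d and 0 means −d, uses one made-up example scale, and does not model the on-disk bit order or how AWQ packs INT4 codes into int32 words. `bit_to_int4`, `dequant` and `group_of` are hypothetical helper names.

```bash
#!/usr/bin/env bash
# Illustrative sketch of the 1-bit -> AWQ INT4 mapping described above.
# Assumptions (not taken from the checkpoint format): bit 1 = +d, bit 0 = -d;
# packing of INT4 codes into int32 words is not modeled.
set -euo pipefail

ZERO=8          # AWQ zero point used for every group
GROUP_SIZE=128  # number of weights sharing one scale d

# Hypothetical helper: sign bit -> INT4 code (+d -> 9, -d -> 7).
bit_to_int4() {
  if [ "$1" -eq 1 ]; then echo 9; else echo 7; fi
}

# Hypothetical helper: AWQ dequantization, weight = scale * (int4 - zero).
# awk does the float multiply; (int4 - zero) is always +1 or -1 here.
dequant() {
  awk -v s="$1" -v q="$2" -v z="$ZERO" 'BEGIN { printf "%.6g\n", s * (q - z) }'
}

# Hypothetical helper: index of the group (and therefore the scale) for weight i.
group_of() {
  echo $(( $1 / GROUP_SIZE ))
}

d=0.0213  # example scale for one group (made-up value)
for bit in 1 0; do
  q=$(bit_to_int4 "$bit")
  echo "bit=$bit int4=$q weight=$(dequant "$d" "$q")"
done
echo "weight 300 uses the scale of group $(group_of 300)"
```

Running it prints `weight=0.0213` for bit 1 and `weight=-0.0213` for bit 0, and reports group 2 for weight index 300. Because the factor (int4 − zero) is exactly +1 or −1, the multiply only copies the scale and its sign, which is why the repack loses nothing.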
|
## Serve
|
```bash
pip install sglang

python -m sglang.launch_server \
  --model /path/to/Bonsai-8B-awq/ \
  --port 8000 \
  --dtype bfloat16
```
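The server only accepts requests once the model has loaded, and the first launch on some GPUs can take minutes (see the appendix), so scripts may want to wait for it. A minimal sketch, assuming the OpenAI-compatible `/v1/models` route answers once sglang is up; `wait_for_server`, its 5 s poll interval and its 600 s default timeout are hypothetical choices, not part of sglang.

```bash
#!/usr/bin/env bash
# Sketch: poll the server until it responds or a timeout expires.
# Assumption: GET /v1/models succeeds once sglang has finished loading.
# Usage: wait_for_server [base_url] [timeout_seconds]
wait_for_server() {
  local url=${1:-http://localhost:8000} timeout=${2:-600} waited=0
  until curl -sf -o /dev/null "$url/v1/models"; do
    if (( waited >= timeout )); then
      echo "server at $url not ready after ${timeout}s" >&2
      return 1
    fi
    sleep 5
    waited=$(( waited + 5 ))
  done
  echo "server at $url ready after ~${waited}s"
}
```

For example, `wait_for_server http://localhost:8000 600 && echo go` only proceeds once the endpoint responds.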
|
## Use
|
```bash
# Completion API
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"Bonsai-8B","prompt":"The capital of France is","max_tokens":20}'

# Chat API
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"Bonsai-8B","messages":[{"role":"user","content":"Who are you?"}],"max_tokens":100}'
```
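For repeated requests, the chat call above can be wrapped in a shell function. A minimal sketch: `bonsai_chat`, `json_escape` and the `BONSAI_URL` variable are hypothetical names (sglang does not read `BONSAI_URL`), and the escaping only covers backslashes and double quotes, not newlines or other control characters.

```bash
#!/usr/bin/env bash
# Sketch: reusable wrapper around the chat completions request above.
# BONSAI_URL is a hypothetical variable for this script only.
BONSAI_URL=${BONSAI_URL:-http://localhost:8000}

# Escape backslashes and double quotes so the text is valid inside a JSON
# string. Newlines and other control characters are NOT handled.
json_escape() {
  local s=$1
  s=${s//\\/\\\\}
  s=${s//\"/\\\"}
  printf '%s' "$s"
}

# Send one user message; optional second argument is max_tokens (default 100).
bonsai_chat() {
  local prompt max_tokens=${2:-100}
  prompt=$(json_escape "$1")
  curl -s "$BONSAI_URL/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d "{\"model\":\"Bonsai-8B\",\"messages\":[{\"role\":\"user\",\"content\":\"$prompt\"}],\"max_tokens\":$max_tokens}"
}

# Example: bonsai_chat 'Who are you?' 100
```

A tool such as `jq` would give full JSON escaping; the plain-bash version avoids an extra dependency at the cost of the limitation noted above.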
|
## Multi-GPU (8× H100)
|
```bash
# Option A: DP=8, 8 independent replicas with no inter-GPU communication
python -m sglang.launch_server \
  --model /path/to/Bonsai-8B-awq/ \
  --dp-size 8 \
  --load-balance-method total_tokens \
  --port 8000 --dtype bfloat16

# Option B: TP=2 DP=4, 4 replicas, each split across 2 GPUs
python -m sglang.launch_server \
  --model /path/to/Bonsai-8B-awq/ \
  --tp-size 2 --dp-size 4 \
  --load-balance-method total_tokens \
  --port 8000 --dtype bfloat16

# Option C: TP=4 DP=2, 2 replicas, each split across 4 GPUs
python -m sglang.launch_server \
  --model /path/to/Bonsai-8B-awq/ \
  --tp-size 4 --dp-size 2 \
  --load-balance-method total_tokens \
  --port 8000 --dtype bfloat16
```
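All three options use the whole node because TP × DP = 8 in each. The small sketch below prints the launch command for a given split and rejects splits that do not multiply out to the GPU count. `launch_cmd`, `NUM_GPUS` and `MODEL_PATH` are hypothetical names, and the check assumes every DP replica occupies TP GPUs, as in the options above.

```bash
#!/usr/bin/env bash
# Sketch: print the sglang launch command for a TP x DP split.
# Assumption: each of the DP replicas uses TP GPUs, so TP * DP must equal
# the GPU count (8 on the H100 node above).
NUM_GPUS=${NUM_GPUS:-8}
MODEL_PATH=${MODEL_PATH:-/path/to/Bonsai-8B-awq/}

launch_cmd() {
  local tp=$1 dp=$2
  if (( tp * dp != NUM_GPUS )); then
    echo "error: TP=$tp x DP=$dp does not use exactly $NUM_GPUS GPUs" >&2
    return 1
  fi
  echo "python -m sglang.launch_server --model $MODEL_PATH" \
    "--tp-size $tp --dp-size $dp" \
    "--load-balance-method total_tokens --port 8000 --dtype bfloat16"
}

# Options A, B and C from above (Option A omits --tp-size, which defaults to 1).
for tp in 1 2 4; do
  launch_cmd "$tp" $(( NUM_GPUS / tp ))
done
```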
|
## Appendix
|
### Launch time
|
With pre-built sgl-kernel wheels, launch takes 15 s. The first launch on an architecture without pre-built wheels (e.g. L40S / sm_89) takes 3–5 min while sglang JIT-compiles the Marlin GEMM and FlashInfer kernels. The compiled artifacts are cached in `~/.cache/tvm-ffi/` and `~/.cache/flashinfer/`, so subsequent launches drop back to 15 s.
|
### Known-good environment
|
An environment in which end-to-end serving of this checkpoint succeeded:
|
- `sglang[all] == 0.5.9`
- `torch == 2.9.1`, `transformers == 4.57.1`, `triton == 3.5.1`
- `ninja == 1.13` on `PATH`
- `nvcc` from **CUDA 12.8** first on `PATH` (sglang's JIT-compiled Marlin kernels need `-std=c++20`, so CUDA 11.x toolkits will fail)
- Python 3.12
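
A quick way to compare a machine against this list is a small report script. This is a sketch, not an official check: it assumes `python3` on `PATH` is the interpreter sglang is installed into, treats a version as matching when it starts with the listed one (so `2.9.1+cu128` counts as `2.9.1`), and only reports, never installs. `check_version` and `py_pkg_version` are hypothetical helpers.

```bash
#!/usr/bin/env bash
# Sketch: report how the local toolchain compares with the known-good list.
# Assumes `python3` on PATH is the interpreter sglang is installed into.

# Print OK when `actual` starts with `expected` (e.g. 2.9.1+cu128 vs 2.9.1).
check_version() {
  local name=$1 actual=$2 expected=$3
  if [[ -n "$actual" && "$actual" == "$expected"* ]]; then
    echo "OK        $name $actual"
  else
    echo "MISMATCH  $name: have ${actual:-missing}, want $expected"
  fi
}

# Installed version of a Python distribution, or nothing if it is absent.
py_pkg_version() {
  python3 -c 'import sys, importlib.metadata as m; print(m.version(sys.argv[1]))' "$1" 2>/dev/null
}

check_version python "$(python3 -c 'import sys; print("%d.%d" % sys.version_info[:2])' 2>/dev/null)" 3.12
check_version sglang "$(py_pkg_version sglang)" 0.5.9
check_version torch "$(py_pkg_version torch)" 2.9.1
check_version transformers "$(py_pkg_version transformers)" 4.57.1
check_version triton "$(py_pkg_version triton)" 3.5.1
check_version ninja "$(ninja --version 2>/dev/null)" 1.13
check_version nvcc "$(nvcc --version 2>/dev/null | sed -n 's/.*release \([0-9.]*\),.*/\1/p')" 12.8
```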