diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index aad5bbc05..5960cfb2d 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -112,7 +112,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--dtype` | Data type for model weights and activations. 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. 'half' for FP16. Recommended for AWQ quantization. 'float16' is the same as 'half'. 'bfloat16' for a balance between precision and range. 'float' is shorthand for FP32 precision. 'float32' for FP32 precision. | auto | | `--quantization` | The quantization method. | None | | `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | None | -| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'fp8_e5m2' and 'fp8_e4m3' is supported for CUDA 11.8+. | auto | +| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'bf16' or 'bfloat16' for BF16 KV cache. 'fp8_e5m2' and 'fp8_e4m3' are supported for CUDA 11.8+. | auto | | `--enable-fp32-lm-head` | If set, the LM head outputs (logits) are in FP32. | False | ## Memory and scheduling diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6f8bc40ee..6fce4cda4 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1567,6 +1567,8 @@ class ModelRunner: self.kv_cache_dtype = torch.float8_e4m3fnuz else: self.kv_cache_dtype = torch.float8_e4m3fn + elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"): + self.kv_cache_dtype = torch.bfloat16 else: raise ValueError( f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 71113c5e8..3ffb7935f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1652,8 +1652,8 @@ class ServerArgs: "--kv-cache-dtype", type=str, default=ServerArgs.kv_cache_dtype, - choices=["auto", "fp8_e5m2", "fp8_e4m3"], - help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.', + choices=["auto", "fp8_e5m2", "fp8_e4m3", "bf16", "bfloat16"], + help='Data type for kv cache storage. "auto" will use model data type. "bf16" or "bfloat16" for BF16 KV cache. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.', ) parser.add_argument( "--enable-fp32-lm-head",