Support a server arg to override the KV cache dtype to bf16 to avoid slow cases (#11749)
This commit is contained in:
@@ -112,7 +112,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--dtype` | Data type for model weights and activations. 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. 'half' for FP16. Recommended for AWQ quantization. 'float16' is the same as 'half'. 'bfloat16' for a balance between precision and range. 'float' is shorthand for FP32 precision. 'float32' for FP32 precision. | auto |
|
| `--dtype` | Data type for model weights and activations. 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. 'half' for FP16. Recommended for AWQ quantization. 'float16' is the same as 'half'. 'bfloat16' for a balance between precision and range. 'float' is shorthand for FP32 precision. 'float32' for FP32 precision. | auto |
|
||||||
| `--quantization` | The quantization method. | None |
|
| `--quantization` | The quantization method. | None |
|
||||||
| `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | None |
|
| `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | None |
|
||||||
| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'fp8_e5m2' and 'fp8_e4m3' is supported for CUDA 11.8+. | auto |
|
| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'bf16' or 'bfloat16' for BF16 KV cache. 'fp8_e5m2' and 'fp8_e4m3' are supported for CUDA 11.8+. | auto |
|
||||||
| `--enable-fp32-lm-head` | If set, the LM head outputs (logits) are in FP32. | False |
|
| `--enable-fp32-lm-head` | If set, the LM head outputs (logits) are in FP32. | False |
|
||||||
|
|
||||||
## Memory and scheduling
|
## Memory and scheduling
|
||||||
|
|||||||
@@ -1567,6 +1567,8 @@ class ModelRunner:
|
|||||||
self.kv_cache_dtype = torch.float8_e4m3fnuz
|
self.kv_cache_dtype = torch.float8_e4m3fnuz
|
||||||
else:
|
else:
|
||||||
self.kv_cache_dtype = torch.float8_e4m3fn
|
self.kv_cache_dtype = torch.float8_e4m3fn
|
||||||
|
elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
|
||||||
|
self.kv_cache_dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
|
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
|
||||||
|
|||||||
@@ -1652,8 +1652,8 @@ class ServerArgs:
|
|||||||
"--kv-cache-dtype",
|
"--kv-cache-dtype",
|
||||||
type=str,
|
type=str,
|
||||||
default=ServerArgs.kv_cache_dtype,
|
default=ServerArgs.kv_cache_dtype,
|
||||||
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
choices=["auto", "fp8_e5m2", "fp8_e4m3", "bf16", "bfloat16"],
|
||||||
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
help='Data type for kv cache storage. "auto" will use model data type. "bf16" or "bfloat16" for BF16 KV cache. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-fp32-lm-head",
|
"--enable-fp32-lm-head",
|
||||||
|
|||||||
Reference in New Issue
Block a user