From e57c3e12b89ad5b06a5166f300991ccfe9867560 Mon Sep 17 00:00:00 2001
From: HAI
Date: Tue, 19 Nov 2024 14:06:29 -0800
Subject: [PATCH] Use native fp8 format on MI300X (#2094)

---
 python/sglang/srt/model_executor/model_runner.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 036be8675..c3e14c1ec 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    is_hip,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -440,7 +441,10 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            self.kv_cache_dtype = torch.float8_e5m2
+            if is_hip():  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e5m2fnuz
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."