Use native fp8 format on MI300X (#2094)
This commit is contained in:
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    is_hip,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -440,6 +441,9 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            self.kv_cache_dtype = torch.float8_e5m2
+            if is_hip(): # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e5m2fnuz
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
Reference in New Issue
Block a user