[Feature] Support fp8 e5m2 kv cache with flashinfer (#1204)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
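This change threads a new kv_cache_dtype through ModelRunner: the dtype is resolved from the server arguments, used when profiling how many KV cache tokens fit in memory, and passed to the token-to-KV-pool constructors. Assuming the kv_cache_dtype server argument is exposed as a CLI flag of the same name (the flag definition is outside this diff), a launch would look something like:

python -m sglang.launch_server --model-path <model> --kv-cache-dtype fp8_e5m2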
@@ -311,7 +311,7 @@ class ModelRunner:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
@@ -319,7 +319,7 @@ class ModelRunner:
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
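The two hunks above switch the per-token cell-size estimate from the element size of the model dtype to that of the KV cache dtype. A minimal sketch of the MHA branch, using hypothetical Llama-7B-like shapes (32 layers, 32 KV heads, head dim 128), shows the effect:

import torch

num_kv_heads, head_dim, num_layers = 32, 128, 32  # hypothetical shapes

def mha_cell_size(kv_cache_dtype: torch.dtype) -> int:
    # Bytes needed per token: K and V (hence the factor 2) across all layers.
    return (
        num_kv_heads
        * head_dim
        * num_layers
        * 2
        * torch._utils._element_size(kv_cache_dtype)
    )

print(mha_cell_size(torch.float16))      # 524288 bytes per token
print(mha_cell_size(torch.float8_e5m2))  # 262144 bytes, half the fp16 footprint

Halving cell_size roughly doubles the token budget that profile_max_num_token can report for the same amount of free GPU memory.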
@@ -333,6 +333,21 @@ class ModelRunner:
         max_num_reqs: int = None,
         max_total_tokens: int = None,
     ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+                logger.warning(
+                    "FP8 KV cache is not supported by the Triton kernels yet; falling back to auto kv cache dtype"
+                )
+                self.kv_cache_dtype = self.dtype
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
+        else:
+            raise ValueError(
+                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+            )
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
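The added block resolves the effective KV cache dtype before the pool is sized, falling back to the model dtype whenever the fp8 path is unavailable. A standalone sketch of the same decision, with the two server flags stubbed as plain booleans:

import torch

def resolve_kv_cache_dtype(
    kv_cache_dtype: str,       # "auto" or "fp8_e5m2", as in server_args
    model_dtype: torch.dtype,  # dtype the model weights run in
    disable_flashinfer: bool,  # the Triton fallback lacks fp8 KV kernels
    enable_mla: bool,          # the MLA path is likewise unsupported here
) -> torch.dtype:
    if kv_cache_dtype == "auto":
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        if disable_flashinfer or enable_mla:
            return model_dtype  # warn and fall back, as the diff does
        return torch.float8_e5m2
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}.")

assert resolve_kv_cache_dtype("fp8_e5m2", torch.float16, False, False) == torch.float8_e5m2
assert resolve_kv_cache_dtype("fp8_e5m2", torch.float16, True, False) == torch.float16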
@@ -369,7 +384,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
@@ -380,7 +395,7 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
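Both pool constructors now receive the resolved dtype rather than the model dtype. A minimal sketch of what an fp8 e5m2 buffer buys, with hypothetical shapes (the real MHATokenToKVPool layout is not shown in this diff): values computed in fp16 are downcast on write, and the 1-byte buffer is what the flashinfer kernels read.

import torch

max_tokens, num_kv_heads, head_dim = 1024, 32, 128  # hypothetical shapes

# Per-layer K buffer stored in fp8 e5m2: 1 byte per element.
k_buffer = torch.zeros(max_tokens, num_kv_heads, head_dim, dtype=torch.float8_e5m2)

new_k = torch.randn(num_kv_heads, head_dim, dtype=torch.float16)
k_buffer[0] = new_k.to(torch.float8_e5m2)  # lossy cast: e5m2 keeps 2 mantissa bits

print(k_buffer.element_size())        # 1, vs 2 for a float16 cache
ref = k_buffer[0].to(torch.float16)   # upcast copy for reference computation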