[Feature] Support fp8 e5m2 kv cache with flashinfer (#1204)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -203,7 +203,6 @@ class RadixAttention(nn.Module):
|
|||||||
return self.decode_forward(q, k, v, input_metadata)
|
return self.decode_forward(q, k, v, input_metadata)
|
||||||
|
|
||||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
||||||
k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
|
input_metadata.token_to_kv_pool.set_kv_buffer(
|
||||||
v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
|
self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
|
||||||
k_cache[input_metadata.out_cache_loc] = cache_k
|
)
|
||||||
v_cache[input_metadata.out_cache_loc] = cache_v
|
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ limitations under the License.
|
|||||||
"""Memory pool."""
|
"""Memory pool."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Union
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -52,14 +53,21 @@ class ReqToTokenPool:
|
|||||||
self.free_slots = list(range(self.size))
|
self.free_slots = list(range(self.size))
|
||||||
|
|
||||||
|
|
||||||
class BaseTokenToKVPool:
|
class BaseTokenToKVPool(ABC):
|
||||||
"""A memory pool that maps a token to its kv cache locations"""
|
"""A memory pool that maps a token to its kv cache locations"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
|
self.dtype = dtype
|
||||||
|
if dtype == torch.float8_e5m2:
|
||||||
|
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
|
||||||
|
self.store_dtype = torch.uint8
|
||||||
|
else:
|
||||||
|
self.store_dtype = dtype
|
||||||
|
|
||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||||
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||||
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
|
|||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||||
self.mem_state[0] = False
|
self.mem_state[0] = False
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_kv_buffer(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
loc: torch.Tensor,
|
||||||
|
cache_k: torch.Tensor,
|
||||||
|
cache_v: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class MHATokenToKVPool(BaseTokenToKVPool):
|
class MHATokenToKVPool(BaseTokenToKVPool):
|
||||||
|
|
||||||
@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|||||||
head_dim: int,
|
head_dim: int,
|
||||||
layer_num: int,
|
layer_num: int,
|
||||||
):
|
):
|
||||||
super().__init__(size)
|
super().__init__(size, dtype)
|
||||||
|
|
||||||
# [size, head_num, head_dim] for each layer
|
# [size, head_num, head_dim] for each layer
|
||||||
self.k_buffer = [
|
self.k_buffer = [
|
||||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
torch.empty(
|
||||||
|
(size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
|
||||||
|
)
|
||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
self.v_buffer = [
|
self.v_buffer = [
|
||||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
torch.empty(
|
||||||
|
(size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
|
||||||
|
)
|
||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
|
if self.store_dtype != self.dtype:
|
||||||
|
return self.k_buffer[layer_id].view(self.dtype)
|
||||||
return self.k_buffer[layer_id]
|
return self.k_buffer[layer_id]
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
|
def get_value_buffer(self, layer_id: int):
|
||||||
|
if self.store_dtype != self.dtype:
|
||||||
|
return self.v_buffer[layer_id].view(self.dtype)
|
||||||
return self.v_buffer[layer_id]
|
return self.v_buffer[layer_id]
|
||||||
|
|
||||||
def get_kv_buffer(self, layer_id: int):
|
def get_kv_buffer(self, layer_id: int):
|
||||||
return self.k_buffer[layer_id], self.v_buffer[layer_id]
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
||||||
|
|
||||||
|
def set_kv_buffer(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
loc: torch.Tensor,
|
||||||
|
cache_k: torch.Tensor,
|
||||||
|
cache_v: torch.Tensor,
|
||||||
|
):
|
||||||
|
if cache_k.dtype != self.dtype:
|
||||||
|
cache_k = cache_k.to(self.dtype)
|
||||||
|
if cache_v.dtype != self.dtype:
|
||||||
|
cache_v = cache_v.to(self.dtype)
|
||||||
|
if self.store_dtype != self.dtype:
|
||||||
|
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||||
|
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||||
|
else:
|
||||||
|
self.k_buffer[layer_id][loc] = cache_k
|
||||||
|
self.v_buffer[layer_id][loc] = cache_v
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPool(BaseTokenToKVPool):
|
class MLATokenToKVPool(BaseTokenToKVPool):
|
||||||
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
|||||||
qk_rope_head_dim: int,
|
qk_rope_head_dim: int,
|
||||||
layer_num: int,
|
layer_num: int,
|
||||||
):
|
):
|
||||||
super().__init__(size)
|
super().__init__(size, dtype)
|
||||||
|
|
||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
self.kv_buffer = [
|
self.kv_buffer = [
|
||||||
torch.empty(
|
torch.empty(
|
||||||
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
||||||
dtype=dtype,
|
dtype=self.store_dtype,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
|
if self.store_dtype != self.dtype:
|
||||||
|
return self.kv_buffer[layer_id].view(self.dtype)
|
||||||
return self.kv_buffer[layer_id]
|
return self.kv_buffer[layer_id]
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
|
def get_value_buffer(self, layer_id: int):
|
||||||
|
if self.store_dtype != self.dtype:
|
||||||
|
return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
|
||||||
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
|
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
|
||||||
|
|
||||||
def get_kv_buffer(self, layer_id: int):
|
def get_kv_buffer(self, layer_id: int):
|
||||||
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
||||||
|
|
||||||
|
def set_kv_buffer(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
loc: torch.Tensor,
|
||||||
|
cache_k: torch.Tensor,
|
||||||
|
cache_v: torch.Tensor,
|
||||||
|
):
|
||||||
|
if cache_k.dtype != self.dtype:
|
||||||
|
cache_k = cache_k.to(self.dtype)
|
||||||
|
if self.store_dtype != self.dtype:
|
||||||
|
self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||||
|
else:
|
||||||
|
self.kv_buffer[layer_id][loc] = cache_k
|
||||||
|
|||||||
@@ -315,6 +315,8 @@ def update_flashinfer_indices(
|
|||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
1,
|
1,
|
||||||
|
data_type=model_runner.kv_cache_dtype,
|
||||||
|
q_data_type=model_runner.dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# extend part
|
# extend part
|
||||||
@@ -393,6 +395,8 @@ def update_flashinfer_indices(
|
|||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
1,
|
1,
|
||||||
|
data_type=model_runner.kv_cache_dtype,
|
||||||
|
q_data_type=model_runner.dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# extend part
|
# extend part
|
||||||
|
|||||||
@@ -311,7 +311,7 @@ class ModelRunner:
|
|||||||
cell_size = (
|
cell_size = (
|
||||||
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
||||||
* self.model_config.num_hidden_layers
|
* self.model_config.num_hidden_layers
|
||||||
* torch._utils._element_size(self.dtype)
|
* torch._utils._element_size(self.kv_cache_dtype)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cell_size = (
|
cell_size = (
|
||||||
@@ -319,7 +319,7 @@ class ModelRunner:
|
|||||||
* self.model_config.head_dim
|
* self.model_config.head_dim
|
||||||
* self.model_config.num_hidden_layers
|
* self.model_config.num_hidden_layers
|
||||||
* 2
|
* 2
|
||||||
* torch._utils._element_size(self.dtype)
|
* torch._utils._element_size(self.kv_cache_dtype)
|
||||||
)
|
)
|
||||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||||
1 - self.mem_fraction_static
|
1 - self.mem_fraction_static
|
||||||
@@ -333,6 +333,21 @@ class ModelRunner:
|
|||||||
max_num_reqs: int = None,
|
max_num_reqs: int = None,
|
||||||
max_total_tokens: int = None,
|
max_total_tokens: int = None,
|
||||||
):
|
):
|
||||||
|
if self.server_args.kv_cache_dtype == "auto":
|
||||||
|
self.kv_cache_dtype = self.dtype
|
||||||
|
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
||||||
|
if self.server_args.disable_flashinfer or self.server_args.enable_mla:
|
||||||
|
logger.warning(
|
||||||
|
"FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
|
||||||
|
)
|
||||||
|
self.kv_cache_dtype = self.dtype
|
||||||
|
else:
|
||||||
|
self.kv_cache_dtype = torch.float8_e5m2
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
|
||||||
|
)
|
||||||
|
|
||||||
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
||||||
if max_total_tokens is not None:
|
if max_total_tokens is not None:
|
||||||
if max_total_tokens > self.max_total_num_tokens:
|
if max_total_tokens > self.max_total_num_tokens:
|
||||||
@@ -369,7 +384,7 @@ class ModelRunner:
|
|||||||
):
|
):
|
||||||
self.token_to_kv_pool = MLATokenToKVPool(
|
self.token_to_kv_pool = MLATokenToKVPool(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
dtype=self.dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
kv_lora_rank=self.model_config.kv_lora_rank,
|
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
@@ -380,7 +395,7 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
self.token_to_kv_pool = MHATokenToKVPool(
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
dtype=self.dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class ServerArgs:
|
|||||||
skip_tokenizer_init: bool = False
|
skip_tokenizer_init: bool = False
|
||||||
load_format: str = "auto"
|
load_format: str = "auto"
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
|
kv_cache_dtype: str = "auto"
|
||||||
trust_remote_code: bool = True
|
trust_remote_code: bool = True
|
||||||
context_length: Optional[int] = None
|
context_length: Optional[int] = None
|
||||||
quantization: Optional[str] = None
|
quantization: Optional[str] = None
|
||||||
@@ -196,6 +197,13 @@ class ServerArgs:
|
|||||||
'* "float" is shorthand for FP32 precision.\n'
|
'* "float" is shorthand for FP32 precision.\n'
|
||||||
'* "float32" for FP32 precision.',
|
'* "float32" for FP32 precision.',
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-cache-dtype",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.kv_cache_dtype,
|
||||||
|
choices=["auto", "fp8_e5m2"],
|
||||||
|
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trust-remote-code",
|
"--trust-remote-code",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user