[Feature] Support fp8 e5m2 kv cache with flashinfer (#1204)

Author: Ke Bao
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Date: 2024-08-26 08:38:11 +08:00
Committed by: GitHub
Parent: 61bb223e0f
Commit: 2c615d120f
5 changed files with 116 additions and 16 deletions

@@ -16,7 +16,8 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from typing import List, Union
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union
 
 import torch
 
@@ -52,14 +53,21 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
+        dtype: torch.dtype,
     ):
         self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
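
The dtype/store_dtype split above is the heart of the change: in the PyTorch versions targeted here, scatter writes (Tensor.index_put) are not implemented for torch.float8_e5m2, so the pool keeps the raw fp8 bytes in a uint8 tensor and reinterprets them at the access boundary. A minimal sketch of the trick (not part of the commit), assuming a PyTorch build with float8 support (2.1+) and a CUDA device:

```python
import torch

x = torch.randn(4, dtype=torch.float16, device="cuda")
x_fp8 = x.to(torch.float8_e5m2)          # quantize to fp8 (lossy cast)

buf = torch.empty(4, dtype=torch.uint8, device="cuda")  # storage buffer
buf.copy_(x_fp8.view(torch.uint8))       # write the raw fp8 bytes as uint8

roundtrip = buf.view(torch.float8_e5m2)  # zero-copy reinterpretation
# No bits are lost: view() reinterprets memory, it never converts values.
assert torch.equal(roundtrip.view(torch.uint8), x_fp8.view(torch.uint8))
```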
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
 
+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
 
 class MHATokenToKVPool(BaseTokenToKVPool):
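
Making BaseTokenToKVPool an ABC turns the pool into an explicit contract: attention backends can read and write the cache without knowing whether the layout is MHA or MLA, and the fp8 handling stays entirely inside the pool. A hypothetical caller sketch (write_layer_kv is not part of this commit; the import path is an assumption based on this repository's layout):

```python
import torch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool  # assumed path

def write_layer_kv(
    pool: BaseTokenToKVPool,
    layer_id: int,
    loc: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
):
    pool.set_kv_buffer(layer_id, loc, k, v)  # fp8 cast/byte-store handled inside
    return pool.get_kv_buffer(layer_id)      # views in the logical dtype
```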
@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
+
         # [size, head_num, head_dim] for each layer
         self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
 
     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v
 
 
 class MLATokenToKVPool(BaseTokenToKVPool):
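
Note the write path in MHATokenToKVPool.set_kv_buffer: incoming fp16/bf16 tensors are first quantized with .to(self.dtype), then byte-reinterpreted with .view(self.store_dtype) so the indexed assignment runs on uint8, sidestepping the missing fp8 index_put kernel. A usage sketch with made-up sizes, assuming the constructor signature shown in this hunk:

```python
import torch
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool  # assumed path

pool = MHATokenToKVPool(size=1024, dtype=torch.float8_e5m2,
                        head_num=8, head_dim=128, layer_num=2)

loc = torch.tensor([3, 7], device="cuda")  # token slots to fill
k = torch.randn(2, 8, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 8, 128, dtype=torch.float16, device="cuda")

pool.set_kv_buffer(0, loc, k, v)  # casts to fp8, scatters as uint8 bytes
k_buf = pool.get_key_buffer(0)    # fp8 view over the uint8 storage
assert k_buf.dtype == torch.float8_e5m2
```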
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
         self.kv_lora_rank = kv_lora_rank
 
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=dtype,
+                dtype=self.store_dtype,
                 device="cuda",
             )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
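
Unlike the MHA pool, MLATokenToKVPool keeps a single fused buffer per token (kv_lora_rank compressed KV channels plus qk_rope_head_dim rotary channels), which is why set_kv_buffer writes only cache_k and ignores cache_v, and why get_value_buffer is just a slice of the key buffer. A layout sketch with hypothetical sizes:

```python
import torch

kv_lora_rank, qk_rope_head_dim = 512, 64
entry = torch.randn(1, kv_lora_rank + qk_rope_head_dim)  # one token's fused entry

compressed_kv = entry[..., :kv_lora_rank]  # what get_value_buffer exposes
k_rope = entry[..., kv_lora_rank:]         # rotary part, used only by keys
assert compressed_kv.shape[-1] + k_rope.shape[-1] == entry.shape[-1]
```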