[Feature] Support fp8 e5m2 kv cache with flashinfer (#1204)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -16,7 +16,8 @@ limitations under the License.
|
||||
"""Memory pool."""
|
||||
|
||||
import logging
|
||||
from typing import List, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -52,14 +53,21 @@ class ReqToTokenPool:
|
||||
self.free_slots = list(range(self.size))
|
||||
|
||||
|
||||
class BaseTokenToKVPool:
|
||||
class BaseTokenToKVPool(ABC):
|
||||
"""A memory pool that maps a token to its kv cache locations"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
if dtype == torch.float8_e5m2:
|
||||
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
self.mem_state[0] = False
|
||||
|
||||
@abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
    """Return the key-cache tensor for layer ``layer_id``.

    Abstract hook: this base implementation only raises, so every concrete
    pool has to provide its own version.

    Raises:
        NotImplementedError: always, when the base implementation is called.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
    """Return the value-cache tensor for layer ``layer_id``.

    Abstract hook: this base implementation only raises, so every concrete
    pool has to provide its own version.

    Raises:
        NotImplementedError: always, when the base implementation is called.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(key, value)`` cache tensors for layer ``layer_id``.

    Abstract hook: this base implementation only raises, so every concrete
    pool has to provide its own version.

    Raises:
        NotImplementedError: always, when the base implementation is called.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
def set_kv_buffer(
    self, layer_id: int, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor
) -> None:
    """Write ``cache_k`` / ``cache_v`` into layer ``layer_id`` at slots ``loc``.

    Abstract hook: this base implementation only raises, so every concrete
    pool has to provide its own version.

    Raises:
        NotImplementedError: always, when the base implementation is called.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
head_dim: int,
|
||||
layer_num: int,
|
||||
):
|
||||
super().__init__(size)
|
||||
super().__init__(size, dtype)
|
||||
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
torch.empty(
|
||||
(size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
torch.empty(
|
||||
(size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
    """Return the key cache of layer ``layer_id`` typed as ``self.dtype``.

    When the backing tensors are allocated with a different ``store_dtype``
    the stored tensor is reinterpreted via ``Tensor.view(dtype)``; otherwise
    the stored tensor itself is returned.
    """
    stored = self.k_buffer[layer_id]
    if self.store_dtype == self.dtype:
        return stored
    return stored.view(self.dtype)
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
    """Return the value cache of layer ``layer_id`` typed as ``self.dtype``.

    When the backing tensors are allocated with a different ``store_dtype``
    the stored tensor is reinterpreted via ``Tensor.view(dtype)``; otherwise
    the stored tensor itself is returned.
    """
    stored = self.v_buffer[layer_id]
    if self.store_dtype == self.dtype:
        return stored
    return stored.view(self.dtype)
|
||||
|
||||
def get_kv_buffer(self, layer_id: int):
    """Return ``(key, value)`` caches of layer ``layer_id`` typed as ``self.dtype``.

    Both tensors are obtained through ``get_key_buffer`` / ``get_value_buffer``
    so that any ``store_dtype`` -> ``dtype`` reinterpretation those accessors
    perform is applied here too.
    """
    # Go through the accessors rather than returning self.k_buffer /
    # self.v_buffer directly: the raw buffers may be allocated with a
    # different store_dtype than self.dtype.
    return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
||||
|
||||
def set_kv_buffer(
    self,
    layer_id: int,
    loc: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
):
    """Store ``cache_k`` / ``cache_v`` into layer ``layer_id`` at indices ``loc``.

    Inputs whose dtype differs from ``self.dtype`` are converted with
    ``Tensor.to`` first. If the buffers are allocated with a different
    ``store_dtype``, the converted tensors are reinterpreted with
    ``Tensor.view(store_dtype)`` before the indexed assignment.
    """
    target_dtype = self.dtype
    if cache_k.dtype != target_dtype:
        cache_k = cache_k.to(target_dtype)
    if cache_v.dtype != target_dtype:
        cache_v = cache_v.to(target_dtype)

    # Buffers hold store_dtype elements, so the payload has to be viewed as
    # that dtype whenever it is not the same as the logical dtype.
    reinterpret = self.store_dtype != target_dtype
    self.k_buffer[layer_id][loc] = (
        cache_k.view(self.store_dtype) if reinterpret else cache_k
    )
    self.v_buffer[layer_id][loc] = (
        cache_v.view(self.store_dtype) if reinterpret else cache_v
    )
|
||||
|
||||
|
||||
class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
qk_rope_head_dim: int,
|
||||
layer_num: int,
|
||||
):
|
||||
super().__init__(size)
|
||||
super().__init__(size, dtype)
|
||||
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.kv_buffer = [
|
||||
torch.empty(
|
||||
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
||||
dtype=dtype,
|
||||
dtype=self.store_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
    """Return the full ``kv_buffer`` entry of layer ``layer_id`` as ``self.dtype``.

    When the backing tensors are allocated with a different ``store_dtype``
    the stored tensor is reinterpreted via ``Tensor.view(dtype)``; otherwise
    the stored tensor itself is returned.
    """
    stored = self.kv_buffer[layer_id]
    if self.store_dtype == self.dtype:
        return stored
    return stored.view(self.dtype)
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
    """Return the leading ``kv_lora_rank`` channels of layer ``layer_id``'s buffer.

    The slice is taken along the last dimension of ``kv_buffer[layer_id]``.
    When the backing tensors are allocated with a ``store_dtype`` different
    from ``self.dtype`` the slice is reinterpreted via ``Tensor.view(dtype)``.
    """
    leading = self.kv_buffer[layer_id][..., : self.kv_lora_rank]
    if self.store_dtype == self.dtype:
        return leading
    return leading.view(self.dtype)
|
||||
|
||||
def get_kv_buffer(self, layer_id: int):
    """Return the ``(key, value)`` pair for layer ``layer_id``.

    Both elements come from ``get_key_buffer`` and ``get_value_buffer``,
    called in that order.
    """
    key_part = self.get_key_buffer(layer_id)
    value_part = self.get_value_buffer(layer_id)
    return key_part, value_part
|
||||
|
||||
def set_kv_buffer(
    self,
    layer_id: int,
    loc: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
):
    """Store ``cache_k`` into layer ``layer_id``'s ``kv_buffer`` at indices ``loc``.

    ``cache_k`` is converted to ``self.dtype`` when its dtype differs and,
    if the buffer is allocated with a different ``store_dtype``, is
    reinterpreted with ``Tensor.view(store_dtype)`` before the indexed
    assignment. ``cache_v`` is accepted but not read by this method.
    """
    target_dtype = self.dtype
    if cache_k.dtype != target_dtype:
        cache_k = cache_k.to(target_dtype)

    # The buffer holds store_dtype elements, so view the payload as that
    # dtype whenever it is not the same as the logical dtype.
    payload = cache_k if self.store_dtype == target_dtype else cache_k.view(self.store_dtype)
    self.kv_buffer[layer_id][loc] = payload
|
||||
|
||||
Reference in New Issue
Block a user