Support FP8 E4M3 KV Cache (#2786)

Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
bjmsong
2025-01-13 13:17:11 +08:00
committed by GitHub
parent 85b2e05770
commit 0bb0f76311
9 changed files with 205 additions and 10 deletions

View File

@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
):
self.size = size
self.dtype = dtype
if dtype == torch.float8_e5m2:
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for the FP8 dtypes handled here (torch.float8_e5m2, torch.float8_e4m3fn)
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
cache_k = (cache_k / k_scale).to(self.dtype)
cache_v = (cache_v / v_scale).to(self.dtype)
if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)