Support FP8 E4M3 KV Cache (#2786)
Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
if dtype == torch.float8_e5m2:
|
||||
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
cache_v = cache_v.to(self.dtype)
|
||||
cache_k = (cache_k / k_scale).to(self.dtype)
|
||||
cache_v = (cache_v / v_scale).to(self.dtype)
|
||||
if self.store_dtype != self.dtype:
|
||||
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
|
||||
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
|
||||
|
||||
Reference in New Issue
Block a user