Support Triton fp8 e5m2 kv cache (#1286)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -128,7 +128,7 @@ def _fwd_kernel(
|
|||||||
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
|
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
|
||||||
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
qk += tl.dot(q.to(k.dtype), k)
|
||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
offs_kpe = (
|
offs_kpe = (
|
||||||
offs_kv_loc[None, :] * stride_buf_kbs
|
offs_kv_loc[None, :] * stride_buf_kbs
|
||||||
@@ -140,7 +140,7 @@ def _fwd_kernel(
|
|||||||
mask=mask_n[None, :],
|
mask=mask_n[None, :],
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
qk += tl.dot(qpe, kpe)
|
qk += tl.dot(qpe.to(kpe.dtype), kpe)
|
||||||
qk *= sm_scale
|
qk *= sm_scale
|
||||||
|
|
||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
@@ -276,9 +276,17 @@ def extend_attention_fwd(
|
|||||||
BLOCK_DV = Lv
|
BLOCK_DV = Lv
|
||||||
|
|
||||||
if CUDA_CAPABILITY[0] >= 9:
|
if CUDA_CAPABILITY[0] >= 9:
|
||||||
BLOCK_M, BLOCK_N = (128, 64)
|
if Lq <= 256:
|
||||||
|
BLOCK_M, BLOCK_N = (128, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (32, 64)
|
||||||
elif CUDA_CAPABILITY[0] >= 8:
|
elif CUDA_CAPABILITY[0] >= 8:
|
||||||
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
|
if Lq <= 128:
|
||||||
|
BLOCK_M, BLOCK_N = (128, 128)
|
||||||
|
elif Lq <= 256:
|
||||||
|
BLOCK_M, BLOCK_N = (64, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (32, 64)
|
||||||
else:
|
else:
|
||||||
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
|
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
|
||||||
|
|
||||||
|
|||||||
@@ -348,13 +348,7 @@ class ModelRunner:
|
|||||||
if self.server_args.kv_cache_dtype == "auto":
|
if self.server_args.kv_cache_dtype == "auto":
|
||||||
self.kv_cache_dtype = self.dtype
|
self.kv_cache_dtype = self.dtype
|
||||||
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
||||||
if self.server_args.disable_flashinfer or self.server_args.enable_mla:
|
self.kv_cache_dtype = torch.float8_e5m2
|
||||||
logger.warning(
|
|
||||||
"FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
|
|
||||||
)
|
|
||||||
self.kv_cache_dtype = self.dtype
|
|
||||||
else:
|
|
||||||
self.kv_cache_dtype = torch.float8_e5m2
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
|
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
|
||||||
|
|||||||
Reference in New Issue
Block a user