Support Triton fp8 e5m2 kv cache (#1286)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -128,7 +128,7 @@ def _fwd_kernel(
         k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
+        qk += tl.dot(q.to(k.dtype), k)
         if BLOCK_DPE > 0:
             offs_kpe = (
                 offs_kv_loc[None, :] * stride_buf_kbs
@@ -140,7 +140,7 @@ def _fwd_kernel(
                 mask=mask_n[None, :],
                 other=0.0,
             )
-            qk += tl.dot(qpe, kpe)
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
         qk *= sm_scale

         if logit_cap > 0:
@@ -276,9 +276,17 @@ def extend_attention_fwd(
     BLOCK_DV = Lv

     if CUDA_CAPABILITY[0] >= 9:
-        BLOCK_M, BLOCK_N = (128, 64)
+        if Lq <= 256:
+            BLOCK_M, BLOCK_N = (128, 64)
+        else:
+            BLOCK_M, BLOCK_N = (32, 64)
     elif CUDA_CAPABILITY[0] >= 8:
-        BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
+        if Lq <= 128:
+            BLOCK_M, BLOCK_N = (128, 128)
+        elif Lq <= 256:
+            BLOCK_M, BLOCK_N = (64, 64)
+        else:
+            BLOCK_M, BLOCK_N = (32, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

@@ -348,13 +348,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
-                logger.warning(
-                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-                )
-                self.kv_cache_dtype = self.dtype
-            else:
-                self.kv_cache_dtype = torch.float8_e5m2
+            self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
Reference in New Issue
Block a user