From 6cb32ef92c99ee7c1192ff90023692adc106049c Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sun, 1 Sep 2024 17:46:40 +0800 Subject: [PATCH] Support Triton fp8 e5m2 kv cache (#1286) Co-authored-by: Yineng Zhang --- python/sglang/srt/layers/extend_attention.py | 16 ++++++++++++---- python/sglang/srt/model_executor/model_runner.py | 8 +------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 097adca3c..31a002f43 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -128,7 +128,7 @@ def _fwd_kernel( k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk += tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: offs_kpe = ( offs_kv_loc[None, :] * stride_buf_kbs @@ -140,7 +140,7 @@ def _fwd_kernel( mask=mask_n[None, :], other=0.0, ) - qk += tl.dot(qpe, kpe) + qk += tl.dot(qpe.to(kpe.dtype), kpe) qk *= sm_scale if logit_cap > 0: @@ -276,9 +276,17 @@ def extend_attention_fwd( BLOCK_DV = Lv if CUDA_CAPABILITY[0] >= 9: - BLOCK_M, BLOCK_N = (128, 64) + if Lq <= 256: + BLOCK_M, BLOCK_N = (128, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) elif CUDA_CAPABILITY[0] >= 8: - BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64) + if Lq <= 128: + BLOCK_M, BLOCK_N = (128, 128) + elif Lq <= 256: + BLOCK_M, BLOCK_N = (64, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e6f5e7431..cee269dc7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -348,13 +348,7 @@ class ModelRunner: if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype elif self.server_args.kv_cache_dtype == "fp8_e5m2": - if self.server_args.disable_flashinfer or self.server_args.enable_mla: - logger.warning( - "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype" - ) - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = torch.float8_e5m2 + self.kv_cache_dtype = torch.float8_e5m2 else: raise ValueError( f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."