diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 097adca3c..31a002f43 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -128,7 +128,7 @@ def _fwd_kernel( k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk += tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: offs_kpe = ( offs_kv_loc[None, :] * stride_buf_kbs @@ -140,7 +140,7 @@ def _fwd_kernel( mask=mask_n[None, :], other=0.0, ) - qk += tl.dot(qpe, kpe) + qk += tl.dot(qpe.to(kpe.dtype), kpe) qk *= sm_scale if logit_cap > 0: @@ -276,9 +276,17 @@ def extend_attention_fwd( BLOCK_DV = Lv if CUDA_CAPABILITY[0] >= 9: - BLOCK_M, BLOCK_N = (128, 64) + if Lq <= 256: + BLOCK_M, BLOCK_N = (128, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) elif CUDA_CAPABILITY[0] >= 8: - BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64) + if Lq <= 128: + BLOCK_M, BLOCK_N = (128, 128) + elif Lq <= 256: + BLOCK_M, BLOCK_N = (64, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e6f5e7431..cee269dc7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -348,13 +348,7 @@ class ModelRunner: if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype elif self.server_args.kv_cache_dtype == "fp8_e5m2": - if self.server_args.disable_flashinfer or self.server_args.enable_mla: - logger.warning( - "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype" - ) - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = torch.float8_e5m2 + self.kv_cache_dtype = torch.float8_e5m2 else: raise ValueError( f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."