diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 0a03f6562..097adca3c 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -275,7 +275,9 @@ def extend_attention_fwd( BLOCK_DPE = 0 BLOCK_DV = Lv - if CUDA_CAPABILITY[0] >= 8: + if CUDA_CAPABILITY[0] >= 9: + BLOCK_M, BLOCK_N = (128, 64) + elif CUDA_CAPABILITY[0] >= 8: BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64) else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)