Add split tile size for Triton attention (#10425)
This commit is contained in:
@@ -94,6 +94,11 @@ class TritonAttnBackend(AttentionBackend):
             "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
         )
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
+        if self.split_tile_size is not None:
+            self.max_kv_splits = (
+                self.max_context_len + self.split_tile_size - 1
+            ) // self.split_tile_size
 
         # Check arguments
         assert not (
@@ -153,6 +158,12 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits.fill_(self.max_kv_splits)
             return
 
+        if self.split_tile_size is not None:
+            num_kv_splits[:] = (
+                seq_lens + self.split_tile_size - 1
+            ) // self.split_tile_size
+            return
+
         if num_seq < 256:
             SCHEDULE_SEQ = 256
         else:
@@ -362,6 +362,7 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     triton_attention_num_kv_splits: int = 8
+    triton_attention_split_tile_size: Optional[int] = None
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
@@ -2100,6 +2101,12 @@ class ServerArgs:
             default=ServerArgs.triton_attention_num_kv_splits,
             help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
         )
+        parser.add_argument(
+            "--triton-attention-split-tile-size",
+            type=int,
+            default=ServerArgs.triton_attention_split_tile_size,
+            help="The size of split KV tile in flash decoding Triton kernel. Used for deterministic inference.",
+        )
         parser.add_argument(
             "--num-continuous-decode-steps",
             type=int,
Reference in New Issue
Block a user