AITER backend extension and workload optimizations (#6838)

Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
Author: HAI
Date: 2025-06-05 23:00:18 -07:00 (committed by GitHub)
Parent: 562f279a2d
Commit: b819381fec
12 changed files with 583 additions and 164 deletions

@@ -27,12 +27,19 @@ if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo
try:
from aiter import mha_batch_prefill_func, paged_attention_ragged
from aiter import (
flash_attn_varlen_func,
mha_batch_prefill_func,
paged_attention_ragged,
)
from aiter.mla import mla_decode_fwd
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
from sglang.srt.configs.model_config import AttentionArch
class WrapperDispatch(Enum):
SLIDING_WINDOW = auto()
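
Editor's note: the new imports above are AMD-only and live behind a try/except. A minimal sketch of that guarded-import pattern, with a hypothetical `_AITER_AVAILABLE` flag (not part of this commit) to make the failure mode explicit:

```python
# Sketch only: mirrors the try/except import in the hunk above. The
# _AITER_AVAILABLE flag is a hypothetical addition, not from this diff.
try:
    from aiter import (
        flash_attn_varlen_func,
        mha_batch_prefill_func,
        paged_attention_ragged,
    )
    from aiter.mla import mla_decode_fwd

    _AITER_AVAILABLE = True
except ImportError:
    _AITER_AVAILABLE = False
    print(
        "aiter is an AMD-specific kernel library. "
        "Please make sure aiter is installed on your AMD device."
    )
```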
@@ -43,6 +50,10 @@ class WrapperDispatch(Enum):
class ForwardMetadata:
kv_indptr: torch.Tensor
kv_indices: torch.Tensor
qo_indptr: torch.Tensor
kv_last_page_len: torch.Tensor
max_extend_len: int
max_prefix_extend_len: int
max_q_len: int
max_kv_len: int
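
Editor's note: this hunk widens ForwardMetadata to eight fields. A hedged reconstruction of the resulting container, with Optional types inferred from the positional `ForwardMetadata(...)` calls later in this diff (non-MLA paths pass None for the MLA-only fields):

```python
# Hedged sketch of the extended metadata container implied by this hunk.
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ForwardMetadata:
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: Optional[torch.Tensor]
    kv_last_page_len: Optional[torch.Tensor]
    max_extend_len: Optional[int]
    max_prefix_extend_len: Optional[int]
    max_q_len: Optional[int]
    max_kv_len: Optional[int]
```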
@@ -63,6 +74,7 @@ class AiterAttnBackend(AttentionBackend):
self.device = model_runner.device
self.is_multimodal = model_runner.model_config.is_multimodal
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
@@ -75,6 +87,8 @@ class AiterAttnBackend(AttentionBackend):
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
@@ -100,6 +114,10 @@ class AiterAttnBackend(AttentionBackend):
self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
model_runner, self
)
if self.use_mla:
self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
model_runner, self
)
# aiter kernel related initialization
self.max_num_partitions = (
@@ -108,33 +126,40 @@ class AiterAttnBackend(AttentionBackend):
nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
self.workspace_buffer = torch.empty(
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
* nbyes_per_qo_elem
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
dtype=torch.uint8,
device=self.device,
)
if not self.use_mla:
self.workspace_buffer = torch.empty(
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
* nbyes_per_qo_elem
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
dtype=torch.uint8,
device=self.device,
)
self.scale = float(1.0 / (self.head_dim**0.5))
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
self.device
)
self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
self.device
)
self.logits_soft_cap = 0.0
self.forward_metadata: ForwardMetadata = None
if self.use_mla:
self.qo_indptr_ = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables for triton attention backend."""
bs = forward_batch.batch_size
kv_indptr = self.kv_indptr
spec_info = forward_batch.spec_info
qo_indptr = None
kv_last_page_len = None
max_extend_len = None
if forward_batch.forward_mode.is_decode_or_idle():
# update for aiter
# create kv_indices and kv_indptr
bs = forward_batch.batch_size
kv_indptr = self.kv_indptr
spec_info = forward_batch.spec_info
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
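
Editor's note: the decode workspace is now allocated only for the non-MLA path. A worked example of the sizing arithmetic from the hunk above, with illustrative values; the extra `2 * ... * 4` bytes presumably hold the per-partition softmax statistics of the split-KV kernel (an assumption, not stated in the diff):

```python
# Illustrative numbers only; the backend derives these from the model config.
import torch

max_bs, num_head, head_dim, max_num_partitions = 8, 16, 128, 4
nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8  # 4 bytes per fp32

workspace_bytes = (
    # partial attention outputs: one fp32 per (batch, head, partition, head_dim)
    max_bs * num_head * max_num_partitions * head_dim * nbytes_per_qo_elem
    # plus two fp32 scalars per (batch, head, partition), likely exp-sum and max-logit
    + 2 * (max_bs * num_head * max_num_partitions) * 4
)
workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8)
print(workspace_bytes)  # 266240 bytes with these toy sizes
```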
@@ -154,38 +179,103 @@ class AiterAttnBackend(AttentionBackend):
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
if self.use_mla:
qo_indptr = self.qo_indptr_[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
kv_last_page_len = self.kv_last_page_len[:bs]
max_extend_len = 1
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
None,
)
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
if self.use_mla:
prefix_lens = forward_batch.extend_prefix_lens
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
None,
)
else:
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
elif forward_batch.forward_mode.is_target_verify():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
if self.use_mla:
prefix_lens = forward_batch.extend_prefix_lens
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
None,
)
else:
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
else:
prefix_lens = forward_batch.extend_prefix_lens
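
Editor's note: in the MLA decode path of the hunk above, each request contributes exactly one query token, so the metadata degenerates nicely. A small worked example:

```python
# Decode: one new token per request, so qo_indptr is just a 0..bs ramp.
import torch

bs = 4
kv_last_page_len = torch.ones(bs, dtype=torch.int32)

qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(kv_last_page_len, dim=0)
print(qo_indptr)          # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
max_extend_len = 1        # a single query token per request
```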
@@ -194,24 +284,49 @@ class AiterAttnBackend(AttentionBackend):
else:
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
if self.use_mla:
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
None,
)
else:
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
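
Editor's note: throughout these paths, `create_flashinfer_kv_indices_triton` flattens the per-request rows of `req_to_token` into one ragged index array. A pure-PyTorch sketch of the result it is expected to produce (inferred from its arguments; this is not the kernel itself):

```python
# Hedged sketch: build kv_indices so that
# kv_indices[kv_indptr[i]:kv_indptr[i+1]] lists the KV-cache slots of request i.
import torch

req_to_token = torch.tensor([[10, 11, 12, 0],
                             [20, 21,  0, 0],
                             [30, 31, 32, 33]], dtype=torch.int32)
req_pool_indices = torch.tensor([0, 2])          # requests in this batch
seq_lens = torch.tensor([3, 4])                  # tokens per request

kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)    # [0, 3, 7]

kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
for i, req in enumerate(req_pool_indices):
    kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[req, : seq_lens[i]]
print(kv_indices)  # tensor([10, 11, 12, 30, 31, 32, 33], dtype=torch.int32)
```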
@@ -239,6 +354,10 @@ class AiterAttnBackend(AttentionBackend):
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
qo_indptr = None
kv_last_page_len = None
max_extend_len = None
if spec_info is None:
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
@@ -255,24 +374,82 @@ class AiterAttnBackend(AttentionBackend):
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
if self.use_mla:
qo_indptr = self.qo_indptr_[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(
self.cuda_graph_kv_last_page_len[:bs], dim=0
)
max_extend_len = 1
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
None,
)
elif forward_mode.is_target_verify():
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
if self.use_mla:
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
max_extend_len = self.num_draft_tokens
kv_last_page_len = None
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
None,
)
else:
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
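
Editor's note: in the target-verify capture path above, every request verifies the same number of draft tokens, so qo_indptr is a fixed-stride ramp. A tiny worked example:

```python
import torch

bs, num_draft_tokens = 3, 4

# Each request contributes exactly num_draft_tokens query tokens.
qo_indptr = torch.arange(
    0, (1 + bs) * num_draft_tokens, step=num_draft_tokens, dtype=torch.int32
)
print(qo_indptr)                 # tensor([ 0,  4,  8, 12], dtype=torch.int32)
max_extend_len = num_draft_tokens
```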
@@ -342,31 +519,113 @@ class AiterAttnBackend(AttentionBackend):
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
if self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
if self.use_mla:
max_extend_len = self.forward_metadata.max_extend_len
max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
kv_indptr = self.forward_metadata.kv_indptr
kv_indices = self.forward_metadata.kv_indices
kv_last_page_lens = self.forward_metadata.kv_last_page_len
qo_indptr = self.forward_metadata.qo_indptr
K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
kv_lora_rank = V_Buffer.shape[-1]
qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim
assert len(q.shape) == 3
assert len(k.shape) == 3
assert len(v.shape) == 3
if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func(
q,
k,
v,
qo_indptr,
qo_indptr,
max_extend_len,
max_extend_len,
softmax_scale=layer.scaling,
causal=True,
)
return o
elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
kvc, k_pe = torch.split(
K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
)
kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
kvprefix = kvprefix.view(
-1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
)
k_prefix, v_prefix = torch.split(
kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
)
k_prefix = torch.cat(
[
k_prefix,
torch.broadcast_to(
k_pe,
(k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
),
],
dim=-1,
)
assert (
forward_batch.extend_prefix_lens.shape
== forward_batch.extend_seq_lens.shape
)
k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
bs0 = forward_batch.batch_size + 1
o = flash_attn_varlen_func(
q,
k,
v,
qo_indptr,
kv_indptr,
max_extend_len,
max_prefix_extend_len,
softmax_scale=layer.scaling,
causal=True,
)
return o
else:
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
o = mha_batch_prefill_func(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache,
v_cache,
self.qo_indptr[:bs0],
self.forward_metadata.kv_indptr[:bs0],
self.forward_metadata.kv_indices,
self.forward_metadata.max_q_len,
self.forward_metadata.max_kv_len,
causal=True,
logits_soft_cap=self.logits_soft_cap,
alibi_slopes=None,
return_lse=False,
return_attn_probs=False,
)
bs0 = forward_batch.batch_size + 1
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
o = mha_batch_prefill_func(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache,
v_cache,
self.qo_indptr[:bs0],
self.forward_metadata.kv_indptr[:bs0],
self.forward_metadata.kv_indices,
self.forward_metadata.max_q_len,
self.forward_metadata.max_kv_len,
causal=True,
logits_soft_cap=self.logits_soft_cap,
alibi_slopes=None,
return_lse=False,
return_attn_probs=False,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
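
Editor's note: the MLA extend path above rebuilds full per-request key/value streams by interleaving the reconstructed prefix K/V with the freshly computed extend K/V. A hedged, shape-level sketch of that split/zip/cat pattern (toy lengths; 8 heads of width 64 are placeholders):

```python
import torch

extend_prefix_lens_cpu = [2, 3]   # cached prefix tokens per request
extend_seq_lens_cpu = [1, 2]      # new (extend) tokens per request

# Flattened prefix K for all requests, and flattened extend K for all requests.
k_prefix_all = torch.randn(sum(extend_prefix_lens_cpu), 8, 64)
k_extend_all = torch.randn(sum(extend_seq_lens_cpu), 8, 64)

k_prefix = torch.split(k_prefix_all, extend_prefix_lens_cpu)
k_extend = torch.split(k_extend_all, extend_seq_lens_cpu)

# Interleave as [prefix_0, extend_0, prefix_1, extend_1, ...] so each request's
# keys are contiguous for the varlen attention call.
k = torch.cat([x for pair in zip(k_prefix, k_extend) for x in pair])
print(k.shape)  # torch.Size([8, 8, 64]): (2+1) + (3+2) tokens
```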
def forward_decode(
self,
@@ -377,6 +636,7 @@ class AiterAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
if layer.qk_head_dim != layer.v_head_dim:
@@ -389,32 +649,48 @@ class AiterAttnBackend(AttentionBackend):
layer, forward_batch.out_cache_loc, k, v
)
self.logits_soft_cap = layer.logit_cap
paged_attention_ragged(
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
self.workspace_buffer,
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-1, 1, layer.tp_v_head_num, layer.v_head_dim
),
self.scale,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.kv_last_page_lens,
1,
self.max_num_partitions,
None,
"auto",
"NHD",
self.logits_soft_cap,
self.k_scale,
self.v_scale,
None,
_AITER_PARTITION_SIZE_ROCM,
)
if self.use_mla:
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
mla_decode_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k_buffer.view(-1, 1, 1, layer.qk_head_dim),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
self.forward_metadata.qo_indptr,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.forward_metadata.kv_last_page_len,
self.forward_metadata.max_extend_len,
layer.scaling,
layer.logit_cap,
)
k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
else:
self.logits_soft_cap = layer.logit_cap
paged_attention_ragged(
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
self.workspace_buffer,
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-1, 1, layer.tp_v_head_num, layer.v_head_dim
),
self.scale,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.kv_last_page_len,
1,
self.max_num_partitions,
None,
"auto",
"NHD",
self.logits_soft_cap,
self.k_scale,
self.v_scale,
None,
_AITER_PARTITION_SIZE_ROCM,
)
return o
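
Editor's note: for the MLA decode branch above, a hedged shape sketch using DeepSeek-style illustrative dimensions (kv_lora_rank=512 and qk_rope_head_dim=64 are assumptions, not values from this diff): the query carries the latent plus rope part, the paged KV buffer stores one latent "head" per token, and the output is only v_head_dim wide.

```python
import torch

# Illustrative MLA dimensions (assumed, not taken from this commit).
num_q_heads, kv_lora_rank, qk_rope_head_dim = 16, 512, 64
qk_head_dim = kv_lora_rank + qk_rope_head_dim   # 576: width of q and of the cache entries
v_head_dim = kv_lora_rank                       # 512: width of the attention output

num_tokens, num_cache_slots = 4, 1024
q = torch.randn(num_tokens, num_q_heads, qk_head_dim)
k_buffer = torch.randn(num_cache_slots, 1, 1, qk_head_dim)   # latent KV cache, one "head"
o = torch.empty(num_tokens, num_q_heads, v_head_dim)          # filled in-place by mla_decode_fwd
```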
@@ -506,9 +782,97 @@ class AiterIndicesUpdaterPrefill:
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
None,
paged_kernel_lens_sum,
self.req_to_token,
)
)
self.kv_indices = kv_indices
class AiterMlaIndicesUpdaterPrefill:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.attn_backend = attn_backend
# Buffers and wrappers
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.update = self.update_single_wrapper
self.kv_indptr = None
self.kv_indices = None
self.qo_indptr = None
self.kv_last_page_len = None
self.max_extend_len = 0
self.max_prefix_extend_len = 0
def update(
self,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
prefix_lens_sum: int,
extend_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
def update_single_wrapper(
self,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
prefix_lens_sum: int,
extend_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
paged_kernel_lens = prefix_lens
paged_kernel_lens_sum = prefix_lens_sum
bs = len(req_pool_indices)
kv_indptr = self.attn_backend.kv_indptr
if spec_info is None:
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
qo_indptr = self.attn_backend.qo_indptr
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
max_extend_len = torch.max(extend_lens).item()
max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
kv_indptr += qo_indptr
else:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
)
)
self.kv_indptr = kv_indptr
self.kv_indices = kv_indices
self.qo_indptr = qo_indptr
self.max_extend_len = max_extend_len
self.max_prefix_extend_len = max_prefix_extend_len
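
Editor's note: a worked example of the pointer arithmetic at the end of update_single_wrapper. kv_indptr first counts only the cached prefix tokens; adding qo_indptr (the extend-token offsets) converts it into offsets over the interleaved [prefix, extend] stream that forward_extend later builds.

```python
import torch

prefix_lens = torch.tensor([2, 3], dtype=torch.int32)   # paged_kernel_lens
extend_lens = torch.tensor([1, 2], dtype=torch.int32)

kv_indptr = torch.zeros(len(prefix_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)         # [0, 2, 5]

qo_indptr = torch.zeros(len(extend_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)         # [0, 1, 3]

max_extend_len = int(extend_lens.max())                  # 2
max_prefix_extend_len = int((extend_lens + prefix_lens).max())  # 5

kv_indptr += qo_indptr   # [0, 3, 8]: prefix + extend tokens seen so far per request
```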