diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 90576a17a..fb476a762 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -24,9 +24,7 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -72,11 +70,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
-
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
+        self.page_size = model_runner.page_size
 
         # Allocate buffers
         global global_workspace_buffer
@@ -97,15 +95,25 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
+        self.kv_indices = torch.empty(
+            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
+            # A hack to pre-initialize large batch size for dp attention
+            if model_runner.server_args.enable_dp_attention:
+                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
+
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
@@ -148,6 +156,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -205,16 +214,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        if kv_indices_buf is None:
-            cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
-                dtype=torch.int32,
-                device="cuda",
-            )
-        else:
-            cuda_graph_kv_indices = kv_indices_buf
-
-        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+        self.cuda_graph_kv_indices = (
+            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
+        )
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -240,6 +242,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
+
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -250,7 +253,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
-
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -321,11 +323,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
+
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
+            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                kv_len_arr_cpu, dim=0
+                num_pages_per_req, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -334,7 +338,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
-
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -381,7 +384,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
@@ -401,7 +403,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -422,6 +423,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
+            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
+
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -483,17 +486,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_nope = reshaped_q[:, :, : layer.v_head_dim]
         q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
+        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
         o = q_nope.new_empty(q_nope.shape)
-        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-            k_buffer[:, :, : layer.v_head_dim],
-            k_buffer[:, :, layer.v_head_dim :],
+            k_buf[:, :, : layer.v_head_dim],
+            k_buf[:, :, layer.v_head_dim :],
             out=o,
         )
 
@@ -512,9 +515,10 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+        self.page_size = model_runner.page_size
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -558,13 +562,17 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+                self.kv_indices[: kv_indptr[-1]]
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -573,39 +581,40 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                q_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_lens,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr=q_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
         else:
             wrapper.plan(
-                fast_decode_kwargs["qo_indptr_cpu"],
-                fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices,
-                fast_decode_kwargs["kv_len_arr_cpu"],
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                False,
-                sm_scale,
-                self.data_type,
-                self.data_type,
+                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
+                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices=kv_indices,
+                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=False,
+                sm_scale=sm_scale,
+                q_data_type=self.data_type,
+                kv_data_type=self.data_type,
             )
@@ -627,12 +636,14 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -646,7 +657,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
-
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -680,13 +690,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            num_pages_per_req = (
+                paged_kernel_lens + self.page_size - 1
+            ) // self.page_size
+            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum,
-                dtype=torch.int32,
-                device=req_pool_indices.device,
-            )
+            kv_indices = self.kv_indices[: kv_indptr[-1]]
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -695,6 +704,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
+                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -712,7 +722,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                     self.req_to_token,
                 )
             )
-
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -726,20 +735,26 @@ class FlashInferMLAIndicesUpdaterPrefill:
             )
         else:
             # mla paged prefill
-            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            if spec_info is not None:
+                assert (
+                    self.page_size == 1
+                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
+                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
+            else:
+                kv_lens = paged_kernel_lens.to(torch.int32)
             wrapper_paged.plan(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_len_arr,
-                self.num_local_heads,
-                self.kv_lora_rank,
-                self.qk_rope_head_dim,
-                1,
-                True,
-                sm_scale,
-                self.q_data_type,
-                self.data_type,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                kv_len_arr=kv_lens,
+                num_heads=self.num_local_heads,
+                head_dim_ckv=self.kv_lora_rank,
+                head_dim_kpe=self.qk_rope_head_dim,
+                page_size=self.page_size,
+                causal=True,
+                sm_scale=sm_scale,
+                q_data_type=self.q_data_type,
+                kv_data_type=self.data_type,
             )
@@ -834,6 +849,7 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -869,6 +885,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+
         def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -885,6 +902,7 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
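
The index updaters above now size kv_indptr in pages rather than tokens and fill a slice of the preallocated kv_indices buffer instead of allocating per step. A minimal sketch of that metadata math, with illustrative values that are not part of the patch:

import torch

page_size = 4
seq_lens = torch.tensor([7, 12, 1], dtype=torch.int32)       # tokens per request
num_pages_per_req = (seq_lens + page_size - 1) // page_size  # ceil division -> [2, 3, 1]

kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)       # [0, 2, 5, 6]

# The backend slices its preallocated buffer; create_flashinfer_kv_indices_triton
# then writes one page index into each of the kv_indptr[-1] slots.
kv_indices_buf = torch.empty(1024, dtype=torch.int32)
kv_indices = kv_indices_buf[: kv_indptr[-1]]                 # 6 page slots for this batch
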
diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py
index e8cd2e158..5c9ab87ef 100644
--- a/python/sglang/srt/layers/attention/utils.py
+++ b/python/sglang/srt/layers/attention/utils.py
@@ -9,18 +9,89 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_to_token_ptr,
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
+    PAGE_SIZE: tl.constexpr = 1,
 ):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
+    """
+    Create KV indices for the FlashInfer attention backend.
 
-    # find the req pool idx, this is for batch to token
+    This Triton kernel builds a lookup table that maps logical request/token
+    coordinates to physical token locations in the global KV cache pool. It is
+    used by the FlashInfer attention backends to access scattered KV cache data
+    efficiently.
+
+    The kernel processes each request in parallel and converts the req_to_token
+    lookup table into a flat list of page indices that the attention kernels can
+    consume.
+
+    General idea:
+        blocktables/kv_indices_ptr = [batch_size * max_pages] (for graph mode with
+        a fixed number of pages), where max_pages = ceil(max_context_len / PAGE_SIZE).
+        kv_indices_ptr stores the flat list of pages used by each request.
+
+    Args:
+        Input arguments (not mutated):
+
+        req_to_token_ptr: Request-to-token-location lookup table.
+            Shape: [max_batch, max_context_len]
+        req_pool_indices_ptr: Request-to-pool-index lookup table. Each request uses
+            one pool.
+            Shape: [batch_size]
+        page_kernel_lens_ptr: Sequence length of each request.
+            Shape: [batch_size]
+        kv_indptr: Must be computed from the number of pages used by each request.
+            FlashInfer attention kernels use it to index into kv_indices_ptr per
+            request.
+            Shape: [batch_size + 1]
+            kv_indptr[i] = start index in kv_indices for request i
+        kv_start_idx: Pointer to an array of per-request start offsets in SGL.
+            Can be None. If provided, the offset is added to the token positions.
+        req_to_token_ptr_stride: Stride of the second dimension of req_to_token.
+            Equal to max_context_len.
+        PAGE_SIZE: Number of tokens per page. Defaults to 1 for FlashInfer.
+
+    Outputs:
+        kv_indices_ptr: Pointer to the output array where the KV page indices are
+            stored.
+            Shape: [total_num_pages], where
+            total_num_pages = sum(ceil(seq_len / PAGE_SIZE)) over the batch.
+
+    Example:
+        If we have:
+        - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
+        - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
+        - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (token slots assigned
+          by the radix tree; they point to each token's location in the KV pool)
+
+        The kernel will output:
+        If PAGE_SIZE = 1:
+            packed:
+            - kv_indptr (passed in as input arg): [0, 3, 5]
+            - kv_indices = [10, 11, 12, 20, 21]
+            padded (max_pages is 10 pages per request):
+            - kv_indptr (passed in as input arg): [0, 10, 20]
+            - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
+                            20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
+
+        If PAGE_SIZE = 2:
+            packed:
+            - kv_indptr (passed in as input arg): [0, 2, 3]
+            - kv_indices = [5, 6, 10]
+            padded (max_pages is 4):
+            - kv_indptr (passed in as input arg): [0, 4, 8, ...] (note that 4 is
+              max_pages)
+            - kv_indices = [5, 6, -1, -1,
+                            10, -1, -1, -1]
+
+        This allows attention kernels to directly access the correct KV cache
+        entries for each request's tokens.
+    """
+    BLOCK_SIZE: tl.constexpr = 512
+    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
+    pid = tl.program_id(axis=0)
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
@@ -31,19 +102,27 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        # index into req_to_token_ptr needs to be int64
-        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
+    kv_range = kv_end - kv_start
+    num_pages = tl.cdiv(kv_range, PAGE_SIZE)
+    num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
+    req_to_token_block_start = (
+        req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
+    )
+    for i in range(num_loops):
+        token_offsets_in_block = (
+            tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
+        ) * PAGE_SIZE
+        page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
+        valid_tokens = token_offsets_in_block < kv_range
+        valid_pages = page_offsets_in_block < num_pages
+        token_numbers = tl.load(
+            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
+        )
+        tl.store(
+            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
+            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
+            mask=valid_pages,
         )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
 @triton.jit
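
The PAGE_SIZE = 2 example in the docstring can be reproduced with a rough pure-PyTorch sketch of the kernel's per-request logic. Like the kernel, it assumes the allocator keeps all token slots of a page in the same physical page, so only the first slot of each page needs to be read; it is an illustration, not part of the patch:

import torch

req_to_token = torch.tensor([[10, 11, 12, -1], [20, 21, -1, -1]], dtype=torch.int32)
page_kernel_lens = [3, 2]
PAGE_SIZE = 2

kv_indices = []
for req, seq_len in enumerate(page_kernel_lens):
    token_slots = req_to_token[req, :seq_len]
    # first token slot of every page, converted to a page number
    first_of_each_page = token_slots[::PAGE_SIZE]
    kv_indices.extend((first_of_each_page // PAGE_SIZE).tolist())

print(kv_indices)  # [5, 6, 10]
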
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 861134ca3..f220770ba 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -636,6 +636,10 @@ class ServerArgs:
                 logger.warning(
                     "DeepSeek MTP does not require setting speculative_draft_model_path."
                 )
+            if self.page_size != 1 and self.attention_backend == "flashinfer":
+                raise ValueError(
+                    "Speculative decoding with page_size != 1 is not supported by the flashinfer backend. Please set page_size to 1."
+                )
 
         # Auto choose parameters
         if self.speculative_num_steps is None:
diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py
index 4196eb290..87ebbee3c 100644
--- a/test/srt/test_create_kvindices.py
+++ b/test/srt/test_create_kvindices.py
@@ -4,7 +4,10 @@ import unittest
 import numpy as np
 import torch
 
-from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.attention.utils import (
+    create_flashinfer_kv_indices_triton,
+    create_flashmla_kv_indices_triton,
+)
 from sglang.test.test_utils import CustomTestCase
 
 
@@ -15,10 +18,14 @@ class TestCreateKvIndices(CustomTestCase):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")
 
-    def _run_test(self, batch, max_batch, max_context_len):
+    def _run_test(self, batch, max_batch, max_context_len, page_size):
+        np.random.seed(9)
+        PAGE_SIZE = page_size
         req_to_token = torch.arange(
             max_batch * max_context_len, dtype=torch.int32, device="cuda"
         ).reshape((max_batch, max_context_len))
+
+        # the block table
         req_pool_indices = torch.tensor(
             torch.from_numpy(
                 np.random.choice(range(max_batch), size=batch, replace=False)
@@ -26,49 +33,84 @@ class TestCreateKvIndices(CustomTestCase):
             dtype=torch.int32,
             device="cuda",
         )
-        paged_kernel_lens = torch.tensor(
+        seq_lens = torch.tensor(
             torch.from_numpy(
                 np.random.choice(range(max_context_len), size=batch, replace=False)
             ),
             dtype=torch.int32,
             device="cuda",
         )
-
+        num_pages_per_req = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE
         kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)
 
         # ref
+        kv_indices_ref = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices_ref = torch.cat(
-            [
-                req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
-                for i in range(batch)
-            ],
-            dim=0,
-        ).contiguous()
+        seq_lens_cpu = seq_lens.cpu().numpy()
+        for i in range(batch):
+            kv_indptr_req = kv_indptr[i]
+            num_toks_seq = seq_lens_cpu[i]
+            curr_req_pool = req_pool_indices_cpu[i]
+            curr_num_pages = num_pages_per_req[i]
+            curr_token_ids = req_to_token[curr_req_pool]
+            curr_pages = (curr_token_ids[:num_toks_seq] // PAGE_SIZE).unique()
+            assert (
+                len(curr_pages) == curr_num_pages
+            ), f"req {i} has #{curr_num_pages} pages, but got {len(curr_pages)} pages"
+            kv_indices_ref[kv_indptr_req : kv_indptr_req + curr_num_pages] = curr_pages
 
         # triton
         kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(batch,)](
             req_to_token,
             req_pool_indices,
-            paged_kernel_lens,
+            seq_lens,
             kv_indptr,
             None,
             kv_indices_triton,
             req_to_token.size(1),
+            PAGE_SIZE,
+        )
+
+        max_pages = max_context_len // PAGE_SIZE
+        kv_indices_flashmla = torch.empty(
+            batch, max_pages, dtype=torch.int32, device="cuda"
         )
+        create_flashmla_kv_indices_triton[(batch,)](
+            req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            kv_indices_flashmla,
+            req_to_token.size(1),
+            max_pages,
+            PAGE_SIZE,
+        )
 
         # Check
         self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))
 
     def test_create_kvindices(self):
-        BATCH = [1, 37, 1786]
+        BATCH = [4, 37, 512, 1786]
         MAX_BATCH = 4096
         MAX_CONTEXT_LEN = 4096
-        for batch in BATCH:
-            self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)
+        PAGE_SIZE = [1, 2, 16, 64]
+
+        # for debug
+        # BATCH = [4]
+        # MAX_BATCH = 4
+        # MAX_CONTEXT_LEN = 10
+
+        # Test for small batch size
+        for page_size in PAGE_SIZE[:1]:
+            print(f"Running test for page size: {page_size} and batch size: {BATCH[0]}")
+            self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)
+
+        # Test for larger batch size
+        for batch in BATCH[1:]:
+            for page_size in PAGE_SIZE:
+                print(
+                    f"Running test for batch size: {batch} and page size: {page_size}"
+                )
+                self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)
 
 
 if __name__ == "__main__":
diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py
index f72aef5a5..b98e65620 100644
--- a/test/srt/test_mla_flashinfer.py
+++ b/test/srt/test_mla_flashinfer.py
@@ -120,5 +120,49 @@ class TestFlashinferMLAMTP(CustomTestCase):
         self.assertGreater(avg_spec_accept_length, 2.5)
 
 
+class TestFlashinferMLAPageSize16(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = ["--trust-remote-code"]
+        if torch.cuda.is_available() and torch.version.cuda:
+            other_args.extend(
+                [
+                    "--cuda-graph-max-bs",
+                    "4",
+                    "--attention-backend",
+                    "flashinfer",
+                    "--page-size",
+                    "16",
+                ]
+            )
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.615)
+
+
 if __name__ == "__main__":
     unittest.main()
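
For manual verification outside the unittest harness, a server can be launched with the same arguments the new TestFlashinferMLAPageSize16 case uses. This is a sketch that assumes the helpers in sglang.test.test_utils keep the signatures used in the test above:

from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--trust-remote-code",
        "--cuda-graph-max-bs", "4",
        "--attention-backend", "flashinfer",
        "--page-size", "16",
    ],
)
# Shut the server down afterwards, e.g. with kill_process_tree(process.pid),
# as the test's tearDownClass does.
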