diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b6329cb28..5d55abe0a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) - seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64) + seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64) orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to( self.device, non_blocking=True ) @@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): prefix_lens_tensor, prefix_lens_cpu_tensor, seq_lens_tensor, - seq_lens_cpu_tensor, + seq_lens_cpu, last_loc, extend_num_tokens, ) @@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.input_ids = input_ids_tensor self.req_pool_indices = req_pool_indices_tensor self.seq_lens = seq_lens_tensor - self.seq_lens_cpu = seq_lens_cpu_tensor + self.seq_lens_cpu = seq_lens_cpu self.orig_seq_lens = orig_seq_lens_tensor self.out_cache_loc = out_cache_loc self.input_embeds = ( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d616e7f1a..dc1fbd2db 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1087,7 +1087,10 @@ class ServerArgs: and self.attention_backend != "flashinfer" ): raise ValueError( - "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend." + f"speculative_eagle_topk({self.speculative_eagle_topk}) > 1 " + f"with page_size({self.page_size}) > 1 is unstable " + "and produces incorrect results for paged attention backends. " + "This combination is only supported for the 'flashinfer' backend." 
) if self.enable_dp_attention: # TODO: support dp attention for ngram speculative decoding diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 6ab1499f9..5d8c920c4 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -388,6 +388,8 @@ class EagleVerifyInput(SpecInput): evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False accept_length_cpu = accept_length.cpu() + # FIXME: this `tolist()` fixes a numerical-calculation consistency issue; + # TODO: unify the tensor and list representations accept_length_list = accept_length_cpu.tolist() if page_size == 1: diff --git a/python/sglang/srt/speculative/ngram_utils.py b/python/sglang/srt/speculative/ngram_utils.py index 79d66a047..345fcbd66 100644 --- a/python/sglang/srt/speculative/ngram_utils.py +++ b/python/sglang/srt/speculative/ngram_utils.py @@ -79,14 +79,21 @@ class NgramVerifyInput(SpecInput): else: # TODO(lsyin): add prefix lens cpu here to support page size > 1 prefix_lens = batch.seq_lens + prefix_lens_cpu = batch.seq_lens_cpu end_offset = prefix_lens + self.draft_token_num + end_offset_cpu = prefix_lens_cpu + self.draft_token_num last_loc = get_last_loc( batch.req_to_token_pool.req_to_token, batch.req_pool_indices, prefix_lens, ) batch.out_cache_loc = batch.alloc_paged_token_slots_extend( - prefix_lens, end_offset, last_loc, len(batch.input_ids) + prefix_lens, + prefix_lens_cpu, + end_offset, + end_offset_cpu, + last_loc, + len(batch.input_ids), ) self.last_loc = last_loc diff --git a/test/srt/test_ngram_speculative_decoding.py b/test/srt/test_ngram_speculative_decoding.py index c791915a8..4495f9121 100644 --- a/test/srt/test_ngram_speculative_decoding.py +++ b/test/srt/test_ngram_speculative_decoding.py @@ -31,7 +31,7 @@ DEFAULT_SERVER_ARGS = [ ] -class TestStandaloneSpeculativeDecodingBase(CustomTestCase): +class 
TestNgramSpeculativeDecodingBase(CustomTestCase): model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST base_url = DEFAULT_URL_FOR_TEST @@ -88,20 +88,30 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase): self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold) -class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase): +class TestNgramSpeculativeDecodingTriton(TestNgramSpeculativeDecodingBase): @classmethod def get_server_args(cls): return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"] -class TestStandaloneSpeculativeDecodingFlashinfer( - TestStandaloneSpeculativeDecodingBase -): +class TestNgramSpeculativeDecodingFlashinfer(TestNgramSpeculativeDecodingBase): @classmethod def get_server_args(cls): return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"] +class TestNgramSpeculativeDecodingPaged(TestNgramSpeculativeDecodingBase): + + @classmethod + def get_server_args(cls): + return DEFAULT_SERVER_ARGS + [ + "--attention-backend", + "flashinfer", + "--page-size", + "64", + ] + + if __name__ == "__main__": unittest.main()