From 7b4e61fff36a4dd5b485397a972d3202358df9c7 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Sat, 8 Feb 2025 16:40:00 -0800
Subject: [PATCH] [Fix] Fix eagle with disable cuda graph (#3411)

---
 .../layers/attention/flashinfer_backend.py | 29 +++++++++++++------
 test/srt/test_eagle_infer.py               | 18 ++++++------
 2 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 1f701f946..5e8879b1d 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -924,38 +924,50 @@ class FlashInferMultiStepDraftBackend:
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.kv_indptr_stride = self.kv_indptr.shape[1]
 
-    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
+
         self.generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.seq_lens,
-            self.cuda_graph_kv_indices,
+            kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
             num_seqs,
             self.topk,
             self.pool_len,
-            self.kv_indptr_stride,
+            kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             triton.next_power_of_2(num_seqs),
             triton.next_power_of_2(self.speculative_num_steps),
             triton.next_power_of_2(bs),
         )
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
-            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         def call_fn(i, forward_batch):
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
@@ -965,7 +977,7 @@ class FlashInferMultiStepDraftBackend:
             )
             self.attn_backends[i].init_forward_metadata(forward_batch)
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, kv_indices, call_fn)
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +985,6 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
-        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1006,7 @@ class FlashInferMultiStepDraftBackend:
             ][0]
             decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
         def call_fn(i, forward_batch):
@@ -1009,7 +1020,7 @@ class FlashInferMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
 
 @triton.jit
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index b5b17dad1..b04b13211 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -22,12 +22,12 @@ from sglang.test.test_utils import (
 
 class TestEAGLEEngine(unittest.TestCase):
     def test_eagle_accuracy(self):
-        prompt = "Today is a sunny day and I like"
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+        prompt1 = "Today is a sunny day and I like"
+        sampling_params1 = {"temperature": 0, "max_new_tokens": 8}
 
         # Get the reference output
         ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        ref_output = ref_engine.generate(prompt, sampling_params)["text"]
+        ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
         ref_engine.shutdown()
 
         # Test cases with different configurations
@@ -60,20 +60,20 @@ class TestEAGLEEngine(unittest.TestCase):
             engine = sgl.Engine(**config)
 
             # Case 1: Test the output of EAGLE engine is the same as normal engine
-            out1 = engine.generate(prompt, sampling_params)["text"]
+            out1 = engine.generate(prompt1, sampling_params1)["text"]
             print(f"{out1=}, {ref_output=}")
             self.assertEqual(out1, ref_output)
 
             # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-            prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
-            sampling_params = {
+            prompt2 = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
+            sampling_params2 = {
                 "temperature": 0,
                 "max_new_tokens": 1024,
                 "skip_special_tokens": False,
             }
 
             tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-            out2 = engine.generate(prompt, sampling_params)["text"]
+            out2 = engine.generate(prompt2, sampling_params2)["text"]
             print(f"{out2=}")
             tokens = tokenizer.encode(out2, truncation=False)
             assert tokenizer.eos_token_id not in tokens
@@ -85,8 +85,8 @@ class TestEAGLEEngine(unittest.TestCase):
                 "The capital of France is",
                 "The future of AI is",
             ]
-            sampling_params = {"temperature": 0, "max_new_tokens": 30}
-            outputs = engine.generate(prompts, sampling_params)
+            sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
+            outputs = engine.generate(prompts, sampling_params3)
             for prompt, output in zip(prompts, outputs):
                 print("===============================")
                 print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
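
Note on the fix (a minimal sketch, not part of the patch): common_template now receives the kv-indices buffer as an explicit argument, so the eager path in init_forward_metadata can allocate a temporary buffer instead of reading self.cuda_graph_kv_indices, which is only created by init_cuda_graph_state and therefore does not exist when CUDA graphs are disabled. The sketch below illustrates that buffer-passing pattern with stand-in names, shapes, and callbacks; it is not SGLang's actual class.

    # Minimal sketch of the buffer-passing pattern introduced by this patch.
    # DraftBackend is a stand-in, not SGLang's FlashInferMultiStepDraftBackend;
    # shapes, dtypes, and the callback are illustrative only.
    import torch

    class DraftBackend:
        def __init__(self, speculative_num_steps, topk, max_context_len):
            self.speculative_num_steps = speculative_num_steps
            self.topk = topk
            self.max_context_len = max_context_len
            # Only allocated when CUDA graphs are enabled.
            self.cuda_graph_kv_indices = None

        def common_template(self, batch_size, kv_indices_buffer, call_fn):
            # The shared helper takes the buffer explicitly instead of assuming
            # that the CUDA-graph buffer has been allocated.
            for i in range(self.speculative_num_steps):
                call_fn(i, kv_indices_buffer[i])

        def init_forward_metadata(self, batch_size, call_fn):
            # Eager path (CUDA graph disabled): allocate a temporary buffer.
            kv_indices = torch.zeros(
                (self.speculative_num_steps, batch_size * self.topk * self.max_context_len),
                dtype=torch.int32,
            )
            self.common_template(batch_size, kv_indices, call_fn)

        def init_cuda_graph_state(self, max_bs):
            # CUDA-graph path: preallocate once, reuse for capture and replay.
            self.cuda_graph_kv_indices = torch.zeros(
                (self.speculative_num_steps, max_bs * self.topk * self.max_context_len),
                dtype=torch.int32,
            )

        def init_forward_metadata_capture_cuda_graph(self, batch_size, call_fn):
            self.common_template(batch_size, self.cuda_graph_kv_indices, call_fn)

    # Works even though init_cuda_graph_state was never called.
    backend = DraftBackend(speculative_num_steps=3, topk=4, max_context_len=16)
    backend.init_forward_metadata(2, lambda i, buf: print(i, tuple(buf.shape)))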
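
A hedged usage sketch for the scenario named in the patch title (EAGLE with CUDA graphs disabled) follows. It assumes disable_cuda_graph is the relevant server argument; the model path is a placeholder, and the EAGLE speculative-decoding settings configured by test_eagle_infer.py are omitted.

    # Usage sketch: force the eager init_forward_metadata path this patch fixes.
    import sglang as sgl

    engine = sgl.Engine(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model path
        disable_cuda_graph=True,  # assumption: the flag referred to by the title
        # ... EAGLE speculative-decoding settings as configured in the test ...
    )
    out = engine.generate(
        "Today is a sunny day and I like",
        {"temperature": 0, "max_new_tokens": 8},
    )["text"]
    print(out)
    engine.shutdown()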