[Fix] Fix eagle with disable cuda graph (#3411)
@@ -924,38 +924,50 @@ class FlashInferMultiStepDraftBackend:
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.kv_indptr_stride = self.kv_indptr.shape[1]

-    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum

         self.generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.seq_lens,
-            self.cuda_graph_kv_indices,
+            kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
             num_seqs,
             self.topk,
             self.pool_len,
-            self.kv_indptr_stride,
+            kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             triton.next_power_of_2(num_seqs),
             triton.next_power_of_2(self.speculative_num_steps),
             triton.next_power_of_2(bs),
         )

         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
-            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
             call_fn(i, forward_batch)

     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         def call_fn(i, forward_batch):
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
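Note (reviewer aside, not part of the patch): the per-step slices taken in common_template grow by bs = topk * batch_size entries each draft step. A minimal, self-contained sketch of that slicing, with shapes mirroring the hunk above and invented numbers:

# Minimal sketch, not part of the patch: how per-step kv_indices views are carved
# out of a (speculative_num_steps, batch_size * topk * max_context_len) buffer.
import torch

speculative_num_steps, batch_size, topk, max_context_len = 3, 2, 4, 128
seq_lens_sum = 50                  # example: total prefix tokens across the batch
bs = topk * batch_size             # one slot per (request, top-k branch)

kv_indices_buffer = torch.zeros(
    (speculative_num_steps, batch_size * topk * max_context_len),
    dtype=torch.int32,
)

for i in range(speculative_num_steps):
    # Step i holds the prefix tokens plus (i + 1) draft tokens per slot.
    step_view = kv_indices_buffer[i][: seq_lens_sum * topk + bs * (i + 1)]
    print(i, step_view.shape)      # sizes 208, 216, 224 with the numbers above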
@@ -965,7 +977,7 @@ class FlashInferMultiStepDraftBackend:
             )
             self.attn_backends[i].init_forward_metadata(forward_batch)

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, kv_indices, call_fn)

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +985,6 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
-        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1006,7 @@ class FlashInferMultiStepDraftBackend:
             ][0]
             decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
         def call_fn(i, forward_batch):
@@ -1009,7 +1020,7 @@ class FlashInferMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


 @triton.jit
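Taken together, the hunks above change the buffer-selection contract: common_template no longer reads self.cuda_graph_kv_indices directly (that tensor only exists once init_cuda_graph_state has run), so the eager path used when CUDA graphs are disabled can hand in its own scratch buffer. A simplified sketch of that contract, with names mirroring the diff but the Triton kernel, call_fn plumbing, and device placement stripped out; this is an illustration, not the actual FlashInferMultiStepDraftBackend:

import torch


class DraftBackendSketch:
    def __init__(self, speculative_num_steps: int, topk: int, max_context_len: int):
        self.speculative_num_steps = speculative_num_steps
        self.topk = topk
        self.max_context_len = max_context_len
        # Only allocated when CUDA graphs are enabled; eager runs must not rely on it.
        self.cuda_graph_kv_indices = None

    def init_cuda_graph_state(self, max_bs: int) -> None:
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.topk * self.max_context_len),
            dtype=torch.int32,
        )

    def init_forward_metadata(self, batch_size: int) -> None:
        # Eager path (CUDA graphs disabled): allocate a per-call scratch buffer.
        kv_indices = torch.zeros(
            (self.speculative_num_steps, batch_size * self.topk * self.max_context_len),
            dtype=torch.int32,
        )
        self.common_template(kv_indices)

    def init_forward_metadata_capture_cuda_graph(self) -> None:
        # Graph capture/replay paths keep reusing the persistent buffer.
        self.common_template(self.cuda_graph_kv_indices)

    def common_template(self, kv_indices_buffer: torch.Tensor) -> None:
        # The stride previously cached as self.kv_indptr_stride now comes from
        # whichever buffer the caller passed in.
        stride = kv_indices_buffer.shape[1]
        for i in range(self.speculative_num_steps):
            step_view = kv_indices_buffer[i]   # per-step work elided
            assert step_view.shape[0] == stride

The second diff below updates the EAGLE engine test so that each test case uses its own prompt and sampling-params variables rather than reusing one pair of names across cases.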
@@ -22,12 +22,12 @@ from sglang.test.test_utils import (
 class TestEAGLEEngine(unittest.TestCase):

     def test_eagle_accuracy(self):
-        prompt = "Today is a sunny day and I like"
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+        prompt1 = "Today is a sunny day and I like"
+        sampling_params1 = {"temperature": 0, "max_new_tokens": 8}

         # Get the reference output
         ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        ref_output = ref_engine.generate(prompt, sampling_params)["text"]
+        ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
         ref_engine.shutdown()

         # Test cases with different configurations
@@ -60,20 +60,20 @@ class TestEAGLEEngine(unittest.TestCase):
             engine = sgl.Engine(**config)

             # Case 1: Test the output of EAGLE engine is the same as normal engine
-            out1 = engine.generate(prompt, sampling_params)["text"]
+            out1 = engine.generate(prompt1, sampling_params1)["text"]
             print(f"{out1=}, {ref_output=}")
             self.assertEqual(out1, ref_output)

             # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-            prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
-            sampling_params = {
+            prompt2 = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+            sampling_params2 = {
                 "temperature": 0,
                 "max_new_tokens": 1024,
                 "skip_special_tokens": False,
             }

             tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-            out2 = engine.generate(prompt, sampling_params)["text"]
+            out2 = engine.generate(prompt2, sampling_params2)["text"]
             print(f"{out2=}")
             tokens = tokenizer.encode(out2, truncation=False)
             assert tokenizer.eos_token_id not in tokens
@@ -85,8 +85,8 @@ class TestEAGLEEngine(unittest.TestCase):
                 "The capital of France is",
                 "The future of AI is",
             ]
-            sampling_params = {"temperature": 0, "max_new_tokens": 30}
-            outputs = engine.generate(prompts, sampling_params)
+            sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
+            outputs = engine.generate(prompts, sampling_params3)
             for prompt, output in zip(prompts, outputs):
                 print("===============================")
                 print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
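Renaming the variables to prompt1/prompt2 and sampling_params1/2/3 keeps the three test cases independent, so a later case can no longer pick up a prompt or sampling dict left over from an earlier one. For completeness, here is the kind of config entry that would exercise the eager path this commit fixes; the diff does not show the actual config list, and the argument names below are assumptions based on sglang server arguments rather than text from the patch:

# Hypothetical config, not taken from the patch: disables CUDA graphs so the
# eager init_forward_metadata path fixed above is exercised.
config = {
    "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    "speculative_algorithm": "EAGLE",
    "speculative_draft_model_path": "<eagle-draft-model-path>",  # placeholder
    "speculative_num_steps": 5,
    "speculative_eagle_topk": 8,
    "speculative_num_draft_tokens": 64,
    "disable_cuda_graph": True,  # forces the non-CUDA-graph code path
}
# engine = sgl.Engine(**config)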