Clean up eagle code (#2756)
This commit is contained in:
@@ -74,11 +74,6 @@ class LogitsMetadata:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||||
if forward_batch.spec_info:
|
|
||||||
capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
|
|
||||||
else:
|
|
||||||
capture_hidden_mode = CaptureHiddenMode.NULL
|
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
|
if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
|
||||||
extend_return_logprob = True
|
extend_return_logprob = True
|
||||||
extend_return_top_logprob = any(
|
extend_return_top_logprob = any(
|
||||||
@@ -98,7 +93,7 @@ class LogitsMetadata:
|
|||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
forward_mode=forward_batch.forward_mode,
|
forward_mode=forward_batch.forward_mode,
|
||||||
capture_hidden_mode=capture_hidden_mode,
|
capture_hidden_mode=forward_batch.capture_hidden_mode,
|
||||||
extend_return_logprob=extend_return_logprob,
|
extend_return_logprob=extend_return_logprob,
|
||||||
extend_return_top_logprob=extend_return_top_logprob,
|
extend_return_top_logprob=extend_return_top_logprob,
|
||||||
extend_seq_lens=forward_batch.extend_seq_lens,
|
extend_seq_lens=forward_batch.extend_seq_lens,
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
|||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
|
|||||||
input_embeds=self.input_embeds,
|
input_embeds=self.input_embeds,
|
||||||
spec_algorithm=self.spec_algorithm,
|
spec_algorithm=self.spec_algorithm,
|
||||||
spec_info=self.spec_info,
|
spec_info=self.spec_info,
|
||||||
|
capture_hidden_mode=(
|
||||||
|
getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
|
||||||
|
if self.spec_info
|
||||||
|
else CaptureHiddenMode.NULL
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
|
|||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[SpecInfo] = None
|
spec_info: Optional[SpecInfo] = None
|
||||||
|
capture_hidden_mode: CaptureHiddenMode = None
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
|||||||
@@ -962,10 +962,13 @@ class Scheduler:
|
|||||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logits_output, next_token_ids, model_worker_batch, spec_info = (
|
(
|
||||||
self.draft_worker.forward_batch_speculative_generation(batch)
|
logits_output,
|
||||||
)
|
next_token_ids,
|
||||||
batch.spec_info = spec_info
|
model_worker_batch,
|
||||||
|
num_accepted_tokens,
|
||||||
|
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
||||||
|
self.num_generated_tokens += num_accepted_tokens
|
||||||
elif batch.forward_mode.is_idle():
|
elif batch.forward_mode.is_idle():
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
self.tp_worker.forward_batch_idle(model_worker_batch)
|
self.tp_worker.forward_batch_idle(model_worker_batch)
|
||||||
|
|||||||
@@ -322,6 +322,8 @@ class CudaGraphRunner:
|
|||||||
global_num_tokens = None
|
global_num_tokens = None
|
||||||
gathered_buffer = None
|
gathered_buffer = None
|
||||||
|
|
||||||
|
spec_info = self.get_spec_info(num_tokens, positions)
|
||||||
|
|
||||||
forward_batch = ForwardBatch(
|
forward_batch = ForwardBatch(
|
||||||
forward_mode=self.capture_forward_mode,
|
forward_mode=self.capture_forward_mode,
|
||||||
batch_size=bs,
|
batch_size=bs,
|
||||||
@@ -341,7 +343,10 @@ class CudaGraphRunner:
|
|||||||
mrope_positions=mrope_positions,
|
mrope_positions=mrope_positions,
|
||||||
gathered_buffer=gathered_buffer,
|
gathered_buffer=gathered_buffer,
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
spec_info=self.get_spec_info(num_tokens, positions),
|
spec_info=spec_info,
|
||||||
|
capture_hidden_mode=(
|
||||||
|
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
@@ -446,10 +451,10 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
if self.model_runner.is_draft_worker:
|
if self.model_runner.is_draft_worker:
|
||||||
spec_info = EAGLEDraftInput()
|
spec_info = EAGLEDraftInput()
|
||||||
|
spec_info.load_server_args(self.model_runner.server_args)
|
||||||
spec_info.hidden_states = self.hidden_states[:num_tokens]
|
spec_info.hidden_states = self.hidden_states[:num_tokens]
|
||||||
spec_info.positions = positions
|
spec_info.positions = positions
|
||||||
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
spec_info.init(self.model_runner.server_args)
|
|
||||||
else:
|
else:
|
||||||
spec_info = EagleVerifyInput(
|
spec_info = EagleVerifyInput(
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -107,6 +107,21 @@ class ForwardMode(IntEnum):
|
|||||||
return self == ForwardMode.DUMMY_FIRST
|
return self == ForwardMode.DUMMY_FIRST
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureHiddenMode(IntEnum):
|
||||||
|
NULL = auto()
|
||||||
|
FULL = auto()
|
||||||
|
LAST = auto()
|
||||||
|
|
||||||
|
def need_capture(self):
|
||||||
|
return self != CaptureHiddenMode.NULL
|
||||||
|
|
||||||
|
def is_full(self):
|
||||||
|
return self == CaptureHiddenMode.FULL
|
||||||
|
|
||||||
|
def is_last(self):
|
||||||
|
return self == CaptureHiddenMode.LAST
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ForwardBatch:
|
class ForwardBatch:
|
||||||
"""Store all inputs of a forward pass."""
|
"""Store all inputs of a forward pass."""
|
||||||
@@ -174,6 +189,7 @@ class ForwardBatch:
|
|||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_info: SpecInfo = None
|
spec_info: SpecInfo = None
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
|
capture_hidden_mode: CaptureHiddenMode = None
|
||||||
|
|
||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
mrope_positions: torch.Tensor = None
|
mrope_positions: torch.Tensor = None
|
||||||
@@ -265,6 +281,7 @@ class ForwardBatch:
|
|||||||
sampling_info=batch.sampling_info,
|
sampling_info=batch.sampling_info,
|
||||||
spec_algorithm=batch.spec_algorithm,
|
spec_algorithm=batch.spec_algorithm,
|
||||||
spec_info=batch.spec_info,
|
spec_info=batch.spec_info,
|
||||||
|
capture_hidden_mode=batch.capture_hidden_mode,
|
||||||
input_embeds=batch.input_embeds,
|
input_embeds=batch.input_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -400,18 +417,3 @@ def compute_position_torch(
|
|||||||
@maybe_torch_compile(dynamic=True)
|
@maybe_torch_compile(dynamic=True)
|
||||||
def clamp_position(seq_lens):
|
def clamp_position(seq_lens):
|
||||||
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
||||||
|
|
||||||
|
|
||||||
class CaptureHiddenMode(IntEnum):
|
|
||||||
NULL = auto()
|
|
||||||
FULL = auto()
|
|
||||||
LAST = auto()
|
|
||||||
|
|
||||||
def need_capture(self):
|
|
||||||
return self != CaptureHiddenMode.NULL
|
|
||||||
|
|
||||||
def is_full(self):
|
|
||||||
return self == CaptureHiddenMode.FULL
|
|
||||||
|
|
||||||
def is_last(self):
|
|
||||||
return self == CaptureHiddenMode.LAST
|
|
||||||
|
|||||||
@@ -9,12 +9,11 @@ import triton.language as tl
|
|||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
create_flashinfer_kv_indices_triton,
|
create_flashinfer_kv_indices_triton,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
|
||||||
from sglang.srt.speculative.spec_info import SpecInfo
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from python.sglang.srt.layers.sampler import SampleOutput
|
|
||||||
from python.sglang.srt.managers.schedule_batch import ScheduleBatch
|
from python.sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
|
|||||||
|
|
||||||
|
|
||||||
class EAGLEDraftInput(SpecInfo):
|
class EAGLEDraftInput(SpecInfo):
|
||||||
hidden_states: torch.Tensor = None
|
def __init__(self):
|
||||||
verified_id: torch.Tensor = None
|
|
||||||
positions: torch.Tensor = None
|
|
||||||
accept_length: torch.Tensor = None
|
|
||||||
has_finished: bool = False
|
|
||||||
unfinished_index: List[int] = None
|
|
||||||
|
|
||||||
def init(self, server_args: ServerArgs):
|
|
||||||
self.prev_mode = ForwardMode.DECODE
|
self.prev_mode = ForwardMode.DECODE
|
||||||
self.sample_output = None
|
self.sample_output = None
|
||||||
self.topk: int = server_args.speculative_eagle_topk
|
|
||||||
self.num_verify_token: int = server_args.speculative_num_draft_tokens
|
|
||||||
self.spec_steps = server_args.speculative_num_steps
|
|
||||||
|
|
||||||
self.scores: torch.Tensor = None
|
self.scores: torch.Tensor = None
|
||||||
self.score_list: List[torch.Tensor] = []
|
self.score_list: List[torch.Tensor] = []
|
||||||
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
|
|||||||
self.parents_list: List[torch.Tensor] = []
|
self.parents_list: List[torch.Tensor] = []
|
||||||
self.cache_list: List[torch.Tenor] = []
|
self.cache_list: List[torch.Tenor] = []
|
||||||
self.iter = 0
|
self.iter = 0
|
||||||
self.root_token: int = None
|
|
||||||
|
|
||||||
assert self.topk <= 10, "topk should <= 10"
|
self.hidden_states: torch.Tensor = None
|
||||||
|
self.verified_id: torch.Tensor = None
|
||||||
|
self.positions: torch.Tensor = None
|
||||||
|
self.accept_length: torch.Tensor = None
|
||||||
|
self.has_finished: bool = False
|
||||||
|
self.unfinished_index: List[int] = None
|
||||||
|
|
||||||
def prepare_for_extend(self, batch: ForwardBatch):
|
def load_server_args(self, server_args: ServerArgs):
|
||||||
|
self.topk: int = server_args.speculative_eagle_topk
|
||||||
|
self.num_verify_token: int = server_args.speculative_num_draft_tokens
|
||||||
|
self.spec_steps = server_args.speculative_num_steps
|
||||||
|
|
||||||
|
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||||
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
|
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
|
||||||
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
||||||
batch.out_cache_loc = out_cache_loc
|
batch.out_cache_loc = out_cache_loc
|
||||||
@@ -226,81 +224,72 @@ class EAGLEDraftInput(SpecInfo):
|
|||||||
|
|
||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
|
||||||
seq_lens = [0] + batch.extend_lens
|
# TODO: support batching inputs
|
||||||
input_ids = batch.input_ids.tolist()
|
assert len(batch.extend_lens) == 1
|
||||||
verified_id = batch.spec_info.verified_id.tolist()
|
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
|
||||||
model_input_ids = []
|
|
||||||
for i in range(len(seq_lens) - 1):
|
|
||||||
model_input_ids.extend(
|
|
||||||
input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
|
|
||||||
)
|
|
||||||
batch.input_ids = torch.tensor(
|
|
||||||
model_input_ids, dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
def capture_for_decode(
|
|
||||||
self,
|
|
||||||
sample_output: SampleOutput,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
prev_mode: ForwardMode,
|
|
||||||
):
|
|
||||||
self.sample_output = sample_output
|
|
||||||
self.prev_mode = prev_mode
|
|
||||||
self.hidden_states = hidden_states
|
|
||||||
|
|
||||||
def prepare_for_decode(self, batch: ScheduleBatch):
|
def prepare_for_decode(self, batch: ScheduleBatch):
|
||||||
prob = self.sample_output # b * (1/topk), vocab
|
prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
|
||||||
top = torch.topk(prob, self.topk, dim=-1)
|
top = torch.topk(prob, self.topk, dim=-1)
|
||||||
topk_index, topk_p = top.indices, top.values # b * (1/topk), topk
|
topk_index, topk_p = (
|
||||||
if self.prev_mode == ForwardMode.DECODE:
|
top.indices,
|
||||||
|
top.values,
|
||||||
|
) # shape: (b * top_k, top_k) or (b, top_k)
|
||||||
|
|
||||||
|
if self.prev_mode.is_decode():
|
||||||
scores = torch.mul(
|
scores = torch.mul(
|
||||||
self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
|
self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
|
||||||
) # (b, topk) mul (b * topk ,topk) -> b, topk, topk
|
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
|
||||||
topk_cs = torch.topk(
|
topk_cs = torch.topk(
|
||||||
scores.flatten(start_dim=1), self.topk, dim=-1
|
scores.flatten(start_dim=1), self.topk, dim=-1
|
||||||
) # (b, topk)
|
) # (b, topk)
|
||||||
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
|
topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
|
||||||
self.scores = topk_cs_p
|
|
||||||
|
|
||||||
selected_input_index = topk_cs_index.flatten() // self.topk # b* topk
|
|
||||||
|
|
||||||
|
selected_input_index = (
|
||||||
|
topk_cs_index.flatten() // self.topk
|
||||||
|
) # shape: (b * topk)
|
||||||
batch.spec_info.hidden_states = batch.spec_info.hidden_states[
|
batch.spec_info.hidden_states = batch.spec_info.hidden_states[
|
||||||
selected_input_index, :
|
selected_input_index, :
|
||||||
]
|
]
|
||||||
|
|
||||||
topk_index = topk_index.reshape(-1, self.topk**2)
|
topk_index = topk_index.reshape(-1, self.topk**2)
|
||||||
batch.input_ids = torch.gather(
|
batch.input_ids = torch.gather(
|
||||||
topk_index, index=topk_cs_index, dim=1
|
topk_index, index=topk_cs_index, dim=1
|
||||||
).flatten()
|
).flatten()
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
|
||||||
self.score_list.append(scores) # b, topk, topk
|
|
||||||
self.token_list.append(topk_index) # b, topk*topk
|
self.scores = topk_cs_p
|
||||||
|
self.score_list.append(scores) # (b, topk, topk)
|
||||||
|
self.token_list.append(topk_index) # (b, topk * topk)
|
||||||
self.origin_score_list.append(topk_p.reshape(topk_index.shape))
|
self.origin_score_list.append(topk_p.reshape(topk_index.shape))
|
||||||
self.parents_list.append(
|
self.parents_list.append(
|
||||||
topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
|
topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
|
||||||
) # b, topk
|
) # shape: (b, topk)
|
||||||
|
else:
|
||||||
elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
|
# ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
|
||||||
self.scores = topk_p # b, top_k
|
|
||||||
self.score_list.append(topk_p.unsqueeze(1))
|
|
||||||
self.token_list.append(topk_index)
|
|
||||||
self.origin_score_list.append(topk_p)
|
|
||||||
batch.spec_info.hidden_states = (
|
batch.spec_info.hidden_states = (
|
||||||
batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
|
batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
batch.input_ids = topk_index.flatten()
|
batch.input_ids = topk_index.flatten()
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
|
batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
|
||||||
|
|
||||||
|
self.scores = topk_p # shape: (b, topk)
|
||||||
|
self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk)
|
||||||
|
self.token_list.append(topk_index) # shape: (b, topk)
|
||||||
|
self.origin_score_list.append(topk_p)
|
||||||
self.parents_list.append(
|
self.parents_list.append(
|
||||||
torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
|
torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.repeat(self.scores.shape[0], 1)
|
.repeat(self.scores.shape[0], 1)
|
||||||
) # b, topk+1
|
) # shape: (b, topk + 1)
|
||||||
self.cache_list.append(batch.out_cache_loc)
|
self.cache_list.append(batch.out_cache_loc)
|
||||||
self.positions = (
|
self.positions = (
|
||||||
batch.seq_lens[:, None]
|
batch.seq_lens[:, None]
|
||||||
+ torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
|
+ torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
|
||||||
).flatten()
|
).flatten()
|
||||||
|
|
||||||
bs = batch.seq_lens.numel()
|
bs = len(batch.seq_lens)
|
||||||
assign_req_to_token_pool[(bs,)](
|
assign_req_to_token_pool[(bs,)](
|
||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_to_token_pool.req_to_token,
|
||||||
@@ -419,11 +408,6 @@ class EAGLEDraftInput(SpecInfo):
|
|||||||
)
|
)
|
||||||
return bs, kv_indices, cum_kv_seq_len
|
return bs, kv_indices, cum_kv_seq_len
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
self.iter = 0
|
|
||||||
self.score_list.clear()
|
|
||||||
self.positions = None
|
|
||||||
|
|
||||||
def clear_draft_cache(self, batch):
|
def clear_draft_cache(self, batch):
|
||||||
draft_cache = torch.cat(self.cache_list, dim=0)
|
draft_cache = torch.cat(self.cache_list, dim=0)
|
||||||
batch.token_to_kv_pool.free(draft_cache)
|
batch.token_to_kv_pool.free(draft_cache)
|
||||||
@@ -460,7 +444,6 @@ class EAGLEDraftInput(SpecInfo):
|
|||||||
[self.hidden_states, spec_info.hidden_states], axis=0
|
[self.hidden_states, spec_info.hidden_states], axis=0
|
||||||
)
|
)
|
||||||
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
|
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
|
||||||
# self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
|
|
||||||
self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
|
self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
|
||||||
|
|
||||||
|
|
||||||
@@ -568,9 +551,6 @@ class EagleVerifyInput(SpecInfo):
|
|||||||
)
|
)
|
||||||
|
|
||||||
accept_index = accept_index[accept_index != -1]
|
accept_index = accept_index[accept_index != -1]
|
||||||
# extract_index = extract_index[extract_index != 0]
|
|
||||||
|
|
||||||
draft_input = EAGLEDraftInput()
|
|
||||||
|
|
||||||
accept_length_cpu = accept_length.tolist()
|
accept_length_cpu = accept_length.tolist()
|
||||||
verified_id = predict[accept_index]
|
verified_id = predict[accept_index]
|
||||||
@@ -596,6 +576,7 @@ class EagleVerifyInput(SpecInfo):
|
|||||||
# retracted_reqs, new_token_ratio = batch.retract_decode()
|
# retracted_reqs, new_token_ratio = batch.retract_decode()
|
||||||
|
|
||||||
low = 0
|
low = 0
|
||||||
|
draft_input = EAGLEDraftInput()
|
||||||
for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
|
for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
|
||||||
req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
|
req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
@@ -615,4 +596,10 @@ class EagleVerifyInput(SpecInfo):
|
|||||||
draft_input.unfinished_index = unfinished_index
|
draft_input.unfinished_index = unfinished_index
|
||||||
|
|
||||||
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
||||||
return draft_input, logits_output, verified_id, finished_extend_len
|
return (
|
||||||
|
draft_input,
|
||||||
|
logits_output,
|
||||||
|
verified_id,
|
||||||
|
finished_extend_len,
|
||||||
|
accept_length_cpu,
|
||||||
|
)
|
||||||
|
|||||||
@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.spec_info.prepare_for_decode(batch)
|
batch.spec_info.prepare_for_decode(batch)
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
self.capture_for_decode(logits_output, forward_batch)
|
self.capture_for_decode(logits_output, forward_batch)
|
||||||
|
|
||||||
def forward_draft_extend(self, batch: ScheduleBatch):
|
def forward_draft_extend(self, batch: ScheduleBatch):
|
||||||
self._swap_mem_pool(batch, self.model_runner)
|
self._set_mem_pool(batch, self.model_runner)
|
||||||
batch.spec_info.prepare_for_extend(batch)
|
batch.spec_info.prepare_for_extend(batch)
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
self.capture_for_decode(logits_output, forward_batch)
|
self.capture_for_decode(logits_output, forward_batch)
|
||||||
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||||
|
|
||||||
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
prev_spec_info = batch.spec_info
|
# Draft
|
||||||
self._swap_mem_pool(batch, self.model_runner)
|
self._set_mem_pool(batch, self.model_runner)
|
||||||
for i in range(self.server_args.speculative_num_steps):
|
for i in range(self.server_args.speculative_num_steps):
|
||||||
self.forward_draft_decode(batch)
|
self.forward_draft_decode(batch)
|
||||||
batch.spec_info.clear_draft_cache(batch)
|
batch.spec_info.clear_draft_cache(batch)
|
||||||
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||||
|
|
||||||
|
# Verify
|
||||||
(
|
(
|
||||||
next_draft_input,
|
next_draft_input,
|
||||||
logits_output,
|
logits_output,
|
||||||
verified_id,
|
verified_id,
|
||||||
self.finish_extend_len,
|
self.finish_extend_len,
|
||||||
|
accept_length_cpu,
|
||||||
model_worker_batch,
|
model_worker_batch,
|
||||||
) = self.verify(batch)
|
) = self.verify(batch)
|
||||||
next_draft_input.init(self.server_args)
|
next_draft_input.load_server_args(self.server_args)
|
||||||
batch.spec_info = next_draft_input
|
batch.spec_info = next_draft_input
|
||||||
# if it is None, means all requsets are finished
|
# if it is None, means all requsets are finished
|
||||||
if batch.spec_info.verified_id is not None:
|
if batch.spec_info.verified_id is not None:
|
||||||
self.forward_extend_after_decode(batch)
|
self.forward_draft_extend_after_decode(batch)
|
||||||
batch.spec_info = prev_spec_info
|
return (
|
||||||
return logits_output, verified_id, model_worker_batch, next_draft_input
|
logits_output,
|
||||||
|
verified_id,
|
||||||
|
model_worker_batch,
|
||||||
|
sum(accept_length_cpu),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
spec_info = EAGLEDraftInput()
|
# Forward with the target model and get hidden states.
|
||||||
spec_info.init(self.server_args)
|
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
model_worker_batch.spec_info = spec_info
|
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
|
||||||
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
||||||
model_worker_batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
model_worker_batch.spec_info.verified_id = next_token_ids
|
|
||||||
model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
|
# Forward with the draft model.
|
||||||
|
spec_info = EAGLEDraftInput()
|
||||||
|
spec_info.load_server_args(self.server_args)
|
||||||
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
|
spec_info.verified_id = next_token_ids
|
||||||
batch.spec_info = spec_info
|
batch.spec_info = spec_info
|
||||||
self.forward_draft_extend(batch)
|
self.forward_draft_extend(batch)
|
||||||
batch.spec_info = None
|
return logits_output, next_token_ids, model_worker_batch, 0
|
||||||
return logits_output, next_token_ids, model_worker_batch, spec_info
|
|
||||||
|
|
||||||
def verify(self, batch: ScheduleBatch):
|
def verify(self, batch: ScheduleBatch):
|
||||||
verify_input = batch.spec_info.prepare_for_verify(batch)
|
verify_input = batch.spec_info.prepare_for_verify(batch)
|
||||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
|
||||||
verify_input.prepare_for_verify(batch)
|
verify_input.prepare_for_verify(batch)
|
||||||
|
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
batch.spec_info = verify_input
|
batch.spec_info = verify_input
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
@@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
return res + (model_worker_batch,)
|
return res + (model_worker_batch,)
|
||||||
|
|
||||||
def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
|
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
|
||||||
batch.token_to_kv_pool = runner.token_to_kv_pool
|
batch.token_to_kv_pool = runner.token_to_kv_pool
|
||||||
batch.req_to_token_pool = runner.req_to_token_pool
|
batch.req_to_token_pool = runner.req_to_token_pool
|
||||||
|
|
||||||
def forward_extend_after_decode(self, batch: ScheduleBatch):
|
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||||
self._swap_mem_pool(batch, self.model_runner)
|
self._set_mem_pool(batch, self.model_runner)
|
||||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||||
if batch.spec_info.has_finished:
|
if batch.spec_info.has_finished:
|
||||||
index = batch.spec_info.unfinished_index
|
index = batch.spec_info.unfinished_index
|
||||||
seq_lens = batch.seq_lens
|
seq_lens = batch.seq_lens
|
||||||
batch.seq_lens = batch.seq_lens[index]
|
batch.seq_lens = batch.seq_lens[index]
|
||||||
|
|
||||||
batch.spec_info.prepare_extend_after_decode(batch)
|
batch.spec_info.prepare_extend_after_decode(batch)
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
|
|
||||||
batch.spec_info.hidden_states = logits_output.hidden_states
|
batch.spec_info.hidden_states = logits_output.hidden_states
|
||||||
self.capture_for_decode(logits_output, forward_batch)
|
self.capture_for_decode(logits_output, forward_batch)
|
||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
if batch.spec_info.has_finished:
|
if batch.spec_info.has_finished:
|
||||||
batch.seq_lens = seq_lens
|
batch.seq_lens = seq_lens
|
||||||
self._swap_mem_pool(batch, self.target_worker.model_runner)
|
self._set_mem_pool(batch, self.target_worker.model_runner)
|
||||||
|
|
||||||
def capture_for_decode(self, logits_output, forward_batch):
|
def capture_for_decode(
|
||||||
if isinstance(logits_output, LogitsProcessorOutput):
|
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||||
logits = logits_output.next_token_logits
|
):
|
||||||
sample_output = torch.softmax(
|
sample_output = torch.softmax(
|
||||||
logits, dim=-1
|
logits_output.next_token_logits, dim=-1
|
||||||
) # TODO: Support more sampling method @kavioyu
|
) # TODO(kavioyu): Support more sampling methods
|
||||||
forward_batch.spec_info.capture_for_decode(
|
spec_info = forward_batch.spec_info
|
||||||
sample_output, logits_output.hidden_states, forward_batch.forward_mode
|
spec_info.sample_output = sample_output
|
||||||
)
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
|
spec_info.prev_mode = forward_batch.forward_mode
|
||||||
|
|
||||||
# Don't support prefix share now.
|
# Don't support prefix share now.
|
||||||
def finish_request(self, reqs: Union[Req, List[Req]]):
|
def finish_request(self, reqs: Union[Req, List[Req]]):
|
||||||
|
|||||||
Reference in New Issue
Block a user