Simplify prepare_extend_after_decode (#6987)
This commit is contained in:
@@ -1636,7 +1636,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if self.spec_info:
|
||||
self.spec_info.merge_batch(other.spec_info)
|
||||
|
||||
def get_model_worker_batch(self) -> ModelWorkerBatch:
|
||||
def get_model_worker_batch(
|
||||
self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
|
||||
) -> ModelWorkerBatch:
|
||||
if self.forward_mode.is_decode_or_idle():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
@@ -1646,16 +1648,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
# Create seq_lens_cpu when needed
|
||||
if (
|
||||
(
|
||||
global_server_args_dict["attention_backend"] == "fa3"
|
||||
or (
|
||||
global_server_args_dict["use_mla_backend"]
|
||||
and global_server_args_dict["attention_backend"] == "flashinfer"
|
||||
)
|
||||
or global_server_args_dict["attention_backend"] == "flashmla"
|
||||
or global_server_args_dict["attention_backend"] == "fa3"
|
||||
or global_server_args_dict["attention_backend"] == "cutlass_mla"
|
||||
or global_server_args_dict["enable_two_batch_overlap"]
|
||||
):
|
||||
seq_lens_cpu = self.seq_lens.cpu()
|
||||
seq_lens_cpu = (
|
||||
seq_lens_cpu_cache
|
||||
if seq_lens_cpu_cache is not None
|
||||
else self.seq_lens.cpu()
|
||||
)
|
||||
else:
|
||||
seq_lens_cpu = None
|
||||
|
||||
|
||||
@@ -1575,10 +1575,9 @@ class Scheduler(
|
||||
num_accepted_tokens,
|
||||
can_run_cuda_graph,
|
||||
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
||||
self.spec_num_total_accepted_tokens += (
|
||||
num_accepted_tokens + batch.batch_size()
|
||||
)
|
||||
self.spec_num_total_forward_ct += batch.batch_size()
|
||||
bs = batch.batch_size()
|
||||
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
|
||||
self.spec_num_total_forward_ct += bs
|
||||
self.num_generated_tokens += num_accepted_tokens
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
|
||||
@@ -56,6 +56,16 @@ def get_is_capture_mode():
|
||||
return is_capture_mode
|
||||
|
||||
|
||||
@contextmanager
|
||||
def model_capture_mode():
|
||||
global is_capture_mode
|
||||
is_capture_mode = True
|
||||
|
||||
yield
|
||||
|
||||
is_capture_mode = False
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||
for sub in model._modules.values():
|
||||
if isinstance(sub, CustomOp):
|
||||
@@ -291,22 +301,13 @@ class CudaGraphRunner:
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with self.model_capture_mode():
|
||||
with model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def model_capture_mode(self):
|
||||
global is_capture_mode
|
||||
is_capture_mode = True
|
||||
|
||||
yield
|
||||
|
||||
is_capture_mode = False
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
|
||||
@@ -650,6 +651,8 @@ class CudaGraphRunner:
|
||||
topk=self.model_runner.server_args.speculative_eagle_topk,
|
||||
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
|
||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||
seq_lens_sum=None,
|
||||
seq_lens_cpu=None,
|
||||
)
|
||||
|
||||
return spec_info
|
||||
|
||||
@@ -1013,13 +1013,13 @@ class ServerArgs:
|
||||
type=str,
|
||||
choices=[
|
||||
"aiter",
|
||||
"flashinfer",
|
||||
"triton",
|
||||
"torch_native",
|
||||
"fa3",
|
||||
"flashmla",
|
||||
"cutlass_mla",
|
||||
"fa3",
|
||||
"flashinfer",
|
||||
"flashmla",
|
||||
"intel_amx",
|
||||
"torch_native",
|
||||
"triton",
|
||||
],
|
||||
default=ServerArgs.attention_backend,
|
||||
help="Choose the kernels for attention layers.",
|
||||
|
||||
@@ -10,6 +10,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
CudaGraphRunner,
|
||||
get_batch_sizes_to_capture,
|
||||
get_global_graph_memory_pool,
|
||||
model_capture_mode,
|
||||
set_global_graph_memory_pool,
|
||||
set_torch_compile_config,
|
||||
)
|
||||
@@ -80,7 +81,8 @@ class EAGLEDraftCudaGraphRunner:
|
||||
|
||||
# Capture
|
||||
try:
|
||||
self.capture()
|
||||
with model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
|
||||
@@ -11,6 +11,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
LogitsProcessorOutput,
|
||||
get_batch_sizes_to_capture,
|
||||
get_global_graph_memory_pool,
|
||||
model_capture_mode,
|
||||
set_global_graph_memory_pool,
|
||||
set_torch_compile_config,
|
||||
)
|
||||
@@ -19,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||
@@ -37,6 +38,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||
self.topk = model_runner.server_args.speculative_eagle_topk
|
||||
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
||||
self.padded_static_len = -1
|
||||
|
||||
@@ -87,7 +89,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
|
||||
# Capture
|
||||
try:
|
||||
self.capture()
|
||||
with model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
|
||||
@@ -170,6 +173,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
forward_batch.positions,
|
||||
forward_batch,
|
||||
)
|
||||
probs = torch.softmax(ret.next_token_logits, dim=-1)
|
||||
ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
|
||||
forward_batch.out_cache_loc = output_cache_loc_backup
|
||||
forward_batch.spec_info.hidden_states = hidden_states_backup
|
||||
@@ -198,7 +203,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
if bs != raw_bs:
|
||||
if bs * self.num_tokens_per_bs != num_tokens:
|
||||
self.seq_lens.fill_(1)
|
||||
self.accept_length.fill_(1)
|
||||
self.out_cache_loc.zero_()
|
||||
@@ -238,8 +243,11 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
out = self.output_buffers[bs]
|
||||
if bs != raw_bs:
|
||||
forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
|
||||
out_copy = out
|
||||
out = LogitsProcessorOutput(
|
||||
next_token_logits=out.next_token_logits[:raw_bs],
|
||||
hidden_states=out.hidden_states[:raw_bs],
|
||||
)
|
||||
out.topk_p = out_copy.topk_p[:raw_bs]
|
||||
out.topk_index = out_copy.topk_index[:raw_bs]
|
||||
return out
|
||||
|
||||
@@ -22,8 +22,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
|
||||
|
||||
if is_cuda():
|
||||
@@ -86,78 +85,29 @@ class EagleDraftInput:
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
speculative_num_steps: int,
|
||||
context_length: int,
|
||||
pad_input: bool = False,
|
||||
):
|
||||
accept_length_cpu = batch.spec_info.accept_length_cpu
|
||||
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
batch.input_ids = self.verified_id
|
||||
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
|
||||
batch.extend_num_tokens = sum(batch.extend_lens)
|
||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
||||
seq_lens_cpu = batch.seq_lens.tolist()
|
||||
batch.return_logprob = False
|
||||
|
||||
self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
|
||||
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
|
||||
self.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
self.accept_length.add_(1)
|
||||
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
|
||||
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
|
||||
|
||||
create_extend_spec_info[(self.accept_length.numel(),)](
|
||||
self.verified_id,
|
||||
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
|
||||
batch.input_ids,
|
||||
batch.seq_lens,
|
||||
self.accept_length,
|
||||
torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
|
||||
self.positions,
|
||||
new_verified_id,
|
||||
next_power_of_2(speculative_num_steps + 1),
|
||||
self.verified_id,
|
||||
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
|
||||
)
|
||||
|
||||
batch.seq_lens_sum = sum(seq_lens_cpu)
|
||||
batch.input_ids = self.verified_id
|
||||
self.verified_id = new_verified_id
|
||||
|
||||
if not pad_input:
|
||||
return
|
||||
|
||||
batch_size = sum(not req.finished() for req in batch.reqs)
|
||||
# Total constant input length after padding
|
||||
static_len = speculative_num_steps + 1
|
||||
# Total size after padding
|
||||
padded_input_size = batch_size * static_len
|
||||
|
||||
padded_len = padded_input_size - batch.input_ids.shape[0]
|
||||
if padded_len > 0:
|
||||
new_input_ids = torch.nn.functional.pad(
|
||||
batch.input_ids, (0, padded_len), value=0
|
||||
)
|
||||
position_padding = torch.arange(padded_len, device=self.positions.device)
|
||||
new_positions = torch.cat([self.positions, position_padding])
|
||||
|
||||
# need dummy hidden states for the padded positions
|
||||
hidden_states_dim = self.hidden_states.shape[-1]
|
||||
new_hidden_states = torch.cat(
|
||||
[
|
||||
self.hidden_states,
|
||||
torch.zeros(
|
||||
(padded_len, hidden_states_dim),
|
||||
dtype=self.hidden_states.dtype,
|
||||
device=self.hidden_states.device,
|
||||
),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# allocate KV cache location for the padded tokens
|
||||
padded_cache_loc = torch.zeros(
|
||||
padded_len,
|
||||
dtype=batch.out_cache_loc.dtype,
|
||||
device=batch.out_cache_loc.device,
|
||||
)
|
||||
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
|
||||
|
||||
batch.input_ids = new_input_ids
|
||||
self.hidden_states = new_hidden_states
|
||||
self.positions = new_positions
|
||||
batch.out_cache_loc = new_out_cache_loc
|
||||
|
||||
def generate_attn_arg_prefill(
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
@@ -173,8 +123,9 @@ class EagleDraftInput:
|
||||
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
|
||||
# TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
|
||||
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
req_to_token,
|
||||
@@ -238,54 +189,10 @@ class EagleVerifyInput:
|
||||
topk: int
|
||||
draft_token_num: int
|
||||
capture_hidden_mode: CaptureHiddenMode
|
||||
seq_lens_sum: int
|
||||
seq_lens_cpu: torch.Tensor
|
||||
grammar: BaseGrammarObject = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
verified_id: torch.Tensor,
|
||||
score_list: List[torch.Tensor],
|
||||
token_list: List[torch.Tensor],
|
||||
parents_list: List[torch.Tensor],
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
topk: int,
|
||||
spec_steps: int,
|
||||
num_verify_tokens: int,
|
||||
):
|
||||
(
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
draft_tokens,
|
||||
) = build_tree_kernel_efficient(
|
||||
verified_id,
|
||||
score_list,
|
||||
token_list,
|
||||
parents_list,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
topk,
|
||||
spec_steps,
|
||||
num_verify_tokens,
|
||||
)
|
||||
|
||||
return cls(
|
||||
draft_token=draft_tokens,
|
||||
custom_mask=tree_mask,
|
||||
positions=position,
|
||||
retrive_index=retrive_index,
|
||||
retrive_next_token=retrive_next_token,
|
||||
retrive_next_sibling=retrive_next_sibling,
|
||||
retrive_cum_len=None,
|
||||
spec_steps=spec_steps,
|
||||
topk=topk,
|
||||
draft_token_num=num_verify_tokens,
|
||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||
)
|
||||
|
||||
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
||||
batch.input_ids = self.draft_token
|
||||
|
||||
@@ -614,26 +521,28 @@ class EagleVerifyInput:
|
||||
|
||||
|
||||
@triton.jit
|
||||
def create_extend_spec_info(
|
||||
def create_extend_after_decode_spec_info(
|
||||
verified_id,
|
||||
seq_len,
|
||||
accept_len,
|
||||
accept_len_cum,
|
||||
seq_lens,
|
||||
accept_lens,
|
||||
positions,
|
||||
new_verified_id,
|
||||
accept_len_upper: tl.constexpr,
|
||||
bs_upper: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
|
||||
seq_length = tl.load(seq_len + pid)
|
||||
accept_length = tl.load(accept_len + pid)
|
||||
positions_ptr = positions + offset
|
||||
data = tl.arange(0, accept_len_upper)
|
||||
mask = data < accept_length
|
||||
tl.store(positions_ptr + data, seq_length - accept_length + data, mask)
|
||||
offsets = tl.arange(0, bs_upper)
|
||||
seq_length = tl.load(seq_lens + pid)
|
||||
accept_length = tl.load(accept_lens + pid)
|
||||
|
||||
offset = tl.load(accept_len_cum + pid) - 1
|
||||
verified_id_data = tl.load(verified_id + offset)
|
||||
accept_len_cumsum = tl.sum(
|
||||
tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
|
||||
)
|
||||
positions_ptr = positions + accept_len_cumsum
|
||||
mask = offsets < accept_length
|
||||
tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
|
||||
|
||||
accept_len_cumsum += accept_length - 1
|
||||
verified_id_data = tl.load(verified_id + accept_len_cumsum)
|
||||
tl.store(new_verified_id + pid, verified_id_data)
|
||||
|
||||
|
||||
@@ -654,8 +563,8 @@ def assign_req_to_token_pool(
|
||||
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
||||
|
||||
length_offset = tl.arange(0, bs_upper)
|
||||
start = tl.load(start_offset + length_offset, mask=length_offset < pid)
|
||||
end = tl.load(end_offset + length_offset, mask=length_offset < pid)
|
||||
start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
|
||||
end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
|
||||
out_offset = tl.sum(end - start, axis=0)
|
||||
|
||||
out_cache_ptr = out_cache_loc + out_offset
|
||||
@@ -736,7 +645,7 @@ def generate_draft_decode_kv_indices(
|
||||
iters += 1
|
||||
|
||||
load_offset = tl.arange(0, bs_upper)
|
||||
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
|
||||
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
|
||||
seq_len = tl.load(paged_kernel_lens + bid)
|
||||
cum_seq_len = tl.sum(seq_lens)
|
||||
|
||||
@@ -765,7 +674,7 @@ def generate_draft_decode_kv_indices(
|
||||
zid = bid * topk + topk_id
|
||||
if zid == 0:
|
||||
zid = num_seqs * topk
|
||||
positions = tl.load(positions + bs_offset, mask=bs_offset < zid)
|
||||
positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
|
||||
base = tl.sum(positions)
|
||||
tl.store(kv_indptr + zid, base + zid * iters)
|
||||
|
||||
@@ -783,7 +692,9 @@ def align_evict_mask_to_page_size(
|
||||
bid = tl.program_id(axis=0)
|
||||
seq_len = tl.load(seq_lens + bid)
|
||||
io_mask = t_range < num_draft_tokens
|
||||
mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
|
||||
mask_row = tl.load(
|
||||
evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
|
||||
)
|
||||
|
||||
num_trues = tl.sum(mask_row)
|
||||
num_false = num_draft_tokens - num_trues
|
||||
|
||||
@@ -23,6 +23,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||
EAGLEDraftCudaGraphRunner,
|
||||
)
|
||||
@@ -69,7 +70,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.server_args = server_args
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.enable_nan_detection = server_args.enable_nan_detection
|
||||
self.gpu_id = gpu_id
|
||||
self.device = server_args.device
|
||||
@@ -78,6 +78,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.padded_static_len = -1
|
||||
|
||||
# Override context length with target model's context length
|
||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||
@@ -184,7 +185,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
from sglang.srt.layers.attention.triton_backend import (
|
||||
@@ -201,7 +201,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
@@ -218,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "flashmla":
|
||||
from sglang.srt.layers.attention.flashmla_backend import (
|
||||
@@ -231,7 +229,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -319,10 +316,12 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
|
||||
else:
|
||||
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
||||
logits_output, next_token_ids, bid, seq_lens_cpu = (
|
||||
self.forward_target_extend(batch)
|
||||
)
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend(
|
||||
batch, logits_output.hidden_states, next_token_ids
|
||||
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
|
||||
)
|
||||
return logits_output, next_token_ids, bid, 0, False
|
||||
|
||||
@@ -346,7 +345,12 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
return logits_output, next_token_ids, model_worker_batch.bid
|
||||
return (
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
model_worker_batch.bid,
|
||||
model_worker_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
def draft(self, batch: ScheduleBatch):
|
||||
# Parse args
|
||||
@@ -452,7 +456,14 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
|
||||
|
||||
ret = EagleVerifyInput.create(
|
||||
(
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
draft_tokens,
|
||||
) = build_tree_kernel_efficient(
|
||||
spec_info.verified_id,
|
||||
score_list,
|
||||
token_list,
|
||||
@@ -463,7 +474,22 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.speculative_num_steps,
|
||||
self.server_args.speculative_num_draft_tokens,
|
||||
)
|
||||
return ret
|
||||
|
||||
return EagleVerifyInput(
|
||||
draft_token=draft_tokens,
|
||||
custom_mask=tree_mask,
|
||||
positions=position,
|
||||
retrive_index=retrive_index,
|
||||
retrive_next_token=retrive_next_token,
|
||||
retrive_next_sibling=retrive_next_sibling,
|
||||
retrive_cum_len=None,
|
||||
spec_steps=self.speculative_num_steps,
|
||||
topk=self.topk,
|
||||
draft_token_num=self.server_args.speculative_num_draft_tokens,
|
||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||
seq_lens_sum=forward_batch.seq_lens_sum,
|
||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
def draft_forward(self, forward_batch: ForwardBatch):
|
||||
# Parse args
|
||||
@@ -523,7 +549,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
spec_info.prepare_for_verify(batch, self.page_size)
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.spec_info = spec_info
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=spec_info.seq_lens_cpu
|
||||
)
|
||||
|
||||
if batch.has_grammar:
|
||||
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
|
||||
@@ -650,6 +678,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch: ScheduleBatch,
|
||||
hidden_states: torch.Tensor,
|
||||
next_token_ids: List[int],
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
):
|
||||
"""Run draft model extend. This API modifies the states of the batch.
|
||||
|
||||
@@ -664,7 +693,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
batch.spec_info.prepare_for_extend(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=seq_lens_cpu
|
||||
)
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -683,19 +714,18 @@ class EAGLEWorker(TpModelWorker):
|
||||
return_logprob_backup = batch.return_logprob
|
||||
|
||||
# Prepare metadata
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
self.server_args.context_length,
|
||||
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
|
||||
)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
batch.return_logprob = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
|
||||
else:
|
||||
forward_batch.seq_lens_sum = batch.seq_lens.sum().item()
|
||||
|
||||
# Run
|
||||
can_cuda_graph = (
|
||||
@@ -706,14 +736,19 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output = self.cuda_graph_runner_for_draft_extend.replay(
|
||||
forward_batch
|
||||
)
|
||||
forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
|
||||
logits_output.topk_p,
|
||||
logits_output.topk_index,
|
||||
)
|
||||
forward_batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
else:
|
||||
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
|
||||
Reference in New Issue
Block a user