Abstraction for spec worker and code cleanup (#11643)
This commit is contained in:
@@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
)
|
)
|
||||||
return req_pool_indices
|
return req_pool_indices
|
||||||
|
|
||||||
def allocate_for_eagle_v2(self):
|
|
||||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
|
||||||
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
|
|
||||||
|
|
||||||
bs = self.batch_size()
|
|
||||||
|
|
||||||
assert self.spec_info.is_draft_input()
|
|
||||||
draft_input: EagleDraftInput = self.spec_info
|
|
||||||
|
|
||||||
# FIXME(lsyin): now implementation does not enable over-allocation
|
|
||||||
# Now seq_lens and allocate_lens are correct
|
|
||||||
self.maybe_wait_verify_done()
|
|
||||||
|
|
||||||
new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
|
|
||||||
num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
|
|
||||||
out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)
|
|
||||||
|
|
||||||
assign_req_to_token_pool[(bs,)](
|
|
||||||
self.req_pool_indices,
|
|
||||||
self.req_to_token_pool.req_to_token,
|
|
||||||
draft_input.allocate_lens,
|
|
||||||
new_allocate_lens,
|
|
||||||
out_cache_loc,
|
|
||||||
self.req_to_token_pool.req_to_token.shape[1],
|
|
||||||
next_power_of_2(bs),
|
|
||||||
)
|
|
||||||
draft_input.allocate_lens = new_allocate_lens
|
|
||||||
|
|
||||||
# FIXME(lsyin): remove seq_lens_sum calculation
|
|
||||||
self.seq_lens_cpu = self.seq_lens.cpu()
|
|
||||||
self.seq_lens_sum = self.seq_lens_cpu.sum().item()
|
|
||||||
|
|
||||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||||
self.encoder_lens_cpu = []
|
self.encoder_lens_cpu = []
|
||||||
self.encoder_cached = []
|
self.encoder_cached = []
|
||||||
@@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
|
|
||||||
if self.is_v2_eagle:
|
if self.is_v2_eagle:
|
||||||
# FIXME(lsyin): make this sync optional
|
# TODO(spec-v2): all v2 spec should go through this path
|
||||||
self.allocate_for_eagle_v2()
|
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||||
|
|
||||||
|
draft_input: EagleDraftInput = self.spec_info
|
||||||
|
draft_input.prepare_for_decode(self)
|
||||||
|
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
# if spec decoding is used, the decode batch is prepared inside
|
# if spec decoding is used, the decode batch is prepared inside
|
||||||
|
|||||||
@@ -215,10 +215,10 @@ class GenerationBatchResult:
|
|||||||
delay_sample_func: Optional[callable] = None
|
delay_sample_func: Optional[callable] = None
|
||||||
future_indices: Optional[FutureIndices] = None
|
future_indices: Optional[FutureIndices] = None
|
||||||
|
|
||||||
# FIXME(lsyin): maybe move to <BetterPlace> ?
|
# FIXME(lsyin): maybe move to a better place?
|
||||||
# sync path: forward stream -> output processor
|
# sync path: forward stream -> output processor
|
||||||
accept_lens: Optional[torch.Tensor] = None
|
accept_lens: Optional[torch.Tensor] = None
|
||||||
last_batch_allocate_lens: Optional[torch.Tensor] = None
|
allocate_lens: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# relay path: forward stream -> next step forward
|
# relay path: forward stream -> next step forward
|
||||||
next_draft_input: Optional[EagleDraftInput] = None
|
next_draft_input: Optional[EagleDraftInput] = None
|
||||||
@@ -246,10 +246,8 @@ class GenerationBatchResult:
|
|||||||
if self.accept_lens is not None:
|
if self.accept_lens is not None:
|
||||||
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
|
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
|
||||||
|
|
||||||
if self.last_batch_allocate_lens is not None:
|
if self.allocate_lens is not None:
|
||||||
self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
|
self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
|
||||||
"cpu", non_blocking=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.copy_done.record()
|
self.copy_done.record()
|
||||||
|
|
||||||
|
|||||||
@@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin:
|
|||||||
skip_stream_req = None
|
skip_stream_req = None
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
|
if result.copy_done is not None:
|
||||||
|
result.copy_done.synchronize()
|
||||||
|
|
||||||
(
|
(
|
||||||
logits_output,
|
logits_output,
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
extend_input_len_per_req,
|
extend_input_len_per_req,
|
||||||
extend_logprob_start_len_per_req,
|
extend_logprob_start_len_per_req,
|
||||||
copy_done,
|
|
||||||
) = (
|
) = (
|
||||||
result.logits_output,
|
result.logits_output,
|
||||||
result.next_token_ids,
|
result.next_token_ids,
|
||||||
result.extend_input_len_per_req,
|
result.extend_input_len_per_req,
|
||||||
result.extend_logprob_start_len_per_req,
|
result.extend_logprob_start_len_per_req,
|
||||||
result.copy_done,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if copy_done is not None:
|
|
||||||
copy_done.synchronize()
|
|
||||||
|
|
||||||
# Move next_token_ids and logprobs to cpu
|
# Move next_token_ids and logprobs to cpu
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
@@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin:
|
|||||||
|
|
||||||
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
||||||
|
|
||||||
def hacky_process_eagle_overlap_result(
|
def _resolve_spec_overlap_token_ids(
|
||||||
self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
|
self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
|
||||||
):
|
) -> List[List[int]]:
|
||||||
# TODO(lsyin): try use a copy stream to share SMs with forward
|
"""Resolve the padding next token ids for speculative decoding with overlap."""
|
||||||
# FIXME(lsyin): better organize this token free logic in eagle-overlap
|
assert result.next_token_ids.is_cpu
|
||||||
last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist()
|
assert result.accept_lens.is_cpu
|
||||||
accept_lens_cpu = result.accept_lens.tolist()
|
assert result.allocate_lens.is_cpu
|
||||||
|
|
||||||
next_token_ids = result.next_token_ids.tolist()
|
next_token_ids = result.next_token_ids.tolist()
|
||||||
|
accept_lens = result.accept_lens.tolist()
|
||||||
|
result.num_accepted_tokens = sum(accept_lens)
|
||||||
|
|
||||||
predict_tokens = []
|
predict_tokens = []
|
||||||
num_draft_tokens = self.draft_worker.speculative_num_draft_tokens
|
stride = self.draft_worker.speculative_num_draft_tokens
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
predict_tokens.append(
|
predict_tokens.append(
|
||||||
next_token_ids[
|
next_token_ids[i * stride : i * stride + accept_lens[i]]
|
||||||
i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
# FIXME(lsyin): move this update elsewhere
|
|
||||||
req.spec_verify_ct += 1
|
req.spec_verify_ct += 1
|
||||||
|
|
||||||
return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens
|
return predict_tokens
|
||||||
|
|
||||||
def process_batch_result_decode(
|
def process_batch_result_decode(
|
||||||
self: Scheduler,
|
self: Scheduler,
|
||||||
batch: ScheduleBatch,
|
batch: ScheduleBatch,
|
||||||
result: GenerationBatchResult,
|
result: GenerationBatchResult,
|
||||||
):
|
):
|
||||||
logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
|
if result.copy_done is not None:
|
||||||
|
result.copy_done.synchronize()
|
||||||
|
|
||||||
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
result.logits_output,
|
result.logits_output,
|
||||||
result.next_token_ids,
|
result.next_token_ids,
|
||||||
result.can_run_cuda_graph,
|
result.can_run_cuda_graph,
|
||||||
result.copy_done,
|
|
||||||
)
|
)
|
||||||
self.num_generated_tokens += len(batch.reqs)
|
|
||||||
|
|
||||||
if copy_done is not None:
|
|
||||||
copy_done.synchronize()
|
|
||||||
|
|
||||||
if batch.spec_algorithm.is_none():
|
if batch.spec_algorithm.is_none():
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||||
elif batch.is_v2_eagle:
|
elif batch.is_v2_eagle:
|
||||||
(
|
next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
|
||||||
last_batch_allocate_lens_cpu,
|
allocate_lens_list = result.allocate_lens.tolist()
|
||||||
accept_lens_cpu,
|
accept_lens_list = result.accept_lens.tolist()
|
||||||
next_token_ids,
|
|
||||||
) = self.hacky_process_eagle_overlap_result(result, batch)
|
|
||||||
result.num_accepted_tokens = sum(accept_lens_cpu)
|
|
||||||
|
|
||||||
# FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result
|
self.num_generated_tokens += len(batch.reqs)
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)
|
self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)
|
||||||
|
|
||||||
@@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if self.enable_overlap and req.finished():
|
if self.enable_overlap and req.finished():
|
||||||
|
indices_to_free = None
|
||||||
if self.page_size == 1:
|
if self.page_size == 1:
|
||||||
if batch.spec_algorithm.is_eagle():
|
if batch.spec_algorithm.is_eagle():
|
||||||
from sglang.srt.speculative.eagle_worker_v2 import (
|
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||||
free_spec_dec_tokens_page_size_1,
|
|
||||||
)
|
|
||||||
|
|
||||||
free_spec_dec_tokens_page_size_1(
|
end_p = allocate_lens_list[i]
|
||||||
self.req_to_token_pool,
|
start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
|
||||||
self.token_to_kv_pool_allocator,
|
indices_to_free = self.req_to_token_pool.req_to_token[
|
||||||
req,
|
req.req_pool_idx
|
||||||
last_batch_allocate_lens_cpu[i],
|
][start_p:end_p]
|
||||||
None,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Free the one extra delayed token
|
# Free the one extra delayed token
|
||||||
self.token_to_kv_pool_allocator.free(
|
indices_to_free = batch.out_cache_loc[i : i + 1]
|
||||||
batch.out_cache_loc[i : i + 1]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if batch.spec_algorithm.is_eagle():
|
if batch.spec_algorithm.is_eagle():
|
||||||
# TODO(lsyin): support eagle with page_size > 1
|
# TODO(spec-v2): support eagle with page_size > 1
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
len(req.origin_input_ids) + len(req.output_ids) - 1
|
len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||||
) % self.page_size == 0:
|
) % self.page_size == 0:
|
||||||
# Only free when the extra token is in a new page
|
# Only free when the extra token is in a new page
|
||||||
self.token_to_kv_pool_allocator.free(
|
indices_to_free = batch.out_cache_loc[i : i + 1]
|
||||||
batch.out_cache_loc[i : i + 1]
|
|
||||||
)
|
if indices_to_free is not None:
|
||||||
|
self.token_to_kv_pool_allocator.free(indices_to_free)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if batch.spec_algorithm.is_none():
|
if batch.spec_algorithm.is_none():
|
||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
elif batch.is_v2_eagle:
|
elif batch.is_v2_eagle:
|
||||||
# FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding
|
# Only v2 eagle's output_ids are updated here.
|
||||||
# !!!unify the logic here!!!
|
|
||||||
req.output_ids.extend(next_token_id)
|
req.output_ids.extend(next_token_id)
|
||||||
|
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
@@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin:
|
|||||||
if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
|
if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
|
||||||
# FIXME(lsyin): fix the messy logic here
|
# FIXME(lsyin): fix the messy logic here
|
||||||
# 1) when not overlap (v2 impl), we free the extra tokens in the req
|
# 1) when not overlap (v2 impl), we free the extra tokens in the req
|
||||||
# 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch
|
# 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
|
||||||
from sglang.srt.speculative.eagle_worker_v2 import (
|
start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
|
||||||
free_spec_dec_tokens_page_size_1,
|
end_p = allocate_lens_list[i]
|
||||||
)
|
indices_to_free = self.req_to_token_pool.req_to_token[
|
||||||
|
req.req_pool_idx
|
||||||
new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
|
][start_p:end_p]
|
||||||
# FIXME(lsyin): remove this assert
|
self.token_to_kv_pool_allocator.free(indices_to_free)
|
||||||
assert new_seq_len == int(
|
|
||||||
batch.seq_lens_cpu[i] + accept_lens_cpu[i]
|
|
||||||
), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"
|
|
||||||
|
|
||||||
free_spec_dec_tokens_page_size_1(
|
|
||||||
self.req_to_token_pool,
|
|
||||||
self.token_to_kv_pool_allocator,
|
|
||||||
req,
|
|
||||||
last_batch_allocate_lens_cpu[i],
|
|
||||||
new_seq_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||||
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
|
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -54,7 +55,140 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TpModelWorker:
|
class BaseTpWorker(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def forward_batch_generation(self, forward_batch: ForwardBatch):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def model_runner(self) -> ModelRunner:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sliding_window_size(self) -> Optional[int]:
|
||||||
|
return self.model_runner.sliding_window_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_hybrid(self) -> bool:
|
||||||
|
return self.model_runner.is_hybrid is not None
|
||||||
|
|
||||||
|
def get_tokens_per_layer_info(self):
|
||||||
|
return (
|
||||||
|
self.model_runner.full_max_total_num_tokens,
|
||||||
|
self.model_runner.swa_max_total_num_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_pad_input_ids_func(self):
|
||||||
|
return getattr(self.model_runner.model, "pad_input_ids", None)
|
||||||
|
|
||||||
|
def get_tp_group(self):
|
||||||
|
return self.model_runner.tp_group
|
||||||
|
|
||||||
|
def get_attention_tp_group(self):
|
||||||
|
return self.model_runner.attention_tp_group
|
||||||
|
|
||||||
|
def get_attention_tp_cpu_group(self):
|
||||||
|
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
|
||||||
|
|
||||||
|
def get_memory_pool(self):
|
||||||
|
return (
|
||||||
|
self.model_runner.req_to_token_pool,
|
||||||
|
self.model_runner.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||||
|
success, message = self.model_runner.update_weights_from_disk(
|
||||||
|
recv_req.model_path, recv_req.load_format
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||||
|
success, message = self.model_runner.init_weights_update_group(
|
||||||
|
recv_req.master_address,
|
||||||
|
recv_req.master_port,
|
||||||
|
recv_req.rank_offset,
|
||||||
|
recv_req.world_size,
|
||||||
|
recv_req.group_name,
|
||||||
|
recv_req.backend,
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
|
||||||
|
success, message = self.model_runner.destroy_weights_update_group(
|
||||||
|
recv_req.group_name,
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def init_weights_send_group_for_remote_instance(
|
||||||
|
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||||
|
):
|
||||||
|
success, message = (
|
||||||
|
self.model_runner.init_weights_send_group_for_remote_instance(
|
||||||
|
recv_req.master_address,
|
||||||
|
recv_req.ports,
|
||||||
|
recv_req.group_rank,
|
||||||
|
recv_req.world_size,
|
||||||
|
recv_req.group_name,
|
||||||
|
recv_req.backend,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def send_weights_to_remote_instance(
|
||||||
|
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
||||||
|
):
|
||||||
|
success, message = self.model_runner.send_weights_to_remote_instance(
|
||||||
|
recv_req.master_address,
|
||||||
|
recv_req.ports,
|
||||||
|
recv_req.group_name,
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_distributed(
|
||||||
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||||
|
):
|
||||||
|
success, message = self.model_runner.update_weights_from_distributed(
|
||||||
|
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||||
|
|
||||||
|
monkey_patch_torch_reductions()
|
||||||
|
success, message = self.model_runner.update_weights_from_tensor(
|
||||||
|
named_tensors=MultiprocessingSerializer.deserialize(
|
||||||
|
recv_req.serialized_named_tensors[self.tp_rank]
|
||||||
|
),
|
||||||
|
load_format=recv_req.load_format,
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
|
parameter = self.model_runner.get_weights_by_name(
|
||||||
|
recv_req.name, recv_req.truncate_size
|
||||||
|
)
|
||||||
|
return parameter
|
||||||
|
|
||||||
|
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
|
||||||
|
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
|
||||||
|
return result
|
||||||
|
|
||||||
|
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
||||||
|
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
|
||||||
|
return result
|
||||||
|
|
||||||
|
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
|
||||||
|
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
|
||||||
|
|
||||||
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||||
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
logits_output, _ = self.model_runner.forward(forward_batch)
|
||||||
|
embeddings = logits_output.embeddings
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class TpModelWorker(BaseTpWorker):
|
||||||
"""A tensor parallel model worker."""
|
"""A tensor parallel model worker."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -92,7 +226,7 @@ class TpModelWorker:
|
|||||||
is_draft_model=is_draft_worker,
|
is_draft_model=is_draft_worker,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_runner = ModelRunner(
|
self._model_runner = ModelRunner(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
mem_fraction_static=server_args.mem_fraction_static,
|
mem_fraction_static=server_args.mem_fraction_static,
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
@@ -171,6 +305,10 @@ class TpModelWorker:
|
|||||||
self.enable_overlap = not server_args.disable_overlap_schedule
|
self.enable_overlap = not server_args.disable_overlap_schedule
|
||||||
self.hicache_layer_transfer_counter = None
|
self.hicache_layer_transfer_counter = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_runner(self) -> ModelRunner:
|
||||||
|
return self._model_runner
|
||||||
|
|
||||||
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
|
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
|
||||||
self.hicache_layer_transfer_counter = counter
|
self.hicache_layer_transfer_counter = counter
|
||||||
|
|
||||||
@@ -193,38 +331,6 @@ class TpModelWorker:
|
|||||||
self.model_runner.token_to_kv_pool.size,
|
self.model_runner.token_to_kv_pool.size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def sliding_window_size(self) -> Optional[int]:
|
|
||||||
return self.model_runner.sliding_window_size
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_hybrid(self) -> bool:
|
|
||||||
return self.model_runner.is_hybrid is not None
|
|
||||||
|
|
||||||
def get_tokens_per_layer_info(self):
|
|
||||||
return (
|
|
||||||
self.model_runner.full_max_total_num_tokens,
|
|
||||||
self.model_runner.swa_max_total_num_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_pad_input_ids_func(self):
|
|
||||||
return getattr(self.model_runner.model, "pad_input_ids", None)
|
|
||||||
|
|
||||||
def get_tp_group(self):
|
|
||||||
return self.model_runner.tp_group
|
|
||||||
|
|
||||||
def get_attention_tp_group(self):
|
|
||||||
return self.model_runner.attention_tp_group
|
|
||||||
|
|
||||||
def get_attention_tp_cpu_group(self):
|
|
||||||
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
|
|
||||||
|
|
||||||
def get_memory_pool(self):
|
|
||||||
return (
|
|
||||||
self.model_runner.req_to_token_pool,
|
|
||||||
self.model_runner.token_to_kv_pool_allocator,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_batch_generation(
|
def forward_batch_generation(
|
||||||
self,
|
self,
|
||||||
model_worker_batch: ModelWorkerBatch,
|
model_worker_batch: ModelWorkerBatch,
|
||||||
@@ -313,93 +419,3 @@ class TpModelWorker:
|
|||||||
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
|
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
|
||||||
can_run_cuda_graph=can_run_cuda_graph,
|
can_run_cuda_graph=can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
|
||||||
logits_output, _ = self.model_runner.forward(forward_batch)
|
|
||||||
embeddings = logits_output.embeddings
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
|
||||||
success, message = self.model_runner.update_weights_from_disk(
|
|
||||||
recv_req.model_path, recv_req.load_format
|
|
||||||
)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
|
||||||
success, message = self.model_runner.init_weights_update_group(
|
|
||||||
recv_req.master_address,
|
|
||||||
recv_req.master_port,
|
|
||||||
recv_req.rank_offset,
|
|
||||||
recv_req.world_size,
|
|
||||||
recv_req.group_name,
|
|
||||||
recv_req.backend,
|
|
||||||
)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
|
|
||||||
success, message = self.model_runner.destroy_weights_update_group(
|
|
||||||
recv_req.group_name,
|
|
||||||
)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def init_weights_send_group_for_remote_instance(
|
|
||||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
|
||||||
):
|
|
||||||
success, message = (
|
|
||||||
self.model_runner.init_weights_send_group_for_remote_instance(
|
|
||||||
recv_req.master_address,
|
|
||||||
recv_req.ports,
|
|
||||||
recv_req.group_rank,
|
|
||||||
recv_req.world_size,
|
|
||||||
recv_req.group_name,
|
|
||||||
recv_req.backend,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def send_weights_to_remote_instance(
|
|
||||||
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
|
||||||
):
|
|
||||||
success, message = self.model_runner.send_weights_to_remote_instance(
|
|
||||||
recv_req.master_address,
|
|
||||||
recv_req.ports,
|
|
||||||
recv_req.group_name,
|
|
||||||
)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def update_weights_from_distributed(
|
|
||||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
|
||||||
):
|
|
||||||
success, message = self.model_runner.update_weights_from_distributed(
|
|
||||||
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
|
|
||||||
)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
|
||||||
|
|
||||||
monkey_patch_torch_reductions()
|
|
||||||
success, message = self.model_runner.update_weights_from_tensor(
|
|
||||||
named_tensors=MultiprocessingSerializer.deserialize(
|
|
||||||
recv_req.serialized_named_tensors[self.tp_rank]
|
|
||||||
),
|
|
||||||
load_format=recv_req.load_format,
|
|
||||||
)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
|
||||||
parameter = self.model_runner.get_weights_by_name(
|
|
||||||
recv_req.name, recv_req.truncate_size
|
|
||||||
)
|
|
||||||
return parameter
|
|
||||||
|
|
||||||
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
|
|
||||||
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
|
|
||||||
return result
|
|
||||||
|
|
||||||
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
|
||||||
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
|
|
||||||
return result
|
|
||||||
|
|
||||||
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
|
|
||||||
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
|
|
||||||
|
|||||||
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
|
|||||||
empty_context,
|
empty_context,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_device_memory_capacity,
|
|
||||||
is_hip,
|
is_hip,
|
||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
require_attn_tp_gather,
|
require_attn_tp_gather,
|
||||||
@@ -274,7 +273,6 @@ class CudaGraphRunner:
|
|||||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
|
||||||
self.encoder_len_fill_value = 0
|
self.encoder_len_fill_value = 0
|
||||||
self.seq_lens_cpu = torch.full(
|
self.seq_lens_cpu = torch.full(
|
||||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||||
|
|||||||
29
python/sglang/srt/speculative/base_spec_worker.py
Normal file
29
python/sglang/srt/speculative/base_spec_worker.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDraftWorker(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def draft():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def draft_extend():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSpecWorker(ABC):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def target_worker(self) -> TpModelWorker:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def draft_worker(self) -> BaseDraftWorker:
|
||||||
|
pass
|
||||||
@@ -40,7 +40,11 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
def __init__(self, eagle_worker: EAGLEWorker):
|
def __init__(self, eagle_worker: EAGLEWorker):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.eagle_worker = eagle_worker
|
self.eagle_worker = eagle_worker
|
||||||
self.model_runner = model_runner = eagle_worker.model_runner
|
if not hasattr(eagle_worker, "model_runner"):
|
||||||
|
# V2: EagleDraftWorker
|
||||||
|
self.model_runner = model_runner = eagle_worker.draft_runner
|
||||||
|
else:
|
||||||
|
self.model_runner = model_runner = eagle_worker.model_runner
|
||||||
self.graphs = {}
|
self.graphs = {}
|
||||||
self.output_buffers = {}
|
self.output_buffers = {}
|
||||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
|
|||||||
@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
def __init__(self, eagle_worker: EAGLEWorker):
|
def __init__(self, eagle_worker: EAGLEWorker):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.eagle_worker = eagle_worker
|
self.eagle_worker = eagle_worker
|
||||||
self.model_runner = model_runner = eagle_worker.model_runner
|
if not hasattr(eagle_worker, "model_runner"):
|
||||||
|
# V2: EagleDraftWorker
|
||||||
|
self.model_runner = model_runner = eagle_worker.draft_runner
|
||||||
|
else:
|
||||||
|
self.model_runner = model_runner = eagle_worker.model_runner
|
||||||
|
|
||||||
self.graphs = {}
|
self.graphs = {}
|
||||||
self.output_buffers = {}
|
self.output_buffers = {}
|
||||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||||
hidden_states_backup = forward_batch.spec_info.hidden_states
|
hidden_states_backup = forward_batch.spec_info.hidden_states
|
||||||
|
|
||||||
ret = self.eagle_worker.draft_model_runner.model.forward(
|
ret = self.model_runner.model.forward(
|
||||||
forward_batch.input_ids,
|
forward_batch.input_ids,
|
||||||
forward_batch.positions,
|
forward_batch.positions,
|
||||||
forward_batch,
|
forward_batch,
|
||||||
|
|||||||
@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
|
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
|
||||||
|
# Constant: alloc length per decode step
|
||||||
|
ALLOC_LEN_PER_DECODE: ClassVar[int] = None
|
||||||
|
|
||||||
# The inputs for decode
|
# The inputs for decode
|
||||||
# shape: (b, topk)
|
# shape: (b, topk)
|
||||||
topk_p: torch.Tensor = None
|
topk_p: torch.Tensor = None
|
||||||
@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
|
|||||||
new_seq_lens: Optional[torch.Tensor] = None
|
new_seq_lens: Optional[torch.Tensor] = None
|
||||||
verify_done: Optional[torch.cuda.Event] = None
|
verify_done: Optional[torch.cuda.Event] = None
|
||||||
|
|
||||||
# FIXME(lsyin): remove this hack
|
|
||||||
ALLOC_LEN_PER_DECODE: ClassVar[int] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__(SpecInputType.EAGLE_DRAFT)
|
super().__init__(SpecInputType.EAGLE_DRAFT)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
|
||||||
|
from sglang.srt.mem_cache.common import alloc_token_slots
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
@@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1(
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EagleDraftInputV2Mixin:
|
class EagleDraftInputV2Mixin:
|
||||||
|
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
|
||||||
|
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
|
||||||
|
|
||||||
|
bs = batch.batch_size()
|
||||||
|
|
||||||
|
# TODO(lsyin): implement over-allocation
|
||||||
|
# Now seq_lens and allocate_lens are correct
|
||||||
|
batch.maybe_wait_verify_done()
|
||||||
|
|
||||||
|
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
|
||||||
|
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
|
||||||
|
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
|
||||||
|
|
||||||
|
assign_req_to_token_pool[(bs,)](
|
||||||
|
batch.req_pool_indices,
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
self.allocate_lens,
|
||||||
|
new_allocate_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
batch.req_to_token_pool.req_to_token.shape[1],
|
||||||
|
next_power_of_2(bs),
|
||||||
|
)
|
||||||
|
self.allocate_lens = new_allocate_lens
|
||||||
|
|
||||||
|
# FIXME(lsyin): make this sync optional
|
||||||
|
batch.seq_lens_cpu = batch.seq_lens.cpu()
|
||||||
|
batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()
|
||||||
|
|
||||||
def prepare_for_v2_draft(
|
def prepare_for_v2_draft(
|
||||||
self: EagleDraftInput,
|
self: EagleDraftInput,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
|
|||||||
@@ -1,17 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import get_tp_group
|
||||||
GroupCoordinator,
|
|
||||||
get_tp_group,
|
|
||||||
patch_tensor_parallel_group,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import (
|
|||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.speculative.spec_utils import (
|
from sglang.srt.speculative.spec_utils import (
|
||||||
assign_draft_cache_locs,
|
assign_draft_cache_locs,
|
||||||
|
detect_nan,
|
||||||
|
draft_tp_context,
|
||||||
fast_topk,
|
fast_topk,
|
||||||
generate_token_bitmask,
|
generate_token_bitmask,
|
||||||
|
load_token_map,
|
||||||
select_top_k_tokens,
|
select_top_k_tokens,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
empty_context,
|
empty_context,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
is_blackwell,
|
|
||||||
is_cuda,
|
is_cuda,
|
||||||
next_power_of_2,
|
next_power_of_2,
|
||||||
)
|
)
|
||||||
@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__)
|
|||||||
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
|
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def draft_tp_context(tp_group: GroupCoordinator):
|
|
||||||
# Draft model doesn't use dp and has its own tp group.
|
|
||||||
# We disable mscclpp now because it doesn't support 2 comm groups.
|
|
||||||
with patch_tensor_parallel_group(tp_group):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
class EAGLEWorker(TpModelWorker):
|
class EAGLEWorker(TpModelWorker):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
)
|
)
|
||||||
self.padded_static_len = -1
|
|
||||||
|
|
||||||
# Override the context length of the draft model to be the same as the target model.
|
# Override the context length of the draft model to be the same as the target model.
|
||||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||||
@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
logits_output, _ = self.draft_model_runner.forward(
|
logits_output, _ = self.draft_model_runner.forward(
|
||||||
forward_batch, skip_attn_backend_init=True
|
forward_batch, skip_attn_backend_init=True
|
||||||
)
|
)
|
||||||
self._detect_nan_if_needed(logits_output)
|
if self.server_args.enable_nan_detection:
|
||||||
|
detect_nan(logits_output)
|
||||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||||
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
if self.hot_token_id is not None:
|
if self.hot_token_id is not None:
|
||||||
@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# and will be applied to produce wrong results
|
# and will be applied to produce wrong results
|
||||||
batch.sampling_info.vocab_mask = None
|
batch.sampling_info.vocab_mask = None
|
||||||
|
|
||||||
self._detect_nan_if_needed(logits_output)
|
if self.enable_nan_detection:
|
||||||
|
detect_nan(logits_output)
|
||||||
|
|
||||||
spec_info.hidden_states = logits_output.hidden_states
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
res: EagleVerifyOutput = spec_info.verify(
|
res: EagleVerifyOutput = spec_info.verify(
|
||||||
batch,
|
batch,
|
||||||
@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
forward_batch.return_logprob = False
|
forward_batch.return_logprob = False
|
||||||
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
||||||
self._detect_nan_if_needed(logits_output)
|
if self.enable_nan_detection:
|
||||||
|
detect_nan(logits_output)
|
||||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
assert forward_batch.spec_info is batch.spec_info
|
assert forward_batch.spec_info is batch.spec_info
|
||||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||||
@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||||
|
|
||||||
self._detect_nan_if_needed(logits_output)
|
if self.enable_nan_detection:
|
||||||
|
detect_nan(logits_output)
|
||||||
|
|
||||||
# Restore backup.
|
# Restore backup.
|
||||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||||
@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
|
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
draft_input.hidden_states = logits_output.hidden_states
|
draft_input.hidden_states = logits_output.hidden_states
|
||||||
|
|
||||||
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
|
|
||||||
if self.enable_nan_detection:
|
|
||||||
logits = logits_output.next_token_logits
|
|
||||||
if torch.any(torch.isnan(logits)):
|
|
||||||
logger.error("Detected errors during sampling! NaN in the logits.")
|
|
||||||
raise ValueError("Detected errors during sampling! NaN in the logits.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_token_map(token_map_path: str) -> List[int]:
|
|
||||||
if not os.path.exists(token_map_path):
|
|
||||||
cache_dir = snapshot_download(
|
|
||||||
os.path.dirname(token_map_path),
|
|
||||||
ignore_patterns=["*.bin", "*.safetensors"],
|
|
||||||
)
|
|
||||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
|
||||||
hot_token_id = torch.load(token_map_path, weights_only=True)
|
|
||||||
return torch.tensor(hot_token_id, dtype=torch.int64)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True)
|
@torch.compile(dynamic=True)
|
||||||
def get_last_loc_large_page_size_top_k_1(
|
def get_last_loc_large_page_size_top_k_1(
|
||||||
|
|||||||
@@ -1,19 +1,25 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.cuda import Stream as CudaStream
|
from torch.cuda import Stream as CudaStream
|
||||||
|
|
||||||
from sglang.srt.environ import envs
|
from sglang.srt.environ import envs
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
|
|
||||||
from sglang.srt.managers.scheduler import GenerationBatchResult
|
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker
|
||||||
|
from sglang.srt.speculative.draft_utils import DraftBackendFactory
|
||||||
|
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||||
|
EAGLEDraftCudaGraphRunner,
|
||||||
|
)
|
||||||
|
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
|
||||||
|
EAGLEDraftExtendCudaGraphRunner,
|
||||||
|
)
|
||||||
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.speculative.eagle_info_v2 import (
|
from sglang.srt.speculative.eagle_info_v2 import (
|
||||||
assign_extend_cache_locs,
|
assign_extend_cache_locs,
|
||||||
@@ -22,69 +28,214 @@ from sglang.srt.speculative.eagle_info_v2 import (
|
|||||||
select_top_k_tokens_tmp,
|
select_top_k_tokens_tmp,
|
||||||
)
|
)
|
||||||
from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient
|
from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient
|
||||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.utils.common import fast_topk, next_power_of_2
|
from sglang.srt.speculative.spec_utils import (
|
||||||
|
detect_nan,
|
||||||
|
draft_tp_context,
|
||||||
|
load_token_map,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils.common import (
|
||||||
|
empty_context,
|
||||||
|
fast_topk,
|
||||||
|
get_available_gpu_memory,
|
||||||
|
next_power_of_2,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EAGLEWorkerV2(EAGLEWorker):
|
def _get_plan_stream(
|
||||||
|
device: str,
|
||||||
|
) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
|
||||||
|
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
|
||||||
|
plan_stream: CudaStream = torch.get_device_module(device).Stream()
|
||||||
|
plan_stream_ctx = torch.cuda.stream(plan_stream)
|
||||||
|
return plan_stream, plan_stream_ctx
|
||||||
|
else:
|
||||||
|
return None, contextlib.nullcontext()
|
||||||
|
|
||||||
|
|
||||||
|
class EagleDraftWorker(BaseDraftWorker):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
dp_rank: Optional[int],
|
dp_rank: int,
|
||||||
moe_ep_rank: int,
|
moe_ep_rank: int,
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
target_worker: TpModelWorker,
|
target_worker: TpModelWorker,
|
||||||
):
|
):
|
||||||
super().__init__(
|
# copy args
|
||||||
server_args,
|
self.server_args = server_args
|
||||||
gpu_id,
|
self.gpu_id = gpu_id
|
||||||
tp_rank,
|
self.tp_rank = tp_rank
|
||||||
dp_rank,
|
self.dp_rank = dp_rank
|
||||||
moe_ep_rank,
|
self.moe_ep_rank = moe_ep_rank
|
||||||
nccl_port,
|
self.nccl_port = nccl_port
|
||||||
target_worker,
|
self.target_worker = target_worker
|
||||||
|
|
||||||
|
# Args for easy access
|
||||||
|
self.device = server_args.device
|
||||||
|
self.topk = server_args.speculative_eagle_topk
|
||||||
|
self.speculative_num_steps = server_args.speculative_num_steps
|
||||||
|
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
|
||||||
|
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
|
server_args.speculative_algorithm
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set constant
|
||||||
EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
|
EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
|
||||||
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
|
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Do not capture cuda graph in `TpModelWorker` init,
|
||||||
|
# will capture later with init_cuda_graphs()
|
||||||
|
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||||
|
server_args.disable_cuda_graph = True
|
||||||
|
|
||||||
|
# Share the allocator with a target worker.
|
||||||
|
# Draft and target worker own their own KV cache pools.
|
||||||
|
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||||
|
target_worker.get_memory_pool()
|
||||||
|
)
|
||||||
|
with empty_context():
|
||||||
|
# Init draft worker
|
||||||
|
self.draft_worker = TpModelWorker(
|
||||||
|
server_args=server_args,
|
||||||
|
gpu_id=gpu_id,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
pp_rank=0, # FIXME
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
moe_ep_rank=moe_ep_rank,
|
||||||
|
nccl_port=nccl_port,
|
||||||
|
is_draft_worker=True,
|
||||||
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Alias for better readability
|
||||||
|
self.draft_runner = self.draft_worker.model_runner
|
||||||
|
|
||||||
|
self.init_token_map()
|
||||||
|
self.init_lm_head()
|
||||||
|
|
||||||
|
# Init attention backend and cuda graphs
|
||||||
|
self.draft_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||||
|
self.draft_tp_context = (
|
||||||
|
draft_tp_context if server_args.enable_dp_attention else empty_context
|
||||||
|
)
|
||||||
|
with self.draft_tp_context(self.draft_runner.tp_group):
|
||||||
|
self.init_attention_backend()
|
||||||
|
self.init_cuda_graphs()
|
||||||
|
|
||||||
self.tree_mask_mode = TreeMaskMode.FULL_MASK
|
self.tree_mask_mode = TreeMaskMode.FULL_MASK
|
||||||
|
|
||||||
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
|
self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
|
||||||
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
|
|
||||||
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
|
|
||||||
else:
|
|
||||||
self.plan_stream = None
|
|
||||||
self.plan_stream_ctx = contextlib.nullcontext()
|
|
||||||
|
|
||||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
def init_token_map(self):
|
||||||
if model_worker_batch.forward_mode.is_decode():
|
# Load hot token ids
|
||||||
# FIXME(lsyin): why shall we use spec_info for both draft and verify?
|
if self.speculative_algorithm.is_eagle3():
|
||||||
draft_input: EagleDraftInput = model_worker_batch.spec_info
|
if self.server_args.speculative_token_map is not None:
|
||||||
assert draft_input.is_draft_input()
|
logger.warning(
|
||||||
verify_input: EagleVerifyInput = self.draft(model_worker_batch)
|
"Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
|
||||||
assert verify_input.is_verify_input()
|
)
|
||||||
model_worker_batch.spec_info = verify_input
|
self.hot_token_id = None
|
||||||
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
|
elif self.server_args.speculative_token_map is not None:
|
||||||
return batch_output
|
self.hot_token_id = load_token_map(self.server_args.speculative_token_map)
|
||||||
|
self.server_args.json_model_override_args = (
|
||||||
|
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Target prefill
|
self.hot_token_id = None
|
||||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
|
||||||
batch_output = self.target_worker.forward_batch_generation(
|
def init_lm_head(self):
|
||||||
model_worker_batch
|
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||||
|
if self.speculative_algorithm.is_eagle3():
|
||||||
|
# most cases EAGLE3 models don't share lm_head
|
||||||
|
# but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares
|
||||||
|
if (
|
||||||
|
hasattr(self.draft_runner.model, "load_lm_head_from_target")
|
||||||
|
and self.draft_runner.model.load_lm_head_from_target
|
||||||
|
):
|
||||||
|
self.draft_runner.model.set_embed_and_head(embed, head)
|
||||||
|
else:
|
||||||
|
self.draft_runner.model.set_embed(embed)
|
||||||
|
|
||||||
|
# grab hot token ids
|
||||||
|
if self.draft_runner.model.hot_token_id is not None:
|
||||||
|
self.hot_token_id = self.draft_runner.model.hot_token_id.to(
|
||||||
|
embed.device
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if self.hot_token_id is not None:
|
||||||
|
head = head.clone()
|
||||||
|
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||||
|
head.data = head.data[self.hot_token_id]
|
||||||
|
|
||||||
|
# Share the embedding and lm_head
|
||||||
|
self.draft_runner.model.set_embed_and_head(embed, head)
|
||||||
|
|
||||||
|
def init_attention_backend(self):
|
||||||
|
# Create multi-step attn backends and cuda graph runners
|
||||||
|
|
||||||
|
self.has_prefill_wrapper_verify = False
|
||||||
|
self.draft_extend_attn_backend = None
|
||||||
|
|
||||||
|
draft_backend_factory = DraftBackendFactory(
|
||||||
|
self.server_args,
|
||||||
|
self.draft_runner,
|
||||||
|
self.topk,
|
||||||
|
self.speculative_num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize decode attention backend
|
||||||
|
self.draft_attn_backend = draft_backend_factory.create_decode_backend()
|
||||||
|
|
||||||
|
# Initialize draft extend attention backend (respects speculative_attention_mode setting)
|
||||||
|
self.draft_extend_attn_backend = (
|
||||||
|
draft_backend_factory.create_draft_extend_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.draft_runner.draft_attn_backend = self.draft_attn_backend
|
||||||
|
self.tree_mask_mode = TreeMaskMode.FULL_MASK
|
||||||
|
|
||||||
|
def init_cuda_graphs(self):
|
||||||
|
"""Capture cuda graphs."""
|
||||||
|
self.cuda_graph_runner = None
|
||||||
|
self.cuda_graph_runner_for_draft_extend = None
|
||||||
|
|
||||||
|
if self.server_args.disable_cuda_graph:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Capture draft
|
||||||
|
if self.speculative_num_steps > 1:
|
||||||
|
tic = time.perf_counter()
|
||||||
|
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
|
logger.info(
|
||||||
|
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||||
|
)
|
||||||
|
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
||||||
|
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
|
logger.info(
|
||||||
|
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Draft prefill
|
# Capture extend
|
||||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
if self.draft_extend_attn_backend:
|
||||||
batch_output.next_draft_input = self.forward_draft_extend(
|
tic = time.perf_counter()
|
||||||
model_worker_batch,
|
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
batch_output.logits_output.hidden_states,
|
logger.info(
|
||||||
batch_output.next_token_ids,
|
f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||||
|
)
|
||||||
|
self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
|
||||||
|
self
|
||||||
|
)
|
||||||
|
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
|
logger.info(
|
||||||
|
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||||
)
|
)
|
||||||
return batch_output
|
|
||||||
|
|
||||||
def draft(self, model_worker_batch: ModelWorkerBatch):
|
def draft(self, model_worker_batch: ModelWorkerBatch):
|
||||||
draft_input: EagleDraftInput = model_worker_batch.spec_info
|
draft_input: EagleDraftInput = model_worker_batch.spec_info
|
||||||
@@ -92,7 +243,7 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
self.req_to_token_pool,
|
self.req_to_token_pool,
|
||||||
model_worker_batch,
|
model_worker_batch,
|
||||||
self.cuda_graph_runner,
|
self.cuda_graph_runner,
|
||||||
self.draft_model_runner,
|
self.draft_runner,
|
||||||
self.topk,
|
self.topk,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
)
|
)
|
||||||
@@ -201,10 +352,11 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
spec_info.hidden_states = hidden_states
|
spec_info.hidden_states = hidden_states
|
||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
logits_output = self.draft_model_runner.model.forward(
|
logits_output = self.draft_runner.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
)
|
)
|
||||||
self._detect_nan_if_needed(logits_output)
|
if self.server_args.enable_nan_detection:
|
||||||
|
detect_nan(logits_output)
|
||||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||||
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
if self.hot_token_id is not None:
|
if self.hot_token_id is not None:
|
||||||
@@ -233,10 +385,190 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
|
|
||||||
return parent_list, top_scores_index, draft_tokens
|
return parent_list, top_scores_index, draft_tokens
|
||||||
|
|
||||||
|
def draft_extend(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _draft_extend_for_prefill(
|
||||||
|
self,
|
||||||
|
batch: ModelWorkerBatch,
|
||||||
|
target_hidden_states: torch.Tensor,
|
||||||
|
next_token_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run draft model extend to correctly fill the KV cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: The batch to run.
|
||||||
|
target_hidden_states: Hidden states from the target model forward
|
||||||
|
next_token_ids: Next token ids generated from the target forward.
|
||||||
|
"""
|
||||||
|
# Construct input_ids
|
||||||
|
pt = 0
|
||||||
|
for i, extend_len in enumerate(batch.extend_seq_lens):
|
||||||
|
input_ids = batch.input_ids[pt : pt + extend_len]
|
||||||
|
batch.input_ids[pt : pt + extend_len] = torch.cat(
|
||||||
|
(input_ids[1:], next_token_ids[i].reshape(1))
|
||||||
|
)
|
||||||
|
pt += extend_len
|
||||||
|
|
||||||
|
# Construct spec_info
|
||||||
|
next_draft_input = EagleDraftInput(
|
||||||
|
hidden_states=target_hidden_states,
|
||||||
|
verified_id=next_token_ids,
|
||||||
|
new_seq_lens=batch.seq_lens,
|
||||||
|
allocate_lens=batch.seq_lens,
|
||||||
|
)
|
||||||
|
batch.spec_info = next_draft_input
|
||||||
|
|
||||||
|
# Run forward
|
||||||
|
forward_batch = ForwardBatch.init_new(batch, self.draft_runner)
|
||||||
|
logits_output, _ = self.draft_runner.forward(forward_batch)
|
||||||
|
|
||||||
|
# Update spec_info for the next draft step
|
||||||
|
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||||
|
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
|
||||||
|
probs, self.topk, dim=-1
|
||||||
|
)
|
||||||
|
next_draft_input.hidden_states = logits_output.hidden_states
|
||||||
|
return next_draft_input
|
||||||
|
|
||||||
|
def _draft_extend_for_decode(
|
||||||
|
self, batch: ModelWorkerBatch, batch_result: GenerationBatchResult
|
||||||
|
):
|
||||||
|
# Batch 2: Draft extend
|
||||||
|
draft_input = EagleDraftInput(
|
||||||
|
hidden_states=batch_result.logits_output.hidden_states,
|
||||||
|
)
|
||||||
|
select_index = (
|
||||||
|
torch.arange(len(batch.seq_lens), device=self.device)
|
||||||
|
* self.speculative_num_draft_tokens
|
||||||
|
+ batch_result.accept_lens
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare for draft extend in a separate stream
|
||||||
|
with self.plan_stream_ctx:
|
||||||
|
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
|
||||||
|
batch,
|
||||||
|
batch_result.next_token_ids,
|
||||||
|
self.speculative_num_draft_tokens,
|
||||||
|
self.draft_runner,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.plan_stream:
|
||||||
|
torch.cuda.current_stream().wait_stream(self.plan_stream)
|
||||||
|
|
||||||
|
# Run draft extend batch in the main compute stream
|
||||||
|
draft_logits_output = self.draft_runner.model.forward(
|
||||||
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reorganize the spec info for the next batch
|
||||||
|
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
|
||||||
|
select_index
|
||||||
|
]
|
||||||
|
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
|
||||||
|
select_index
|
||||||
|
]
|
||||||
|
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
|
||||||
|
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
|
ret_hidden_states = draft_logits_output.hidden_states
|
||||||
|
|
||||||
|
# Construct the return values
|
||||||
|
next_draft_input = batch_result.next_draft_input
|
||||||
|
(
|
||||||
|
next_draft_input.topk_p,
|
||||||
|
next_draft_input.topk_index,
|
||||||
|
next_draft_input.hidden_states,
|
||||||
|
) = (
|
||||||
|
ret_topk_p,
|
||||||
|
ret_topk_index,
|
||||||
|
ret_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EAGLEWorkerV2(BaseSpecWorker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
gpu_id: int,
|
||||||
|
tp_rank: int,
|
||||||
|
dp_rank: Optional[int],
|
||||||
|
moe_ep_rank: int,
|
||||||
|
nccl_port: int,
|
||||||
|
target_worker: TpModelWorker,
|
||||||
|
):
|
||||||
|
# Parse arguments
|
||||||
|
self.server_args = server_args
|
||||||
|
self.topk = server_args.speculative_eagle_topk
|
||||||
|
self.speculative_num_steps = server_args.speculative_num_steps
|
||||||
|
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
|
||||||
|
self.enable_nan_detection = server_args.enable_nan_detection
|
||||||
|
self.gpu_id = gpu_id
|
||||||
|
self.device = server_args.device
|
||||||
|
self._target_worker = target_worker
|
||||||
|
self.page_size = server_args.page_size
|
||||||
|
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
|
server_args.speculative_algorithm
|
||||||
|
)
|
||||||
|
|
||||||
|
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||||
|
target_worker.get_memory_pool()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override the context length of the draft model to be the same as the target model.
|
||||||
|
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||||
|
|
||||||
|
self._draft_worker = EagleDraftWorker(
|
||||||
|
server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker
|
||||||
|
)
|
||||||
|
|
||||||
|
# Some dummy tensors
|
||||||
|
self.num_new_pages_per_topk = torch.empty(
|
||||||
|
(), dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
|
self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target_worker(self):
|
||||||
|
return self._target_worker
|
||||||
|
|
||||||
|
@property
|
||||||
|
def draft_worker(self):
|
||||||
|
return self._draft_worker
|
||||||
|
|
||||||
|
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||||
|
if model_worker_batch.forward_mode.is_decode():
|
||||||
|
draft_input: EagleDraftInput = model_worker_batch.spec_info
|
||||||
|
assert draft_input.is_draft_input()
|
||||||
|
verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
|
||||||
|
assert verify_input.is_verify_input()
|
||||||
|
model_worker_batch.spec_info = verify_input
|
||||||
|
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
|
||||||
|
self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
|
||||||
|
return batch_output
|
||||||
|
else:
|
||||||
|
# Target prefill
|
||||||
|
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
batch_output = self.target_worker.forward_batch_generation(
|
||||||
|
model_worker_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
# Draft prefill
|
||||||
|
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
|
batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill(
|
||||||
|
model_worker_batch,
|
||||||
|
batch_output.logits_output.hidden_states,
|
||||||
|
batch_output.next_token_ids,
|
||||||
|
)
|
||||||
|
return batch_output
|
||||||
|
|
||||||
def verify(
|
def verify(
|
||||||
self,
|
self,
|
||||||
batch: ModelWorkerBatch,
|
batch: ModelWorkerBatch,
|
||||||
pre_draft_allocate_lens: torch.Tensor,
|
cur_allocate_lens: torch.Tensor,
|
||||||
):
|
):
|
||||||
# Since batch.seq_lens is allocated in another stream, we need
|
# Since batch.seq_lens is allocated in another stream, we need
|
||||||
# record_stream() to prevent pytorch gc and reuse the gpu memory
|
# record_stream() to prevent pytorch gc and reuse the gpu memory
|
||||||
@@ -284,7 +616,8 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
logits_output = forward_batch_output.logits_output
|
logits_output = forward_batch_output.logits_output
|
||||||
|
|
||||||
# Sample
|
# Sample
|
||||||
self._detect_nan_if_needed(logits_output)
|
if self.enable_nan_detection:
|
||||||
|
detect_nan(logits_output)
|
||||||
(
|
(
|
||||||
predict,
|
predict,
|
||||||
accept_length,
|
accept_length,
|
||||||
@@ -303,53 +636,11 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
self.speculative_num_draft_tokens,
|
self.speculative_num_draft_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Batch 2: Draft extend
|
# Construct the next draft input
|
||||||
draft_input = EagleDraftInput(
|
|
||||||
hidden_states=logits_output.hidden_states,
|
|
||||||
)
|
|
||||||
select_index = (
|
|
||||||
torch.arange(len(batch.seq_lens), device=self.device)
|
|
||||||
* self.speculative_num_draft_tokens
|
|
||||||
+ accept_length
|
|
||||||
- 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare for draft extend in a separate stream
|
|
||||||
with self.plan_stream_ctx:
|
|
||||||
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
|
|
||||||
batch,
|
|
||||||
predict,
|
|
||||||
self.speculative_num_draft_tokens,
|
|
||||||
self.draft_model_runner,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.plan_stream:
|
|
||||||
torch.cuda.current_stream().wait_stream(self.plan_stream)
|
|
||||||
|
|
||||||
# Run draft extend batch in the main compute stream
|
|
||||||
draft_logits_output = self.draft_model_runner.model.forward(
|
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reorganize the spec info for the next batch
|
|
||||||
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
|
|
||||||
select_index
|
|
||||||
]
|
|
||||||
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
|
|
||||||
select_index
|
|
||||||
]
|
|
||||||
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
|
|
||||||
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
|
|
||||||
ret_hidden_states = draft_logits_output.hidden_states
|
|
||||||
|
|
||||||
# Construct the return values
|
|
||||||
next_draft_input = EagleDraftInput(
|
next_draft_input = EagleDraftInput(
|
||||||
topk_p=ret_topk_p,
|
|
||||||
topk_index=ret_topk_index,
|
|
||||||
hidden_states=ret_hidden_states,
|
|
||||||
verified_id=verified_id,
|
verified_id=verified_id,
|
||||||
new_seq_lens=new_seq_lens,
|
new_seq_lens=new_seq_lens,
|
||||||
allocate_lens=pre_draft_allocate_lens,
|
allocate_lens=cur_allocate_lens,
|
||||||
verify_done=verify_done,
|
verify_done=verify_done,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -359,53 +650,9 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
can_run_cuda_graph=can_run_cuda_graph,
|
can_run_cuda_graph=can_run_cuda_graph,
|
||||||
next_draft_input=next_draft_input,
|
next_draft_input=next_draft_input,
|
||||||
accept_lens=accept_length,
|
accept_lens=accept_length,
|
||||||
last_batch_allocate_lens=pre_draft_allocate_lens,
|
allocate_lens=cur_allocate_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_draft_extend(
|
|
||||||
self,
|
|
||||||
batch: ModelWorkerBatch,
|
|
||||||
target_hidden_states: torch.Tensor,
|
|
||||||
next_token_ids: torch.Tensor,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Run draft model extend to correctly fill the KV cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch: The batch to run.
|
|
||||||
target_hidden_states: Hidden states from the target model forward
|
|
||||||
next_token_ids: Next token ids generated from the target forward.
|
|
||||||
"""
|
|
||||||
# Construct input_ids
|
|
||||||
pt = 0
|
|
||||||
for i, extend_len in enumerate(batch.extend_seq_lens):
|
|
||||||
input_ids = batch.input_ids[pt : pt + extend_len]
|
|
||||||
batch.input_ids[pt : pt + extend_len] = torch.cat(
|
|
||||||
(input_ids[1:], next_token_ids[i].reshape(1))
|
|
||||||
)
|
|
||||||
pt += extend_len
|
|
||||||
|
|
||||||
# Construct spec_info
|
|
||||||
next_draft_input = EagleDraftInput(
|
|
||||||
hidden_states=target_hidden_states,
|
|
||||||
verified_id=next_token_ids,
|
|
||||||
new_seq_lens=batch.seq_lens,
|
|
||||||
allocate_lens=batch.seq_lens,
|
|
||||||
)
|
|
||||||
batch.spec_info = next_draft_input
|
|
||||||
|
|
||||||
# Run forward
|
|
||||||
forward_batch = ForwardBatch.init_new(batch, self.draft_model_runner)
|
|
||||||
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
|
||||||
|
|
||||||
# Update spec_info for the next draft step
|
|
||||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
|
||||||
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
|
|
||||||
probs, self.topk, dim=-1
|
|
||||||
)
|
|
||||||
next_draft_input.hidden_states = logits_output.hidden_states
|
|
||||||
return next_draft_input
|
|
||||||
|
|
||||||
def move_accepted_tokens_to_target_kvcache(
|
def move_accepted_tokens_to_target_kvcache(
|
||||||
self,
|
self,
|
||||||
batch: ModelWorkerBatch,
|
batch: ModelWorkerBatch,
|
||||||
@@ -449,32 +696,3 @@ class EAGLEWorkerV2(EAGLEWorker):
|
|||||||
self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
|
self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
|
||||||
tgt_cache_loc, accepted_out_cache_loc
|
tgt_cache_loc, accepted_out_cache_loc
|
||||||
)
|
)
|
||||||
|
|
||||||
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
|
|
||||||
if self.enable_nan_detection:
|
|
||||||
logits = logits_output.next_token_logits
|
|
||||||
if torch.any(torch.isnan(logits)):
|
|
||||||
logger.error("Detected errors during sampling! NaN in the logits.")
|
|
||||||
raise ValueError("Detected errors during sampling! NaN in the logits.")
|
|
||||||
|
|
||||||
|
|
||||||
def free_spec_dec_tokens_page_size_1(
|
|
||||||
req_to_token_pool: ReqToTokenPool,
|
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
|
||||||
req: Req,
|
|
||||||
allocate_len: int,
|
|
||||||
new_seq_len: int,
|
|
||||||
):
|
|
||||||
# FIXME(lsyin): move this function elsewhere
|
|
||||||
|
|
||||||
# free extra allocated tokens
|
|
||||||
if new_seq_len is None:
|
|
||||||
# True only for overlap eagle and the current batch is decode. This seq will be part of the decode, so the final iteration's allocation is not used (i.e. this case).
|
|
||||||
start_len = allocate_len - EagleDraftInput.ALLOC_LEN_PER_DECODE
|
|
||||||
else:
|
|
||||||
# True for 1) non-overlap; 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration, so start_lens is passed in.
|
|
||||||
start_len = new_seq_len
|
|
||||||
indices_to_free = req_to_token_pool.req_to_token[req.req_pool_idx][
|
|
||||||
start_len:allocate_len
|
|
||||||
]
|
|
||||||
token_to_kv_pool_allocator.free(indices_to_free)
|
|
||||||
|
|||||||
@@ -1,25 +1,37 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
|
from sglang.srt.distributed.parallel_state import (
|
||||||
|
GroupCoordinator,
|
||||||
|
patch_tensor_parallel_group,
|
||||||
|
)
|
||||||
from sglang.srt.environ import envs
|
from sglang.srt.environ import envs
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
from sglang.srt.utils import is_cuda, is_hip
|
from sglang.srt.utils import is_cuda, is_hip
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
||||||
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
|
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
||||||
|
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
from sgl_kernel import fast_topk
|
from sgl_kernel import fast_topk
|
||||||
elif is_hip():
|
elif is_hip():
|
||||||
from sgl_kernel import fast_topk
|
from sgl_kernel import fast_topk
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -603,3 +615,29 @@ def generate_token_bitmask(
|
|||||||
|
|
||||||
verify_input.grammar = grammar
|
verify_input.grammar = grammar
|
||||||
return allocate_token_bitmask
|
return allocate_token_bitmask
|
||||||
|
|
||||||
|
|
||||||
|
def load_token_map(token_map_path: str) -> List[int]:
|
||||||
|
if not os.path.exists(token_map_path):
|
||||||
|
cache_dir = snapshot_download(
|
||||||
|
os.path.dirname(token_map_path),
|
||||||
|
ignore_patterns=["*.bin", "*.safetensors"],
|
||||||
|
)
|
||||||
|
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||||
|
hot_token_id = torch.load(token_map_path, weights_only=True)
|
||||||
|
return torch.tensor(hot_token_id, dtype=torch.int64)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def draft_tp_context(tp_group: GroupCoordinator):
|
||||||
|
# Draft model doesn't use dp and has its own tp group.
|
||||||
|
# We disable mscclpp now because it doesn't support 2 comm groups.
|
||||||
|
with patch_tensor_parallel_group(tp_group):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def detect_nan(logits_output: LogitsProcessorOutput):
|
||||||
|
logits = logits_output.next_token_logits
|
||||||
|
if torch.any(torch.isnan(logits)):
|
||||||
|
logger.error("Detected errors during sampling! NaN in the logits.")
|
||||||
|
raise ValueError("Detected errors during sampling! NaN in the logits.")
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map
|
||||||
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
|
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
@@ -18,14 +17,6 @@ logger = logging.getLogger(__name__)
|
|||||||
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
|
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def draft_tp_context(tp_group: GroupCoordinator):
|
|
||||||
# Draft model doesn't use dp and has its own tp group.
|
|
||||||
# We disable mscclpp now because it doesn't support 2 comm groups.
|
|
||||||
with patch_tensor_parallel_group(tp_group):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
class StandaloneWorker(EAGLEWorker):
|
class StandaloneWorker(EAGLEWorker):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker):
|
|||||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
)
|
)
|
||||||
self.padded_static_len = -1
|
|
||||||
|
|
||||||
# Override the context length of the draft model to be the same as the target model.
|
# Override the context length of the draft model to be the same as the target model.
|
||||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||||
|
|||||||
Reference in New Issue
Block a user