Abstraction for spec worker and code cleanup (#11643)
@@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        )

        return req_pool_indices

    def allocate_for_eagle_v2(self):
        from sglang.srt.speculative.eagle_info import EagleDraftInput
        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool

        bs = self.batch_size()

        assert self.spec_info.is_draft_input()
        draft_input: EagleDraftInput = self.spec_info

        # FIXME(lsyin): now implementation does not enable over-allocation
        # Now seq_lens and allocate_lens are correct
        self.maybe_wait_verify_done()

        new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
        num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
        out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)

        assign_req_to_token_pool[(bs,)](
            self.req_pool_indices,
            self.req_to_token_pool.req_to_token,
            draft_input.allocate_lens,
            new_allocate_lens,
            out_cache_loc,
            self.req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )
        draft_input.allocate_lens = new_allocate_lens

        # FIXME(lsyin): remove seq_lens_sum calculation
        self.seq_lens_cpu = self.seq_lens.cpu()
        self.seq_lens_sum = self.seq_lens_cpu.sum().item()

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

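For readers unfamiliar with the token-pool bookkeeping above, here is a plain-PyTorch sketch of what `assign_req_to_token_pool` conceptually does: each request's share of the freshly allocated KV-cache slots in the flat `out_cache_loc` buffer is written into its row of `req_to_token`, between the old and the new allocation length. The loop-based reference below is an illustrative assumption, not the Triton kernel used in the diff.

```python
import torch

def assign_req_to_token_pool_ref(
    req_pool_indices: torch.Tensor,  # (bs,) row index per request
    req_to_token: torch.Tensor,      # (num_reqs, max_context_len) slot map
    start_lens: torch.Tensor,        # (bs,) previously allocated length per request
    end_lens: torch.Tensor,          # (bs,) new allocated length per request
    out_cache_loc: torch.Tensor,     # (sum(end - start),) freshly allocated KV slots
) -> None:
    # Reference (non-Triton) version: copy each request's new slots into its
    # row of req_to_token, consuming out_cache_loc in request order.
    pt = 0
    for i in range(req_pool_indices.numel()):
        row = req_pool_indices[i].item()
        start, end = start_lens[i].item(), end_lens[i].item()
        n = end - start
        req_to_token[row, start:end] = out_cache_loc[pt : pt + n]
        pt += n
```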
@@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        bs = len(self.reqs)

        if self.is_v2_eagle:
            # FIXME(lsyin): make this sync optional
            self.allocate_for_eagle_v2()
            # TODO(spec-v2): all v2 spec should go through this path
            from sglang.srt.speculative.eagle_info import EagleDraftInput

            draft_input: EagleDraftInput = self.spec_info
            draft_input.prepare_for_decode(self)

        if not self.spec_algorithm.is_none():
            # if spec decoding is used, the decode batch is prepared inside

@@ -215,10 +215,10 @@ class GenerationBatchResult:
    delay_sample_func: Optional[callable] = None
    future_indices: Optional[FutureIndices] = None

    # FIXME(lsyin): maybe move to <BetterPlace> ?
    # FIXME(lsyin): maybe move to a better place?
    # sync path: forward stream -> output processor
    accept_lens: Optional[torch.Tensor] = None
    last_batch_allocate_lens: Optional[torch.Tensor] = None
    allocate_lens: Optional[torch.Tensor] = None

    # relay path: forward stream -> next step forward
    next_draft_input: Optional[EagleDraftInput] = None

@@ -246,10 +246,8 @@ class GenerationBatchResult:
        if self.accept_lens is not None:
            self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)

        if self.last_batch_allocate_lens is not None:
            self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
                "cpu", non_blocking=True
            )
        if self.allocate_lens is not None:
            self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)

        self.copy_done.record()

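The `copy_done` pattern above, non-blocking device-to-host copies recorded on a CUDA event that the output processor synchronizes later, is a standard overlap idiom. A minimal self-contained sketch, independent of the SGLang classes (note that non-blocking D2H copies only truly overlap when the destination is pinned memory):

```python
import torch

def start_async_copy_to_cpu(tensors):
    """Kick off non-blocking D2H copies and return (cpu_tensors, event)."""
    copy_done = torch.cuda.Event()
    cpu_tensors = [t.to("cpu", non_blocking=True) for t in tensors]
    copy_done.record()  # recorded on the current stream, after the copies
    return cpu_tensors, copy_done

# Consumer side (e.g. an output processor) synchronizes before reading:
# cpu_tensors, copy_done = start_async_copy_to_cpu([accept_lens, allocate_lens])
# ... other GPU work runs meanwhile ...
# copy_done.synchronize()
# values = [t.tolist() for t in cpu_tensors]
```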
@@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin:
        skip_stream_req = None

        if self.is_generation:
            if result.copy_done is not None:
                result.copy_done.synchronize()

            (
                logits_output,
                next_token_ids,
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
                copy_done,
            ) = (
                result.logits_output,
                result.next_token_ids,
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
                result.copy_done,
            )

            if copy_done is not None:
                copy_done.synchronize()

            # Move next_token_ids and logprobs to cpu
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:

@@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin:

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

    def hacky_process_eagle_overlap_result(
    def _resolve_spec_overlap_token_ids(
        self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
    ):
        # TODO(lsyin): try use a copy stream to share SMs with forward
        # FIXME(lsyin): better organize this token free logic in eagle-overlap
        last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist()
        accept_lens_cpu = result.accept_lens.tolist()
    ) -> List[List[int]]:
        """Resolve the padding next token ids for speculative decoding with overlap."""
        assert result.next_token_ids.is_cpu
        assert result.accept_lens.is_cpu
        assert result.allocate_lens.is_cpu

        next_token_ids = result.next_token_ids.tolist()
        accept_lens = result.accept_lens.tolist()
        result.num_accepted_tokens = sum(accept_lens)

        predict_tokens = []
        num_draft_tokens = self.draft_worker.speculative_num_draft_tokens
        stride = self.draft_worker.speculative_num_draft_tokens
        for i, req in enumerate(batch.reqs):
            predict_tokens.append(
                next_token_ids[
                    i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
                ]
                next_token_ids[i * stride : i * stride + accept_lens[i]]
            )
            # FIXME(lsyin): move this update elsewhere
            req.spec_verify_ct += 1

        return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens
        return predict_tokens

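As a worked example of the slicing in the renamed helper above: with `speculative_num_draft_tokens = 4` and two requests whose accepted lengths are `[2, 3]`, the flat padded `next_token_ids` buffer is cut into per-request accepted prefixes. A standalone sketch with illustrative values:

```python
def resolve_overlap_token_ids(next_token_ids, accept_lens, stride):
    """Slice a padded flat token buffer into per-request accepted tokens."""
    return [
        next_token_ids[i * stride : i * stride + n]
        for i, n in enumerate(accept_lens)
    ]

# Example: stride=4, accept_lens=[2, 3]
# next_token_ids = [11, 12, 0, 0, 21, 22, 23, 0]
# -> [[11, 12], [21, 22, 23]]
print(resolve_overlap_token_ids([11, 12, 0, 0, 21, 22, 23, 0], [2, 3], 4))
```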
    def process_batch_result_decode(
        self: Scheduler,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
        if result.copy_done is not None:
            result.copy_done.synchronize()

        logits_output, next_token_ids, can_run_cuda_graph = (
            result.logits_output,
            result.next_token_ids,
            result.can_run_cuda_graph,
            result.copy_done,
        )
        self.num_generated_tokens += len(batch.reqs)

        if copy_done is not None:
            copy_done.synchronize()

        if batch.spec_algorithm.is_none():
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()
        elif batch.is_v2_eagle:
            (
                last_batch_allocate_lens_cpu,
                accept_lens_cpu,
                next_token_ids,
            ) = self.hacky_process_eagle_overlap_result(result, batch)
            result.num_accepted_tokens = sum(accept_lens_cpu)
            next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
            allocate_lens_list = result.allocate_lens.tolist()
            accept_lens_list = result.accept_lens.tolist()

        # FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result
        self.num_generated_tokens += len(batch.reqs)
        if not self.spec_algorithm.is_none():
            self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)

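One way to read the `update_spec_metrics` call above: the usual speculative-decoding health metric is the mean number of accepted tokens per verify step. A small hedged sketch of such bookkeeping (not the actual SGLang implementation, just the arithmetic it implies):

```python
class SpecMetrics:
    """Track average accepted tokens per speculative verify step (illustrative)."""

    def __init__(self):
        self.total_verify_steps = 0
        self.total_accepted_tokens = 0

    def update(self, batch_size: int, num_accepted_tokens: int) -> None:
        # One verify step happened for each request in the batch.
        self.total_verify_steps += batch_size
        self.total_accepted_tokens += num_accepted_tokens

    @property
    def mean_accept_length(self) -> float:
        if self.total_verify_steps == 0:
            return 0.0
        return self.total_accepted_tokens / self.total_verify_steps
```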
@@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin:
                continue

            if self.enable_overlap and req.finished():
                indices_to_free = None
                if self.page_size == 1:
                    if batch.spec_algorithm.is_eagle():
                        from sglang.srt.speculative.eagle_worker_v2 import (
                            free_spec_dec_tokens_page_size_1,
                        )
                        from sglang.srt.speculative.eagle_info import EagleDraftInput

                        free_spec_dec_tokens_page_size_1(
                            self.req_to_token_pool,
                            self.token_to_kv_pool_allocator,
                            req,
                            last_batch_allocate_lens_cpu[i],
                            None,
                        )
                        end_p = allocate_lens_list[i]
                        start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
                        indices_to_free = self.req_to_token_pool.req_to_token[
                            req.req_pool_idx
                        ][start_p:end_p]
                    else:
                        # Free the one extra delayed token
                        self.token_to_kv_pool_allocator.free(
                            batch.out_cache_loc[i : i + 1]
                        )
                        indices_to_free = batch.out_cache_loc[i : i + 1]
                else:
                    if batch.spec_algorithm.is_eagle():
                        # TODO(lsyin): support eagle with page_size > 1
                        # TODO(spec-v2): support eagle with page_size > 1
                        raise NotImplementedError()
                    else:
                        if (
                            len(req.origin_input_ids) + len(req.output_ids) - 1
                        ) % self.page_size == 0:
                            # Only free when the extra token is in a new page
                            self.token_to_kv_pool_allocator.free(
                                batch.out_cache_loc[i : i + 1]
                            )
                            indices_to_free = batch.out_cache_loc[i : i + 1]

                if indices_to_free is not None:
                    self.token_to_kv_pool_allocator.free(indices_to_free)
                continue

            if batch.spec_algorithm.is_none():
                req.output_ids.append(next_token_id)
            elif batch.is_v2_eagle:
                # FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding
                # !!!unify the logic here!!!
                # Only v2 eagle's output_ids are updated here.
                req.output_ids.extend(next_token_id)

            req.check_finished()

@@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin:
            if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
                # FIXME(lsyin): fix the messy logic here
                # 1) when not overlap (v2 impl), we free the extra tokens in the req
                # 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch
                from sglang.srt.speculative.eagle_worker_v2 import (
                    free_spec_dec_tokens_page_size_1,
                )

                new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
                # FIXME(lsyin): remove this assert
                assert new_seq_len == int(
                    batch.seq_lens_cpu[i] + accept_lens_cpu[i]
                ), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"

                free_spec_dec_tokens_page_size_1(
                    self.req_to_token_pool,
                    self.token_to_kv_pool_allocator,
                    req,
                    last_batch_allocate_lens_cpu[i],
                    new_seq_len,
                )
                # 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
                start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
                end_p = allocate_lens_list[i]
                indices_to_free = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx
                ][start_p:end_p]
                self.token_to_kv_pool_allocator.free(indices_to_free)

            if self.server_args.disaggregation_decode_enable_offload_kvcache:
                # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes

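The freeing logic above repeatedly computes a `[start_p:end_p)` slice of a request's `req_to_token` row and hands it back to the KV allocator. A compact hedged helper capturing that pattern (illustrative only; the diff keeps this logic inline in the scheduler):

```python
import torch

def free_overallocated_slots(
    req_to_token: torch.Tensor,  # (num_reqs, max_context_len) slot map
    allocator,                   # object exposing .free(indices)
    req_pool_idx: int,
    start_p: int,                # first unused slot (e.g. seq_len + accepted tokens)
    end_p: int,                  # end of the over-allocated region (allocate_len)
) -> None:
    if end_p <= start_p:
        return  # nothing over-allocated for this request
    indices_to_free = req_to_token[req_pool_idx][start_p:end_p]
    allocator.free(indices_to_free)
```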
@@ -15,6 +15,7 @@

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

import torch

@@ -54,7 +55,140 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)


class TpModelWorker:
class BaseTpWorker(ABC):
    @abstractmethod
    def forward_batch_generation(self, forward_batch: ForwardBatch):
        pass

    @property
    @abstractmethod
    def model_runner(self) -> ModelRunner:
        pass

    @property
    def sliding_window_size(self) -> Optional[int]:
        return self.model_runner.sliding_window_size

    @property
    def is_hybrid(self) -> bool:
        return self.model_runner.is_hybrid is not None

    def get_tokens_per_layer_info(self):
        return (
            self.model_runner.full_max_total_num_tokens,
            self.model_runner.swa_max_total_num_tokens,
        )

    def get_pad_input_ids_func(self):
        return getattr(self.model_runner.model, "pad_input_ids", None)

    def get_tp_group(self):
        return self.model_runner.tp_group

    def get_attention_tp_group(self):
        return self.model_runner.attention_tp_group

    def get_attention_tp_cpu_group(self):
        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)

    def get_memory_pool(self):
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool_allocator,
        )

def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||
success, message = self.model_runner.update_weights_from_disk(
|
||||
recv_req.model_path, recv_req.load_format
|
||||
)
|
||||
return success, message
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
success, message = self.model_runner.init_weights_update_group(
|
||||
recv_req.master_address,
|
||||
recv_req.master_port,
|
||||
recv_req.rank_offset,
|
||||
recv_req.world_size,
|
||||
recv_req.group_name,
|
||||
recv_req.backend,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
|
||||
success, message = self.model_runner.destroy_weights_update_group(
|
||||
recv_req.group_name,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def init_weights_send_group_for_remote_instance(
|
||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||
):
|
||||
success, message = (
|
||||
self.model_runner.init_weights_send_group_for_remote_instance(
|
||||
recv_req.master_address,
|
||||
recv_req.ports,
|
||||
recv_req.group_rank,
|
||||
recv_req.world_size,
|
||||
recv_req.group_name,
|
||||
recv_req.backend,
|
||||
)
|
||||
)
|
||||
return success, message
|
||||
|
||||
def send_weights_to_remote_instance(
|
||||
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
||||
):
|
||||
success, message = self.model_runner.send_weights_to_remote_instance(
|
||||
recv_req.master_address,
|
||||
recv_req.ports,
|
||||
recv_req.group_name,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
success, message = self.model_runner.update_weights_from_distributed(
|
||||
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
|
||||
)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
|
||||
monkey_patch_torch_reductions()
|
||||
success, message = self.model_runner.update_weights_from_tensor(
|
||||
named_tensors=MultiprocessingSerializer.deserialize(
|
||||
recv_req.serialized_named_tensors[self.tp_rank]
|
||||
),
|
||||
load_format=recv_req.load_format,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.model_runner.get_weights_by_name(
|
||||
recv_req.name, recv_req.truncate_size
|
||||
)
|
||||
return parameter
|
||||
|
||||
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
|
||||
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
|
||||
return result
|
||||
|
||||
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
||||
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
|
||||
return result
|
||||
|
||||
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
|
||||
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
|
||||
|
||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
logits_output, _ = self.model_runner.forward(forward_batch)
|
||||
embeddings = logits_output.embeddings
|
||||
return embeddings
|
||||


class TpModelWorker(BaseTpWorker):
    """A tensor parallel model worker."""

    def __init__(
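The refactor above is the core of this PR: shared accessors move into an abstract `BaseTpWorker`, and `TpModelWorker` keeps construction plus a `model_runner` property. A stripped-down sketch of the same pattern, with hypothetical names where they differ from the diff:

```python
from abc import ABC, abstractmethod

class BaseWorker(ABC):
    """Shared helpers live here and go through the abstract accessor."""

    @property
    @abstractmethod
    def model_runner(self):
        ...

    def get_memory_pool(self):
        # Every concrete worker inherits this for free.
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool_allocator,
        )

class ConcreteWorker(BaseWorker):
    def __init__(self, runner):
        self._model_runner = runner  # built once in __init__

    @property
    def model_runner(self):
        return self._model_runner
```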
@@ -92,7 +226,7 @@ class TpModelWorker:
            is_draft_model=is_draft_worker,
        )

        self.model_runner = ModelRunner(
        self._model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
@@ -171,6 +305,10 @@ class TpModelWorker:
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.hicache_layer_transfer_counter = None

    @property
    def model_runner(self) -> ModelRunner:
        return self._model_runner

    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
        self.hicache_layer_transfer_counter = counter

@@ -193,38 +331,6 @@ class TpModelWorker:
|
||||
self.model_runner.token_to_kv_pool.size,
|
||||
)
|
||||
|
||||
@property
|
||||
def sliding_window_size(self) -> Optional[int]:
|
||||
return self.model_runner.sliding_window_size
|
||||
|
||||
@property
|
||||
def is_hybrid(self) -> bool:
|
||||
return self.model_runner.is_hybrid is not None
|
||||
|
||||
def get_tokens_per_layer_info(self):
|
||||
return (
|
||||
self.model_runner.full_max_total_num_tokens,
|
||||
self.model_runner.swa_max_total_num_tokens,
|
||||
)
|
||||
|
||||
def get_pad_input_ids_func(self):
|
||||
return getattr(self.model_runner.model, "pad_input_ids", None)
|
||||
|
||||
def get_tp_group(self):
|
||||
return self.model_runner.tp_group
|
||||
|
||||
def get_attention_tp_group(self):
|
||||
return self.model_runner.attention_tp_group
|
||||
|
||||
def get_attention_tp_cpu_group(self):
|
||||
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
|
||||
|
||||
def get_memory_pool(self):
|
||||
return (
|
||||
self.model_runner.req_to_token_pool,
|
||||
self.model_runner.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
def forward_batch_generation(
|
||||
self,
|
||||
model_worker_batch: ModelWorkerBatch,
|
||||
@@ -313,93 +419,3 @@ class TpModelWorker:
|
||||
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
|
||||
can_run_cuda_graph=can_run_cuda_graph,
|
||||
)
|
||||
|
||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
logits_output, _ = self.model_runner.forward(forward_batch)
|
||||
embeddings = logits_output.embeddings
|
||||
return embeddings
|
||||
|
||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||
success, message = self.model_runner.update_weights_from_disk(
|
||||
recv_req.model_path, recv_req.load_format
|
||||
)
|
||||
return success, message
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
success, message = self.model_runner.init_weights_update_group(
|
||||
recv_req.master_address,
|
||||
recv_req.master_port,
|
||||
recv_req.rank_offset,
|
||||
recv_req.world_size,
|
||||
recv_req.group_name,
|
||||
recv_req.backend,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
|
||||
success, message = self.model_runner.destroy_weights_update_group(
|
||||
recv_req.group_name,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def init_weights_send_group_for_remote_instance(
|
||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
||||
):
|
||||
success, message = (
|
||||
self.model_runner.init_weights_send_group_for_remote_instance(
|
||||
recv_req.master_address,
|
||||
recv_req.ports,
|
||||
recv_req.group_rank,
|
||||
recv_req.world_size,
|
||||
recv_req.group_name,
|
||||
recv_req.backend,
|
||||
)
|
||||
)
|
||||
return success, message
|
||||
|
||||
def send_weights_to_remote_instance(
|
||||
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
||||
):
|
||||
success, message = self.model_runner.send_weights_to_remote_instance(
|
||||
recv_req.master_address,
|
||||
recv_req.ports,
|
||||
recv_req.group_name,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
success, message = self.model_runner.update_weights_from_distributed(
|
||||
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
|
||||
)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
|
||||
monkey_patch_torch_reductions()
|
||||
success, message = self.model_runner.update_weights_from_tensor(
|
||||
named_tensors=MultiprocessingSerializer.deserialize(
|
||||
recv_req.serialized_named_tensors[self.tp_rank]
|
||||
),
|
||||
load_format=recv_req.load_format,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.model_runner.get_weights_by_name(
|
||||
recv_req.name, recv_req.truncate_size
|
||||
)
|
||||
return parameter
|
||||
|
||||
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
|
||||
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
|
||||
return result
|
||||
|
||||
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
||||
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
|
||||
return result
|
||||
|
||||
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
|
||||
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
|
||||
|
||||
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    get_bool_env_var,
    get_device_memory_capacity,
    is_hip,
    log_info_on_rank0,
    require_attn_tp_gather,
@@ -274,7 +273,6 @@ class CudaGraphRunner:
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )

        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32

python/sglang/srt/speculative/base_spec_worker.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from sglang.srt.managers.tp_worker import TpModelWorker


class BaseDraftWorker(ABC):
    @abstractmethod
    def draft():
        pass

    @abstractmethod
    def draft_extend():
        pass


class BaseSpecWorker(ABC):
    @property
    @abstractmethod
    def target_worker(self) -> TpModelWorker:
        pass

    @property
    @abstractmethod
    def draft_worker(self) -> BaseDraftWorker:
        pass

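To make the intent of the new abstract interfaces concrete, here is a minimal hypothetical implementation pair; every name below other than `BaseDraftWorker` and `BaseSpecWorker` is invented for illustration and does not appear in the diff.

```python
class DummyDraftWorker(BaseDraftWorker):
    """Trivial draft worker that proposes nothing (illustration only)."""

    def draft(self):
        return []   # no speculative tokens proposed

    def draft_extend(self):
        pass        # nothing to fill into a draft KV cache


class DummySpecWorker(BaseSpecWorker):
    """Wires a target worker and a draft worker behind the abstract properties."""

    def __init__(self, target_worker, draft_worker):
        self._target_worker = target_worker
        self._draft_worker = draft_worker

    @property
    def target_worker(self):
        return self._target_worker

    @property
    def draft_worker(self):
        return self._draft_worker
```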
@@ -40,7 +40,11 @@ class EAGLEDraftCudaGraphRunner:
|
||||
def __init__(self, eagle_worker: EAGLEWorker):
|
||||
# Parse args
|
||||
self.eagle_worker = eagle_worker
|
||||
self.model_runner = model_runner = eagle_worker.model_runner
|
||||
if not hasattr(eagle_worker, "model_runner"):
|
||||
# V2: EagleDraftWorker
|
||||
self.model_runner = model_runner = eagle_worker.draft_runner
|
||||
else:
|
||||
self.model_runner = model_runner = eagle_worker.model_runner
|
||||
self.graphs = {}
|
||||
self.output_buffers = {}
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
|
||||
@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
def __init__(self, eagle_worker: EAGLEWorker):
|
||||
# Parse args
|
||||
self.eagle_worker = eagle_worker
|
||||
self.model_runner = model_runner = eagle_worker.model_runner
|
||||
if not hasattr(eagle_worker, "model_runner"):
|
||||
# V2: EagleDraftWorker
|
||||
self.model_runner = model_runner = eagle_worker.draft_runner
|
||||
else:
|
||||
self.model_runner = model_runner = eagle_worker.model_runner
|
||||
|
||||
self.graphs = {}
|
||||
self.output_buffers = {}
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||
hidden_states_backup = forward_batch.spec_info.hidden_states
|
||||
|
||||
ret = self.eagle_worker.draft_model_runner.model.forward(
|
||||
ret = self.model_runner.model.forward(
|
||||
forward_batch.input_ids,
|
||||
forward_batch.positions,
|
||||
forward_batch,
|
||||
|
||||
@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):

@dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
    # Constant: alloc length per decode step
    ALLOC_LEN_PER_DECODE: ClassVar[int] = None

    # The inputs for decode
    # shape: (b, topk)
    topk_p: torch.Tensor = None
@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
    new_seq_lens: Optional[torch.Tensor] = None
    verify_done: Optional[torch.cuda.Event] = None

    # FIXME(lsyin): remove this hack
    ALLOC_LEN_PER_DECODE: ClassVar[int] = None

    def __post_init__(self):
        super().__init__(SpecInputType.EAGLE_DRAFT)

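Elsewhere in this PR the constant is set as `ALLOC_LEN_PER_DECODE = max(speculative_num_steps * topk, speculative_num_draft_tokens)`, i.e. enough KV slots per decode step for either all draft expansions or all verified candidates. A quick worked example with illustrative, not prescribed, settings:

```python
# Hypothetical EAGLE configuration, for illustration only.
speculative_num_steps = 3         # draft autoregressive steps
speculative_eagle_topk = 4        # branches kept per step
speculative_num_draft_tokens = 8  # tokens sent to the target for verification

alloc_len_per_decode = max(
    speculative_num_steps * speculative_eagle_topk,  # 12 slots for drafting
    speculative_num_draft_tokens,                    # 8 slots for verification
)
assert alloc_len_per_decode == 12  # each decode step reserves 12 extra KV slots
```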
@@ -9,7 +9,8 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
|
||||
from sglang.srt.mem_cache.common import alloc_token_slots
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
@@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1(
|
||||
|
||||
@dataclass
|
||||
class EagleDraftInputV2Mixin:
|
||||
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
|
||||
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
|
||||
|
||||
bs = batch.batch_size()
|
||||
|
||||
# TODO(lsyin): implement over-allocation
|
||||
# Now seq_lens and allocate_lens are correct
|
||||
batch.maybe_wait_verify_done()
|
||||
|
||||
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
|
||||
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
|
||||
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
|
||||
|
||||
assign_req_to_token_pool[(bs,)](
|
||||
batch.req_pool_indices,
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
self.allocate_lens,
|
||||
new_allocate_lens,
|
||||
out_cache_loc,
|
||||
batch.req_to_token_pool.req_to_token.shape[1],
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
self.allocate_lens = new_allocate_lens
|
||||
|
||||
# FIXME(lsyin): make this sync optional
|
||||
batch.seq_lens_cpu = batch.seq_lens.cpu()
|
||||
batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()
|
||||
|
||||
def prepare_for_v2_draft(
|
||||
self: EagleDraftInput,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
|
||||
@@ -1,17 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
GroupCoordinator,
|
||||
get_tp_group,
|
||||
patch_tensor_parallel_group,
|
||||
)
|
||||
from sglang.srt.distributed import get_tp_group
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.speculative.spec_utils import (
|
||||
assign_draft_cache_locs,
|
||||
detect_nan,
|
||||
draft_tp_context,
|
||||
fast_topk,
|
||||
generate_token_bitmask,
|
||||
load_token_map,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
get_available_gpu_memory,
|
||||
get_bool_env_var,
|
||||
is_blackwell,
|
||||
is_cuda,
|
||||
next_power_of_2,
|
||||
)
|
||||
@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__)
|
||||
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def draft_tp_context(tp_group: GroupCoordinator):
|
||||
# Draft model doesn't use dp and has its own tp group.
|
||||
# We disable mscclpp now because it doesn't support 2 comm groups.
|
||||
with patch_tensor_parallel_group(tp_group):
|
||||
yield
|
||||
|
||||
|
||||
class EAGLEWorker(TpModelWorker):
|
||||
|
||||
def __init__(
|
||||
@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.padded_static_len = -1
|
||||
|
||||
# Override the context length of the draft model to be the same as the target model.
|
||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||
@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output, _ = self.draft_model_runner.forward(
|
||||
forward_batch, skip_attn_backend_init=True
|
||||
)
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
if self.server_args.enable_nan_detection:
|
||||
detect_nan(logits_output)
|
||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
if self.hot_token_id is not None:
|
||||
@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
# and will be applied to produce wrong results
|
||||
batch.sampling_info.vocab_mask = None
|
||||
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
if self.enable_nan_detection:
|
||||
detect_nan(logits_output)
|
||||
|
||||
spec_info.hidden_states = logits_output.hidden_states
|
||||
res: EagleVerifyOutput = spec_info.verify(
|
||||
batch,
|
||||
@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
forward_batch.return_logprob = False
|
||||
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
if self.enable_nan_detection:
|
||||
detect_nan(logits_output)
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info is batch.spec_info
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
if self.enable_nan_detection:
|
||||
detect_nan(logits_output)
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
draft_input.hidden_states = logits_output.hidden_states
|
||||
|
||||
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
|
||||
if self.enable_nan_detection:
|
||||
logits = logits_output.next_token_logits
|
||||
if torch.any(torch.isnan(logits)):
|
||||
logger.error("Detected errors during sampling! NaN in the logits.")
|
||||
raise ValueError("Detected errors during sampling! NaN in the logits.")
|
||||
|
||||
|
||||
def load_token_map(token_map_path: str) -> List[int]:
|
||||
if not os.path.exists(token_map_path):
|
||||
cache_dir = snapshot_download(
|
||||
os.path.dirname(token_map_path),
|
||||
ignore_patterns=["*.bin", "*.safetensors"],
|
||||
)
|
||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||
hot_token_id = torch.load(token_map_path, weights_only=True)
|
||||
return torch.tensor(hot_token_id, dtype=torch.int64)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_last_loc_large_page_size_top_k_1(
|
||||
|
||||
@@ -1,19 +1,25 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.cuda import Stream as CudaStream
|
||||
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker
|
||||
from sglang.srt.speculative.draft_utils import DraftBackendFactory
|
||||
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||
EAGLEDraftCudaGraphRunner,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
|
||||
EAGLEDraftExtendCudaGraphRunner,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.eagle_info_v2 import (
|
||||
assign_extend_cache_locs,
|
||||
@@ -22,69 +28,214 @@ from sglang.srt.speculative.eagle_info_v2 import (
|
||||
select_top_k_tokens_tmp,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import TreeMaskMode, build_tree_kernel_efficient
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||
from sglang.srt.utils.common import fast_topk, next_power_of_2
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.speculative.spec_utils import (
|
||||
detect_nan,
|
||||
draft_tp_context,
|
||||
load_token_map,
|
||||
)
|
||||
from sglang.srt.utils.common import (
|
||||
empty_context,
|
||||
fast_topk,
|
||||
get_available_gpu_memory,
|
||||
next_power_of_2,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)


class EAGLEWorkerV2(EAGLEWorker):
def _get_plan_stream(
    device: str,
) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
    if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
        plan_stream: CudaStream = torch.get_device_module(device).Stream()
        plan_stream_ctx = torch.cuda.stream(plan_stream)
        return plan_stream, plan_stream_ctx
    else:
        return None, contextlib.nullcontext()

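The optional `plan_stream` above lets metadata preparation overlap with the main compute stream; the consumer must then make the compute stream wait on the plan stream before launching kernels that read the prepared buffers. A minimal hedged usage sketch (assumes a CUDA device; the planning work is a placeholder):

```python
import torch

plan_stream = torch.cuda.Stream()            # side stream for planning work
plan_stream_ctx = torch.cuda.stream(plan_stream)

def prepare_metadata():
    # Placeholder for attention-plan / metadata preparation work.
    return torch.arange(16, device="cuda")

with plan_stream_ctx:
    metadata = prepare_metadata()            # runs on the side stream

# Before the main stream consumes `metadata`, order the streams explicitly.
torch.cuda.current_stream().wait_stream(plan_stream)
out = metadata * 2                           # safe: runs after planning completes
```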
class EagleDraftWorker(BaseDraftWorker):
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
dp_rank: Optional[int],
|
||||
dp_rank: int,
|
||||
moe_ep_rank: int,
|
||||
nccl_port: int,
|
||||
target_worker: TpModelWorker,
|
||||
):
|
||||
super().__init__(
|
||||
server_args,
|
||||
gpu_id,
|
||||
tp_rank,
|
||||
dp_rank,
|
||||
moe_ep_rank,
|
||||
nccl_port,
|
||||
target_worker,
|
||||
# copy args
|
||||
self.server_args = server_args
|
||||
self.gpu_id = gpu_id
|
||||
self.tp_rank = tp_rank
|
||||
self.dp_rank = dp_rank
|
||||
self.moe_ep_rank = moe_ep_rank
|
||||
self.nccl_port = nccl_port
|
||||
self.target_worker = target_worker
|
||||
|
||||
# Args for easy access
|
||||
self.device = server_args.device
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
|
||||
# Set constant
|
||||
EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
|
||||
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
|
||||
)
|
||||
|
||||
# Do not capture cuda graph in `TpModelWorker` init,
|
||||
# will capture later with init_cuda_graphs()
|
||||
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||
server_args.disable_cuda_graph = True
|
||||
|
||||
# Share the allocator with a target worker.
|
||||
# Draft and target worker own their own KV cache pools.
|
||||
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||
target_worker.get_memory_pool()
|
||||
)
|
||||
with empty_context():
|
||||
# Init draft worker
|
||||
self.draft_worker = TpModelWorker(
|
||||
server_args=server_args,
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
pp_rank=0, # FIXME
|
||||
dp_rank=dp_rank,
|
||||
moe_ep_rank=moe_ep_rank,
|
||||
nccl_port=nccl_port,
|
||||
is_draft_worker=True,
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
# Alias for better readability
|
||||
self.draft_runner = self.draft_worker.model_runner
|
||||
|
||||
self.init_token_map()
|
||||
self.init_lm_head()
|
||||
|
||||
# Init attention backend and cuda graphs
|
||||
self.draft_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||
self.draft_tp_context = (
|
||||
draft_tp_context if server_args.enable_dp_attention else empty_context
|
||||
)
|
||||
with self.draft_tp_context(self.draft_runner.tp_group):
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
|
||||
self.tree_mask_mode = TreeMaskMode.FULL_MASK
|
||||
|
||||
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
|
||||
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
|
||||
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
|
||||
else:
|
||||
self.plan_stream = None
|
||||
self.plan_stream_ctx = contextlib.nullcontext()
|
||||
self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
|
||||
|
||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
||||
if model_worker_batch.forward_mode.is_decode():
|
||||
# FIXME(lsyin): why shall we use spec_info for both draft and verify?
|
||||
draft_input: EagleDraftInput = model_worker_batch.spec_info
|
||||
assert draft_input.is_draft_input()
|
||||
verify_input: EagleVerifyInput = self.draft(model_worker_batch)
|
||||
assert verify_input.is_verify_input()
|
||||
model_worker_batch.spec_info = verify_input
|
||||
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
|
||||
return batch_output
|
||||
def init_token_map(self):
|
||||
# Load hot token ids
|
||||
if self.speculative_algorithm.is_eagle3():
|
||||
if self.server_args.speculative_token_map is not None:
|
||||
logger.warning(
|
||||
"Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
|
||||
)
|
||||
self.hot_token_id = None
|
||||
elif self.server_args.speculative_token_map is not None:
|
||||
self.hot_token_id = load_token_map(self.server_args.speculative_token_map)
|
||||
self.server_args.json_model_override_args = (
|
||||
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
|
||||
)
|
||||
else:
|
||||
# Target prefill
|
||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||
batch_output = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
self.hot_token_id = None
|
||||
|
||||
def init_lm_head(self):
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
if self.speculative_algorithm.is_eagle3():
|
||||
# most cases EAGLE3 models don't share lm_head
|
||||
# but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares
|
||||
if (
|
||||
hasattr(self.draft_runner.model, "load_lm_head_from_target")
|
||||
and self.draft_runner.model.load_lm_head_from_target
|
||||
):
|
||||
self.draft_runner.model.set_embed_and_head(embed, head)
|
||||
else:
|
||||
self.draft_runner.model.set_embed(embed)
|
||||
|
||||
# grab hot token ids
|
||||
if self.draft_runner.model.hot_token_id is not None:
|
||||
self.hot_token_id = self.draft_runner.model.hot_token_id.to(
|
||||
embed.device
|
||||
)
|
||||
|
||||
else:
|
||||
if self.hot_token_id is not None:
|
||||
head = head.clone()
|
||||
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
|
||||
# Share the embedding and lm_head
|
||||
self.draft_runner.model.set_embed_and_head(embed, head)
|
||||
|
||||
def init_attention_backend(self):
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
|
||||
self.has_prefill_wrapper_verify = False
|
||||
self.draft_extend_attn_backend = None
|
||||
|
||||
draft_backend_factory = DraftBackendFactory(
|
||||
self.server_args,
|
||||
self.draft_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
|
||||
# Initialize decode attention backend
|
||||
self.draft_attn_backend = draft_backend_factory.create_decode_backend()
|
||||
|
||||
# Initialize draft extend attention backend (respects speculative_attention_mode setting)
|
||||
self.draft_extend_attn_backend = (
|
||||
draft_backend_factory.create_draft_extend_backend()
|
||||
)
|
||||
|
||||
self.draft_runner.draft_attn_backend = self.draft_attn_backend
|
||||
self.tree_mask_mode = TreeMaskMode.FULL_MASK
|
||||
|
||||
def init_cuda_graphs(self):
|
||||
"""Capture cuda graphs."""
|
||||
self.cuda_graph_runner = None
|
||||
self.cuda_graph_runner_for_draft_extend = None
|
||||
|
||||
if self.server_args.disable_cuda_graph:
|
||||
return
|
||||
|
||||
# Capture draft
|
||||
if self.speculative_num_steps > 1:
|
||||
tic = time.perf_counter()
|
||||
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||
)
|
||||
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||
)
|
||||
|
||||
# Draft prefill
|
||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
batch_output.next_draft_input = self.forward_draft_extend(
|
||||
model_worker_batch,
|
||||
batch_output.logits_output.hidden_states,
|
||||
batch_output.next_token_ids,
|
||||
# Capture extend
|
||||
if self.draft_extend_attn_backend:
|
||||
tic = time.perf_counter()
|
||||
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
|
||||
)
|
||||
self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
|
||||
self
|
||||
)
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||
)
|
||||
return batch_output
|
||||
|
||||
def draft(self, model_worker_batch: ModelWorkerBatch):
|
||||
draft_input: EagleDraftInput = model_worker_batch.spec_info
|
||||
@@ -92,7 +243,7 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
self.req_to_token_pool,
|
||||
model_worker_batch,
|
||||
self.cuda_graph_runner,
|
||||
self.draft_model_runner,
|
||||
self.draft_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
@@ -201,10 +352,11 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
spec_info.hidden_states = hidden_states
|
||||
|
||||
# Run forward
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
logits_output = self.draft_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
if self.server_args.enable_nan_detection:
|
||||
detect_nan(logits_output)
|
||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
if self.hot_token_id is not None:
|
||||
@@ -233,10 +385,190 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
|
||||
return parent_list, top_scores_index, draft_tokens
|
||||
|
||||
def draft_extend(self):
|
||||
pass
|
||||
|
||||
def _draft_extend_for_prefill(
|
||||
self,
|
||||
batch: ModelWorkerBatch,
|
||||
target_hidden_states: torch.Tensor,
|
||||
next_token_ids: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Run draft model extend to correctly fill the KV cache.
|
||||
|
||||
Args:
|
||||
batch: The batch to run.
|
||||
target_hidden_states: Hidden states from the target model forward
|
||||
next_token_ids: Next token ids generated from the target forward.
|
||||
"""
|
||||
# Construct input_ids
|
||||
pt = 0
|
||||
for i, extend_len in enumerate(batch.extend_seq_lens):
|
||||
input_ids = batch.input_ids[pt : pt + extend_len]
|
||||
batch.input_ids[pt : pt + extend_len] = torch.cat(
|
||||
(input_ids[1:], next_token_ids[i].reshape(1))
|
||||
)
|
||||
pt += extend_len
|
||||
|
||||
# Construct spec_info
|
||||
next_draft_input = EagleDraftInput(
|
||||
hidden_states=target_hidden_states,
|
||||
verified_id=next_token_ids,
|
||||
new_seq_lens=batch.seq_lens,
|
||||
allocate_lens=batch.seq_lens,
|
||||
)
|
||||
batch.spec_info = next_draft_input
|
||||
|
||||
# Run forward
|
||||
forward_batch = ForwardBatch.init_new(batch, self.draft_runner)
|
||||
logits_output, _ = self.draft_runner.forward(forward_batch)
|
||||
|
||||
# Update spec_info for the next draft step
|
||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
|
||||
probs, self.topk, dim=-1
|
||||
)
|
||||
next_draft_input.hidden_states = logits_output.hidden_states
|
||||
return next_draft_input
|
||||
|
||||
def _draft_extend_for_decode(
|
||||
self, batch: ModelWorkerBatch, batch_result: GenerationBatchResult
|
||||
):
|
||||
# Batch 2: Draft extend
|
||||
draft_input = EagleDraftInput(
|
||||
hidden_states=batch_result.logits_output.hidden_states,
|
||||
)
|
||||
select_index = (
|
||||
torch.arange(len(batch.seq_lens), device=self.device)
|
||||
* self.speculative_num_draft_tokens
|
||||
+ batch_result.accept_lens
|
||||
- 1
|
||||
)
|
||||
|
||||
# Prepare for draft extend in a separate stream
|
||||
with self.plan_stream_ctx:
|
||||
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
|
||||
batch,
|
||||
batch_result.next_token_ids,
|
||||
self.speculative_num_draft_tokens,
|
||||
self.draft_runner,
|
||||
)
|
||||
|
||||
if self.plan_stream:
|
||||
torch.cuda.current_stream().wait_stream(self.plan_stream)
|
||||
|
||||
# Run draft extend batch in the main compute stream
|
||||
draft_logits_output = self.draft_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
# Reorganize the spec info for the next batch
|
||||
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
|
||||
select_index
|
||||
]
|
||||
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
|
||||
select_index
|
||||
]
|
||||
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
|
||||
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
ret_hidden_states = draft_logits_output.hidden_states
|
||||
|
||||
# Construct the return values
|
||||
next_draft_input = batch_result.next_draft_input
|
||||
(
|
||||
next_draft_input.topk_p,
|
||||
next_draft_input.topk_index,
|
||||
next_draft_input.hidden_states,
|
||||
) = (
|
||||
ret_topk_p,
|
||||
ret_topk_index,
|
||||
ret_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
class EAGLEWorkerV2(BaseSpecWorker):
    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self._target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len

        self._draft_worker = EagleDraftWorker(
            server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker
        )

        # Some dummy tensors
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)

        self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)

    @property
    def target_worker(self):
        return self._target_worker

    @property
    def draft_worker(self):
        return self._draft_worker

    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
        if model_worker_batch.forward_mode.is_decode():
            draft_input: EagleDraftInput = model_worker_batch.spec_info
            assert draft_input.is_draft_input()
            verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
            assert verify_input.is_verify_input()
            model_worker_batch.spec_info = verify_input
            batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
            self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
            return batch_output
        else:
            # Target prefill
            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
            batch_output = self.target_worker.forward_batch_generation(
                model_worker_batch
            )

            # Draft prefill
            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
            batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill(
                model_worker_batch,
                batch_output.logits_output.hidden_states,
                batch_output.next_token_ids,
            )
            return batch_output

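The decode path above is the whole v2 speculative loop in four calls: draft, verify, draft-extend, return. A compressed hedged sketch of that control flow, written outside the real classes as a free function:

```python
def spec_decode_step(spec_worker, batch):
    """One speculative decoding step: draft -> verify -> refill draft KV cache."""
    draft_input = batch.spec_info                   # produced by the previous step
    verify_input = spec_worker.draft_worker.draft(batch)
    batch.spec_info = verify_input                  # hand candidates to the target
    result = spec_worker.verify(batch, draft_input.allocate_lens)
    spec_worker.draft_worker._draft_extend_for_decode(batch, result)
    return result                                   # carries accept_lens and next_draft_input
```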
    def verify(
        self,
        batch: ModelWorkerBatch,
        pre_draft_allocate_lens: torch.Tensor,
        cur_allocate_lens: torch.Tensor,
    ):
        # Since batch.seq_lens is allocated in another stream, we need
        # record_stream() to prevent pytorch gc and reuse the gpu memory
@@ -284,7 +616,8 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
logits_output = forward_batch_output.logits_output
|
||||
|
||||
# Sample
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
if self.enable_nan_detection:
|
||||
detect_nan(logits_output)
|
||||
(
|
||||
predict,
|
||||
accept_length,
|
||||
@@ -303,53 +636,11 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
self.speculative_num_draft_tokens,
|
||||
)
|
||||
|
||||
# Batch 2: Draft extend
|
||||
draft_input = EagleDraftInput(
|
||||
hidden_states=logits_output.hidden_states,
|
||||
)
|
||||
select_index = (
|
||||
torch.arange(len(batch.seq_lens), device=self.device)
|
||||
* self.speculative_num_draft_tokens
|
||||
+ accept_length
|
||||
- 1
|
||||
)
|
||||
|
||||
# Prepare for draft extend in a separate stream
|
||||
with self.plan_stream_ctx:
|
||||
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
|
||||
batch,
|
||||
predict,
|
||||
self.speculative_num_draft_tokens,
|
||||
self.draft_model_runner,
|
||||
)
|
||||
|
||||
if self.plan_stream:
|
||||
torch.cuda.current_stream().wait_stream(self.plan_stream)
|
||||
|
||||
# Run draft extend batch in the main compute stream
|
||||
draft_logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
# Reorganize the spec info for the next batch
|
||||
draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
|
||||
select_index
|
||||
]
|
||||
draft_logits_output.hidden_states = draft_logits_output.hidden_states[
|
||||
select_index
|
||||
]
|
||||
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
|
||||
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
ret_hidden_states = draft_logits_output.hidden_states
|
||||
|
||||
# Construct the return values
|
||||
# Construct the next draft input
|
||||
next_draft_input = EagleDraftInput(
|
||||
topk_p=ret_topk_p,
|
||||
topk_index=ret_topk_index,
|
||||
hidden_states=ret_hidden_states,
|
||||
verified_id=verified_id,
|
||||
new_seq_lens=new_seq_lens,
|
||||
allocate_lens=pre_draft_allocate_lens,
|
||||
allocate_lens=cur_allocate_lens,
|
||||
verify_done=verify_done,
|
||||
)
|
||||
|
||||
@@ -359,53 +650,9 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
can_run_cuda_graph=can_run_cuda_graph,
|
||||
next_draft_input=next_draft_input,
|
||||
accept_lens=accept_length,
|
||||
last_batch_allocate_lens=pre_draft_allocate_lens,
|
||||
allocate_lens=cur_allocate_lens,
|
||||
)
|
||||
|
||||
def forward_draft_extend(
|
||||
self,
|
||||
batch: ModelWorkerBatch,
|
||||
target_hidden_states: torch.Tensor,
|
||||
next_token_ids: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Run draft model extend to correctly fill the KV cache.
|
||||
|
||||
Args:
|
||||
batch: The batch to run.
|
||||
target_hidden_states: Hidden states from the target model forward
|
||||
next_token_ids: Next token ids generated from the target forward.
|
||||
"""
|
||||
# Construct input_ids
|
||||
pt = 0
|
||||
for i, extend_len in enumerate(batch.extend_seq_lens):
|
||||
input_ids = batch.input_ids[pt : pt + extend_len]
|
||||
batch.input_ids[pt : pt + extend_len] = torch.cat(
|
||||
(input_ids[1:], next_token_ids[i].reshape(1))
|
||||
)
|
||||
pt += extend_len
|
||||
|
||||
# Construct spec_info
|
||||
next_draft_input = EagleDraftInput(
|
||||
hidden_states=target_hidden_states,
|
||||
verified_id=next_token_ids,
|
||||
new_seq_lens=batch.seq_lens,
|
||||
allocate_lens=batch.seq_lens,
|
||||
)
|
||||
batch.spec_info = next_draft_input
|
||||
|
||||
# Run forward
|
||||
forward_batch = ForwardBatch.init_new(batch, self.draft_model_runner)
|
||||
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
||||
|
||||
# Update spec_info for the next draft step
|
||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
|
||||
probs, self.topk, dim=-1
|
||||
)
|
||||
next_draft_input.hidden_states = logits_output.hidden_states
|
||||
return next_draft_input
|
||||
|
||||
def move_accepted_tokens_to_target_kvcache(
|
||||
self,
|
||||
batch: ModelWorkerBatch,
|
||||
@@ -449,32 +696,3 @@ class EAGLEWorkerV2(EAGLEWorker):
|
||||
self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
|
||||
tgt_cache_loc, accepted_out_cache_loc
|
||||
)
|
||||
|
||||
    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def free_spec_dec_tokens_page_size_1(
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: TokenToKVPoolAllocator,
    req: Req,
    allocate_len: int,
    new_seq_len: int,
):
    # FIXME(lsyin): move this function elsewhere

    # free extra allocated tokens
    if new_seq_len is None:
        # True only for overlap eagle and the current batch is decode. This seq will be part of the decode, so the final iteration's allocation is not used (i.e. this case).
        start_len = allocate_len - EagleDraftInput.ALLOC_LEN_PER_DECODE
    else:
        # True for 1) non-overlap; 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration, so start_lens is passed in.
        start_len = new_seq_len
    indices_to_free = req_to_token_pool.req_to_token[req.req_pool_idx][
        start_len:allocate_len
    ]
    token_to_kv_pool_allocator.free(indices_to_free)

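The `new_seq_len` argument selects between the two freeing modes described in the comments above. A hedged usage sketch of both call sites, wrapped in a helper whose name and `still_decoding` flag are placeholders introduced only for illustration:

```python
def free_after_step(req_to_token_pool, token_to_kv_pool_allocator, req,
                    last_batch_allocate_len, still_decoding: bool):
    """Illustrative wrapper over the two freeing modes of the function above."""
    if still_decoding:
        # Overlap decode: the request keeps running, so only the final
        # iteration's unused ALLOC_LEN_PER_DECODE slots are returned.
        new_seq_len = None
    else:
        # Non-overlap, or overlap where the current batch is prefill: the request
        # will not run another iteration, so free everything past its real length.
        new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
    free_spec_dec_tokens_page_size_1(
        req_to_token_pool,
        token_to_kv_pool_allocator,
        req,
        last_batch_allocate_len,
        new_seq_len,
    )
```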
@@ -1,25 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
patch_tensor_parallel_group,
|
||||
)
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
||||
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import fast_topk
|
||||
elif is_hip():
|
||||
from sgl_kernel import fast_topk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -603,3 +615,29 @@ def generate_token_bitmask(
|
||||
|
||||
verify_input.grammar = grammar
|
||||
return allocate_token_bitmask
|
||||
|
||||
|
||||
def load_token_map(token_map_path: str) -> List[int]:
|
||||
if not os.path.exists(token_map_path):
|
||||
cache_dir = snapshot_download(
|
||||
os.path.dirname(token_map_path),
|
||||
ignore_patterns=["*.bin", "*.safetensors"],
|
||||
)
|
||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||
hot_token_id = torch.load(token_map_path, weights_only=True)
|
||||
return torch.tensor(hot_token_id, dtype=torch.int64)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def draft_tp_context(tp_group: GroupCoordinator):
|
||||
# Draft model doesn't use dp and has its own tp group.
|
||||
# We disable mscclpp now because it doesn't support 2 comm groups.
|
||||
with patch_tensor_parallel_group(tp_group):
|
||||
yield
|
||||
|
||||
|
||||
def detect_nan(logits_output: LogitsProcessorOutput):
|
||||
logits = logits_output.next_token_logits
|
||||
if torch.any(torch.isnan(logits)):
|
||||
logger.error("Detected errors during sampling! NaN in the logits.")
|
||||
raise ValueError("Detected errors during sampling! NaN in the logits.")
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map
|
||||
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
|
||||
|
||||
if is_cuda():
|
||||
@@ -18,14 +17,6 @@ logger = logging.getLogger(__name__)
|
||||
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def draft_tp_context(tp_group: GroupCoordinator):
|
||||
# Draft model doesn't use dp and has its own tp group.
|
||||
# We disable mscclpp now because it doesn't support 2 comm groups.
|
||||
with patch_tensor_parallel_group(tp_group):
|
||||
yield
|
||||
|
||||
|
||||
class StandaloneWorker(EAGLEWorker):
|
||||
|
||||
def __init__(
|
||||
@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker):
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.padded_static_len = -1
|
||||
|
||||
# Override the context length of the draft model to be the same as the target model.
|
||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||
|
||||