Remove sampling info events and overlap thread file (#11300)
This commit is contained in:
@@ -783,16 +783,6 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.prepare_mlp_sync_batch(batch)
|
self.prepare_mlp_sync_batch(batch)
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.result_queue.append((batch.copy(), result))
|
self.result_queue.append((batch.copy(), result))
|
||||||
|
|
||||||
if (self.last_batch is None) or (not self.last_batch_in_queue):
|
|
||||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
|
||||||
# It is now used for triggering the sampling_info_done event.
|
|
||||||
tmp_batch = ScheduleBatch(
|
|
||||||
reqs=None,
|
|
||||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
|
||||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
|
||||||
)
|
|
||||||
self.set_next_batch_sampling_info_done(tmp_batch)
|
|
||||||
last_batch_in_queue = True
|
last_batch_in_queue = True
|
||||||
|
|
||||||
elif prepare_mlp_sync_flag:
|
elif prepare_mlp_sync_flag:
|
||||||
@@ -806,9 +796,6 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
# Process the results of the previous batch but skip if the last batch is extend
|
# Process the results of the previous batch but skip if the last batch is extend
|
||||||
if self.last_batch and self.last_batch_in_queue:
|
if self.last_batch and self.last_batch_in_queue:
|
||||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||||
tmp_batch.next_batch_sampling_info = (
|
|
||||||
self.tp_worker.cur_sampling_info if batch else None
|
|
||||||
)
|
|
||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
|
|
||||||
queue_size = (
|
queue_size = (
|
||||||
|
|||||||
@@ -338,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.result_queue.append((batch.copy(), result))
|
self.result_queue.append((batch.copy(), result))
|
||||||
|
|
||||||
if self.last_batch is None:
|
|
||||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
|
||||||
# It is now used for triggering the sampling_info_done event.
|
|
||||||
tmp_batch = ScheduleBatch(
|
|
||||||
reqs=None,
|
|
||||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
|
||||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
|
||||||
)
|
|
||||||
self.set_next_batch_sampling_info_done(tmp_batch)
|
|
||||||
|
|
||||||
if self.last_batch:
|
if self.last_batch:
|
||||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||||
tmp_batch.next_batch_sampling_info = (
|
|
||||||
self.tp_worker.cur_sampling_info if batch else None
|
|
||||||
)
|
|
||||||
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
||||||
|
|
||||||
if len(self.disagg_prefill_inflight_queue) > 0:
|
if len(self.disagg_prefill_inflight_queue) > 0:
|
||||||
@@ -491,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
|
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
|
||||||
|
|
||||||
# We need to remove the sync in the following function for overlap schedule.
|
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
|
||||||
self.maybe_send_health_check_signal()
|
self.maybe_send_health_check_signal()
|
||||||
|
|
||||||
def process_disagg_prefill_inflight_queue(
|
def process_disagg_prefill_inflight_queue(
|
||||||
|
|||||||
@@ -891,7 +891,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
# Sampling info
|
# Sampling info
|
||||||
sampling_info: SamplingBatchInfo = None
|
sampling_info: SamplingBatchInfo = None
|
||||||
next_batch_sampling_info: SamplingBatchInfo = None
|
|
||||||
|
|
||||||
# Batched arguments to model runner
|
# Batched arguments to model runner
|
||||||
input_ids: torch.Tensor = None # shape: [b], int64
|
input_ids: torch.Tensor = None # shape: [b], int64
|
||||||
|
|||||||
@@ -1012,22 +1012,9 @@ class Scheduler(
|
|||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.result_queue.append((batch.copy(), result))
|
self.result_queue.append((batch.copy(), result))
|
||||||
|
|
||||||
if self.last_batch is None:
|
|
||||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
|
||||||
# It is now used for triggering the sampling_info_done event.
|
|
||||||
tmp_batch = ScheduleBatch(
|
|
||||||
reqs=None,
|
|
||||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
|
||||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
|
||||||
)
|
|
||||||
self.process_batch_result(tmp_batch, None)
|
|
||||||
|
|
||||||
if self.last_batch:
|
if self.last_batch:
|
||||||
# Process the results of the last batch
|
# Process the results of the last batch
|
||||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||||
tmp_batch.next_batch_sampling_info = (
|
|
||||||
self.tp_worker.cur_sampling_info if batch else None
|
|
||||||
)
|
|
||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
elif batch is None:
|
elif batch is None:
|
||||||
# When the server is idle, do self-check and re-init some states
|
# When the server is idle, do self-check and re-init some states
|
||||||
@@ -2100,7 +2087,7 @@ class Scheduler(
|
|||||||
self.record_batch_in_overlap(model_worker_batch)
|
self.record_batch_in_overlap(model_worker_batch)
|
||||||
|
|
||||||
# Sampling info will be modified during forward
|
# Sampling info will be modified during forward
|
||||||
model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
|
model_worker_batch.sampling_info = (
|
||||||
model_worker_batch.sampling_info.copy_for_forward()
|
model_worker_batch.sampling_info.copy_for_forward()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2219,9 +2206,6 @@ class Scheduler(
|
|||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
if result.copy_done is not None:
|
if result.copy_done is not None:
|
||||||
result.copy_done.synchronize()
|
result.copy_done.synchronize()
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
|
||||||
elif batch.forward_mode.is_dummy_first():
|
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
|
||||||
|
|
||||||
self.maybe_send_health_check_signal()
|
self.maybe_send_health_check_signal()
|
||||||
|
|
||||||
@@ -2431,13 +2415,6 @@ class Scheduler(
|
|||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||||
|
|
||||||
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
|
|
||||||
if batch.next_batch_sampling_info:
|
|
||||||
if batch.next_batch_sampling_info.grammars is not None:
|
|
||||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
|
||||||
self.default_stream.synchronize()
|
|
||||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
|
||||||
|
|
||||||
def watchdog_thread(self):
|
def watchdog_thread(self):
|
||||||
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
|
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
|
||||||
self.watchdog_last_forward_ct = 0
|
self.watchdog_last_forward_ct = 0
|
||||||
|
|||||||
@@ -173,8 +173,6 @@ class SchedulerOutputProcessorMixin:
|
|||||||
)
|
)
|
||||||
logprob_pt += num_input_logprobs
|
logprob_pt += num_input_logprobs
|
||||||
|
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
|
||||||
|
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
embeddings = result.embeddings.tolist()
|
embeddings = result.embeddings.tolist()
|
||||||
|
|
||||||
@@ -295,7 +293,6 @@ class SchedulerOutputProcessorMixin:
|
|||||||
self.abort_request(AbortReq(rid=req.rid))
|
self.abort_request(AbortReq(rid=req.rid))
|
||||||
req.grammar.finished = req.finished()
|
req.grammar.finished = req.finished()
|
||||||
|
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
|
||||||
self.stream_output(batch.reqs, batch.return_logprob)
|
self.stream_output(batch.reqs, batch.return_logprob)
|
||||||
self.token_to_kv_pool_allocator.free_group_end()
|
self.token_to_kv_pool_allocator.free_group_end()
|
||||||
|
|
||||||
|
|||||||
@@ -1,307 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""A tensor parallel worker."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import logging
|
|
||||||
import signal
|
|
||||||
import threading
|
|
||||||
from queue import Queue
|
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.managers.io_struct import (
|
|
||||||
DestroyWeightsUpdateGroupReqInput,
|
|
||||||
GetWeightsByNameReqInput,
|
|
||||||
InitWeightsSendGroupForRemoteInstanceReqInput,
|
|
||||||
InitWeightsUpdateGroupReqInput,
|
|
||||||
LoadLoRAAdapterReqInput,
|
|
||||||
SendWeightsToRemoteInstanceReqInput,
|
|
||||||
UnloadLoRAAdapterReqInput,
|
|
||||||
UpdateWeightFromDiskReqInput,
|
|
||||||
UpdateWeightsFromDistributedReqInput,
|
|
||||||
UpdateWeightsFromTensorReqInput,
|
|
||||||
)
|
|
||||||
from sglang.srt.managers.overlap_utils import FutureMap
|
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
from sglang.srt.utils import DynamicGradMode
|
|
||||||
from sglang.utils import get_exception_traceback
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TpModelWorkerClient:
|
|
||||||
"""A tensor parallel model worker."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server_args: ServerArgs,
|
|
||||||
gpu_id: int,
|
|
||||||
tp_rank: int,
|
|
||||||
moe_ep_rank: int,
|
|
||||||
pp_rank: int,
|
|
||||||
dp_rank: Optional[int],
|
|
||||||
nccl_port: int,
|
|
||||||
):
|
|
||||||
# Load the model
|
|
||||||
self.worker = TpModelWorker(
|
|
||||||
server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
|
|
||||||
)
|
|
||||||
self.max_running_requests = self.worker.max_running_requests
|
|
||||||
self.device = self.worker.device
|
|
||||||
self.gpu_id = gpu_id
|
|
||||||
|
|
||||||
# Init future mappings
|
|
||||||
self.future_map = FutureMap(self.max_running_requests, self.device)
|
|
||||||
|
|
||||||
# Launch threads
|
|
||||||
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
|
|
||||||
self.output_queue = Queue()
|
|
||||||
self.forward_stream = torch.get_device_module(self.device).Stream()
|
|
||||||
self.forward_thread = threading.Thread(
|
|
||||||
target=self.forward_thread_func,
|
|
||||||
)
|
|
||||||
self.forward_thread.start()
|
|
||||||
self.parent_process = psutil.Process().parent()
|
|
||||||
self.scheduler_stream = torch.get_device_module(self.device).current_stream()
|
|
||||||
if self.device == "cpu":
|
|
||||||
self.scheduler_stream.synchronize = lambda: None # No-op for CPU
|
|
||||||
|
|
||||||
self.hicache_layer_transfer_counter = None
|
|
||||||
|
|
||||||
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
|
|
||||||
self.hicache_layer_transfer_counter = counter
|
|
||||||
|
|
||||||
def get_worker_info(self):
|
|
||||||
return self.worker.get_worker_info()
|
|
||||||
|
|
||||||
def get_tokens_per_layer_info(self):
|
|
||||||
return self.worker.get_tokens_per_layer_info()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sliding_window_size(self) -> Optional[int]:
|
|
||||||
return self.worker.sliding_window_size
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_hybrid(self) -> bool:
|
|
||||||
return self.worker.is_hybrid
|
|
||||||
|
|
||||||
def get_pad_input_ids_func(self):
|
|
||||||
return self.worker.get_pad_input_ids_func()
|
|
||||||
|
|
||||||
def get_tp_group(self):
|
|
||||||
return self.worker.get_tp_group()
|
|
||||||
|
|
||||||
def get_attention_tp_group(self):
|
|
||||||
return self.worker.get_attention_tp_group()
|
|
||||||
|
|
||||||
def get_attention_tp_cpu_group(self):
|
|
||||||
return self.worker.get_attention_tp_cpu_group()
|
|
||||||
|
|
||||||
def get_memory_pool(self):
|
|
||||||
return (
|
|
||||||
self.worker.model_runner.req_to_token_pool,
|
|
||||||
self.worker.model_runner.token_to_kv_pool_allocator,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_kv_cache(self):
|
|
||||||
return self.worker.model_runner.token_to_kv_pool
|
|
||||||
|
|
||||||
def forward_thread_func(self):
|
|
||||||
try:
|
|
||||||
with torch.get_device_module(self.device).stream(self.forward_stream):
|
|
||||||
self.forward_thread_func_()
|
|
||||||
except Exception:
|
|
||||||
traceback = get_exception_traceback()
|
|
||||||
logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
|
|
||||||
self.parent_process.send_signal(signal.SIGQUIT)
|
|
||||||
|
|
||||||
@DynamicGradMode()
|
|
||||||
def forward_thread_func_(self):
|
|
||||||
batch_pt = 0
|
|
||||||
batch_lists: List = [None] * 2
|
|
||||||
|
|
||||||
while True:
|
|
||||||
model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
|
|
||||||
if not model_worker_batch:
|
|
||||||
break
|
|
||||||
|
|
||||||
sync_event.wait()
|
|
||||||
|
|
||||||
# Keep a reference of model_worker_batch by storing it into a list.
|
|
||||||
# Otherwise, the tensor members of model_worker_batch will be released
|
|
||||||
# by pytorch and cause CUDA illegal memory access errors.
|
|
||||||
batch_lists[batch_pt % 2] = model_worker_batch
|
|
||||||
batch_pt += 1
|
|
||||||
|
|
||||||
# Create event
|
|
||||||
copy_done = torch.get_device_module(self.device).Event()
|
|
||||||
|
|
||||||
# Resolve future tokens in the input
|
|
||||||
self.future_map.resolve_future(model_worker_batch)
|
|
||||||
|
|
||||||
# Run forward
|
|
||||||
forward_batch_output = self.worker.forward_batch_generation(
|
|
||||||
model_worker_batch,
|
|
||||||
model_worker_batch.launch_done,
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
|
||||||
forward_batch_output.logits_output,
|
|
||||||
forward_batch_output.next_token_ids,
|
|
||||||
forward_batch_output.can_run_cuda_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update the future token ids map
|
|
||||||
bs = len(model_worker_batch.seq_lens)
|
|
||||||
if model_worker_batch.is_prefill_only:
|
|
||||||
# For prefill-only requests, create dummy token IDs on CPU
|
|
||||||
next_token_ids = torch.zeros(bs, dtype=torch.long)
|
|
||||||
|
|
||||||
# store the future indices into future map
|
|
||||||
self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
|
|
||||||
|
|
||||||
# Copy results to the CPU
|
|
||||||
if model_worker_batch.return_logprob:
|
|
||||||
if logits_output.next_token_logprobs is not None:
|
|
||||||
logits_output.next_token_logprobs = (
|
|
||||||
logits_output.next_token_logprobs.to("cpu", non_blocking=True)
|
|
||||||
)
|
|
||||||
if logits_output.input_token_logprobs is not None:
|
|
||||||
logits_output.input_token_logprobs = (
|
|
||||||
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
|
||||||
)
|
|
||||||
if logits_output.hidden_states is not None:
|
|
||||||
logits_output.hidden_states = logits_output.hidden_states.to(
|
|
||||||
"cpu", non_blocking=True
|
|
||||||
)
|
|
||||||
# Only copy to CPU if not already on CPU
|
|
||||||
if next_token_ids.device.type != "cpu":
|
|
||||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
|
||||||
copy_done.record()
|
|
||||||
|
|
||||||
self.output_queue.put(
|
|
||||||
(copy_done, logits_output, next_token_ids, can_run_cuda_graph)
|
|
||||||
)
|
|
||||||
|
|
||||||
def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
|
|
||||||
"""
|
|
||||||
This function is called to resolve the last batch result and
|
|
||||||
wait for the current batch to be launched. Used in overlap mode.
|
|
||||||
"""
|
|
||||||
copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
|
|
||||||
self.output_queue.get()
|
|
||||||
)
|
|
||||||
|
|
||||||
if launch_done is not None:
|
|
||||||
launch_done.wait()
|
|
||||||
copy_done.synchronize()
|
|
||||||
|
|
||||||
if logits_output.next_token_logprobs is not None:
|
|
||||||
logits_output.next_token_logprobs = (
|
|
||||||
logits_output.next_token_logprobs.tolist()
|
|
||||||
)
|
|
||||||
if logits_output.input_token_logprobs is not None:
|
|
||||||
logits_output.input_token_logprobs = tuple(
|
|
||||||
logits_output.input_token_logprobs.tolist()
|
|
||||||
)
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
|
||||||
return logits_output, next_token_ids, can_run_cuda_graph
|
|
||||||
|
|
||||||
def forward_batch_generation(
|
|
||||||
self, model_worker_batch: ModelWorkerBatch
|
|
||||||
) -> ForwardBatchOutput:
|
|
||||||
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
|
|
||||||
model_worker_batch.sampling_info = self.cur_sampling_info = (
|
|
||||||
model_worker_batch.sampling_info.copy_for_forward()
|
|
||||||
)
|
|
||||||
|
|
||||||
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
|
||||||
sync_event = torch.get_device_module(self.device).Event()
|
|
||||||
sync_event.record(self.scheduler_stream)
|
|
||||||
|
|
||||||
# Push a new batch to the queue
|
|
||||||
bs = len(model_worker_batch.seq_lens)
|
|
||||||
cur_future_map_ct = self.future_map.update_ct(bs)
|
|
||||||
self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
|
|
||||||
|
|
||||||
# get this forward batch's future token ids
|
|
||||||
future_next_token_ids = self.future_map.update_next_future(
|
|
||||||
cur_future_map_ct, bs
|
|
||||||
)
|
|
||||||
return ForwardBatchOutput(
|
|
||||||
next_token_ids=future_next_token_ids,
|
|
||||||
can_run_cuda_graph=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
|
||||||
success, message = self.worker.update_weights_from_disk(recv_req)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
|
||||||
success, message = self.worker.init_weights_update_group(recv_req)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
|
|
||||||
success, message = self.worker.destroy_weights_update_group(recv_req)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def init_weights_send_group_for_remote_instance(
|
|
||||||
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
|
||||||
):
|
|
||||||
success, message = self.worker.init_weights_send_group_for_remote_instance(
|
|
||||||
recv_req
|
|
||||||
)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def send_weights_to_remote_instance(
|
|
||||||
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
|
||||||
):
|
|
||||||
success, message = self.worker.send_weights_to_remote_instance(recv_req)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def update_weights_from_distributed(
|
|
||||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
|
||||||
):
|
|
||||||
success, message = self.worker.update_weights_from_distributed(recv_req)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
|
||||||
success, message = self.worker.update_weights_from_tensor(recv_req)
|
|
||||||
return success, message
|
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
|
||||||
return self.worker.get_weights_by_name(recv_req)
|
|
||||||
|
|
||||||
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
|
|
||||||
return self.worker.load_lora_adapter(recv_req)
|
|
||||||
|
|
||||||
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
|
||||||
return self.worker.unload_lora_adapter(recv_req)
|
|
||||||
|
|
||||||
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
|
|
||||||
return self.worker.can_run_lora_batch(lora_ids)
|
|
||||||
|
|
||||||
def __delete__(self):
|
|
||||||
self.input_queue.put((None, None))
|
|
||||||
self.copy_queue.put((None, None, None))
|
|
||||||
@@ -75,10 +75,6 @@ class ForwardMode(IntEnum):
|
|||||||
# Used in speculative decoding: extend a batch in the draft model.
|
# Used in speculative decoding: extend a batch in the draft model.
|
||||||
DRAFT_EXTEND = auto()
|
DRAFT_EXTEND = auto()
|
||||||
|
|
||||||
# A dummy first batch to start the pipeline for overlap scheduler.
|
|
||||||
# It is now used for triggering the sampling_info_done event for the first prefill batch.
|
|
||||||
DUMMY_FIRST = auto()
|
|
||||||
|
|
||||||
# Split Prefill for PD multiplexing
|
# Split Prefill for PD multiplexing
|
||||||
SPLIT_PREFILL = auto()
|
SPLIT_PREFILL = auto()
|
||||||
|
|
||||||
@@ -128,9 +124,6 @@ class ForwardMode(IntEnum):
|
|||||||
def is_cpu_graph(self):
|
def is_cpu_graph(self):
|
||||||
return self == ForwardMode.DECODE
|
return self == ForwardMode.DECODE
|
||||||
|
|
||||||
def is_dummy_first(self):
|
|
||||||
return self == ForwardMode.DUMMY_FIRST
|
|
||||||
|
|
||||||
def is_split_prefill(self):
|
def is_split_prefill(self):
|
||||||
return self == ForwardMode.SPLIT_PREFILL
|
return self == ForwardMode.SPLIT_PREFILL
|
||||||
|
|
||||||
|
|||||||
@@ -2057,15 +2057,11 @@ class ModelRunner:
|
|||||||
def _preprocess_logits(
|
def _preprocess_logits(
|
||||||
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
|
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
|
||||||
):
|
):
|
||||||
# Apply logit bias
|
# NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
|
||||||
if sampling_info.sampling_info_done:
|
# was executed after we processed last batch's results.
|
||||||
# Overlap mode: the function update_regex_vocab_mask was executed
|
|
||||||
# in process_batch_result of the last batch.
|
# Calculate logits bias and apply it to next_token_logits.
|
||||||
if sampling_info.grammars:
|
sampling_info.update_regex_vocab_mask()
|
||||||
sampling_info.sampling_info_done.wait()
|
|
||||||
else:
|
|
||||||
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
|
||||||
sampling_info.update_regex_vocab_mask()
|
|
||||||
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
|
|||||||
@@ -44,12 +44,9 @@ class SamplingBatchInfo:
|
|||||||
vocab_mask: Optional[torch.Tensor] = None
|
vocab_mask: Optional[torch.Tensor] = None
|
||||||
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||||
|
|
||||||
# An event used for overlap schedule
|
|
||||||
sampling_info_done: Optional[threading.Event] = None
|
|
||||||
|
|
||||||
# Penalizer
|
# Penalizer
|
||||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||||
linear_penalty: torch.Tensor = None
|
acc_linear_penalties: torch.Tensor = None # Used in the overlap mode
|
||||||
|
|
||||||
# Whether any request has custom logit processor
|
# Whether any request has custom logit processor
|
||||||
has_custom_logit_processor: bool = False
|
has_custom_logit_processor: bool = False
|
||||||
@@ -217,19 +214,19 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
def update_penalties(self):
|
def update_penalties(self):
|
||||||
if self.penalizer_orchestrator.is_required:
|
if self.penalizer_orchestrator.is_required:
|
||||||
self.linear_penalty = torch.zeros(
|
self.acc_linear_penalties = torch.zeros(
|
||||||
(len(self.temperatures), self.vocab_size),
|
(len(self.temperatures), self.vocab_size),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=self.temperatures.device,
|
device=self.temperatures.device,
|
||||||
)
|
)
|
||||||
self.penalizer_orchestrator.apply(self.linear_penalty)
|
self.penalizer_orchestrator.apply(self.acc_linear_penalties)
|
||||||
else:
|
else:
|
||||||
self.linear_penalty = None
|
self.acc_linear_penalties = None
|
||||||
|
|
||||||
def apply_logits_bias(self, logits: torch.Tensor):
|
def apply_logits_bias(self, logits: torch.Tensor):
|
||||||
if self.linear_penalty is not None:
|
if self.acc_linear_penalties is not None:
|
||||||
# Used in the overlap mode
|
# Used in the overlap mode
|
||||||
logits.add_(self.linear_penalty)
|
logits.add_(self.acc_linear_penalties)
|
||||||
|
|
||||||
if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
|
if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
|
||||||
# Used in the non-overlap mode
|
# Used in the non-overlap mode
|
||||||
@@ -373,11 +370,7 @@ class SamplingBatchInfo:
|
|||||||
def copy_for_forward(self):
|
def copy_for_forward(self):
|
||||||
# Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
|
# Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
|
||||||
self.update_penalties()
|
self.update_penalties()
|
||||||
return dataclasses.replace(
|
return dataclasses.replace(self, penalizer_orchestrator=None)
|
||||||
self,
|
|
||||||
sampling_info_done=threading.Event(),
|
|
||||||
penalizer_orchestrator=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_bias_tensor(
|
def merge_bias_tensor(
|
||||||
|
|||||||
Reference in New Issue
Block a user