Remove sampling info events and overlap thread file (#11300)
@@ -783,16 +783,6 @@ class SchedulerDisaggregationDecodeMixin:
                 self.prepare_mlp_sync_batch(batch)
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
-
-                if (self.last_batch is None) or (not self.last_batch_in_queue):
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.set_next_batch_sampling_info_done(tmp_batch)
                 last_batch_in_queue = True

             elif prepare_mlp_sync_flag:
@@ -806,9 +796,6 @@ class SchedulerDisaggregationDecodeMixin:
             # Process the results of the previous batch but skip if the last batch is extend
             if self.last_batch and self.last_batch_in_queue:
                 tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
                 self.process_batch_result(tmp_batch, tmp_result)

             queue_size = (
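
Both decode-loop hunks leave the overlap structure itself in place: the scheduler launches `run_batch` for batch N, parks `(batch.copy(), result)` on `self.result_queue`, and only then finishes batch N-1, so CPU-side result processing overlaps the GPU forward pass of the next batch. Only the sampling-info hand-off between neighboring batches is deleted. A minimal sketch of such a one-batch-deep pipeline, with plain callables standing in for the scheduler's methods (the driver below is illustrative, not sglang's API):

from collections import deque

def overlap_loop(batches, run_batch, process_result):
    # One-batch-deep pipeline: launch batch N, then finish batch N-1.
    result_queue = deque()
    last_batch = None
    for batch in batches + [None]:  # trailing None drains the last result
        if batch is not None:
            # Launch the forward pass; the result is still "in flight".
            result_queue.append((batch, run_batch(batch)))
        if last_batch is not None:
            # Finish batch N-1 on the CPU while batch N occupies the GPU.
            prev, prev_result = result_queue.popleft()
            process_result(prev, prev_result)
        last_batch = batch

overlap_loop([1, 2, 3], run_batch=lambda b: b * 10,
             process_result=lambda b, r: print("done", b, r))
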
@@ -338,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))

-                if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.set_next_batch_sampling_info_done(tmp_batch)
-
             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
                 self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)

             if len(self.disagg_prefill_inflight_queue) > 0:
@@ -491,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
                 if self.enable_overlap:
                     self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)

-        # We need to remove the sync in the following function for overlap schedule.
-        self.set_next_batch_sampling_info_done(batch)
         self.maybe_send_health_check_signal()

     def process_disagg_prefill_inflight_queue(
@@ -891,7 +891,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

     # Sampling info
     sampling_info: SamplingBatchInfo = None
-    next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
@@ -1012,22 +1012,9 @@ class Scheduler(
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))

-                if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.process_batch_result(tmp_batch, None)
-
             if self.last_batch:
                 # Process the results of the last batch
                 tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
@@ -2100,7 +2087,7 @@ class Scheduler(
             self.record_batch_in_overlap(model_worker_batch)

         # Sampling info will be modified during forward
-        model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
+        model_worker_batch.sampling_info = (
             model_worker_batch.sampling_info.copy_for_forward()
         )

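
This hunk drops only the `self.tp_worker.cur_sampling_info` side channel; the `copy_for_forward()` snapshot stays, because the scheduler keeps mutating `sampling_info` in place while it builds batch N+1, and the forward pass of batch N must see a frozen view. A rough illustration of that snapshot idea via `dataclasses.replace` (the field names here are invented for the example, not sglang's):

import dataclasses

import torch

@dataclasses.dataclass
class SamplingState:
    temperatures: torch.Tensor

    def copy_for_forward(self) -> "SamplingState":
        # Shallow copy: the forward pass keeps its own field bindings even
        # if the scheduler rebinds them while preparing the next batch.
        return dataclasses.replace(self)

state = SamplingState(torch.tensor([1.0]))
snapshot = state.copy_for_forward()
state.temperatures = torch.tensor([0.7])    # edited for the next batch
assert snapshot.temperatures.item() == 1.0  # batch N still sees 1.0
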
@@ -2219,9 +2206,6 @@ class Scheduler(
         if self.enable_overlap:
             if result.copy_done is not None:
                 result.copy_done.synchronize()
-            self.set_next_batch_sampling_info_done(batch)
-        elif batch.forward_mode.is_dummy_first():
-            self.set_next_batch_sampling_info_done(batch)

         self.maybe_send_health_check_signal()

@@ -2431,13 +2415,6 @@ class Scheduler(
                 self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

-    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
-        if batch.next_batch_sampling_info:
-            if batch.next_batch_sampling_info.grammars is not None:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.default_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
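
The deleted `set_next_batch_sampling_info_done` was the producer half of an event handshake: the scheduler finished building the grammar vocab mask, synchronized its stream, and set a `threading.Event` that the model runner's `_preprocess_logits` (in a later hunk) waited on before sampling. A generic sketch of the pattern being removed, using one shared event between two plain threads rather than the actual scheduler/worker wiring:

import threading

import torch

sampling_info_done = threading.Event()
shared = {}

def scheduler_side():
    # Stand-in for update_regex_vocab_mask(): mark disallowed token ids.
    shared["vocab_mask"] = torch.tensor([False, True, False, True])
    sampling_info_done.set()   # publish: the mask is ready

def worker_side(logits: torch.Tensor) -> torch.Tensor:
    sampling_info_done.wait()  # block sampling until the mask exists
    return logits.masked_fill(shared["vocab_mask"], float("-inf"))

t = threading.Thread(target=scheduler_side)
t.start()
print(worker_side(torch.zeros(4)))
t.join()
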
@@ -173,8 +173,6 @@ class SchedulerOutputProcessorMixin:
                     )
                     logprob_pt += num_input_logprobs

-            self.set_next_batch_sampling_info_done(batch)
-
         else:  # embedding or reward model
             embeddings = result.embeddings.tolist()

@@ -295,7 +293,6 @@ class SchedulerOutputProcessorMixin:
                         self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()

-        self.set_next_batch_sampling_info_done(batch)
         self.stream_output(batch.reqs, batch.return_logprob)
         self.token_to_kv_pool_allocator.free_group_end()
@@ -1,307 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""A tensor parallel worker."""
-
-from __future__ import annotations
-
-import dataclasses
-import logging
-import signal
-import threading
-from queue import Queue
-from typing import TYPE_CHECKING, List, Optional, Tuple
-
-import psutil
-import torch
-
-from sglang.srt.managers.io_struct import (
-    DestroyWeightsUpdateGroupReqInput,
-    GetWeightsByNameReqInput,
-    InitWeightsSendGroupForRemoteInstanceReqInput,
-    InitWeightsUpdateGroupReqInput,
-    LoadLoRAAdapterReqInput,
-    SendWeightsToRemoteInstanceReqInput,
-    UnloadLoRAAdapterReqInput,
-    UpdateWeightFromDiskReqInput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromTensorReqInput,
-)
-from sglang.srt.managers.overlap_utils import FutureMap
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
-from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
-from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode
-from sglang.utils import get_exception_traceback
-
-if TYPE_CHECKING:
-    from sglang.srt.managers.cache_controller import LayerDoneCounter
-
-logger = logging.getLogger(__name__)
-
-
-class TpModelWorkerClient:
-    """A tensor parallel model worker."""
-
-    def __init__(
-        self,
-        server_args: ServerArgs,
-        gpu_id: int,
-        tp_rank: int,
-        moe_ep_rank: int,
-        pp_rank: int,
-        dp_rank: Optional[int],
-        nccl_port: int,
-    ):
-        # Load the model
-        self.worker = TpModelWorker(
-            server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
-        )
-        self.max_running_requests = self.worker.max_running_requests
-        self.device = self.worker.device
-        self.gpu_id = gpu_id
-
-        # Init future mappings
-        self.future_map = FutureMap(self.max_running_requests, self.device)
-
-        # Launch threads
-        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
-        self.output_queue = Queue()
-        self.forward_stream = torch.get_device_module(self.device).Stream()
-        self.forward_thread = threading.Thread(
-            target=self.forward_thread_func,
-        )
-        self.forward_thread.start()
-        self.parent_process = psutil.Process().parent()
-        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
-        if self.device == "cpu":
-            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU
-
-        self.hicache_layer_transfer_counter = None
-
-    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
-        self.hicache_layer_transfer_counter = counter
-
-    def get_worker_info(self):
-        return self.worker.get_worker_info()
-
-    def get_tokens_per_layer_info(self):
-        return self.worker.get_tokens_per_layer_info()
-
-    @property
-    def sliding_window_size(self) -> Optional[int]:
-        return self.worker.sliding_window_size
-
-    @property
-    def is_hybrid(self) -> bool:
-        return self.worker.is_hybrid
-
-    def get_pad_input_ids_func(self):
-        return self.worker.get_pad_input_ids_func()
-
-    def get_tp_group(self):
-        return self.worker.get_tp_group()
-
-    def get_attention_tp_group(self):
-        return self.worker.get_attention_tp_group()
-
-    def get_attention_tp_cpu_group(self):
-        return self.worker.get_attention_tp_cpu_group()
-
-    def get_memory_pool(self):
-        return (
-            self.worker.model_runner.req_to_token_pool,
-            self.worker.model_runner.token_to_kv_pool_allocator,
-        )
-
-    def get_kv_cache(self):
-        return self.worker.model_runner.token_to_kv_pool
-
-    def forward_thread_func(self):
-        try:
-            with torch.get_device_module(self.device).stream(self.forward_stream):
-                self.forward_thread_func_()
-        except Exception:
-            traceback = get_exception_traceback()
-            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
-            self.parent_process.send_signal(signal.SIGQUIT)
-
-    @DynamicGradMode()
-    def forward_thread_func_(self):
-        batch_pt = 0
-        batch_lists: List = [None] * 2
-
-        while True:
-            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
-            if not model_worker_batch:
-                break
-
-            sync_event.wait()
-
-            # Keep a reference of model_worker_batch by storing it into a list.
-            # Otherwise, the tensor members of model_worker_batch will be released
-            # by pytorch and cause CUDA illegal memory access errors.
-            batch_lists[batch_pt % 2] = model_worker_batch
-            batch_pt += 1
-
-            # Create event
-            copy_done = torch.get_device_module(self.device).Event()
-
-            # Resolve future tokens in the input
-            self.future_map.resolve_future(model_worker_batch)
-
-            # Run forward
-            forward_batch_output = self.worker.forward_batch_generation(
-                model_worker_batch,
-                model_worker_batch.launch_done,
-            )
-
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                forward_batch_output.logits_output,
-                forward_batch_output.next_token_ids,
-                forward_batch_output.can_run_cuda_graph,
-            )
-
-            # Update the future token ids map
-            bs = len(model_worker_batch.seq_lens)
-            if model_worker_batch.is_prefill_only:
-                # For prefill-only requests, create dummy token IDs on CPU
-                next_token_ids = torch.zeros(bs, dtype=torch.long)
-
-            # store the future indices into future map
-            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
-
-            # Copy results to the CPU
-            if model_worker_batch.return_logprob:
-                if logits_output.next_token_logprobs is not None:
-                    logits_output.next_token_logprobs = (
-                        logits_output.next_token_logprobs.to("cpu", non_blocking=True)
-                    )
-                if logits_output.input_token_logprobs is not None:
-                    logits_output.input_token_logprobs = (
-                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
-                    )
-            if logits_output.hidden_states is not None:
-                logits_output.hidden_states = logits_output.hidden_states.to(
-                    "cpu", non_blocking=True
-                )
-            # Only copy to CPU if not already on CPU
-            if next_token_ids.device.type != "cpu":
-                next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_done.record()
-
-            self.output_queue.put(
-                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
-            )
-
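The loop above is the heart of the deleted worker: a dedicated thread pulls (model_worker_batch, future_map_ct, sync_event) tuples from `input_queue`, runs the forward pass on its own stream, and pushes the results plus a `copy_done` event to `output_queue`; parking each batch in `batch_lists` keeps its tensors referenced until the device is done with them. A stripped-down, CPU-only sketch of the same two-queue worker pattern (plain `queue.Queue`, no streams or device events):

import threading
from queue import Queue

input_queue: Queue = Queue()
output_queue: Queue = Queue()

def forward_thread():
    while True:
        batch = input_queue.get()
        if batch is None:       # shutdown sentinel, like the None in __delete__
            break
        result = sum(batch)     # stand-in for the forward pass
        output_queue.put((batch, result))

t = threading.Thread(target=forward_thread, daemon=True)
t.start()

input_queue.put([1, 2, 3])      # scheduler side: launching is non-blocking
# ... the scheduler can prepare the next batch here ...
print(output_queue.get())       # later: ([1, 2, 3], 6)
input_queue.put(None)
t.join()
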
-    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
-        """
-        This function is called to resolve the last batch result and
-        wait for the current batch to be launched. Used in overlap mode.
-        """
-        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
-            self.output_queue.get()
-        )
-
-        if launch_done is not None:
-            launch_done.wait()
-        copy_done.synchronize()
-
-        if logits_output.next_token_logprobs is not None:
-            logits_output.next_token_logprobs = (
-                logits_output.next_token_logprobs.tolist()
-            )
-        if logits_output.input_token_logprobs is not None:
-            logits_output.input_token_logprobs = tuple(
-                logits_output.input_token_logprobs.tolist()
-            )
-        next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids, can_run_cuda_graph
-
-    def forward_batch_generation(
-        self, model_worker_batch: ModelWorkerBatch
-    ) -> ForwardBatchOutput:
-        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
-        model_worker_batch.sampling_info = self.cur_sampling_info = (
-            model_worker_batch.sampling_info.copy_for_forward()
-        )
-
-        # A cuda stream sync here to avoid the cuda illegal memory access error.
-        sync_event = torch.get_device_module(self.device).Event()
-        sync_event.record(self.scheduler_stream)
-
-        # Push a new batch to the queue
-        bs = len(model_worker_batch.seq_lens)
-        cur_future_map_ct = self.future_map.update_ct(bs)
-        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
-
-        # get this forward batch's future token ids
-        future_next_token_ids = self.future_map.update_next_future(
-            cur_future_map_ct, bs
-        )
-        return ForwardBatchOutput(
-            next_token_ids=future_next_token_ids,
-            can_run_cuda_graph=False,
-        )
-
-    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-        success, message = self.worker.update_weights_from_disk(recv_req)
-        return success, message
-
-    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-        success, message = self.worker.init_weights_update_group(recv_req)
-        return success, message
-
-    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
-        success, message = self.worker.destroy_weights_update_group(recv_req)
-        return success, message
-
-    def init_weights_send_group_for_remote_instance(
-        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
-    ):
-        success, message = self.worker.init_weights_send_group_for_remote_instance(
-            recv_req
-        )
-        return success, message
-
-    def send_weights_to_remote_instance(
-        self, recv_req: SendWeightsToRemoteInstanceReqInput
-    ):
-        success, message = self.worker.send_weights_to_remote_instance(recv_req)
-        return success, message
-
-    def update_weights_from_distributed(
-        self, recv_req: UpdateWeightsFromDistributedReqInput
-    ):
-        success, message = self.worker.update_weights_from_distributed(recv_req)
-        return success, message
-
-    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-        success, message = self.worker.update_weights_from_tensor(recv_req)
-        return success, message
-
-    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-        return self.worker.get_weights_by_name(recv_req)
-
-    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        return self.worker.load_lora_adapter(recv_req)
-
-    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        return self.worker.unload_lora_adapter(recv_req)
-
-    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
-        return self.worker.can_run_lora_batch(lora_ids)
-
-    def __delete__(self):
-        self.input_queue.put((None, None))
-        self.copy_queue.put((None, None, None))
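
The other half of the deleted design is the future indirection: `forward_batch_generation` returns placeholder token ids immediately (`update_next_future`), the scheduler builds the next batch on top of them, the forward thread later publishes the real ids with `store_to_map`, and `resolve_future` patches them into the next batch's inputs. A toy version of that mechanism, assuming the common encoding of futures as negative indices into a preallocated buffer (the real `FutureMap` in sglang.srt.managers.overlap_utils may differ in interface and detail):

import torch

class ToyFutureMap:
    # Hand out placeholder token ids now; fill in the real ids later.
    def __init__(self, capacity: int):
        self.buf = torch.zeros(capacity + 1, dtype=torch.long)
        self.ct = 0

    def alloc(self, bs: int) -> torch.Tensor:
        # Negative ids -(ct+1)..-(ct+bs) denote not-yet-computed tokens.
        futures = -torch.arange(self.ct + 1, self.ct + bs + 1)
        self.ct += bs
        return futures

    def store(self, futures: torch.Tensor, real_ids: torch.Tensor):
        # Called by the forward thread once sampling has finished.
        self.buf[-futures] = real_ids

    def resolve(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Called before the next forward pass: patch placeholders in place.
        mask = input_ids < 0
        input_ids[mask] = self.buf[-input_ids[mask]]
        return input_ids

fm = ToyFutureMap(capacity=8)
futs = fm.alloc(2)                     # scheduler builds the next batch on these
fm.store(futs, torch.tensor([42, 7]))  # forward thread publishes real tokens
print(fm.resolve(futs.clone()))        # tensor([42,  7])
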
@@ -75,10 +75,6 @@ class ForwardMode(IntEnum):
     # Used in speculative decoding: extend a batch in the draft model.
     DRAFT_EXTEND = auto()

-    # A dummy first batch to start the pipeline for overlap scheduler.
-    # It is now used for triggering the sampling_info_done event for the first prefill batch.
-    DUMMY_FIRST = auto()
-
     # Split Prefill for PD multiplexing
     SPLIT_PREFILL = auto()

@@ -128,9 +124,6 @@ class ForwardMode(IntEnum):
     def is_cpu_graph(self):
         return self == ForwardMode.DECODE

-    def is_dummy_first(self):
-        return self == ForwardMode.DUMMY_FIRST
-
     def is_split_prefill(self):
         return self == ForwardMode.SPLIT_PREFILL

@@ -2057,15 +2057,11 @@ class ModelRunner:
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
-        # Apply logit bias
-        if sampling_info.sampling_info_done:
-            # Overlap mode: the function update_regex_vocab_mask was executed
-            # in process_batch_result of the last batch.
-            if sampling_info.grammars:
-                sampling_info.sampling_info_done.wait()
-        else:
-            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-            sampling_info.update_regex_vocab_mask()
+        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+        # was executed after we processed last batch's results.
+
+        # Calculate logits bias and apply it to next_token_logits.
+        sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

     def sample(
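
After this change `_preprocess_logits` is unconditional: build the vocab mask, apply the accumulated bias, then sample, with no event to wait on. A condensed sketch of that post-change flow, with standalone functions standing in for the SamplingBatchInfo methods:

import torch

def preprocess_logits(logits, vocab_mask=None, linear_penalty=None):
    # Penalties accumulated earlier arrive as a dense [batch, vocab] bias.
    if linear_penalty is not None:
        logits.add_(linear_penalty)
    # Grammar mask: forbid tokens the grammar disallows in this state.
    if vocab_mask is not None:
        logits.masked_fill_(vocab_mask, float("-inf"))
    return logits

def sample(logits, temperature=1.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

logits = torch.randn(2, 16)
mask = torch.zeros(2, 16, dtype=torch.bool)
mask[:, 8:] = True  # toy grammar: only token ids 0..7 are legal
print(sample(preprocess_logits(logits, vocab_mask=mask)).squeeze(-1))
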
@@ -44,12 +44,9 @@ class SamplingBatchInfo:
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None

-    # An event used for overlap schedule
-    sampling_info_done: Optional[threading.Event] = None
-
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
-    linear_penalty: torch.Tensor = None
+    acc_linear_penalties: torch.Tensor = None  # Used in the overlap mode

     # Whether any request has custom logit processor
     has_custom_logit_processor: bool = False
@@ -217,19 +214,19 @@ class SamplingBatchInfo:

     def update_penalties(self):
         if self.penalizer_orchestrator.is_required:
-            self.linear_penalty = torch.zeros(
+            self.acc_linear_penalties = torch.zeros(
                 (len(self.temperatures), self.vocab_size),
                 dtype=torch.float32,
                 device=self.temperatures.device,
             )
-            self.penalizer_orchestrator.apply(self.linear_penalty)
+            self.penalizer_orchestrator.apply(self.acc_linear_penalties)
         else:
-            self.linear_penalty = None
+            self.acc_linear_penalties = None

     def apply_logits_bias(self, logits: torch.Tensor):
-        if self.linear_penalty is not None:
+        if self.acc_linear_penalties is not None:
             # Used in the overlap mode
-            logits.add_(self.linear_penalty)
+            logits.add_(self.acc_linear_penalties)

         if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
             # Used in the non-overlap mode
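
The rename from `linear_penalty` to `acc_linear_penalties` reflects what `update_penalties` actually does: it folds every active penalizer into one pre-accumulated [batch, vocab] tensor so that, in overlap mode, `apply_logits_bias` needs a single `add_` and no reference to the orchestrator. A small illustration of accumulating additive penalties into one buffer (toy penalizers, not penaltylib's classes):

import torch

def accumulate_penalties(penalizers, batch_size, vocab_size):
    # One dense buffer absorbs every penalizer up front; sampling then
    # depends only on this tensor, not on the penalizer objects.
    acc = torch.zeros(batch_size, vocab_size, dtype=torch.float32)
    for apply_penalty in penalizers:
        apply_penalty(acc)  # each penalizer adds its contribution in place
    return acc

def frequency_penalty(seen_counts, alpha=0.5):
    return lambda acc: acc.sub_(alpha * seen_counts)

seen = torch.zeros(1, 16)
seen[0, 3] = 4                  # token 3 was already sampled four times
acc = accumulate_penalties([frequency_penalty(seen)], 1, 16)
logits = torch.zeros(1, 16)
logits.add_(acc)                # what apply_logits_bias does in overlap mode
print(logits[0, 3])             # tensor(-2.)
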
@@ -373,11 +370,7 @@ class SamplingBatchInfo:
     def copy_for_forward(self):
         # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
         self.update_penalties()
-        return dataclasses.replace(
-            self,
-            sampling_info_done=threading.Event(),
-            penalizer_orchestrator=None,
-        )
+        return dataclasses.replace(self, penalizer_orchestrator=None)


 def merge_bias_tensor(