Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
@@ -28,6 +28,7 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://localhost:30000")
    parser.add_argument("--log-requests", action="store_true")
    parser.add_argument("--log-requests-level", type=int, default=2)
    parser.add_argument(
        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
    )
@@ -38,7 +39,7 @@ if __name__ == "__main__":
        args.url + "/configure_logging",
        json={
            "log_requests": args.log_requests,
            "log_requests_level": 1,  # Log full requests
            "log_requests_level": args.log_requests_level,  # Log full requests
            "dump_requests_folder": args.dump_requests_folder,
            "dump_requests_threshold": args.dump_requests_threshold,
        },

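# Annotation (not part of the commit): a minimal usage sketch of the updated
# benchmark script and the /configure_logging endpoint it hits. The flag names
# and JSON keys come from the diff above; the script file name and server URL
# are assumptions.
#
#   python configure_logging_script.py --url http://localhost:30000 \
#       --log-requests --log-requests-level 2
#
# Equivalent direct call:
import requests  # assumed dependency of the script

resp = requests.post(
    "http://localhost:30000/configure_logging",
    json={
        "log_requests": True,
        "log_requests_level": 2,  # see get_log_request_metadata further below
        "dump_requests_folder": "/tmp/sglang_request_dump",
        "dump_requests_threshold": 1000,
    },
)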
@@ -198,6 +198,8 @@ class DataParallelController:
        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
        self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

        print(f"{scheduler_info=}")

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
@@ -220,6 +222,7 @@ class DataParallelController:
                    TokenizedEmbeddingReqInput,
                ),
            ):
                logger.info("dispatching")
                self.dispatching(recv_req)
            else:
                # Send other control messages to first worker of tp group

@@ -14,6 +14,7 @@
"""DetokenizerManager is a process that detokenizes the token ids."""

import dataclasses
import json
import logging
import os
import signal
@@ -27,11 +28,16 @@ import zmq
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchMultimodalDecodeReq,
    BatchStrOut,
    BatchTokenIDOut,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, get_zmq_socket
from sglang.srt.utils import (
    configure_logger,
    get_zmq_socket,
    kill_itself_when_parent_died,
)
from sglang.utils import (
    TypeBasedDispatcher,
    find_printable_text,
@@ -86,14 +92,23 @@ class DetokenizerManager:
        )

        self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
        self.is_dummy = server_args.load_format == "dummy"

        self._request_dispatcher = TypeBasedDispatcher(
            [
                (BatchEmbeddingOut, self.handle_batch_embedding_out),
                (BatchTokenIDOut, self.handle_batch_token_id_out),
                (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
            ]
        )

    def event_loop(self):
        """The event loop that handles requests"""
        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            output = self._request_dispatcher(recv_obj)
            self.send_to_tokenizer.send_pyobj(output)

    def trim_matched_stop(
        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
    ):
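# Annotation (not part of the commit): TypeBasedDispatcher's implementation is not
# shown in this diff, so the sketch below is an assumption about its behavior --
# it appears to map an object's runtime type to a registered handler, which is how
# the event loop above routes the three batch-output types.
from typing import Any, Callable, List, Tuple, Type


class TypeBasedDispatcherSketch:
    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any):
        # Dispatch on the type of the received object.
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj}")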
@@ -117,14 +132,6 @@ class DetokenizerManager:
            return output[:-1]
        return output

    def event_loop(self):
        """The event loop that handles requests"""
        while True:
            recv_obj = self.recv_from_scheduler.recv_pyobj()
            output = self._request_dispatcher(recv_obj)
            self.send_to_tokenizer.send_pyobj(output)

    def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
        # If it is embedding model, no detokenization is needed.
        return recv_obj
@@ -173,7 +180,6 @@ class DetokenizerManager:

        # Incremental decoding
        output_strs = []
        finished_reqs = []
        for i in range(bs):
            try:
                s = self.decode_status[recv_obj.rids[i]]
@@ -196,8 +202,6 @@ class DetokenizerManager:
                    new_text = ""
                else:
                    new_text = find_printable_text(new_text)
            else:
                finished_reqs.append(recv_obj.rids[i])

            output_strs.append(
                self.trim_matched_stop(
@@ -207,7 +211,7 @@ class DetokenizerManager:
                )
            )

        out = BatchStrOut(
        return BatchStrOut(
            rids=recv_obj.rids,
            finished_reasons=recv_obj.finished_reasons,
            output_strs=output_strs,
@@ -223,14 +227,15 @@ class DetokenizerManager:
            input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
            output_top_logprobs_val=recv_obj.output_top_logprobs_val,
            output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
            input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
            output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
            output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
            output_hidden_states=recv_obj.output_hidden_states,
        )

        # Remove decode status for completed requests
        for rid in finished_reqs:
            self.decode_status.pop(rid)

        return out
    def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
        raise NotImplementedError()


class LimitedCapacityDict(OrderedDict):
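# Annotation (not part of the commit): the body of LimitedCapacityDict falls outside
# this hunk. A plausible minimal sketch, assuming it evicts the oldest entry once
# `capacity` is exceeded, which matches its use as a bounded decode-status cache:
from collections import OrderedDict


class LimitedCapacityDictSketch(OrderedDict):
    def __init__(self, capacity: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if len(self) > self.capacity:
            # Evict the least recently inserted item.
            self.popitem(last=False)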
@@ -250,6 +255,7 @@ def run_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
):
    kill_itself_when_parent_died()
    setproctitle.setproctitle("sglang::detokenizer")
    configure_logger(server_args)
    parent_process = psutil.Process().parent()

@@ -16,10 +16,11 @@ The definition of objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -55,6 +56,8 @@ class GenerateReqInput:
    logprob_start_len: Optional[Union[List[int], int]] = None
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # If return logprobs, the token ids to return logprob for.
    token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
    # Whether to detokenize tokens in text in the returned logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
@@ -146,6 +149,8 @@ class GenerateReqInput:
            self.logprob_start_len = -1
        if self.top_logprobs_num is None:
            self.top_logprobs_num = 0
        if not self.token_ids_logprob:  # covers both None and []
            self.token_ids_logprob = None
        else:
            if self.parallel_sample_num == 1:
                num = self.batch_size
@@ -191,6 +196,17 @@ class GenerateReqInput:
            else:
                assert self.parallel_sample_num == 1

            if not self.token_ids_logprob:  # covers both None and []
                self.token_ids_logprob = [None] * num
            elif not isinstance(self.token_ids_logprob, list):
                self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
            elif not isinstance(self.token_ids_logprob[0], list):
                self.token_ids_logprob = [
                    copy.deepcopy(self.token_ids_logprob) for _ in range(num)
                ]
            else:
                assert self.parallel_sample_num == 1

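# Annotation (not part of the commit): a runnable sketch of the broadcast rule
# above, with `num` standing in for the effective batch size.
import copy

def normalize_token_ids_logprob(token_ids_logprob, num):
    if not token_ids_logprob:  # covers both None and []
        return [None] * num
    if not isinstance(token_ids_logprob, list):  # a single token id
        return [[token_ids_logprob] for _ in range(num)]
    if not isinstance(token_ids_logprob[0], list):  # one flat list, shared by all
        return [copy.deepcopy(token_ids_logprob) for _ in range(num)]
    return token_ids_logprob  # already one list per request

assert normalize_token_ids_logprob(5, 3) == [[5], [5], [5]]
assert normalize_token_ids_logprob([5, 7], 2) == [[5, 7], [5, 7]]
assert normalize_token_ids_logprob(None, 2) == [None, None]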
        if self.custom_logit_processor is None:
            self.custom_logit_processor = [None] * num
        elif not isinstance(self.custom_logit_processor, list):
@@ -198,6 +214,12 @@ class GenerateReqInput:
            else:
                assert self.parallel_sample_num == 1

        # Other checks
        if self.session_params is not None:
            assert isinstance(self.session_params, dict) or isinstance(
                self.session_params[0], dict
            )

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid
@@ -212,6 +234,7 @@ class GenerateReqInput:
            return_logprob=self.return_logprob[i],
            logprob_start_len=self.logprob_start_len[i],
            top_logprobs_num=self.top_logprobs_num[i],
            token_ids_logprob=self.token_ids_logprob[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
            log_metrics=self.log_metrics,
@@ -244,6 +267,8 @@ class TokenizedGenerateReqInput:
    logprob_start_len: int
    # If return logprobs, the number of top logprobs to return at each position.
    top_logprobs_num: int
    # If return logprobs, the token ids to return logprobs for.
    token_ids_logprob: List[int]
    # Whether to stream output
    stream: bool

@@ -378,10 +403,21 @@ class BatchTokenIDOut:
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalDecodeReq:
    # The request id
    rids: List[str]


@dataclass
class BatchStrOut:
    # The request id
@@ -406,10 +442,21 @@ class BatchStrOut:
    input_top_logprobs_idx: List[List]
    output_top_logprobs_val: List[List]
    output_top_logprobs_idx: List[List]
    input_token_ids_logprobs_val: List[List]
    input_token_ids_logprobs_idx: List[List]
    output_token_ids_logprobs_val: List[List]
    output_token_ids_logprobs_idx: List[List]

    # Hidden states
    output_hidden_states: List[List[float]]


@dataclass
class BatchMultimodalOut:
    # The request id
    rids: List[str]


@dataclass
class BatchEmbeddingOut:
    # The request id
@@ -439,6 +486,8 @@ class UpdateWeightFromDiskReqInput:
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str
    # Number of paused requests during weight sync.
    num_paused_requests: Optional[int] = 0


@dataclass
@@ -526,11 +575,57 @@ class AbortReq:
    rid: str


class ProfileReq(Enum):
@dataclass
class GetInternalStateReq:
    pass


@dataclass
class GetInternalStateReqOutput:
    internal_state: Dict[Any, Any]


@dataclass
class SetInternalStateReq:
    server_args: Dict[str, Any]


@dataclass
class SetInternalStateReqOutput:
    updated: bool
    server_args: Dict[str, Any]


@dataclass
class ProfileReqInput:
    # The output directory
    output_dir: Optional[str] = None
    # If set, it profiles as many as this number of steps.
    # If it is set, profiling is automatically stopped after this step, and
    # the caller doesn't need to run stop_profile.
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None


class ProfileReqType(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class ProfileReq:
    type: ProfileReqType
    output_dir: Optional[str] = None
    num_steps: Optional[int] = None
    activities: Optional[List[str]] = None


@dataclass
class ProfileReqOutput:
    success: bool
    message: str


@dataclass
class ConfigureLoggingReq:
    log_requests: Optional[bool] = None
@@ -556,6 +651,11 @@ class OpenSessionReqOutput:
    success: bool


@dataclass
class HealthCheckOutput:
    pass


@dataclass
class Function:
    description: Optional[str] = None

@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import copy
import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
@@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
    from sglang.srt.server_args import ServerArgs
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm


INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -65,6 +69,8 @@ global_server_args_dict = {
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "device": ServerArgs.device,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -230,6 +236,7 @@ class Req:
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
@@ -256,17 +263,24 @@ class Req:
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params

        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # Memory pool info
        self.req_pool_idx = None
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        self.stream = stream
        self.eos_token_ids = eos_token_ids
@@ -289,38 +303,56 @@ class Req:
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
        # Updated if chunked.
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None

        # Chunked prefill
        self.is_being_chunked = 0
        # Whether or not it is chunked. It increments whenever
        # it is chunked, and decrements whenever the chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = None
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states = []

        # Logprobs (internal values)
@@ -345,6 +377,13 @@ class Req:
        self.spec_verify_ct = 0
        self.lora_path = lora_path

        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = "Unknown error"

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
@@ -422,7 +461,9 @@ class Req:
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -517,6 +558,8 @@ class Req:
        self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
        self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
        self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
        self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
        self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
        self.logprob_start_len = prompt_tokens + k
        self.last_update_decode_tokens = len(self.output_ids) - k

@@ -527,16 +570,19 @@ class Req:
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None

        # For incremental logprobs
        # TODO: Fix the `logprob_start_len`
        self.last_update_decode_tokens = 0
        self.logprob_start_len = 10**9

    def __repr__(self):
        return (
            f"rid(n={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )


@@ -576,11 +622,13 @@ class ScheduleBatch:

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
@@ -588,6 +636,8 @@ class ScheduleBatch:
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
@@ -606,7 +656,7 @@ class ScheduleBatch:

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False
@@ -653,8 +703,10 @@ class ScheduleBatch:
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

@@ -765,6 +817,7 @@ class ScheduleBatch:
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []
        extend_input_logprob_token_ids = []

        pt = 0
        for i, req in enumerate(reqs):
@@ -783,22 +836,64 @@ class ScheduleBatch:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            if req.return_logprob:
                # Compute the relative logprob_start_len in an extend batch
                if req.logprob_start_len >= pre_len:
                    extend_logprob_start_len = min(
                        req.logprob_start_len - pre_len, req.extend_input_len - 1
                    )
                else:
                    raise RuntimeError(
                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
                    )
                req.extend_logprob_start_len = extend_logprob_start_len

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. This is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

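# Annotation (not part of the commit): a runnable walkthrough of the padding rule
# documented in the comment above, using the same toy numbers (chunk size 2).
# `prefix_len`/`fill_len` stand in for len(req.prefix_indices)/len(req.fill_ids).
def input_logprob_token_ids(origin_input_ids, prefix_len, fill_len, extend_logprob_start_len):
    # Slide the window by 1: token i's input logprob is scored against token i+1.
    ids = origin_input_ids[prefix_len + 1 : fill_len + 1]
    extend_input_len = fill_len - prefix_len
    # Pad with 0 when the window overflows past the end of the prompt.
    return ids + [0] * (extend_input_len - extend_logprob_start_len - len(ids))

# First chunk: fill_ids = [1, 2] -> [2, 3]; second chunk: fill_ids = [3, 4] -> [4, 0].
assert input_logprob_token_ids([1, 2, 3, 4], 0, 2, 0) == [2, 3]
assert input_logprob_token_ids([1, 2, 3, 4], 2, 4, 0) == [4, 0]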
        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -821,10 +916,12 @@ class ScheduleBatch:
        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
@@ -860,7 +957,6 @@ class ScheduleBatch:
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -905,25 +1001,43 @@ class ScheduleBatch:

        return False

    def retract_decode(self):
    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

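# Annotation (not part of the commit): a small numeric check of get_required_tokens.
# All parameter values below are hypothetical, chosen only to illustrate the
# speculative-decoding headroom term.
def required_tokens(num_reqs, retract_decode_steps=20, topk=4, steps=5, draft_tokens=8, spec=True):
    headroom = num_reqs * topk * steps + num_reqs * draft_tokens if spec else 0
    return num_reqs * retract_decode_steps + headroom

assert required_tokens(10) == 10 * 20 + 10 * 4 * 5 + 10 * 8  # == 480
assert required_tokens(10, spec=False) == 200  # no headroom without spec decoding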
        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
@@ -1048,17 +1162,40 @@ class ScheduleBatch:
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

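# Annotation (not part of the commit): this is the core of "support penalty in
# overlap mode". Under overlap scheduling, `self.output_ids` for the current step
# may not be materialized yet when the next batch is prepared, so the already-known
# last token of each request (falling back to the last prompt token) is fed to the
# penalizer instead. The orchestrator itself is not shown in this diff; a minimal
# sketch of what cumulating output tokens means for a frequency-style penalty:
import torch

class FrequencyPenaltySketch:
    def __init__(self, batch_size: int, vocab_size: int, alpha: float = 0.5):
        self.alpha = alpha
        self.counts = torch.zeros(batch_size, vocab_size)

    def cumulate_output_tokens(self, token_ids: torch.Tensor):
        # token_ids: one newly generated token id per request, shape [batch].
        self.counts[torch.arange(token_ids.numel()), token_ids] += 1

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Penalize tokens proportionally to how often they were already emitted.
        return logits - self.alpha * self.counts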
        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
@@ -1086,14 +1223,15 @@ class ScheduleBatch:

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        chunked_req_to_exclude: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
                if not self.reqs[i].finished()
                and self.reqs[i] is not chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
@@ -1105,31 +1243,34 @@ class ScheduleBatch:
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)
        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(new_indices)
            self.spec_info.filter_batch(keep_indices_device)

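# Annotation (not part of the commit): besides the rename, the change above builds
# the device-side index tensor once and reuses it for every GPU gather. A toy
# illustration of the same pattern:
import torch

seq_lens = torch.tensor([5, 9, 3, 7])
keep_indices = [0, 3]  # e.g., requests 1 and 2 finished
keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64)
assert torch.equal(seq_lens[keep_indices_device], torch.tensor([5, 7]))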
    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1152,10 +1293,13 @@ class ScheduleBatch:
        self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
@@ -1192,7 +1336,9 @@ class ScheduleBatch:
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
@@ -1219,6 +1365,7 @@ class ScheduleBatch:
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        )

    def copy(self):
@@ -1262,9 +1409,11 @@ class ModelWorkerBatch:
    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
@@ -1272,6 +1421,7 @@ class ModelWorkerBatch:
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]
@@ -1293,7 +1443,8 @@ class ModelWorkerBatch:

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None

@@ -272,7 +272,7 @@ class PrefillAdder:

        self.req_states = None
        self.can_run_list = []
        self.new_being_chunked_req = None
        self.new_chunked_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

@@ -327,7 +327,7 @@ class PrefillAdder:
        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_being_chunked_req(self, req: Req):
    def add_chunked_req(self, req: Req):
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -354,7 +354,7 @@ class PrefillAdder:
        finally:
            self.tree_cache.dec_lock_ref(last_node)

    def add_one_req_ignore_eos(self, req: Req):
    def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
        def add_req_state(r, insert_sort=False):
            new_token_ratio = (
                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
@@ -403,6 +403,7 @@ class PrefillAdder:
            self.rem_chunk_tokens is None
            or req.extend_input_len <= self.rem_chunk_tokens
        ):
            # Non-chunked prefill
            self.can_run_list.append(req)
            self._prefill_one_req(
                0,
@@ -418,14 +419,14 @@ class PrefillAdder:
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
            self.can_run_list.append(req)
            self.new_being_chunked_req = req
            self.new_chunked_req = req
            self._prefill_one_req(0, trunc_len, 0)

        return self.budget_state()

    def add_one_req(self, req: Req):
    def add_one_req(self, req: Req, has_chunked_req: bool):
        if req.sampling_params.ignore_eos and self.tree_cache.disable:
            return self.add_one_req_ignore_eos(req)
            return self.add_one_req_ignore_eos(req, has_chunked_req)

        total_tokens = req.extend_input_len + min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
@@ -443,14 +444,7 @@ class PrefillAdder:
        if total_tokens > self.rem_total_tokens:
            return AddReqResult.NO_TOKEN

        if (
            self.rem_chunk_tokens is None
            or input_tokens <= self.rem_chunk_tokens
            or (
                req.return_logprob
                and req.logprob_start_len != len(req.origin_input_ids) - 1
            )
        ):
        if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
            # Non-chunked prefill
            self.can_run_list.append(req)
            self.tree_cache.inc_lock_ref(req.last_node)
@@ -470,8 +464,9 @@ class PrefillAdder:

        req.extend_input_len = trunc_len
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]

        self.can_run_list.append(req)
        self.new_being_chunked_req = req
        self.new_chunked_req = req
        self.tree_cache.inc_lock_ref(req.last_node)
        self._prefill_one_req(prefix_len, trunc_len, 0)

File diff suppressed because it is too large.
@@ -35,12 +35,12 @@ class SessionReqNode:
        for req_node in self.childs:
            req_node.clear(req_dict)

        if self.req.finished_reason == None:
        if self.req.finished_reason is None:
            self.req.to_abort = True
        del req_dict[self.req.rid]

    def abort(self):
        if self.req.finished_reason == None:
        if self.req.finished_reason is None:
            self.req.to_abort = True

    def __str__(self):
@@ -132,6 +132,10 @@ class Session:
            lora_path=req.lora_path,
            session_id=self.session_id,
            custom_logit_processor=req.custom_logit_processor,
            stream=req.stream,
            return_logprob=req.return_logprob,
            top_logprobs_num=req.top_logprobs_num,
            token_ids_logprob=req.token_ids_logprob,
        )
        if last_req is not None:
            new_req.image_inputs = last_req.image_inputs

@@ -16,6 +16,7 @@
import asyncio
import copy
import dataclasses
import json
import logging
import os
import pickle
@@ -24,9 +25,21 @@ import sys
import threading
import time
import uuid
from collections import deque
from datetime import datetime
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import fastapi
import uvloop
@@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import (
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchMultimodalOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
@@ -51,18 +65,25 @@ from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    HealthCheckOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SessionParams,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
@@ -98,7 +119,10 @@ class ReqState:

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0
@@ -113,11 +137,10 @@ class TokenizerManager:
        port_args: PortArgs,
    ):
        # Parse args

        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = 0
        self.log_requests_level = server_args.log_requests_level

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
@@ -143,6 +166,7 @@ class TokenizerManager:
        )

        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

@@ -178,9 +202,12 @@ class TokenizerManager:
        # Store states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()

        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
@@ -192,8 +219,19 @@ class TokenizerManager:
        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Others
        self.gracefully_exit = False
        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # Communicators
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
@@ -212,22 +250,26 @@ class TokenizerManager:
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )
        self.start_profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.set_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
@@ -259,6 +301,19 @@ class TokenizerManager:
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.start_profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (HealthCheckOutput, lambda x: None),
            ]
        )

@@ -280,9 +335,9 @@ class TokenizerManager:
        obj.normalize_batch_and_arguments()

        if self.log_requests:
            max_length = 2048 if self.log_requests_level == 0 else 1 << 30
            max_length, skip_names, _ = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        async with self.model_update_lock.reader_lock:
@@ -336,6 +391,7 @@ class TokenizerManager:
        return_logprob = obj.return_logprob
        logprob_start_len = obj.logprob_start_len
        top_logprobs_num = obj.top_logprobs_num
        token_ids_logprob = obj.token_ids_logprob
        session_params = (
            SessionParams(**obj.session_params) if obj.session_params else None
        )
@@ -378,6 +434,7 @@ class TokenizerManager:
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                token_ids_logprob,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
@@ -401,8 +458,7 @@ class TokenizerManager:
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        event = asyncio.Event()
        state = ReqState([], False, event, obj, created_time=created_time)
        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        self.send_to_scheduler.send_pyobj(tokenized_obj)

@@ -420,7 +476,10 @@ class TokenizerManager:
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                    raise ValueError(
                        "Request is disconnected from the client side. "
                        f"Abort request {obj.rid}"
                    )
                continue

            out = state.out_list[-1]
@@ -428,8 +487,11 @@ class TokenizerManager:
            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length = 2048 if self.log_requests_level == 0 else 1 << 30
                    msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)
                del self.rid_to_state[obj.rid]

@@ -452,7 +514,10 @@ class TokenizerManager:
            else:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                    raise ValueError(
                        "Request is disconnected from the client side. "
                        f"Abort request {obj.rid}"
                    )

    async def _handle_batch_request(
        self,
@@ -543,12 +608,25 @@ class TokenizerManager:
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)
    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
    ):
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            num_steps=num_steps,
            activities=activities,
        )
        result = (await self.start_profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        self.send_to_scheduler.send_pyobj(req)

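# Annotation (not part of the commit): start_profile is now async and fans out a
# ProfileReq through a _Communicator instead of sending a bare enum value. A hedged
# usage sketch from inside an async handler (argument values are illustrative):
result = await tokenizer_manager.start_profile(
    output_dir="/tmp/sglang_profile",  # hypothetical path
    num_steps=10,                      # auto-stops after 10 steps; no stop_profile needed
    activities=["CPU", "GPU"],         # assumed activity names
)
# Without num_steps, profiling runs until an explicit stop:
tokenizer_manager.stop_profile()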
    async def update_weights_from_disk(
@@ -581,7 +659,7 @@ class TokenizerManager:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message
            return result.success, result.message, result.num_paused_requests
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result
@@ -593,7 +671,8 @@ class TokenizerManager:
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            return all_success, all_message
            all_paused_requests = [r.num_paused_requests for r in result]
            return all_success, all_message, all_paused_requests

    async def init_weights_update_group(
        self,
@@ -688,6 +767,54 @@ class TokenizerManager:
    ):
        await self.send_to_scheduler.send_pyobj(obj)

    async def get_internal_state(self) -> Dict[Any, Any]:
        req = GetInternalStateReq()
        res: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        return res[0].internal_state

    async def set_internal_state(
        self, obj: SetInternalStateReq
    ) -> SetInternalStateReqOutput:
        res: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return res[0]

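# Annotation (not part of the commit): a hedged usage sketch of the new internal-state
# API. Each call fans out over dp_size schedulers via a _Communicator; only the first
# reply is used here. The dict keys are illustrative, not a documented schema:
state = await tokenizer_manager.get_internal_state()
print(state)  # scheduler-reported fields, e.g. {"memory_usage": ...} (assumed)

ack = await tokenizer_manager.set_internal_state(
    SetInternalStateReq(server_args={"log_requests": True})  # assumed updatable key
)
print(ack.updated, ack.server_args)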
    def get_log_request_metadata(self):
        max_length = None
        skip_names = None
        out_skip_names = None
        if self.log_requests:
            if self.log_requests_level == 0:
                max_length = 1 << 30
                skip_names = set(
                    [
                        "text",
                        "input_ids",
                        "input_embeds",
                        "image_data",
                        "audio_data",
                        "lora_path",
                    ]
                )
                out_skip_names = set(
                    [
                        "text",
                        "output_ids",
                    ]
                )
            elif self.log_requests_level == 1:
                max_length = 2048
            elif self.log_requests_level == 2:
                max_length = 1 << 30
            else:
                raise ValueError(
                    f"Invalid --log-requests-level: {self.log_requests_level=}"
                )
        return max_length, skip_names, out_skip_names

    def configure_logging(self, obj: ConfigureLoggingReq):
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
@@ -698,6 +825,7 @@ class TokenizerManager:
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        logging.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
@@ -762,15 +890,20 @@ class TokenizerManager:
        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)
            self.last_receive_tstamp = time.time()

    def _handle_batch_output(
        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                continue

            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
@@ -781,14 +914,12 @@ class TokenizerManager:
                self.convert_logprob_style(
                    meta_info,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs,
                    recv_obj,
                    i,
                )

            if self.server_args.speculative_algorithm:
                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info.update(
                    {
@@ -806,10 +937,20 @@ class TokenizerManager:
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                if self.server_args.stream_output and state.obj.stream:
                    output_token_ids = recv_obj.output_ids[i][
                        state.last_output_offset :
                    ]
                    state.last_output_offset = len(recv_obj.output_ids[i])
                else:
                    output_token_ids = recv_obj.output_ids[i]

                out_dict = {
                    "token_ids": recv_obj.output_ids[i],
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOut):
                raise NotImplementedError()
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
@@ -817,10 +958,17 @@ class TokenizerManager:
                    "meta_info": meta_info,
                }

            state.out_list.append(out_dict)
            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time

            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
@@ -830,6 +978,7 @@ class TokenizerManager:
        self,
        meta_info: dict,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
@@ -857,6 +1006,20 @@ class TokenizerManager:
                return_text_in_logprobs,
            )

        if token_ids_logprob is not None:
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )
            meta_info["output_token_ids_logprobs"] = (
                self.detokenize_top_logprobs_tokens(
                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
                    return_text_in_logprobs,
                )
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
@@ -900,34 +1063,30 @@ class TokenizerManager:
            else 0
        )

        if state.first_token_time is None:
            state.first_token_time = time.time()
        if state.first_token_time == 0.0:
            state.first_token_time = state.last_time = time.time()
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            if completion_tokens >= 2:
                # Compute time_per_output_token for the streaming case
                self.metrics_collector.observe_time_per_output_token(
                    (time.time() - state.first_token_time) / (completion_tokens - 1)
            num_new_tokens = completion_tokens - state.last_completion_tokens
            if num_new_tokens:
                new_time = time.time()
                interval = new_time - state.last_time
                self.metrics_collector.observe_inter_token_latency(
                    interval,
                    num_new_tokens,
                )
                state.last_time = new_time
                state.last_completion_tokens = completion_tokens

        if state.finished:
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i], completion_tokens
                recv_obj.prompt_tokens[i],
                completion_tokens,
                state.finished_time - state.created_time,
            )
            self.metrics_collector.observe_e2e_request_latency(
                time.time() - state.created_time
            )
            # Compute time_per_output_token for the non-streaming case
            if (
                hasattr(state.obj, "stream")
                and not state.obj.stream
                and completion_tokens >= 1
            ):
                self.metrics_collector.observe_time_per_output_token(
                    (time.time() - state.created_time) / completion_tokens
                )

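# Annotation (not part of the commit): the metric switches from a per-token average
# anchored at first_token_time to incremental inter-token latency. A runnable toy
# illustration of the new bookkeeping with made-up timestamps:
def inter_token_latency(last_time, last_tokens, now, tokens):
    num_new = tokens - last_tokens          # e.g., 8 - 5 = 3 new tokens
    interval = now - last_time              # e.g., 10.30 - 10.00 = 0.30 s
    return interval / num_new               # 0.1 s/token for this chunk

assert abs(inter_token_latency(10.00, 5, 10.30, 8) - 0.1) < 1e-9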
    def dump_requests(self, state: ReqState, out_dict: dict):
        self.dump_request_list.append(
@@ -996,22 +1155,38 @@ T = TypeVar("T")


class _Communicator(Generic[T]):
    """Note: The communicator now only runs up to 1 in-flight request at any time."""

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_future: Optional[asyncio.Future] = None
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        self._ready_queue: Deque[asyncio.Future] = deque()

    async def __call__(self, obj):
        self._sender.send_pyobj(obj)
        self._result_future = asyncio.Future()
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_future
        await self._result_event.wait()
        result_values = self._result_values
        self._result_future = self._result_values = None
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_future.set_result(None)
            self._result_event.set()

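# Annotation (not part of the commit): the ready queue above serializes callers --
# a second await on the same communicator parks on an asyncio.Event until the first
# caller has collected all fan_out replies. A hedged usage sketch (names illustrative):
import asyncio

async def main(tokenizer_manager):
    # Both calls go through the same _Communicator; the second waits in
    # _ready_queue until the first has received dp_size responses.
    first = asyncio.create_task(tokenizer_manager.get_internal_state())
    second = asyncio.create_task(tokenizer_manager.get_internal_state())
    return await asyncio.gather(first, second)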
@@ -15,10 +15,13 @@

import logging
import threading
from typing import Optional
from typing import Optional, Tuple

import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
@@ -159,7 +162,7 @@ class TpModelWorker:
        model_worker_batch: ModelWorkerBatch,
        launch_done: Optional[threading.Event] = None,
        skip_sample: bool = False,
    ):
    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        if launch_done:

@@ -175,7 +175,7 @@ class TpModelWorkerClient:
                    logits_output.next_token_logprobs.tolist()
                )
            if logits_output.input_token_logprobs is not None:
                logits_output.input_token_logprobs = (
                logits_output.input_token_logprobs = tuple(
                    logits_output.input_token_logprobs.tolist()
                )
            next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
            sampling_info,
            sampling_info_done=threading.Event(),
            scaling_penalties=sampling_info.scaling_penalties,
            linear_penalties=sampling_info.linear_penalties,
            penalizer_orchestrator=None,
        )

        # A cuda stream sync here to avoid the cuda illegal memory access error.