This commit is contained in:
Ying Sheng
2024-07-05 10:06:17 -07:00
committed by GitHub
parent 5a57b8addd
commit dc1b8bcfaa
21 changed files with 487 additions and 354 deletions

View File

@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
connect_rpyc_service,
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
start_rpyc_service_process,
connect_rpyc_service,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
@@ -368,9 +368,11 @@ class ModelTpServer:
if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
and (req.extend_input_len + new_batch_input_tokens
<= self.max_prefill_tokens
or len(can_run_list) == 0)
and (
req.extend_input_len + new_batch_input_tokens
<= self.max_prefill_tokens
or len(can_run_list) == 0
)
):
delta = self.tree_cache.inc_lock_ref(req.last_node)
available_size += delta
@@ -452,7 +454,9 @@ class ModelTpServer:
next_token_ids,
].tolist()
output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
else:
@@ -582,7 +586,9 @@ class ModelTpServer:
req.check_finished()
if req.return_logprob:
req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
req.decode_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
@@ -759,16 +765,27 @@ class ModelTpClient:
with ThreadPoolExecutor(self.tp_size) as executor:
# Launch model processes
if server_args.nnodes == 1:
self.procs = list(executor.map(
lambda args: start_rpyc_service_process(*args),
[(ModelTpService, p) for p in model_port_args.model_tp_ports],
))
self.procs = list(
executor.map(
lambda args: start_rpyc_service_process(*args),
[
(ModelTpService, p)
for p in model_port_args.model_tp_ports
],
)
)
addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
else:
addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
addrs = [
(ip, port)
for ip, port in zip(
model_port_args.model_tp_ips, model_port_args.model_tp_ports
)
]
self.model_services = list(executor.map(
lambda args: connect_rpyc_service(*args), addrs))
self.model_services = list(
executor.map(lambda args: connect_rpyc_service(*args), addrs)
)
# Init model
def init_model(i):