[Minor] Fix code style (#2311)
This commit is contained in:
@@ -25,7 +25,6 @@ import uuid
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import torch
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@@ -337,6 +336,12 @@ class TokenizerManager:
|
||||
rids.append(tmp_obj.rid)
|
||||
else:
|
||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||
if batch_size > 128:
|
||||
logger.warning(
|
||||
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
|
||||
"The performance might be better if you just duplicate the requests n times or use "
|
||||
"many threads to send them one by one with parallel sampling (n > 1)."
|
||||
)
|
||||
|
||||
# Tokenize all requests
|
||||
objs = [obj[i] for i in range(batch_size)]
|
||||
@@ -494,9 +499,7 @@ class TokenizerManager:
|
||||
result = await self.parameter_update_result
|
||||
return result.success, result.message
|
||||
else:
|
||||
logger.error(
|
||||
f"Another parameter update is in progress in tokenizer manager"
|
||||
)
|
||||
logger.error("Another parameter update is in progress in tokenizer manager")
|
||||
return (
|
||||
False,
|
||||
"Another parameter update is in progress. Please try again later.",
|
||||
@@ -597,7 +600,68 @@ class TokenizerManager:
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||
|
||||
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
state = self.rid_to_state.get(rid, None)
|
||||
if state is None:
|
||||
continue
|
||||
|
||||
recv_obj.meta_info[i]["id"] = rid
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
out_dict = {
|
||||
"embedding": recv_obj.embeddings[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished_reason[i] is not None
|
||||
state.event.set()
|
||||
|
||||
if self.enable_metrics:
|
||||
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
|
||||
|
||||
if state.first_token_time is None:
|
||||
state.first_token_time = time.time()
|
||||
self.metrics_collector.observe_time_to_first_token(
|
||||
state.first_token_time - state.created_time
|
||||
)
|
||||
else:
|
||||
if completion_tokens >= 2:
|
||||
self.metrics_collector.observe_time_per_output_token(
|
||||
(time.time() - state.first_token_time)
|
||||
/ (completion_tokens - 1)
|
||||
)
|
||||
|
||||
if state.finished:
|
||||
self.metrics_collector.inc_prompt_tokens(
|
||||
recv_obj.meta_info[i]["prompt_tokens"]
|
||||
)
|
||||
self.metrics_collector.inc_generation_tokens(
|
||||
completion_tokens
|
||||
)
|
||||
self.metrics_collector.observe_e2e_request_latency(
|
||||
time.time() - state.created_time
|
||||
)
|
||||
if completion_tokens >= 1:
|
||||
self.metrics_collector.observe_time_per_output_token(
|
||||
(time.time() - state.created_time)
|
||||
/ completion_tokens
|
||||
)
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id
|
||||
)
|
||||
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
@@ -605,13 +669,16 @@ class TokenizerManager:
|
||||
# set future if the all results are recevied
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
continue
|
||||
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for init parameter update group"
|
||||
self.init_weights_update_group_result.set_result(recv_obj)
|
||||
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for update weights from distributed"
|
||||
self.parameter_update_result.set_result(recv_obj)
|
||||
continue
|
||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.get_weights_by_name_result.set_result(recv_obj)
|
||||
@@ -621,76 +688,8 @@ class TokenizerManager:
|
||||
self.get_weights_by_name_result.set_result(
|
||||
self.get_weights_by_name_tmp
|
||||
)
|
||||
continue
|
||||
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for init parameter update group"
|
||||
self.init_weights_update_group_result.set_result(recv_obj)
|
||||
continue
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id
|
||||
)
|
||||
continue
|
||||
|
||||
assert isinstance(
|
||||
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
||||
), f"Unexpected obj received: {type(recv_obj)}"
|
||||
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
state = self.rid_to_state.get(rid, None)
|
||||
if state is None:
|
||||
continue
|
||||
|
||||
recv_obj.meta_info[i]["id"] = rid
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
out_dict = {
|
||||
"embedding": recv_obj.embeddings[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished_reason[i] is not None
|
||||
state.event.set()
|
||||
|
||||
if self.enable_metrics:
|
||||
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
|
||||
|
||||
if state.first_token_time is None:
|
||||
state.first_token_time = time.time()
|
||||
self.metrics_collector.observe_time_to_first_token(
|
||||
state.first_token_time - state.created_time
|
||||
)
|
||||
else:
|
||||
if completion_tokens >= 2:
|
||||
self.metrics_collector.observe_time_per_output_token(
|
||||
(time.time() - state.first_token_time)
|
||||
/ (completion_tokens - 1)
|
||||
)
|
||||
|
||||
if state.finished:
|
||||
self.metrics_collector.inc_prompt_tokens(
|
||||
recv_obj.meta_info[i]["prompt_tokens"]
|
||||
)
|
||||
self.metrics_collector.inc_generation_tokens(completion_tokens)
|
||||
self.metrics_collector.observe_e2e_request_latency(
|
||||
time.time() - state.created_time
|
||||
)
|
||||
if completion_tokens >= 1:
|
||||
self.metrics_collector.observe_time_per_output_token(
|
||||
(time.time() - state.created_time) / completion_tokens
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid object: {recv_obj=}")
|
||||
|
||||
def convert_logprob_style(
|
||||
self,
|
||||
|
||||
@@ -218,16 +218,6 @@ class ModelRunner:
|
||||
)
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
# Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
|
||||
# so we disable padding in cuda graph.
|
||||
if self.device == "cuda" and not all(
|
||||
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
|
||||
):
|
||||
self.server_args.disable_cuda_graph_padding = True
|
||||
logger.info(
|
||||
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
|
||||
)
|
||||
|
||||
# Check memory for tensor parallelism
|
||||
if self.tp_size > 1:
|
||||
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
|
||||
@@ -82,7 +82,6 @@ from sglang.srt.utils import (
|
||||
assert_pkg_version,
|
||||
configure_logger,
|
||||
delete_directory,
|
||||
init_custom_process_group,
|
||||
is_port_available,
|
||||
kill_process_tree,
|
||||
maybe_set_triton_cache_manager,
|
||||
@@ -154,13 +153,11 @@ async def get_model_info():
|
||||
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
try:
|
||||
return await _get_server_info()
|
||||
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return {
|
||||
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||
**scheduler_info,
|
||||
"version": __version__,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/flush_cache")
|
||||
@@ -567,14 +564,6 @@ def launch_server(
|
||||
t.join()
|
||||
|
||||
|
||||
async def _get_server_info():
|
||||
return {
|
||||
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||
**scheduler_info,
|
||||
"version": __version__,
|
||||
}
|
||||
|
||||
|
||||
def _set_envs_and_config(server_args: ServerArgs):
|
||||
# Set global environments
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
@@ -687,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
delete_directory(server_args.model_path)
|
||||
|
||||
|
||||
STREAM_END_SYMBOL = b"data: [DONE]"
|
||||
STREAM_CHUNK_START_SYMBOL = b"data:"
|
||||
|
||||
|
||||
class Engine:
|
||||
"""
|
||||
SRT Engine without an HTTP server layer.
|
||||
|
||||
This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
|
||||
launching the HTTP server adds unnecessary complexity or overhead,
|
||||
"""
|
||||
|
||||
def __init__(self, log_level: str = "error", *args, **kwargs):
|
||||
"""See the arguments in server_args.py::ServerArgs"""
|
||||
|
||||
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
||||
atexit.register(self.shutdown)
|
||||
|
||||
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
||||
launch_engine(server_args=server_args)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
lora_path=lora_path,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# get the current event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
ret = loop.run_until_complete(generate_request(obj, None))
|
||||
|
||||
if stream is True:
|
||||
|
||||
def generator_wrapper():
|
||||
offset = 0
|
||||
loop = asyncio.get_event_loop()
|
||||
generator = ret.body_iterator
|
||||
while True:
|
||||
chunk = loop.run_until_complete(generator.__anext__())
|
||||
|
||||
if chunk.startswith(STREAM_END_SYMBOL):
|
||||
break
|
||||
else:
|
||||
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
|
||||
data["text"] = data["text"][offset:]
|
||||
offset += len(data["text"])
|
||||
yield data
|
||||
|
||||
# we cannot yield in the scope of generate() because python does not allow yield + return in the same function
|
||||
# however, it allows to wrap the generator as a subfunction and return
|
||||
return generator_wrapper()
|
||||
else:
|
||||
return ret
|
||||
|
||||
async def async_generate(
|
||||
self,
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
sampling_params: Optional[Dict] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
lora_path=lora_path,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
ret = await generate_request(obj, None)
|
||||
|
||||
if stream is True:
|
||||
generator = ret.body_iterator
|
||||
|
||||
async def generator_wrapper():
|
||||
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
chunk = await generator.__anext__()
|
||||
|
||||
if chunk.startswith(STREAM_END_SYMBOL):
|
||||
break
|
||||
else:
|
||||
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
|
||||
data["text"] = data["text"][offset:]
|
||||
offset += len(data["text"])
|
||||
yield data
|
||||
|
||||
return generator_wrapper()
|
||||
else:
|
||||
return ret
|
||||
|
||||
def shutdown(self):
|
||||
kill_process_tree(os.getpid(), include_parent=False)
|
||||
|
||||
def get_tokenizer(self):
|
||||
global tokenizer_manager
|
||||
|
||||
if tokenizer_manager is None:
|
||||
raise ReferenceError("Tokenizer Manager is not initialized.")
|
||||
else:
|
||||
return tokenizer_manager.tokenizer
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||
):
|
||||
obj = EmbeddingReqInput(text=prompt)
|
||||
|
||||
# get the current event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(encode_request(obj, None))
|
||||
|
||||
def start_profile(self):
|
||||
tokenizer_manager.start_profile()
|
||||
|
||||
def stop_profile(self):
|
||||
tokenizer_manager.stop_profile()
|
||||
|
||||
def get_server_info(self):
|
||||
return {
|
||||
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||
**scheduler_info,
|
||||
"version": __version__,
|
||||
}
|
||||
|
||||
def init_weights_update_group(
|
||||
self,
|
||||
master_address: str,
|
||||
master_port: int,
|
||||
rank_offset: int,
|
||||
world_size: int,
|
||||
group_name: str,
|
||||
backend: str = "nccl",
|
||||
):
|
||||
"""Initialize parameter update group."""
|
||||
obj = InitWeightsUpdateGroupReqInput(
|
||||
master_address=master_address,
|
||||
master_port=master_port,
|
||||
rank_offset=rank_offset,
|
||||
world_size=world_size,
|
||||
group_name=group_name,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
async def _init_group():
|
||||
return await tokenizer_manager.init_weights_update_group(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_init_group())
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromDistributedReqInput(
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
async def _update_weights():
|
||||
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_update_weights())
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
"""Get weights by parameter name."""
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
|
||||
async def _get_weights():
|
||||
return await tokenizer_manager.get_weights_by_name(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_get_weights())
|
||||
|
||||
|
||||
class Runtime:
|
||||
"""
|
||||
A wrapper for the server.
|
||||
A wrapper for the HTTP server.
|
||||
This is used for launching the server in a python program without
|
||||
using the commond line interface.
|
||||
|
||||
It is mainly used for the frontend language.
|
||||
You should use the Engine class if you want to do normal offline processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -839,201 +1035,3 @@ class Runtime:
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
|
||||
STREAM_END_SYMBOL = b"data: [DONE]"
|
||||
STREAM_CHUNK_START_SYMBOL = b"data:"
|
||||
|
||||
|
||||
class Engine:
|
||||
"""
|
||||
SRT Engine without an HTTP server layer.
|
||||
|
||||
This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
|
||||
launching the HTTP server adds unnecessary complexity or overhead,
|
||||
"""
|
||||
|
||||
def __init__(self, log_level: str = "error", *args, **kwargs):
|
||||
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
||||
atexit.register(self.shutdown)
|
||||
|
||||
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
||||
launch_engine(server_args=server_args)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
lora_path=lora_path,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# get the current event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
ret = loop.run_until_complete(generate_request(obj, None))
|
||||
|
||||
if stream is True:
|
||||
|
||||
def generator_wrapper():
|
||||
offset = 0
|
||||
loop = asyncio.get_event_loop()
|
||||
generator = ret.body_iterator
|
||||
while True:
|
||||
chunk = loop.run_until_complete(generator.__anext__())
|
||||
|
||||
if chunk.startswith(STREAM_END_SYMBOL):
|
||||
break
|
||||
else:
|
||||
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
|
||||
data["text"] = data["text"][offset:]
|
||||
offset += len(data["text"])
|
||||
yield data
|
||||
|
||||
# we cannot yield in the scope of generate() because python does not allow yield + return in the same function
|
||||
# however, it allows to wrap the generator as a subfunction and return
|
||||
return generator_wrapper()
|
||||
else:
|
||||
return ret
|
||||
|
||||
async def async_generate(
|
||||
self,
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
sampling_params: Optional[Dict] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
lora_path=lora_path,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
ret = await generate_request(obj, None)
|
||||
|
||||
if stream is True:
|
||||
generator = ret.body_iterator
|
||||
|
||||
async def generator_wrapper():
|
||||
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
chunk = await generator.__anext__()
|
||||
|
||||
if chunk.startswith(STREAM_END_SYMBOL):
|
||||
break
|
||||
else:
|
||||
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
|
||||
data["text"] = data["text"][offset:]
|
||||
offset += len(data["text"])
|
||||
yield data
|
||||
|
||||
return generator_wrapper()
|
||||
else:
|
||||
return ret
|
||||
|
||||
def shutdown(self):
|
||||
kill_process_tree(os.getpid(), include_parent=False)
|
||||
|
||||
def get_tokenizer(self):
|
||||
global tokenizer_manager
|
||||
|
||||
if tokenizer_manager is None:
|
||||
raise ReferenceError("Tokenizer Manager is not initialized.")
|
||||
else:
|
||||
return tokenizer_manager.tokenizer
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||
):
|
||||
obj = EmbeddingReqInput(text=prompt)
|
||||
|
||||
# get the current event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(encode_request(obj, None))
|
||||
|
||||
def start_profile(self):
|
||||
tokenizer_manager.start_profile()
|
||||
|
||||
def stop_profile(self):
|
||||
tokenizer_manager.stop_profile()
|
||||
|
||||
async def get_server_info(self):
|
||||
return await _get_server_info()
|
||||
|
||||
def init_weights_update_group(
|
||||
self,
|
||||
master_address: str,
|
||||
master_port: int,
|
||||
rank_offset: int,
|
||||
world_size: int,
|
||||
group_name: str,
|
||||
backend: str = "nccl",
|
||||
):
|
||||
"""Initialize parameter update group."""
|
||||
obj = InitWeightsUpdateGroupReqInput(
|
||||
master_address=master_address,
|
||||
master_port=master_port,
|
||||
rank_offset=rank_offset,
|
||||
world_size=world_size,
|
||||
group_name=group_name,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
async def _init_group():
|
||||
return await tokenizer_manager.init_weights_update_group(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_init_group())
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromDistributedReqInput(
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
async def _update_weights():
|
||||
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_update_weights())
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
"""Get weights by parameter name."""
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
|
||||
async def _get_weights():
|
||||
return await tokenizer_manager.get_weights_by_name(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_get_weights())
|
||||
|
||||
Reference in New Issue
Block a user