[Minor] Fix code style (#2311)
@@ -25,7 +25,6 @@ import uuid
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
@@ -337,6 +336,12 @@ class TokenizerManager:
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:
logger.warning(
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
"The performance might be better if you just duplicate the requests n times or use "
"many threads to send them one by one with parallel sampling (n > 1)."
)

# Tokenize all requests
objs = [obj[i] for i in range(batch_size)]
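The warning above points to a client-side workaround: rather than sending one large batch with parallel sampling (n > 1), duplicate each prompt n times with n = 1, or fan the requests out across threads. A minimal sketch of the duplication pattern against the native /generate endpoint; the host, port, and prompt values are placeholders, not part of this diff:

    import requests

    SERVER = "http://localhost:30000"  # assumed address of a running sglang server
    prompts = ["What is the capital of France?"] * 4
    n = 8  # desired samples per prompt

    # Duplicate each prompt n times instead of asking for n parallel samples.
    payload = {
        "text": [p for p in prompts for _ in range(n)],
        "sampling_params": {"temperature": 0.8, "max_new_tokens": 64},
    }
    outputs = requests.post(f"{SERVER}/generate", json=payload, timeout=600).json()

    # Regroup the flat output list back into n samples per original prompt.
    grouped = [outputs[i * n : (i + 1) * n] for i in range(len(prompts))]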
@@ -494,9 +499,7 @@ class TokenizerManager:
result = await self.parameter_update_result
return result.success, result.message
else:
logger.error(
f"Another parameter update is in progress in tokenizer manager"
)
logger.error("Another parameter update is in progress in tokenizer manager")
return (
False,
"Another parameter update is in progress. Please try again later.",
@@ -597,7 +600,68 @@ class TokenizerManager:
InitWeightsUpdateGroupReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()

if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue

recv_obj.meta_info[i]["id"] = rid
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()

if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]

if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)

if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(
completion_tokens
)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time)
/ completion_tokens
)
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
@@ -605,13 +669,16 @@ class TokenizerManager:
# set the future if all results are received
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.parameter_update_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
@@ -621,76 +688,8 @@ class TokenizerManager:
self.get_weights_by_name_result.set_result(
self.get_weights_by_name_tmp
)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
continue

assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
), f"Unexpected obj received: {type(recv_obj)}"

for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue

recv_obj.meta_info[i]["id"] = rid
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()

if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]

if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)

if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(completion_tokens)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time) / completion_tokens
)
else:
raise ValueError(f"Invalid object: {recv_obj=}")

def convert_logprob_style(
self,
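For reference, the metrics branch above reduces to three request-level quantities: time to first token, time per output token, and end-to-end latency. A standalone sketch of the same arithmetic with illustrative variable names (these are not TokenizerManager attributes):

    import time

    created_time = time.time()        # request accepted
    # ... first token arrives ...
    first_token_time = time.time()
    ttft = first_token_time - created_time

    # ... request finishes after generating `completion_tokens` tokens ...
    completion_tokens = 17
    finish_time = time.time()

    # Inter-token latency excludes the first token, hence the (n - 1) divisor.
    if completion_tokens >= 2:
        tpot = (finish_time - first_token_time) / (completion_tokens - 1)

    # End-to-end latency and the coarser per-token rate recorded at finish time.
    e2e_latency = finish_time - created_time
    if completion_tokens >= 1:
        overall_tpot = e2e_latency / completion_tokens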
@@ -218,16 +218,6 @@ class ModelRunner:
)
self.tp_group = get_tp_group()

# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
# so we disable padding in cuda graph.
if self.device == "cuda" and not all(
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)

# Check memory for tensor parallelism
if self.tp_size > 1:
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
@@ -82,7 +82,6 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
delete_directory,
init_custom_process_group,
is_port_available,
kill_process_tree,
maybe_set_triton_cache_manager,
@@ -154,13 +153,11 @@ async def get_model_info():

@app.get("/get_server_info")
async def get_server_info():
try:
return await _get_server_info()

except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}


@app.post("/flush_cache")
@@ -567,14 +564,6 @@ def launch_server(
t.join()


async def _get_server_info():
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}


def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -687,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
delete_directory(server_args.model_path)


STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"


class Engine:
"""
SRT Engine without an HTTP server layer.

This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
launching the HTTP server adds unnecessary complexity or overhead.
"""
def __init__(self, log_level: str = "error", *args, **kwargs):
"""See the arguments in server_args.py::ServerArgs"""

# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
atexit.register(self.shutdown)

server_args = ServerArgs(*args, log_level=log_level, **kwargs)
launch_engine(server_args=server_args)

def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
)

# get the current event loop
loop = asyncio.get_event_loop()
ret = loop.run_until_complete(generate_request(obj, None))

if stream is True:

def generator_wrapper():
offset = 0
loop = asyncio.get_event_loop()
generator = ret.body_iterator
while True:
chunk = loop.run_until_complete(generator.__anext__())

if chunk.startswith(STREAM_END_SYMBOL):
break
else:
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
data["text"] = data["text"][offset:]
offset += len(data["text"])
yield data

# we cannot yield in the scope of generate() because python does not allow yield + return in the same function
# however, it allows to wrap the generator as a subfunction and return
return generator_wrapper()
else:
return ret
async def async_generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Dict] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
)

ret = await generate_request(obj, None)

if stream is True:
generator = ret.body_iterator

async def generator_wrapper():

offset = 0

while True:
chunk = await generator.__anext__()

if chunk.startswith(STREAM_END_SYMBOL):
break
else:
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
data["text"] = data["text"][offset:]
offset += len(data["text"])
yield data

return generator_wrapper()
else:
return ret
def shutdown(self):
kill_process_tree(os.getpid(), include_parent=False)

def get_tokenizer(self):
global tokenizer_manager

if tokenizer_manager is None:
raise ReferenceError("Tokenizer Manager is not initialized.")
else:
return tokenizer_manager.tokenizer

def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
obj = EmbeddingReqInput(text=prompt)

# get the current event loop
loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None))

def start_profile(self):
tokenizer_manager.start_profile()

def stop_profile(self):
tokenizer_manager.stop_profile()

def get_server_info(self):
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}

def init_weights_update_group(
self,
master_address: str,
master_port: int,
rank_offset: int,
world_size: int,
group_name: str,
backend: str = "nccl",
):
"""Initialize parameter update group."""
obj = InitWeightsUpdateGroupReqInput(
master_address=master_address,
master_port=master_port,
rank_offset=rank_offset,
world_size=world_size,
group_name=group_name,
backend=backend,
)

async def _init_group():
return await tokenizer_manager.init_weights_update_group(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_init_group())

def update_weights_from_distributed(self, name, dtype, shape):
"""Update weights from distributed source."""
obj = UpdateWeightsFromDistributedReqInput(
name=name,
dtype=dtype,
shape=shape,
)

async def _update_weights():
return await tokenizer_manager.update_weights_from_distributed(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_update_weights())

def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)

async def _get_weights():
return await tokenizer_manager.get_weights_by_name(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())
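Taken together, the methods above let the Engine run offline without the HTTP layer. A hedged usage sketch; the package-root import and the model path are assumptions for illustration, not taken from this diff:

    import sglang as sgl  # assumes Engine is re-exported at the package root

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

    # Non-streaming call: returns the decoded result directly.
    out = engine.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 16})
    print(out)

    # Streaming call: generate(..., stream=True) returns a plain generator of text deltas.
    for chunk in engine.generate("Write a haiku about GPUs.", {"max_new_tokens": 48}, stream=True):
        print(chunk["text"], end="", flush=True)

    engine.shutdown()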
class Runtime:
"""
A wrapper for the server.
A wrapper for the HTTP server.
This is used for launching the server in a python program without
using the command line interface.

It is mainly used for the frontend language.
You should use the Engine class if you want to do normal offline processing.
"""

def __init__(
@@ -839,201 +1035,3 @@ class Runtime:
def __del__(self):
self.shutdown()


STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"


class Engine:
"""
SRT Engine without an HTTP server layer.

This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
launching the HTTP server adds unnecessary complexity or overhead.
"""

def __init__(self, log_level: str = "error", *args, **kwargs):
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
atexit.register(self.shutdown)

server_args = ServerArgs(*args, log_level=log_level, **kwargs)
launch_engine(server_args=server_args)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
)

# get the current event loop
loop = asyncio.get_event_loop()
ret = loop.run_until_complete(generate_request(obj, None))

if stream is True:

def generator_wrapper():
offset = 0
loop = asyncio.get_event_loop()
generator = ret.body_iterator
while True:
chunk = loop.run_until_complete(generator.__anext__())

if chunk.startswith(STREAM_END_SYMBOL):
break
else:
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
data["text"] = data["text"][offset:]
offset += len(data["text"])
yield data

# we cannot yield in the scope of generate() because python does not allow yield + return in the same function
# however, it allows to wrap the generator as a subfunction and return
return generator_wrapper()
else:
return ret
async def async_generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Dict] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
)

ret = await generate_request(obj, None)

if stream is True:
generator = ret.body_iterator

async def generator_wrapper():

offset = 0

while True:
chunk = await generator.__anext__()

if chunk.startswith(STREAM_END_SYMBOL):
break
else:
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
data["text"] = data["text"][offset:]
offset += len(data["text"])
yield data

return generator_wrapper()
else:
return ret
def shutdown(self):
kill_process_tree(os.getpid(), include_parent=False)

def get_tokenizer(self):
global tokenizer_manager

if tokenizer_manager is None:
raise ReferenceError("Tokenizer Manager is not initialized.")
else:
return tokenizer_manager.tokenizer

def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
obj = EmbeddingReqInput(text=prompt)

# get the current event loop
loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None))

def start_profile(self):
tokenizer_manager.start_profile()

def stop_profile(self):
tokenizer_manager.stop_profile()

async def get_server_info(self):
return await _get_server_info()

def init_weights_update_group(
self,
master_address: str,
master_port: int,
rank_offset: int,
world_size: int,
group_name: str,
backend: str = "nccl",
):
"""Initialize parameter update group."""
obj = InitWeightsUpdateGroupReqInput(
master_address=master_address,
master_port=master_port,
rank_offset=rank_offset,
world_size=world_size,
group_name=group_name,
backend=backend,
)

async def _init_group():
return await tokenizer_manager.init_weights_update_group(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_init_group())

def update_weights_from_distributed(self, name, dtype, shape):
"""Update weights from distributed source."""
obj = UpdateWeightsFromDistributedReqInput(
name=name,
dtype=dtype,
shape=shape,
)

async def _update_weights():
return await tokenizer_manager.update_weights_from_distributed(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_update_weights())

def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)

async def _get_weights():
return await tokenizer_manager.get_weights_by_name(obj, None)

loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())
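async_generate mirrors the synchronous path but must be driven from a running event loop; with stream=True it hands back an async generator. A minimal sketch with the same caveats as the earlier example (import path and model are placeholders):

    import asyncio

    import sglang as sgl  # assumes Engine is re-exported at the package root

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

    async def main():
        stream = await engine.async_generate(
            "Explain tensor parallelism in one sentence.",
            {"max_new_tokens": 64},
            stream=True,
        )
        async for chunk in stream:
            print(chunk["text"], end="", flush=True)

    asyncio.run(main())
    engine.shutdown()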
@@ -67,7 +67,7 @@ class TestGetWeightsByName(unittest.TestCase):
terminate_process(self.process)

def assert_tie_word_embeddings(self, truncate_size):
print(f"assert_tie_word_embeddings")
print("assert_tie_word_embeddings")
if self.backend == "Engine":
backend_ret = _process_return(
self.engine.get_weights_by_name("lm_head.weight", truncate_size)
@@ -79,7 +79,7 @@ class TestGetWeightsByName(unittest.TestCase):
json={"name": "lm_head.weight", "truncate_size": truncate_size},
).json()
)
print(f"assert_tie_word_embeddings of hf and backend")
print("assert_tie_word_embeddings of hf and backend")
assert np.allclose(
self.hf_model.get_parameter("model.embed_tokens.weight")
.cpu()

@@ -127,7 +127,7 @@ def init_process_hf(
hf_instruct_params = []
hf_base_params = []

print(f"get parameter in hf instruct model and base model")
print("get parameter in hf instruct model and base model")
for parameter_name in checking_parameters:
hf_instruct_params.append(
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
@@ -186,7 +186,6 @@ def init_process_hf(
param_queue.put(("broadcast_time", broadcast_time))

# Delete the huggingface models to free up memory.

del hf_instruct_model
del hf_base_model
gc.collect()
@@ -238,7 +237,6 @@ def init_process_sgl(
print(f"rank {rank} init server on url: {url}")

# Get weights of instruct model, i.e. pre-training weights.

instruct_params = []
for parameter_name in checking_parameters:
instruct_params.append(
@@ -253,7 +251,6 @@ def init_process_sgl(
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))

# Init weight update group with the training engine.

if backend == "Engine":
engine.init_weights_update_group(
master_address="localhost",
@@ -282,7 +279,6 @@ def init_process_sgl(
# The last parameter is lm_head.weight, which is tied
# with embed_tokens.weight. Actually, we only need
# to update embed_tokens.weight once.

tie_word_embeddings = (
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
)
@@ -291,7 +287,6 @@ def init_process_sgl(
update_parameters.remove("lm_head.weight")

# Get weights from the training engine and update the inference engine.

for parameter_name in update_parameters:
if backend == "Engine":
engine.update_weights_from_distributed(
@@ -312,7 +307,6 @@ def init_process_sgl(
time_end_update = time.time()

# Measure the latency of broadcast/weights update.

update_time = time_end_update - time_begin_update
print(
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
@@ -320,7 +314,6 @@ def init_process_sgl(
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))

# Get the weights of post-training model after weights update for correctness check.

base_params = []
for parameter_name in checking_parameters:
if backend == "Engine":
@@ -340,7 +333,6 @@ def init_process_sgl(
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))

# Shutdown the engine or terminate the server process.

if backend == "Engine":
engine.shutdown()
else:
@@ -426,7 +418,6 @@ def test_update_weights_from_distributed(

# Check the correctness of weights update by verifying
# the weights of instruct model and base model.

for i in range(len(params["hf_instruct"])):
verify_params_close(
params["hf_instruct"][i],
@@ -463,7 +454,6 @@ def test_update_weights_from_distributed(
), "hf_instruct_params and hf_base_params have different lengths"

# Check if the weights of lm_head are tied with embed_tokens.

params_to_check = [
(
params["hf_instruct"],
@@ -509,7 +499,6 @@ def test_update_weights_from_distributed(

# Time limit for broadcast and update on CI is 3 / 6
# On local H100, it's 1 / 2

time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6

assert (
@@ -526,7 +515,6 @@ def test_update_weights_from_distributed(
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"

# Delete the context and close the parameter queue.

del context
param_queue.close()
param_queue.join_thread()
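The test above exercises the full weight-update path: the inference engine joins a process group with a training process, receives parameters over it, and is then probed for correctness. A condensed sketch of the engine side only; the address, port, world size, dtype, and shape are placeholders, and the matching torch.distributed calls on the trainer side are outside this diff:

    import sglang as sgl  # assumes Engine is re-exported at the package root

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

    # Join the weight-update group; the trainer holds rank 0, so the engine starts at rank_offset=1.
    engine.init_weights_update_group(
        master_address="localhost",  # placeholder rendezvous address
        master_port=23456,           # placeholder port
        rank_offset=1,
        world_size=2,
        group_name="weight_update_group",
        backend="nccl",
    )

    # For each parameter the trainer broadcasts a tensor over the group and the
    # engine receives it into the named weight; dtype/shape must match the sender.
    engine.update_weights_from_distributed(
        name="model.embed_tokens.weight",
        dtype="bfloat16",            # placeholder; format follows the request input dataclass
        shape=[128256, 4096],        # placeholder shape
    )

    # Optionally read back a truncated copy to verify the update took effect.
    head = engine.get_weights_by_name("model.embed_tokens.weight", truncate_size=100)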