[Minor] Fix code style (#2311)
This commit is contained in:
@@ -25,7 +25,6 @@ import uuid
|
|||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import torch
|
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
@@ -337,6 +336,12 @@ class TokenizerManager:
|
|||||||
rids.append(tmp_obj.rid)
|
rids.append(tmp_obj.rid)
|
||||||
else:
|
else:
|
||||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||||
|
if batch_size > 128:
|
||||||
|
logger.warning(
|
||||||
|
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
|
||||||
|
"The performance might be better if you just duplicate the requests n times or use "
|
||||||
|
"many threads to send them one by one with parallel sampling (n > 1)."
|
||||||
|
)
|
||||||
|
|
||||||
# Tokenize all requests
|
# Tokenize all requests
|
||||||
objs = [obj[i] for i in range(batch_size)]
|
objs = [obj[i] for i in range(batch_size)]
|
||||||
@@ -494,9 +499,7 @@ class TokenizerManager:
|
|||||||
result = await self.parameter_update_result
|
result = await self.parameter_update_result
|
||||||
return result.success, result.message
|
return result.success, result.message
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error("Another parameter update is in progress in tokenizer manager")
|
||||||
f"Another parameter update is in progress in tokenizer manager"
|
|
||||||
)
|
|
||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
"Another parameter update is in progress. Please try again later.",
|
"Another parameter update is in progress. Please try again later.",
|
||||||
@@ -597,47 +600,7 @@ class TokenizerManager:
|
|||||||
InitWeightsUpdateGroupReqOutput,
|
InitWeightsUpdateGroupReqOutput,
|
||||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
|
|
||||||
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
|
||||||
if self.server_args.dp_size == 1:
|
|
||||||
self.model_update_result.set_result(recv_obj)
|
|
||||||
else: # self.server_args.dp_size > 1
|
|
||||||
self.model_update_tmp.append(recv_obj)
|
|
||||||
# set the future once all results are received
|
|
||||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
|
||||||
self.model_update_result.set_result(self.model_update_tmp)
|
|
||||||
continue
|
|
||||||
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
|
||||||
assert (
|
|
||||||
self.server_args.dp_size == 1
|
|
||||||
), "dp_size must be 1 for update weights from distributed"
|
|
||||||
self.parameter_update_result.set_result(recv_obj)
|
|
||||||
continue
|
|
||||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
|
||||||
if self.server_args.dp_size == 1:
|
|
||||||
self.get_weights_by_name_result.set_result(recv_obj)
|
|
||||||
else:
|
|
||||||
self.get_weights_by_name_tmp.append(recv_obj)
|
|
||||||
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
|
|
||||||
self.get_weights_by_name_result.set_result(
|
|
||||||
self.get_weights_by_name_tmp
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
|
|
||||||
assert (
|
|
||||||
self.server_args.dp_size == 1
|
|
||||||
), "dp_size must be 1 for init parameter update group"
|
|
||||||
self.init_weights_update_group_result.set_result(recv_obj)
|
|
||||||
continue
|
|
||||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
|
||||||
self.session_futures[recv_obj.session_id].set_result(
|
|
||||||
recv_obj.session_id
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
assert isinstance(
|
|
||||||
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
|
||||||
), f"Unexpected obj received: {type(recv_obj)}"
|
|
||||||
|
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
state = self.rid_to_state.get(rid, None)
|
state = self.rid_to_state.get(rid, None)
|
||||||
if state is None:
|
if state is None:
|
||||||
@@ -683,14 +646,50 @@ class TokenizerManager:
|
|||||||
self.metrics_collector.inc_prompt_tokens(
|
self.metrics_collector.inc_prompt_tokens(
|
||||||
recv_obj.meta_info[i]["prompt_tokens"]
|
recv_obj.meta_info[i]["prompt_tokens"]
|
||||||
)
|
)
|
||||||
self.metrics_collector.inc_generation_tokens(completion_tokens)
|
self.metrics_collector.inc_generation_tokens(
|
||||||
|
completion_tokens
|
||||||
|
)
|
||||||
self.metrics_collector.observe_e2e_request_latency(
|
self.metrics_collector.observe_e2e_request_latency(
|
||||||
time.time() - state.created_time
|
time.time() - state.created_time
|
||||||
)
|
)
|
||||||
if completion_tokens >= 1:
|
if completion_tokens >= 1:
|
||||||
self.metrics_collector.observe_time_per_output_token(
|
self.metrics_collector.observe_time_per_output_token(
|
||||||
(time.time() - state.created_time) / completion_tokens
|
(time.time() - state.created_time)
|
||||||
|
/ completion_tokens
|
||||||
)
|
)
|
||||||
|
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||||
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
|
recv_obj.session_id
|
||||||
|
)
|
||||||
|
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||||
|
if self.server_args.dp_size == 1:
|
||||||
|
self.model_update_result.set_result(recv_obj)
|
||||||
|
else: # self.server_args.dp_size > 1
|
||||||
|
self.model_update_tmp.append(recv_obj)
|
||||||
|
# set the future once all results are received
|
||||||
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||||
|
self.model_update_result.set_result(self.model_update_tmp)
|
||||||
|
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be 1 for init parameter update group"
|
||||||
|
self.init_weights_update_group_result.set_result(recv_obj)
|
||||||
|
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be 1 for update weights from distributed"
|
||||||
|
self.parameter_update_result.set_result(recv_obj)
|
||||||
|
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||||
|
if self.server_args.dp_size == 1:
|
||||||
|
self.get_weights_by_name_result.set_result(recv_obj)
|
||||||
|
else:
|
||||||
|
self.get_weights_by_name_tmp.append(recv_obj)
|
||||||
|
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
|
||||||
|
self.get_weights_by_name_result.set_result(
|
||||||
|
self.get_weights_by_name_tmp
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid object: {recv_obj=}")
|
||||||
|
|
||||||
def convert_logprob_style(
|
def convert_logprob_style(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -218,16 +218,6 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
|
|
||||||
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
|
|
||||||
# so we disable padding in cuda graph.
|
|
||||||
if self.device == "cuda" and not all(
|
|
||||||
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
|
|
||||||
):
|
|
||||||
self.server_args.disable_cuda_graph_padding = True
|
|
||||||
logger.info(
|
|
||||||
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check memory for tensor parallelism
|
# Check memory for tensor parallelism
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ from sglang.srt.utils import (
|
|||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
delete_directory,
|
delete_directory,
|
||||||
init_custom_process_group,
|
|
||||||
is_port_available,
|
is_port_available,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
@@ -154,13 +153,11 @@ async def get_model_info():
|
|||||||
|
|
||||||
@app.get("/get_server_info")
|
@app.get("/get_server_info")
|
||||||
async def get_server_info():
|
async def get_server_info():
|
||||||
try:
|
return {
|
||||||
return await _get_server_info()
|
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||||
|
**scheduler_info,
|
||||||
except Exception as e:
|
"version": __version__,
|
||||||
return ORJSONResponse(
|
}
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/flush_cache")
|
@app.post("/flush_cache")
|
||||||
@@ -567,14 +564,6 @@ def launch_server(
|
|||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
async def _get_server_info():
|
|
||||||
return {
|
|
||||||
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
|
||||||
**scheduler_info,
|
|
||||||
"version": __version__,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _set_envs_and_config(server_args: ServerArgs):
|
def _set_envs_and_config(server_args: ServerArgs):
|
||||||
# Set global environments
|
# Set global environments
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
@@ -687,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
|||||||
delete_directory(server_args.model_path)
|
delete_directory(server_args.model_path)
|
||||||
|
|
||||||
|
|
||||||
|
STREAM_END_SYMBOL = b"data: [DONE]"
|
||||||
|
STREAM_CHUNK_START_SYMBOL = b"data:"
|
||||||
|
|
||||||
|
|
||||||
|
class Engine:
|
||||||
|
"""
|
||||||
|
SRT Engine without an HTTP server layer.
|
||||||
|
|
||||||
|
This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
|
||||||
|
launching the HTTP server adds unnecessary complexity or overhead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, log_level: str = "error", *args, **kwargs):
|
||||||
|
"""See the arguments in server_args.py::ServerArgs"""
|
||||||
|
|
||||||
|
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
||||||
|
atexit.register(self.shutdown)
|
||||||
|
|
||||||
|
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
||||||
|
launch_engine(server_args=server_args)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
|
prompt: Optional[Union[List[str], str]] = None,
|
||||||
|
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||||
|
# The token ids for text; one can either specify text or input_ids.
|
||||||
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
):
|
||||||
|
obj = GenerateReqInput(
|
||||||
|
text=prompt,
|
||||||
|
input_ids=input_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
lora_path=lora_path,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the current event loop
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
ret = loop.run_until_complete(generate_request(obj, None))
|
||||||
|
|
||||||
|
if stream is True:
|
||||||
|
|
||||||
|
def generator_wrapper():
|
||||||
|
offset = 0
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
generator = ret.body_iterator
|
||||||
|
while True:
|
||||||
|
chunk = loop.run_until_complete(generator.__anext__())
|
||||||
|
|
||||||
|
if chunk.startswith(STREAM_END_SYMBOL):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
|
||||||
|
data["text"] = data["text"][offset:]
|
||||||
|
offset += len(data["text"])
|
||||||
|
yield data
|
||||||
|
|
||||||
|
# we cannot yield in the scope of generate() because python does not allow yield + return in the same function
|
||||||
|
# however, it allows to wrap the generator as a subfunction and return
|
||||||
|
return generator_wrapper()
|
||||||
|
else:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def async_generate(
|
||||||
|
self,
|
||||||
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
|
prompt: Optional[Union[List[str], str]] = None,
|
||||||
|
sampling_params: Optional[Dict] = None,
|
||||||
|
# The token ids for text; one can either specify text or input_ids.
|
||||||
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
):
|
||||||
|
obj = GenerateReqInput(
|
||||||
|
text=prompt,
|
||||||
|
input_ids=input_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
lora_path=lora_path,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
ret = await generate_request(obj, None)
|
||||||
|
|
||||||
|
if stream is True:
|
||||||
|
generator = ret.body_iterator
|
||||||
|
|
||||||
|
async def generator_wrapper():
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
chunk = await generator.__anext__()
|
||||||
|
|
||||||
|
if chunk.startswith(STREAM_END_SYMBOL):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
|
||||||
|
data["text"] = data["text"][offset:]
|
||||||
|
offset += len(data["text"])
|
||||||
|
yield data
|
||||||
|
|
||||||
|
return generator_wrapper()
|
||||||
|
else:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
kill_process_tree(os.getpid(), include_parent=False)
|
||||||
|
|
||||||
|
def get_tokenizer(self):
|
||||||
|
global tokenizer_manager
|
||||||
|
|
||||||
|
if tokenizer_manager is None:
|
||||||
|
raise ReferenceError("Tokenizer Manager is not initialized.")
|
||||||
|
else:
|
||||||
|
return tokenizer_manager.tokenizer
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||||
|
):
|
||||||
|
obj = EmbeddingReqInput(text=prompt)
|
||||||
|
|
||||||
|
# get the current event loop
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(encode_request(obj, None))
|
||||||
|
|
||||||
|
def start_profile(self):
|
||||||
|
tokenizer_manager.start_profile()
|
||||||
|
|
||||||
|
def stop_profile(self):
|
||||||
|
tokenizer_manager.stop_profile()
|
||||||
|
|
||||||
|
def get_server_info(self):
|
||||||
|
return {
|
||||||
|
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||||
|
**scheduler_info,
|
||||||
|
"version": __version__,
|
||||||
|
}
|
||||||
|
|
||||||
|
def init_weights_update_group(
|
||||||
|
self,
|
||||||
|
master_address: str,
|
||||||
|
master_port: int,
|
||||||
|
rank_offset: int,
|
||||||
|
world_size: int,
|
||||||
|
group_name: str,
|
||||||
|
backend: str = "nccl",
|
||||||
|
):
|
||||||
|
"""Initialize parameter update group."""
|
||||||
|
obj = InitWeightsUpdateGroupReqInput(
|
||||||
|
master_address=master_address,
|
||||||
|
master_port=master_port,
|
||||||
|
rank_offset=rank_offset,
|
||||||
|
world_size=world_size,
|
||||||
|
group_name=group_name,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _init_group():
|
||||||
|
return await tokenizer_manager.init_weights_update_group(obj, None)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(_init_group())
|
||||||
|
|
||||||
|
def update_weights_from_distributed(self, name, dtype, shape):
|
||||||
|
"""Update weights from distributed source."""
|
||||||
|
obj = UpdateWeightsFromDistributedReqInput(
|
||||||
|
name=name,
|
||||||
|
dtype=dtype,
|
||||||
|
shape=shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _update_weights():
|
||||||
|
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(_update_weights())
|
||||||
|
|
||||||
|
def get_weights_by_name(self, name, truncate_size=100):
|
||||||
|
"""Get weights by parameter name."""
|
||||||
|
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||||
|
|
||||||
|
async def _get_weights():
|
||||||
|
return await tokenizer_manager.get_weights_by_name(obj, None)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(_get_weights())
|
||||||
|
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
"""
|
"""
|
||||||
A wrapper for the server.
|
A wrapper for the HTTP server.
|
||||||
This is used for launching the server in a python program without
|
This is used for launching the server in a python program without
|
||||||
using the command line interface.
|
using the command line interface.
|
||||||
|
|
||||||
|
It is mainly used for the frontend language.
|
||||||
|
You should use the Engine class if you want to do normal offline processing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -839,201 +1035,3 @@ class Runtime:
|
|||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
|
|
||||||
|
|
||||||
STREAM_END_SYMBOL = b"data: [DONE]"
|
|
||||||
STREAM_CHUNK_START_SYMBOL = b"data:"
|
|
||||||
|
|
||||||
|
|
||||||
class Engine:
|
|
||||||
"""
|
|
||||||
SRT Engine without an HTTP server layer.
|
|
||||||
|
|
||||||
This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
|
|
||||||
launching the HTTP server adds unnecessary complexity or overhead.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, log_level: str = "error", *args, **kwargs):
|
|
||||||
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
|
||||||
atexit.register(self.shutdown)
|
|
||||||
|
|
||||||
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
|
||||||
launch_engine(server_args=server_args)
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
|
||||||
prompt: Optional[Union[List[str], str]] = None,
|
|
||||||
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
|
||||||
# The token ids for text; one can either specify text or input_ids.
|
|
||||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
|
||||||
stream: bool = False,
|
|
||||||
):
|
|
||||||
obj = GenerateReqInput(
|
|
||||||
text=prompt,
|
|
||||||
input_ids=input_ids,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
return_logprob=return_logprob,
|
|
||||||
logprob_start_len=logprob_start_len,
|
|
||||||
top_logprobs_num=top_logprobs_num,
|
|
||||||
lora_path=lora_path,
|
|
||||||
stream=stream,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get the current event loop
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
ret = loop.run_until_complete(generate_request(obj, None))
|
|
||||||
|
|
||||||
if stream is True:
|
|
||||||
|
|
||||||
def generator_wrapper():
|
|
||||||
offset = 0
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
generator = ret.body_iterator
|
|
||||||
while True:
|
|
||||||
chunk = loop.run_until_complete(generator.__anext__())
|
|
||||||
|
|
||||||
if chunk.startswith(STREAM_END_SYMBOL):
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
|
|
||||||
data["text"] = data["text"][offset:]
|
|
||||||
offset += len(data["text"])
|
|
||||||
yield data
|
|
||||||
|
|
||||||
# we cannot yield in the scope of generate() because python does not allow yield + return in the same function
|
|
||||||
# however, it allows to wrap the generator as a subfunction and return
|
|
||||||
return generator_wrapper()
|
|
||||||
else:
|
|
||||||
return ret
|
|
||||||
|
|
||||||
async def async_generate(
|
|
||||||
self,
|
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
|
||||||
prompt: Optional[Union[List[str], str]] = None,
|
|
||||||
sampling_params: Optional[Dict] = None,
|
|
||||||
# The token ids for text; one can either specify text or input_ids.
|
|
||||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
|
||||||
stream: bool = False,
|
|
||||||
):
|
|
||||||
obj = GenerateReqInput(
|
|
||||||
text=prompt,
|
|
||||||
input_ids=input_ids,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
return_logprob=return_logprob,
|
|
||||||
logprob_start_len=logprob_start_len,
|
|
||||||
top_logprobs_num=top_logprobs_num,
|
|
||||||
lora_path=lora_path,
|
|
||||||
stream=stream,
|
|
||||||
)
|
|
||||||
|
|
||||||
ret = await generate_request(obj, None)
|
|
||||||
|
|
||||||
if stream is True:
|
|
||||||
generator = ret.body_iterator
|
|
||||||
|
|
||||||
async def generator_wrapper():
|
|
||||||
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
chunk = await generator.__anext__()
|
|
||||||
|
|
||||||
if chunk.startswith(STREAM_END_SYMBOL):
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
|
|
||||||
data["text"] = data["text"][offset:]
|
|
||||||
offset += len(data["text"])
|
|
||||||
yield data
|
|
||||||
|
|
||||||
return generator_wrapper()
|
|
||||||
else:
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
kill_process_tree(os.getpid(), include_parent=False)
|
|
||||||
|
|
||||||
def get_tokenizer(self):
|
|
||||||
global tokenizer_manager
|
|
||||||
|
|
||||||
if tokenizer_manager is None:
|
|
||||||
raise ReferenceError("Tokenizer Manager is not initialized.")
|
|
||||||
else:
|
|
||||||
return tokenizer_manager.tokenizer
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self,
|
|
||||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
|
||||||
):
|
|
||||||
obj = EmbeddingReqInput(text=prompt)
|
|
||||||
|
|
||||||
# get the current event loop
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return loop.run_until_complete(encode_request(obj, None))
|
|
||||||
|
|
||||||
def start_profile(self):
|
|
||||||
tokenizer_manager.start_profile()
|
|
||||||
|
|
||||||
def stop_profile(self):
|
|
||||||
tokenizer_manager.stop_profile()
|
|
||||||
|
|
||||||
async def get_server_info(self):
|
|
||||||
return await _get_server_info()
|
|
||||||
|
|
||||||
def init_weights_update_group(
|
|
||||||
self,
|
|
||||||
master_address: str,
|
|
||||||
master_port: int,
|
|
||||||
rank_offset: int,
|
|
||||||
world_size: int,
|
|
||||||
group_name: str,
|
|
||||||
backend: str = "nccl",
|
|
||||||
):
|
|
||||||
"""Initialize parameter update group."""
|
|
||||||
obj = InitWeightsUpdateGroupReqInput(
|
|
||||||
master_address=master_address,
|
|
||||||
master_port=master_port,
|
|
||||||
rank_offset=rank_offset,
|
|
||||||
world_size=world_size,
|
|
||||||
group_name=group_name,
|
|
||||||
backend=backend,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _init_group():
|
|
||||||
return await tokenizer_manager.init_weights_update_group(obj, None)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return loop.run_until_complete(_init_group())
|
|
||||||
|
|
||||||
def update_weights_from_distributed(self, name, dtype, shape):
|
|
||||||
"""Update weights from distributed source."""
|
|
||||||
obj = UpdateWeightsFromDistributedReqInput(
|
|
||||||
name=name,
|
|
||||||
dtype=dtype,
|
|
||||||
shape=shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _update_weights():
|
|
||||||
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return loop.run_until_complete(_update_weights())
|
|
||||||
|
|
||||||
def get_weights_by_name(self, name, truncate_size=100):
|
|
||||||
"""Get weights by parameter name."""
|
|
||||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
|
||||||
|
|
||||||
async def _get_weights():
|
|
||||||
return await tokenizer_manager.get_weights_by_name(obj, None)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return loop.run_until_complete(_get_weights())
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class TestGetWeightsByName(unittest.TestCase):
|
|||||||
terminate_process(self.process)
|
terminate_process(self.process)
|
||||||
|
|
||||||
def assert_tie_word_embeddings(self, truncate_size):
|
def assert_tie_word_embeddings(self, truncate_size):
|
||||||
print(f"assert_tie_word_embeddings")
|
print("assert_tie_word_embeddings")
|
||||||
if self.backend == "Engine":
|
if self.backend == "Engine":
|
||||||
backend_ret = _process_return(
|
backend_ret = _process_return(
|
||||||
self.engine.get_weights_by_name("lm_head.weight", truncate_size)
|
self.engine.get_weights_by_name("lm_head.weight", truncate_size)
|
||||||
@@ -79,7 +79,7 @@ class TestGetWeightsByName(unittest.TestCase):
|
|||||||
json={"name": "lm_head.weight", "truncate_size": truncate_size},
|
json={"name": "lm_head.weight", "truncate_size": truncate_size},
|
||||||
).json()
|
).json()
|
||||||
)
|
)
|
||||||
print(f"assert_tie_word_embeddings of hf and backend")
|
print("assert_tie_word_embeddings of hf and backend")
|
||||||
assert np.allclose(
|
assert np.allclose(
|
||||||
self.hf_model.get_parameter("model.embed_tokens.weight")
|
self.hf_model.get_parameter("model.embed_tokens.weight")
|
||||||
.cpu()
|
.cpu()
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ def init_process_hf(
|
|||||||
hf_instruct_params = []
|
hf_instruct_params = []
|
||||||
hf_base_params = []
|
hf_base_params = []
|
||||||
|
|
||||||
print(f"get parameter in hf instruct model and base model")
|
print("get parameter in hf instruct model and base model")
|
||||||
for parameter_name in checking_parameters:
|
for parameter_name in checking_parameters:
|
||||||
hf_instruct_params.append(
|
hf_instruct_params.append(
|
||||||
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
|
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
|
||||||
@@ -186,7 +186,6 @@ def init_process_hf(
|
|||||||
param_queue.put(("broadcast_time", broadcast_time))
|
param_queue.put(("broadcast_time", broadcast_time))
|
||||||
|
|
||||||
# Delete the huggingface models to free up memory.
|
# Delete the huggingface models to free up memory.
|
||||||
|
|
||||||
del hf_instruct_model
|
del hf_instruct_model
|
||||||
del hf_base_model
|
del hf_base_model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -238,7 +237,6 @@ def init_process_sgl(
|
|||||||
print(f"rank {rank} init server on url: {url}")
|
print(f"rank {rank} init server on url: {url}")
|
||||||
|
|
||||||
# Get weights of instruct model, i.e. pre-training weights.
|
# Get weights of instruct model, i.e. pre-training weights.
|
||||||
|
|
||||||
instruct_params = []
|
instruct_params = []
|
||||||
for parameter_name in checking_parameters:
|
for parameter_name in checking_parameters:
|
||||||
instruct_params.append(
|
instruct_params.append(
|
||||||
@@ -253,7 +251,6 @@ def init_process_sgl(
|
|||||||
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
|
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
|
||||||
|
|
||||||
# Init weight update group with the training engine.
|
# Init weight update group with the training engine.
|
||||||
|
|
||||||
if backend == "Engine":
|
if backend == "Engine":
|
||||||
engine.init_weights_update_group(
|
engine.init_weights_update_group(
|
||||||
master_address="localhost",
|
master_address="localhost",
|
||||||
@@ -282,7 +279,6 @@ def init_process_sgl(
|
|||||||
# The last parameter is lm_head.weight, which is tied
|
# The last parameter is lm_head.weight, which is tied
|
||||||
# with embed_tokens.weight. Actually, we only need
|
# with embed_tokens.weight. Actually, we only need
|
||||||
# to update embed_tokens.weight once.
|
# to update embed_tokens.weight once.
|
||||||
|
|
||||||
tie_word_embeddings = (
|
tie_word_embeddings = (
|
||||||
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
|
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
|
||||||
)
|
)
|
||||||
@@ -291,7 +287,6 @@ def init_process_sgl(
|
|||||||
update_parameters.remove("lm_head.weight")
|
update_parameters.remove("lm_head.weight")
|
||||||
|
|
||||||
# Get weights from the training engine and update the inference engine.
|
# Get weights from the training engine and update the inference engine.
|
||||||
|
|
||||||
for parameter_name in update_parameters:
|
for parameter_name in update_parameters:
|
||||||
if backend == "Engine":
|
if backend == "Engine":
|
||||||
engine.update_weights_from_distributed(
|
engine.update_weights_from_distributed(
|
||||||
@@ -312,7 +307,6 @@ def init_process_sgl(
|
|||||||
time_end_update = time.time()
|
time_end_update = time.time()
|
||||||
|
|
||||||
# Measure the latency of broadcast/weights update.
|
# Measure the latency of broadcast/weights update.
|
||||||
|
|
||||||
update_time = time_end_update - time_begin_update
|
update_time = time_end_update - time_begin_update
|
||||||
print(
|
print(
|
||||||
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
|
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
|
||||||
@@ -320,7 +314,6 @@ def init_process_sgl(
|
|||||||
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
|
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
|
||||||
|
|
||||||
# Get the weights of post-training model after weights update for correctness check.
|
# Get the weights of post-training model after weights update for correctness check.
|
||||||
|
|
||||||
base_params = []
|
base_params = []
|
||||||
for parameter_name in checking_parameters:
|
for parameter_name in checking_parameters:
|
||||||
if backend == "Engine":
|
if backend == "Engine":
|
||||||
@@ -340,7 +333,6 @@ def init_process_sgl(
|
|||||||
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
|
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
|
||||||
|
|
||||||
# Shutdown the engine or terminate the server process.
|
# Shutdown the engine or terminate the server process.
|
||||||
|
|
||||||
if backend == "Engine":
|
if backend == "Engine":
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
else:
|
else:
|
||||||
@@ -426,7 +418,6 @@ def test_update_weights_from_distributed(
|
|||||||
|
|
||||||
# Check the correctness of weights update by verifying
|
# Check the correctness of weights update by verifying
|
||||||
# the weights of instruct model and base model.
|
# the weights of instruct model and base model.
|
||||||
|
|
||||||
for i in range(len(params["hf_instruct"])):
|
for i in range(len(params["hf_instruct"])):
|
||||||
verify_params_close(
|
verify_params_close(
|
||||||
params["hf_instruct"][i],
|
params["hf_instruct"][i],
|
||||||
@@ -463,7 +454,6 @@ def test_update_weights_from_distributed(
|
|||||||
), "hf_instruct_params and hf_base_params have different lengths"
|
), "hf_instruct_params and hf_base_params have different lengths"
|
||||||
|
|
||||||
# Check if the weights of lm_head are tied with embed_tokens.
|
# Check if the weights of lm_head are tied with embed_tokens.
|
||||||
|
|
||||||
params_to_check = [
|
params_to_check = [
|
||||||
(
|
(
|
||||||
params["hf_instruct"],
|
params["hf_instruct"],
|
||||||
@@ -509,7 +499,6 @@ def test_update_weights_from_distributed(
|
|||||||
|
|
||||||
# Time limit for broadcast and update on CI is 3 / 6
|
# Time limit for broadcast and update on CI is 3 / 6
|
||||||
# On local H100, it's 1 / 2
|
# On local H100, it's 1 / 2
|
||||||
|
|
||||||
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
|
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@@ -526,7 +515,6 @@ def test_update_weights_from_distributed(
|
|||||||
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
|
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
|
||||||
|
|
||||||
# Delete the context and close the parameter queue.
|
# Delete the context and close the parameter queue.
|
||||||
|
|
||||||
del context
|
del context
|
||||||
param_queue.close()
|
param_queue.close()
|
||||||
param_queue.join_thread()
|
param_queue.join_thread()
|
||||||
|
|||||||
Reference in New Issue
Block a user