[Minor] Fix code style (#2311)

This commit is contained in:
Lianmin Zheng
2024-12-02 02:27:36 -08:00
committed by GitHub
parent c54bda300a
commit 18108abe5d
5 changed files with 292 additions and 317 deletions

View File

@@ -25,7 +25,6 @@ import uuid
from typing import Dict, List, Optional, Tuple, Union
import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
@@ -337,6 +336,12 @@ class TokenizerManager:
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:
logger.warning(
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
"The performance might be better if you just duplicate the requests n times or use "
"many threads to send them one by one with parallel sampling (n > 1)."
)
# Tokenize all requests
objs = [obj[i] for i in range(batch_size)]
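For reference, a minimal client-side sketch of the workaround suggested by the warning above, assuming a running server on localhost:30000 and that parallel sampling is requested through the "n" field of sampling_params (the endpoint URL, port, and prompts are illustrative, not part of this commit):

import requests

prompts = ["What is the capital of France?", "Write a haiku about GPUs."]
n = 4  # number of parallel samples wanted per prompt

# Instead of one large batch with sampling_params={"n": n}, duplicate each prompt
# n times and request a single sample per copy.
duplicated = [p for p in prompts for _ in range(n)]
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": duplicated,
        "sampling_params": {"temperature": 0.8, "max_new_tokens": 64},
    },
)
outputs = resp.json()  # one output dict per duplicated prompt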
@@ -494,9 +499,7 @@ class TokenizerManager:
result = await self.parameter_update_result
return result.success, result.message
else:
logger.error(
f"Another parameter update is in progress in tokenizer manager"
)
logger.error("Another parameter update is in progress in tokenizer manager")
return (
False,
"Another parameter update is in progress. Please try again later.",
@@ -597,7 +600,68 @@ class TokenizerManager:
InitWeightsUpdateGroupReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
recv_obj.meta_info[i]["id"] = rid
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)
if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(
completion_tokens
)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time)
/ completion_tokens
)
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
@@ -605,13 +669,16 @@ class TokenizerManager:
# set the future once all results are received
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.parameter_update_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
@@ -621,76 +688,8 @@ class TokenizerManager:
self.get_weights_by_name_result.set_result(
self.get_weights_by_name_tmp
)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
continue
assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
), f"Unexpected obj received: {type(recv_obj)}"
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
recv_obj.meta_info[i]["id"] = rid
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)
if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(completion_tokens)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time) / completion_tokens
)
else:
raise ValueError(f"Invalid object: {recv_obj=}")
def convert_logprob_style(
self,

View File

@@ -218,16 +218,6 @@ class ModelRunner:
)
self.tp_group = get_tp_group()
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
# so we disable padding in cuda graph.
if self.device == "cuda" and not all(
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)
# Check memory for tensor parallelism
if self.tp_size > 1:
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)

View File

@@ -82,7 +82,6 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
delete_directory,
init_custom_process_group,
is_port_available,
kill_process_tree,
maybe_set_triton_cache_manager,
@@ -154,13 +153,11 @@ async def get_model_info():
@app.get("/get_server_info")
async def get_server_info():
try:
return await _get_server_info()
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}
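A quick usage note for the simplified endpoint above: it now returns the server arguments merged with scheduler_info and the package version directly, without the try/except wrapper. A minimal client sketch (the port is an assumption; use whatever port the server was launched with):

import requests

info = requests.get("http://localhost:30000/get_server_info").json()
print(info["version"])   # sglang package version
print(sorted(info))      # ServerArgs fields merged with scheduler_info keys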
@app.post("/flush_cache")
@@ -567,14 +564,6 @@ def launch_server(
t.join()
async def _get_server_info():
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -687,11 +676,218 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
delete_directory(server_args.model_path)
STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"
class Engine:
"""
SRT Engine without an HTTP server layer.
This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
launching the HTTP server adds unnecessary complexity or overhead.
"""
def __init__(self, log_level: str = "error", *args, **kwargs):
"""See the arguments in server_args.py::ServerArgs"""
# Before the Python program terminates, call shutdown() implicitly so users don't have to call .shutdown() explicitly.
atexit.register(self.shutdown)
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
launch_engine(server_args=server_args)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
)
# get the current event loop
loop = asyncio.get_event_loop()
ret = loop.run_until_complete(generate_request(obj, None))
if stream is True:
def generator_wrapper():
offset = 0
loop = asyncio.get_event_loop()
generator = ret.body_iterator
while True:
chunk = loop.run_until_complete(generator.__anext__())
if chunk.startswith(STREAM_END_SYMBOL):
break
else:
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
data["text"] = data["text"][offset:]
offset += len(data["text"])
yield data
# We cannot yield directly inside generate(): a function that contains `yield` becomes a generator function,
# so the non-streaming branch could not return a plain result. Instead, we wrap the yielding loop in an
# inner generator function (generator_wrapper) and return that generator.
return generator_wrapper()
else:
return ret
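As the comment above explains, the streaming path returns the inner generator instead of yielding from generate() itself. A minimal consumption sketch (the model path and prompt are placeholders; Engine is assumed to be importable from this module):

# Illustrative streaming usage; model path and prompt are placeholders.
engine = Engine(model_path="/path/to/model")
for chunk in engine.generate("The capital of France is", stream=True):
    print(chunk["text"], end="", flush=True)  # each chunk carries only the new text delta
engine.shutdown()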
async def async_generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Dict] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
)
ret = await generate_request(obj, None)
if stream is True:
generator = ret.body_iterator
async def generator_wrapper():
offset = 0
while True:
chunk = await generator.__anext__()
if chunk.startswith(STREAM_END_SYMBOL):
break
else:
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
data["text"] = data["text"][offset:]
offset += len(data["text"])
yield data
return generator_wrapper()
else:
return ret
def shutdown(self):
kill_process_tree(os.getpid(), include_parent=False)
def get_tokenizer(self):
global tokenizer_manager
if tokenizer_manager is None:
raise ReferenceError("Tokenizer Manager is not initialized.")
else:
return tokenizer_manager.tokenizer
def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
obj = EmbeddingReqInput(text=prompt)
# get the current event loop
loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None))
def start_profile(self):
tokenizer_manager.start_profile()
def stop_profile(self):
tokenizer_manager.stop_profile()
def get_server_info(self):
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
**scheduler_info,
"version": __version__,
}
def init_weights_update_group(
self,
master_address: str,
master_port: int,
rank_offset: int,
world_size: int,
group_name: str,
backend: str = "nccl",
):
"""Initialize parameter update group."""
obj = InitWeightsUpdateGroupReqInput(
master_address=master_address,
master_port=master_port,
rank_offset=rank_offset,
world_size=world_size,
group_name=group_name,
backend=backend,
)
async def _init_group():
return await tokenizer_manager.init_weights_update_group(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_init_group())
def update_weights_from_distributed(self, name, dtype, shape):
"""Update weights from distributed source."""
obj = UpdateWeightsFromDistributedReqInput(
name=name,
dtype=dtype,
shape=shape,
)
async def _update_weights():
return await tokenizer_manager.update_weights_from_distributed(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_update_weights())
def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
async def _get_weights():
return await tokenizer_manager.get_weights_by_name(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())
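A hedged sketch of how a trainer process might drive the three weight-update helpers defined above. The addresses, group name, parameter name, dtype, and shape are placeholders, and the trainer side that creates rank 0 of the same process group is assumed to exist separately:

# Illustrative only; all concrete values are placeholders.
engine = Engine(model_path="/path/to/model")

# 1) Join a weight-sync process group whose rank 0 is owned by the trainer.
engine.init_weights_update_group(
    master_address="127.0.0.1",
    master_port=29500,
    rank_offset=1,        # engine ranks start after the trainer's rank 0
    world_size=2,
    group_name="weight_sync",
    backend="nccl",
)

# 2) Receive one parameter broadcast by the trainer over that group.
engine.update_weights_from_distributed(
    name="model.embed_tokens.weight",  # placeholder parameter name
    dtype="bfloat16",
    shape=[32000, 4096],               # placeholder shape
)

# 3) Optionally read a truncated copy of the parameter back to verify the update.
print(engine.get_weights_by_name("model.embed_tokens.weight", truncate_size=5))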
class Runtime:
"""
A wrapper for the server.
A wrapper for the HTTP server.
This is used for launching the server in a python program without
using the command line interface.
It is mainly used for the frontend language.
You should use the Engine class if you want to do normal offline processing.
"""
def __init__(
@@ -839,201 +1035,3 @@ class Runtime:
def __del__(self):
self.shutdown()
STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"
class Engine:
"""
SRT Engine without an HTTP server layer.
This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
launching the HTTP server adds unnecessary complexity or overhead.
"""
def __init__(self, log_level: str = "error", *args, **kwargs):
# Before the Python program terminates, call shutdown() implicitly so users don't have to call .shutdown() explicitly.
atexit.register(self.shutdown)
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
launch_engine(server_args=server_args)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
)
# get the current event loop
loop = asyncio.get_event_loop()
ret = loop.run_until_complete(generate_request(obj, None))
if stream is True:
def generator_wrapper():
offset = 0
loop = asyncio.get_event_loop()
generator = ret.body_iterator
while True:
chunk = loop.run_until_complete(generator.__anext__())
if chunk.startswith(STREAM_END_SYMBOL):
break
else:
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
data["text"] = data["text"][offset:]
offset += len(data["text"])
yield data
# We cannot yield directly inside generate(): a function that contains `yield` becomes a generator function,
# so the non-streaming branch could not return a plain result. Instead, we wrap the yielding loop in an
# inner generator function (generator_wrapper) and return that generator.
return generator_wrapper()
else:
return ret
async def async_generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Dict] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
)
ret = await generate_request(obj, None)
if stream is True:
generator = ret.body_iterator
async def generator_wrapper():
offset = 0
while True:
chunk = await generator.__anext__()
if chunk.startswith(STREAM_END_SYMBOL):
break
else:
data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
data["text"] = data["text"][offset:]
offset += len(data["text"])
yield data
return generator_wrapper()
else:
return ret
def shutdown(self):
kill_process_tree(os.getpid(), include_parent=False)
def get_tokenizer(self):
global tokenizer_manager
if tokenizer_manager is None:
raise ReferenceError("Tokenizer Manager is not initialized.")
else:
return tokenizer_manager.tokenizer
def encode(
self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
):
obj = EmbeddingReqInput(text=prompt)
# get the current event loop
loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None))
def start_profile(self):
tokenizer_manager.start_profile()
def stop_profile(self):
tokenizer_manager.stop_profile()
async def get_server_info(self):
return await _get_server_info()
def init_weights_update_group(
self,
master_address: str,
master_port: int,
rank_offset: int,
world_size: int,
group_name: str,
backend: str = "nccl",
):
"""Initialize parameter update group."""
obj = InitWeightsUpdateGroupReqInput(
master_address=master_address,
master_port=master_port,
rank_offset=rank_offset,
world_size=world_size,
group_name=group_name,
backend=backend,
)
async def _init_group():
return await tokenizer_manager.init_weights_update_group(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_init_group())
def update_weights_from_distributed(self, name, dtype, shape):
"""Update weights from distributed source."""
obj = UpdateWeightsFromDistributedReqInput(
name=name,
dtype=dtype,
shape=shape,
)
async def _update_weights():
return await tokenizer_manager.update_weights_from_distributed(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_update_weights())
def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
async def _get_weights():
return await tokenizer_manager.get_weights_by_name(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())