feat: support truss endpoint for benchmark serving (#1906)
This commit is contained in:
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_truss(
|
||||||
|
request_func_input: RequestFuncInput,
|
||||||
|
pbar: Optional[tqdm] = None,
|
||||||
|
) -> RequestFuncOutput:
|
||||||
|
api_url = request_func_input.api_url
|
||||||
|
|
||||||
|
prompt = request_func_input.prompt
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
|
payload = {
|
||||||
|
"model": request_func_input.model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"best_of": 1,
|
||||||
|
"max_tokens": request_func_input.output_len,
|
||||||
|
"stream": not args.disable_stream,
|
||||||
|
"ignore_eos": not args.disable_ignore_eos,
|
||||||
|
**request_func_input.extra_request_body,
|
||||||
|
}
|
||||||
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
|
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
|
generated_text = ""
|
||||||
|
ttft = 0.0
|
||||||
|
st = time.perf_counter()
|
||||||
|
most_recent_timestamp = st
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
url=api_url, json=payload, headers=headers
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
async for chunk_bytes in response.content:
|
||||||
|
chunk_bytes = chunk_bytes.strip()
|
||||||
|
if not chunk_bytes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||||
|
latency = time.perf_counter() - st
|
||||||
|
if chunk == "[DONE]":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
|
||||||
|
# NOTE: Some completion API might have a last
|
||||||
|
# usage summary response without a token so we
|
||||||
|
# want to check a token was generated
|
||||||
|
if data["choices"][0]["delta"]["content"]:
|
||||||
|
timestamp = time.perf_counter()
|
||||||
|
# First token
|
||||||
|
if ttft == 0.0:
|
||||||
|
ttft = time.perf_counter() - st
|
||||||
|
output.ttft = ttft
|
||||||
|
|
||||||
|
# Decoding phase
|
||||||
|
else:
|
||||||
|
output.itl.append(timestamp - most_recent_timestamp)
|
||||||
|
|
||||||
|
most_recent_timestamp = timestamp
|
||||||
|
generated_text += data["choices"][0]["delta"]["content"]
|
||||||
|
|
||||||
|
output.generated_text = generated_text
|
||||||
|
output.success = True
|
||||||
|
output.latency = latency
|
||||||
|
output.output_len = request_func_input.output_len
|
||||||
|
else:
|
||||||
|
output.error = response.reason or ""
|
||||||
|
output.success = False
|
||||||
|
except Exception:
|
||||||
|
output.success = False
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
output.error = "".join(traceback.format_exception(*exc_info))
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
async def async_request_sglang_generate(
|
async def async_request_sglang_generate(
|
||||||
request_func_input: RequestFuncInput,
|
request_func_input: RequestFuncInput,
|
||||||
pbar: Optional[tqdm] = None,
|
pbar: Optional[tqdm] = None,
|
||||||
@@ -350,6 +429,7 @@ ASYNC_REQUEST_FUNCS = {
|
|||||||
"lmdeploy": async_request_openai_completions,
|
"lmdeploy": async_request_openai_completions,
|
||||||
"trt": async_request_trt_llm,
|
"trt": async_request_trt_llm,
|
||||||
"gserver": async_request_gserver,
|
"gserver": async_request_gserver,
|
||||||
|
"truss": async_request_truss,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -873,6 +953,7 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
"vllm": 8000,
|
"vllm": 8000,
|
||||||
"trt": 8000,
|
"trt": 8000,
|
||||||
"gserver": 9988,
|
"gserver": 9988,
|
||||||
|
"truss": 8080,
|
||||||
}.get(args.backend, 30000)
|
}.get(args.backend, 30000)
|
||||||
|
|
||||||
model_url = (
|
model_url = (
|
||||||
@@ -905,9 +986,20 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
elif args.backend == "gserver":
|
elif args.backend == "gserver":
|
||||||
api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
|
api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
|
||||||
args.model = args.model or "default"
|
args.model = args.model or "default"
|
||||||
|
elif args.backend == "truss":
|
||||||
|
api_url = (
|
||||||
|
f"{args.base_url}/v1/models/model:predict"
|
||||||
|
if args.base_url
|
||||||
|
else f"http://{args.host}:{args.port}/v1/models/model:predict"
|
||||||
|
)
|
||||||
|
|
||||||
# Get model name
|
# Get model name
|
||||||
if args.model is None:
|
if args.model is None:
|
||||||
|
if args.backend == "truss":
|
||||||
|
print(
|
||||||
|
"Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
try:
|
try:
|
||||||
response = requests.get(model_url)
|
response = requests.get(model_url)
|
||||||
model_list = response.json().get("data", [])
|
model_list = response.json().get("data", [])
|
||||||
|
|||||||
Reference in New Issue
Block a user