[Minor] Improve logging and rename the health check endpoint name (#1180)

Lianmin Zheng
2024-08-21 19:24:36 -07:00
committed by GitHub
parent 83e23c69b3
commit 5623826f73
6 changed files with 21 additions and 30 deletions

View File

@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
 import dataclasses
 import logging
 import multiprocessing
-import os
 from enum import Enum, auto

 import numpy as np

View File

@@ -17,7 +17,6 @@ limitations under the License.
 import logging
 import multiprocessing
-import os
 from typing import List

 import zmq

View File

@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 @dataclasses.dataclass
 class DecodeStatus:
+    """Store the status of incremental decoding."""
+
     vid: int
     decoded_text: str
     decode_ids: List[int]
@@ -47,6 +49,8 @@ class DecodeStatus:
 class DetokenizerManager:
+    """DetokenizerManager is a process that detokenizes the token ids."""
+
     def __init__(
         self,
         server_args: ServerArgs,

View File

@@ -62,12 +62,16 @@ logger = logging.getLogger(__name__)
 @dataclasses.dataclass
 class ReqState:
+    """Store the state a request."""
+
     out_list: List
     finished: bool
     event: asyncio.Event


 class TokenizerManager:
+    """TokenizerManager is a process that tokenizes the text."""
+
     def __init__(
         self,
         server_args: ServerArgs,
@@ -481,11 +485,7 @@ class TokenizerManager:
         # Log requests
         if self.server_args.log_requests and state.finished:
-            if obj.text is None:
-                in_obj = {"input_ids": obj.input_ids}
-            else:
-                in_obj = {"text": obj.text}
-            logger.info(f"in={in_obj}, out={out}")
+            logger.info(f"in={obj}, out={out}")

         state.out_list = []
         if state.finished:
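The hunk above replaces the text/input_ids special-casing with a single line that logs the request object itself. A rough before/after illustration of the logged output, using a hypothetical stand-in dataclass rather than the real GenerateReqInput (its field names here are assumptions for this sketch):

import logging
from dataclasses import dataclass
from typing import List, Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class FakeReq:
    # Hypothetical stand-in for GenerateReqInput; field names are assumptions.
    text: Optional[str]
    input_ids: Optional[List[int]]

obj = FakeReq(text="hello", input_ids=None)
out = {"text": " world"}  # placeholder output payload

# Old behavior: log only the field that was actually provided.
in_obj = {"input_ids": obj.input_ids} if obj.text is None else {"text": obj.text}
logger.info(f"in={in_obj}, out={out}")

# New behavior: log the whole request object; its repr carries every field.
logger.info(f"in={obj}, out={out}")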

View File

@@ -92,11 +92,15 @@ app = FastAPI()
 tokenizer_manager = None


-@app.get("/v1/health")
-async def health(request: Request) -> Response:
-    """
-    Generate 1 token to verify the health of the inference service.
-    """
+@app.get("/health")
+async def health() -> Response:
+    """Check the health of the http server."""
+    return Response(status_code=200)
+
+
+@app.get("/health_generate")
+async def health_generate(request: Request) -> Response:
+    """Check the health of the inference server by generating one token."""
     gri = GenerateReqInput(
         text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
     )
@@ -109,12 +113,6 @@ async def health(request: Request) -> Response:
         return Response(status_code=503)


-@app.get("/health")
-async def health() -> Response:
-    """Health check."""
-    return Response(status_code=200)
-
-
 @app.get("/get_model_info") 
 async def get_model_info():
     result = {
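With the two hunks above, /health becomes a cheap liveness probe of the HTTP server, while /health_generate keeps the old behavior of generating one token end to end. A minimal probe sketch, assuming a server is already running and listening on localhost:30000 (the host and port are assumptions, not part of this commit):

import urllib.request

BASE = "http://localhost:30000"  # assumed address of a running server

# Cheap check: only verifies the HTTP server is up and responding.
with urllib.request.urlopen(f"{BASE}/health", timeout=5) as resp:
    print("/health:", resp.status)

# Deeper check: asks the server to generate one token end to end.
with urllib.request.urlopen(f"{BASE}/health_generate", timeout=60) as resp:
    print("/health_generate:", resp.status)

Note that urllib raises HTTPError on a 503, so a failed generation check surfaces as an exception rather than a printed status.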

View File

@@ -422,13 +422,13 @@ class ServerArgs:
         parser.add_argument(
             "--enable-mla",
             action="store_true",
-            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            "This only affects Triton attention kernels.",
         )
         parser.add_argument(
             "--efficient-weight-load",
parser.add_argument( parser.add_argument(
"--efficient-weight-load", "--efficient-weight-load",
@@ -452,15 +452,6 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"

-    def print_mode_args(self):
-        return (
-            f"disable_flashinfer={self.disable_flashinfer}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
-            f"disable_radix_cache={self.disable_radix_cache}, "
-            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
-            f"disable_disk_cache={self.disable_disk_cache}, "
-        )
-
     def check_server_args(self):
         assert (
             self.tp_size % self.nnodes == 0
@@ -469,7 +460,7 @@ class ServerArgs:
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
         if "gemma-2" in self.model_path.lower():
-            logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False
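The help-string edits above only change what argparse prints; runtime behavior is unchanged, since the two adjacent string literals in the --attention-reduce-in-fp32 help are simply concatenated by Python into one help string. A standalone sketch of the same store_true pattern (illustrative only, not the real sglang parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-mla",
    action="store_true",
    help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
    "--attention-reduce-in-fp32",
    action="store_true",
    # Adjacent literals join into a single help string at definition time.
    help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
    "This only affects Triton attention kernels.",
)

args = parser.parse_args(["--enable-mla"])
print(args.enable_mla, args.attention_reduce_in_fp32)  # True False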