[Minor] Improve logging and rename the health check endpoint name (#1180)
This commit is contained in:
@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
|
||||
import dataclasses
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
from enum import Enum, auto
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import zmq
|
||||
|
||||
@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DecodeStatus:
|
||||
"""Store the status of incremental decoding."""
|
||||
|
||||
vid: int
|
||||
decoded_text: str
|
||||
decode_ids: List[int]
|
||||
@@ -47,6 +49,8 @@ class DecodeStatus:
|
||||
|
||||
|
||||
class DetokenizerManager:
|
||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
|
||||
@@ -62,12 +62,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    # Accumulated output chunks for this request; cleared after being
    # consumed by the waiter (see the streaming loop in TokenizerManager).
    out_list: List
    # True once the request has finished generation.
    finished: bool
    # Event used to wake the coroutine awaiting new output for this
    # request — presumably set each time out_list grows; TODO confirm.
    event: asyncio.Event
|
||||
|
||||
|
||||
class TokenizerManager:
|
||||
"""TokenizerManager is a process that tokenizes the text."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
@@ -481,11 +485,7 @@ class TokenizerManager:
|
||||
|
||||
# Log requests
|
||||
if self.server_args.log_requests and state.finished:
|
||||
if obj.text is None:
|
||||
in_obj = {"input_ids": obj.input_ids}
|
||||
else:
|
||||
in_obj = {"text": obj.text}
|
||||
logger.info(f"in={in_obj}, out={out}")
|
||||
logger.info(f"in={obj}, out={out}")
|
||||
|
||||
state.out_list = []
|
||||
if state.finished:
|
||||
|
||||
@@ -92,11 +92,15 @@ app = FastAPI()
|
||||
tokenizer_manager = None
|
||||
|
||||
|
||||
@app.get("/v1/health")
|
||||
async def health(request: Request) -> Response:
|
||||
"""
|
||||
Generate 1 token to verify the health of the inference service.
|
||||
"""
|
||||
@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    # A bare 200 is enough here: this probe only verifies that the HTTP
    # layer is alive (see /health_generate for an inference-level check).
    ok = Response(status_code=200)
    return ok
|
||||
|
||||
|
||||
@app.get("/health_generate")
|
||||
async def health_generate(request: Request) -> Response:
|
||||
"""Check the health of the inference server by generating one token."""
|
||||
gri = GenerateReqInput(
|
||||
text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
|
||||
)
|
||||
@@ -109,12 +113,6 @@ async def health(request: Request) -> Response:
|
||||
return Response(status_code=503)
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> Response:
    """Health check."""
    # Always returns HTTP 200 while the server process is responsive;
    # performs no inference work.
    return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
result = {
|
||||
|
||||
@@ -422,13 +422,13 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--enable-mla",
|
||||
action="store_true",
|
||||
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
|
||||
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-reduce-in-fp32",
|
||||
action="store_true",
|
||||
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
|
||||
"This only affects Triton attention kernels",
|
||||
"This only affects Triton attention kernels.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--efficient-weight-load",
|
||||
@@ -452,15 +452,6 @@ class ServerArgs:
|
||||
def url(self):
    """Return the base HTTP URL of this server, e.g. "http://127.0.0.1:30000"."""
    host, port = self.host, self.port
    return f"http://{host}:{port}"
|
||||
|
||||
def print_mode_args(self):
    """Return a comma-separated summary string of the mode-related flags."""
    # Each entry renders as "name=value, " — note the trailing separator
    # after the last flag, matching the original concatenated f-strings.
    flags = (
        ("disable_flashinfer", self.disable_flashinfer),
        ("attention_reduce_in_fp32", self.attention_reduce_in_fp32),
        ("disable_radix_cache", self.disable_radix_cache),
        ("disable_regex_jump_forward", self.disable_regex_jump_forward),
        ("disable_disk_cache", self.disable_disk_cache),
    )
    return "".join(f"{name}={value}, " for name, value in flags)
|
||||
|
||||
def check_server_args(self):
|
||||
assert (
|
||||
self.tp_size % self.nnodes == 0
|
||||
@@ -469,7 +460,7 @@ class ServerArgs:
|
||||
self.dp_size > 1 and self.node_rank is not None
|
||||
), "multi-node data parallel is not supported"
|
||||
if "gemma-2" in self.model_path.lower():
|
||||
logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
|
||||
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
|
||||
self.disable_flashinfer = False
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user