[Minor] Improve logging and rename the health check endpoint name (#1180)
This commit is contained in:
@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ limitations under the License.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class DecodeStatus:
|
class DecodeStatus:
|
||||||
|
"""Store the status of incremental decoding."""
|
||||||
|
|
||||||
vid: int
|
vid: int
|
||||||
decoded_text: str
|
decoded_text: str
|
||||||
decode_ids: List[int]
|
decode_ids: List[int]
|
||||||
@@ -47,6 +49,8 @@ class DecodeStatus:
|
|||||||
|
|
||||||
|
|
||||||
class DetokenizerManager:
|
class DetokenizerManager:
|
||||||
|
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
|
|||||||
@@ -62,12 +62,16 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ReqState:
|
class ReqState:
|
||||||
|
"""Store the state a request."""
|
||||||
|
|
||||||
out_list: List
|
out_list: List
|
||||||
finished: bool
|
finished: bool
|
||||||
event: asyncio.Event
|
event: asyncio.Event
|
||||||
|
|
||||||
|
|
||||||
class TokenizerManager:
|
class TokenizerManager:
|
||||||
|
"""TokenizerManager is a process that tokenizes the text."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
@@ -481,11 +485,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
# Log requests
|
# Log requests
|
||||||
if self.server_args.log_requests and state.finished:
|
if self.server_args.log_requests and state.finished:
|
||||||
if obj.text is None:
|
logger.info(f"in={obj}, out={out}")
|
||||||
in_obj = {"input_ids": obj.input_ids}
|
|
||||||
else:
|
|
||||||
in_obj = {"text": obj.text}
|
|
||||||
logger.info(f"in={in_obj}, out={out}")
|
|
||||||
|
|
||||||
state.out_list = []
|
state.out_list = []
|
||||||
if state.finished:
|
if state.finished:
|
||||||
|
|||||||
@@ -92,11 +92,15 @@ app = FastAPI()
|
|||||||
tokenizer_manager = None
|
tokenizer_manager = None
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/health")
|
@app.get("/health")
|
||||||
async def health(request: Request) -> Response:
|
async def health() -> Response:
|
||||||
"""
|
"""Check the health of the http server."""
|
||||||
Generate 1 token to verify the health of the inference service.
|
return Response(status_code=200)
|
||||||
"""
|
|
||||||
|
|
||||||
|
@app.get("/health_generate")
|
||||||
|
async def health_generate(request: Request) -> Response:
|
||||||
|
"""Check the health of the inference server by generating one token."""
|
||||||
gri = GenerateReqInput(
|
gri = GenerateReqInput(
|
||||||
text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
|
text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
|
||||||
)
|
)
|
||||||
@@ -109,12 +113,6 @@ async def health(request: Request) -> Response:
|
|||||||
return Response(status_code=503)
|
return Response(status_code=503)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health() -> Response:
|
|
||||||
"""Health check."""
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_model_info")
|
@app.get("/get_model_info")
|
||||||
async def get_model_info():
|
async def get_model_info():
|
||||||
result = {
|
result = {
|
||||||
|
|||||||
@@ -422,13 +422,13 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-mla",
|
"--enable-mla",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
|
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--attention-reduce-in-fp32",
|
"--attention-reduce-in-fp32",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
|
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
|
||||||
"This only affects Triton attention kernels",
|
"This only affects Triton attention kernels.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--efficient-weight-load",
|
"--efficient-weight-load",
|
||||||
@@ -452,15 +452,6 @@ class ServerArgs:
|
|||||||
def url(self):
|
def url(self):
|
||||||
return f"http://{self.host}:{self.port}"
|
return f"http://{self.host}:{self.port}"
|
||||||
|
|
||||||
def print_mode_args(self):
|
|
||||||
return (
|
|
||||||
f"disable_flashinfer={self.disable_flashinfer}, "
|
|
||||||
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
|
|
||||||
f"disable_radix_cache={self.disable_radix_cache}, "
|
|
||||||
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
|
|
||||||
f"disable_disk_cache={self.disable_disk_cache}, "
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_server_args(self):
|
def check_server_args(self):
|
||||||
assert (
|
assert (
|
||||||
self.tp_size % self.nnodes == 0
|
self.tp_size % self.nnodes == 0
|
||||||
@@ -469,7 +460,7 @@ class ServerArgs:
|
|||||||
self.dp_size > 1 and self.node_rank is not None
|
self.dp_size > 1 and self.node_rank is not None
|
||||||
), "multi-node data parallel is not supported"
|
), "multi-node data parallel is not supported"
|
||||||
if "gemma-2" in self.model_path.lower():
|
if "gemma-2" in self.model_path.lower():
|
||||||
logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
|
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
|
||||||
self.disable_flashinfer = False
|
self.disable_flashinfer = False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user