Improve docs and warnings (#1164)

This commit is contained in:
Lianmin Zheng
2024-08-20 08:31:29 -07:00
committed by GitHub
parent d8476818ef
commit a8ae640328
7 changed files with 25 additions and 24 deletions

View File

@@ -147,13 +147,12 @@ def get_tokenizer(
and kwargs.get("use_fast", True)
and tokenizer_name != _FAST_LLAMA_TOKENIZER
):
pass
# warnings.warn(
# "For some LLaMA V1 models, initializing the fast tokenizer may "
# "take a long time. To reduce the initialization time, consider "
# f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
# "tokenizer."
# )
warnings.warn(
"For some LLaMA V1 models, initializing the fast tokenizer may "
"take a long time. To reduce the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer."
)
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,

View File

@@ -270,7 +270,7 @@ class Req:
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
warnings.warn(
logging.warning(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False
@@ -791,7 +791,7 @@ class ScheduleBatch:
)
if not torch.all(success):
warnings.warn("Sampling failed, fallback to top_k=1 strategy")
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(

View File

@@ -774,7 +774,7 @@ class ModelTpServer:
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
else:
warnings.warn(
logging.warning(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.waiting_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"

View File

@@ -237,7 +237,7 @@ class ModelRunner:
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_total_tokens is not None:
if max_total_tokens > self.max_total_num_tokens:
warnings.warn(
logging.warning(
f"max_total_tokens={max_total_tokens} is larger than the profiled value "
f"{self.max_total_num_tokens}. "
f"Use the profiled value instead."

View File

@@ -17,10 +17,10 @@ limitations under the License.
import asyncio
import json
import logging
import os
import time
import uuid
import warnings
from http import HTTPStatus
from typing import Dict, List, Optional
@@ -65,6 +65,8 @@ from sglang.srt.openai_api.protocol import (
UsageInfo,
)
logger = logging.getLogger(__name__)
chat_template_name = None
@@ -408,7 +410,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
"Parallel sampling is not supported for completions from files"
)
if request.echo and request.logprobs:
warnings.warn(
logger.warning(
"Echo is not compatible with logprobs. "
"To compute logprobs of input prompt, please use SGLang /request API."
)