Correctly abort the failed grammar requests & Improve the handling of abort (#6803)
This commit is contained in:
@@ -60,7 +60,7 @@ class BaseGrammarObject:
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def copy(self) -> "BaseGrammarObject":
|
def copy(self) -> "BaseGrammarObject":
|
||||||
raise NotImplementedError()
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def finished(self):
|
def finished(self):
|
||||||
@@ -99,9 +99,12 @@ class BaseGrammarObject:
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
INVALID_GRAMMAR_OBJ = BaseGrammarObject()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheEntry:
|
class CacheEntry:
|
||||||
value: Optional[BaseGrammarObject]
|
value: BaseGrammarObject
|
||||||
event: Event
|
event: Event
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from llguidance.torch import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.constrained.base_grammar_backend import (
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
|
INVALID_GRAMMAR_OBJ,
|
||||||
BaseGrammarBackend,
|
BaseGrammarBackend,
|
||||||
BaseGrammarObject,
|
BaseGrammarObject,
|
||||||
)
|
)
|
||||||
@@ -126,8 +127,8 @@ class GuidanceBackend(BaseGrammarBackend):
|
|||||||
serialized_grammar=serialized_grammar,
|
serialized_grammar=serialized_grammar,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
|
logger.error(f"Hit invalid grammar: {serialized_grammar=}, {e=}")
|
||||||
return None
|
return INVALID_GRAMMAR_OBJ
|
||||||
|
|
||||||
def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
|
def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
|
||||||
try:
|
try:
|
||||||
@@ -138,8 +139,8 @@ class GuidanceBackend(BaseGrammarBackend):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
|
logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
||||||
return None
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._from_serialized(serialized_grammar)
|
return self._from_serialized(serialized_grammar)
|
||||||
|
|
||||||
def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
|
def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
|
||||||
@@ -151,8 +152,8 @@ class GuidanceBackend(BaseGrammarBackend):
|
|||||||
serialized_grammar = grammar_from("ebnf", key_string)
|
serialized_grammar = grammar_from("ebnf", key_string)
|
||||||
return self._from_serialized(serialized_grammar)
|
return self._from_serialized(serialized_grammar)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
|
logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
|
||||||
return None
|
return INVALID_GRAMMAR_OBJ
|
||||||
|
|
||||||
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
|
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
|
||||||
try:
|
try:
|
||||||
@@ -169,5 +170,5 @@ class GuidanceBackend(BaseGrammarBackend):
|
|||||||
g = StructTag.to_grammar(tags)
|
g = StructTag.to_grammar(tags)
|
||||||
return self._from_serialized(g)
|
return self._from_serialized(g)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
|
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
||||||
return None
|
return INVALID_GRAMMAR_OBJ
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from outlines.models.transformers import TransformerTokenizer
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from sglang.srt.constrained.base_grammar_backend import (
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
|
INVALID_GRAMMAR_OBJ,
|
||||||
BaseGrammarBackend,
|
BaseGrammarBackend,
|
||||||
BaseGrammarObject,
|
BaseGrammarObject,
|
||||||
)
|
)
|
||||||
@@ -151,8 +152,8 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|||||||
# outlines <= 0.0.46
|
# outlines <= 0.0.46
|
||||||
guide = RegexGuide(regex, self.outlines_tokenizer)
|
guide = RegexGuide(regex, self.outlines_tokenizer)
|
||||||
except interegular.patterns.InvalidSyntax as e:
|
except interegular.patterns.InvalidSyntax as e:
|
||||||
logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
|
logger.error(f"Hit invalid regex schema: {regex=}, {e=}")
|
||||||
return None
|
return INVALID_GRAMMAR_OBJ
|
||||||
|
|
||||||
jump_forward_map = None
|
jump_forward_map = None
|
||||||
return OutlinesGrammar(guide, jump_forward_map)
|
return OutlinesGrammar(guide, jump_forward_map)
|
||||||
@@ -170,8 +171,8 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|||||||
whitespace_pattern=self.whitespace_pattern,
|
whitespace_pattern=self.whitespace_pattern,
|
||||||
)
|
)
|
||||||
except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
|
except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
|
||||||
logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
|
logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
||||||
return None
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._compile_regex(regex)
|
return self._compile_regex(regex)
|
||||||
|
|
||||||
def dispatch_regex(self, key_string: str):
|
def dispatch_regex(self, key_string: str):
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from xgrammar import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.constrained.base_grammar_backend import (
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
|
INVALID_GRAMMAR_OBJ,
|
||||||
BaseGrammarBackend,
|
BaseGrammarBackend,
|
||||||
BaseGrammarObject,
|
BaseGrammarObject,
|
||||||
)
|
)
|
||||||
@@ -152,10 +153,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
if True:
|
||||||
tokenizer, vocab_size=vocab_size
|
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||||
)
|
tokenizer, vocab_size=vocab_size
|
||||||
override_stop_tokens = None
|
)
|
||||||
|
override_stop_tokens = None
|
||||||
|
|
||||||
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@@ -178,25 +180,26 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
ctx = self.grammar_compiler.compile_builtin_json_grammar()
|
ctx = self.grammar_compiler.compile_builtin_json_grammar()
|
||||||
else:
|
else:
|
||||||
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
|
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
|
||||||
except RuntimeError as e:
|
|
||||||
logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||||
return None
|
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
|
||||||
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._from_context(ctx, key_string)
|
return self._from_context(ctx, key_string)
|
||||||
|
|
||||||
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
try:
|
try:
|
||||||
ctx = self.grammar_compiler.compile_grammar(key_string)
|
ctx = self.grammar_compiler.compile_grammar(key_string)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
|
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
|
||||||
return None
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._from_context(ctx, key_string)
|
return self._from_context(ctx, key_string)
|
||||||
|
|
||||||
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
try:
|
try:
|
||||||
ctx = self.grammar_compiler.compile_regex(key_string)
|
ctx = self.grammar_compiler.compile_regex(key_string)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
|
||||||
return None
|
return INVALID_GRAMMAR_OBJ
|
||||||
return self._from_context(ctx, key_string)
|
return self._from_context(ctx, key_string)
|
||||||
|
|
||||||
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
@@ -213,13 +216,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
ctx = self.grammar_compiler.compile_structural_tag(
|
ctx = self.grammar_compiler.compile_structural_tag(
|
||||||
tags, structural_tag["triggers"]
|
tags, structural_tag["triggers"]
|
||||||
)
|
)
|
||||||
except RuntimeError as e:
|
except (RuntimeError, json.decoder.JSONDecodeError) as e:
|
||||||
logging.warning(
|
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
|
||||||
f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
|
return INVALID_GRAMMAR_OBJ
|
||||||
)
|
|
||||||
return None
|
|
||||||
return self._from_context(ctx, key_string)
|
return self._from_context(ctx, key_string)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
if self.grammar_compiler:
|
self.grammar_compiler.clear_cache()
|
||||||
self.grammar_compiler.clear_cache()
|
|
||||||
|
|||||||
@@ -256,7 +256,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
) + b"\n\n"
|
) + b"\n\n"
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
out = {"error": {"message": str(e)}}
|
out = {"error": {"message": str(e)}}
|
||||||
logger.error(f"Error: {e}")
|
logger.error(f"[http_server] Error: {e}")
|
||||||
yield b"data: " + orjson.dumps(
|
yield b"data: " + orjson.dumps(
|
||||||
out, option=orjson.OPT_NON_STR_KEYS
|
out, option=orjson.OPT_NON_STR_KEYS
|
||||||
) + b"\n\n"
|
) + b"\n\n"
|
||||||
@@ -274,7 +274,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
).__anext__()
|
).__anext__()
|
||||||
return ret
|
return ret
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error: {e}")
|
logger.error(f"[http_server] Error: {e}")
|
||||||
return _create_error_response(e)
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ import hashlib
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -51,6 +52,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
|
|||||||
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
||||||
ScheduleBatchDisaggregationDecodeMixin,
|
ScheduleBatchDisaggregationDecodeMixin,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
@@ -60,7 +62,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
|
|||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
|
from sglang.srt.utils import flatten_nested_list, support_triton
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
@@ -771,6 +773,16 @@ class Req:
|
|||||||
logger.info(f"{prefix}: {self.time_stats}")
|
logger.info(f"{prefix}: {self.time_stats}")
|
||||||
self.has_log_time_stats = True
|
self.has_log_time_stats = True
|
||||||
|
|
||||||
|
def set_finish_with_abort(self, error_msg: str):
|
||||||
|
if get_tensor_model_parallel_rank() == 0:
|
||||||
|
logger.error(f"{error_msg}, {self.rid=}")
|
||||||
|
self.multimodal_inputs = None
|
||||||
|
self.grammar = None
|
||||||
|
self.origin_input_ids = [0] # set it to one token to skip the long prefill
|
||||||
|
self.finished_reason = FINISH_ABORT(
|
||||||
|
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return (
|
||||||
f"Req(rid={self.rid}, "
|
f"Req(rid={self.rid}, "
|
||||||
|
|||||||
@@ -35,7 +35,10 @@ from torch.distributed import barrier
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
|
INVALID_GRAMMAR_OBJ,
|
||||||
|
create_grammar_backend,
|
||||||
|
)
|
||||||
from sglang.srt.disaggregation.decode import (
|
from sglang.srt.disaggregation.decode import (
|
||||||
DecodePreallocQueue,
|
DecodePreallocQueue,
|
||||||
DecodeTransferQueue,
|
DecodeTransferQueue,
|
||||||
@@ -949,12 +952,12 @@ class Scheduler(
|
|||||||
if self.disaggregation_mode != DisaggregationMode.NULL:
|
if self.disaggregation_mode != DisaggregationMode.NULL:
|
||||||
# Invalid request for disaggregated mode
|
# Invalid request for disaggregated mode
|
||||||
if recv_req.bootstrap_room is None:
|
if recv_req.bootstrap_room is None:
|
||||||
error_message = (
|
error_msg = (
|
||||||
f"Invalid request: Disaggregated request received without "
|
f"Invalid request: Disaggregated request received without "
|
||||||
f"boostrap room id. {req.rid=}"
|
f"boostrap room id. {req.rid=}"
|
||||||
)
|
)
|
||||||
logger.error(error_message)
|
logger.error(error_msg)
|
||||||
prepare_abort(req, error_message)
|
prepare_abort(req, error_msg)
|
||||||
self.stream_output([req], req.return_logprob)
|
self.stream_output([req], req.return_logprob)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -985,29 +988,23 @@ class Scheduler(
|
|||||||
req.extend_image_inputs(image_inputs)
|
req.extend_image_inputs(image_inputs)
|
||||||
|
|
||||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||||
error_msg = (
|
req.set_finish_with_abort(
|
||||||
"Multimodal prompt is too long after expanding multimodal tokens. "
|
error_msg=(
|
||||||
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
"Multimodal prompt is too long after expanding multimodal tokens. "
|
||||||
)
|
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
||||||
logger.error(error_msg)
|
)
|
||||||
req.origin_input_ids = [0]
|
|
||||||
req.multimodal_inputs = None
|
|
||||||
req.sampling_params.max_new_tokens = 0
|
|
||||||
req.finished_reason = FINISH_ABORT(
|
|
||||||
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
|
||||||
)
|
)
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Validate prompts length
|
# Validate prompt length
|
||||||
error_msg = validate_input_length(
|
error_msg = validate_input_length(
|
||||||
req,
|
req,
|
||||||
self.max_req_input_len,
|
self.max_req_input_len,
|
||||||
self.server_args.allow_auto_truncate,
|
self.server_args.allow_auto_truncate,
|
||||||
)
|
)
|
||||||
if error_msg:
|
if error_msg:
|
||||||
req.origin_input_ids = [0]
|
req.set_finish_with_abort(error_msg)
|
||||||
req.sampling_params.max_new_tokens = 0
|
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1019,12 +1016,9 @@ class Scheduler(
|
|||||||
req.logprob_start_len = recv_req.logprob_start_len
|
req.logprob_start_len = recv_req.logprob_start_len
|
||||||
|
|
||||||
if req.logprob_start_len >= len(req.origin_input_ids):
|
if req.logprob_start_len >= len(req.origin_input_ids):
|
||||||
req.finished_reason = FINISH_ABORT(
|
error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
|
||||||
f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
|
|
||||||
HTTPStatus.BAD_REQUEST,
|
|
||||||
"BadRequestError",
|
|
||||||
)
|
|
||||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||||
|
req.set_finish_with_abort(error_msg)
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1061,6 +1055,10 @@ class Scheduler(
|
|||||||
if not cache_hit:
|
if not cache_hit:
|
||||||
req.grammar_key = key
|
req.grammar_key = key
|
||||||
add_to_grammar_queue = True
|
add_to_grammar_queue = True
|
||||||
|
else:
|
||||||
|
if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar.
|
||||||
|
error_msg = f"Invalid grammar request with cache hit: {key=}"
|
||||||
|
req.set_finish_with_abort(error_msg)
|
||||||
|
|
||||||
if add_to_grammar_queue:
|
if add_to_grammar_queue:
|
||||||
req.queue_time_start = time.perf_counter()
|
req.queue_time_start = time.perf_counter()
|
||||||
@@ -1108,19 +1106,13 @@ class Scheduler(
|
|||||||
req.extend_image_inputs(image_inputs)
|
req.extend_image_inputs(image_inputs)
|
||||||
|
|
||||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||||
error_msg = (
|
req.set_finish_with_abort(
|
||||||
"Multimodal prompt is too long after expanding multimodal tokens. "
|
error_msg=(
|
||||||
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
"Multimodal prompt is too long after expanding multimodal tokens. "
|
||||||
|
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
logger.error(error_msg)
|
self._add_request_to_queue(req)
|
||||||
req.origin_input_ids = [0]
|
|
||||||
req.multimodal_inputs = None
|
|
||||||
req.sampling_params.max_new_tokens = 0
|
|
||||||
req.finished_reason = FINISH_ABORT(
|
|
||||||
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
|
||||||
)
|
|
||||||
req.queue_time_start = time.perf_counter()
|
|
||||||
self.waiting_queue.append(req)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Validate prompts length
|
# Validate prompts length
|
||||||
@@ -1785,17 +1777,25 @@ class Scheduler(
|
|||||||
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
|
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
|
||||||
|
|
||||||
num_ready_reqs = 0
|
num_ready_reqs = 0
|
||||||
num_abort_reqs = 0
|
num_timeout_reqs = 0
|
||||||
for req in self.grammar_queue:
|
for req in self.grammar_queue:
|
||||||
try:
|
try:
|
||||||
|
if req.finished(): # It is aborted by AbortReq
|
||||||
|
num_ready_reqs += 1
|
||||||
|
continue
|
||||||
req.grammar = req.grammar.result(timeout=0.03)
|
req.grammar = req.grammar.result(timeout=0.03)
|
||||||
if req.grammar:
|
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
|
||||||
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
|
if req.grammar is INVALID_GRAMMAR_OBJ:
|
||||||
|
req.set_finish_with_abort(
|
||||||
|
f"Invalid grammar request: {req.grammar_key=}"
|
||||||
|
)
|
||||||
num_ready_reqs += 1
|
num_ready_reqs += 1
|
||||||
except futures._base.TimeoutError:
|
except futures._base.TimeoutError:
|
||||||
req.grammar_wait_ct += 1
|
req.grammar_wait_ct += 1
|
||||||
|
# NOTE(lianmin): this timeout is the waiting time of the above line. It is
|
||||||
|
# not the waiting time from it enters the grammar queue.
|
||||||
if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
|
if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
|
||||||
num_abort_reqs = 1
|
num_timeout_reqs = 1
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.server_args.enable_dp_attention:
|
if self.server_args.enable_dp_attention:
|
||||||
@@ -1807,28 +1807,33 @@ class Scheduler(
|
|||||||
|
|
||||||
if tp_size > 1:
|
if tp_size > 1:
|
||||||
# Sync across TP ranks to make sure they have the same number of ready requests
|
# Sync across TP ranks to make sure they have the same number of ready requests
|
||||||
tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
|
tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
|
||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
|
tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
|
||||||
)
|
)
|
||||||
num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
|
num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()
|
||||||
|
|
||||||
for i in range(num_ready_reqs, num_ready_reqs_max):
|
for i in range(num_ready_reqs, num_ready_reqs_max):
|
||||||
req = self.grammar_queue[i]
|
req = self.grammar_queue[i]
|
||||||
|
if req.finished(): # It is aborted by AbortReq
|
||||||
|
continue
|
||||||
req.grammar = req.grammar.result()
|
req.grammar = req.grammar.result()
|
||||||
if req.grammar:
|
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
|
||||||
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
|
if req.grammar is INVALID_GRAMMAR_OBJ:
|
||||||
|
req.set_finish_with_abort(
|
||||||
|
f"Invalid grammar request: {req.grammar_key=}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
num_ready_reqs_max = num_ready_reqs
|
||||||
|
num_timeout_reqs_max = num_timeout_reqs
|
||||||
|
|
||||||
for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
|
for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
|
||||||
req = self.grammar_queue[i]
|
req = self.grammar_queue[i]
|
||||||
req.grammar.cancel()
|
req.grammar.cancel()
|
||||||
req.grammar = None
|
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
|
||||||
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
|
req.set_finish_with_abort(error_msg)
|
||||||
logger.error(error_msg)
|
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
|
||||||
req.finished_reason = FINISH_ABORT(
|
num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
|
||||||
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
|
||||||
)
|
|
||||||
num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max
|
|
||||||
|
|
||||||
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
|
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
|
||||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||||
@@ -2024,8 +2029,6 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def abort_request(self, recv_req: AbortReq):
|
def abort_request(self, recv_req: AbortReq):
|
||||||
# TODO(lmzheng): abort the requests in the grammar queue.
|
|
||||||
|
|
||||||
# Delete requests in the waiting queue
|
# Delete requests in the waiting queue
|
||||||
to_del = []
|
to_del = []
|
||||||
for i, req in enumerate(self.waiting_queue):
|
for i, req in enumerate(self.waiting_queue):
|
||||||
@@ -2047,8 +2050,16 @@ class Scheduler(
|
|||||||
for req in reqs:
|
for req in reqs:
|
||||||
if req.rid.startswith(recv_req.rid) and not req.finished():
|
if req.rid.startswith(recv_req.rid) and not req.finished():
|
||||||
logger.debug(f"Abort running request. {req.rid=}")
|
logger.debug(f"Abort running request. {req.rid=}")
|
||||||
|
# We must use to_abort because it is in a running batch
|
||||||
req.to_abort = True
|
req.to_abort = True
|
||||||
|
|
||||||
|
# Delete the requests in the grammar queue
|
||||||
|
for req in self.grammar_queue:
|
||||||
|
if req.rid.startswith(recv_req.rid):
|
||||||
|
logger.debug(f"Abort grammar queue request. {req.rid=}")
|
||||||
|
req.grammar.cancel()
|
||||||
|
req.set_finish_with_abort("Aborted by AbortReq.")
|
||||||
|
|
||||||
def _pause_engine(self) -> Tuple[List[Req], int]:
|
def _pause_engine(self) -> Tuple[List[Req], int]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ class TokenizerManager:
|
|||||||
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
else:
|
else:
|
||||||
self.mm_processor = get_dummy_processor()
|
self.mm_processor = None
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = self.processor = None
|
self.tokenizer = self.processor = None
|
||||||
@@ -425,8 +425,8 @@ class TokenizerManager:
|
|||||||
is_single = obj.is_single
|
is_single = obj.is_single
|
||||||
if is_single:
|
if is_single:
|
||||||
tokenized_obj = await self._tokenize_one_request(obj)
|
tokenized_obj = await self._tokenize_one_request(obj)
|
||||||
self._send_one_request(obj, tokenized_obj, created_time)
|
state = self._send_one_request(obj, tokenized_obj, created_time)
|
||||||
async for response in self._wait_one_response(obj, request):
|
async for response in self._wait_one_response(obj, state, request):
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
async for response in self._handle_batch_request(
|
async for response in self._handle_batch_request(
|
||||||
@@ -462,8 +462,7 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
|
|
||||||
image_inputs: Optional[Dict] = None
|
if self.mm_processor and obj.contains_mm_input():
|
||||||
if obj.contains_mm_input():
|
|
||||||
image_inputs = await self.mm_processor.process_mm_data_async(
|
image_inputs = await self.mm_processor.process_mm_data_async(
|
||||||
image_data=obj.image_data,
|
image_data=obj.image_data,
|
||||||
input_text=input_text or input_ids,
|
input_text=input_text or input_ids,
|
||||||
@@ -472,6 +471,8 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
if image_inputs and "input_ids" in image_inputs:
|
if image_inputs and "input_ids" in image_inputs:
|
||||||
input_ids = image_inputs["input_ids"]
|
input_ids = image_inputs["input_ids"]
|
||||||
|
else:
|
||||||
|
image_inputs: Optional[Dict] = None
|
||||||
|
|
||||||
self._validate_token_len(obj, input_ids)
|
self._validate_token_len(obj, input_ids)
|
||||||
return self._create_tokenized_object(
|
return self._create_tokenized_object(
|
||||||
@@ -631,15 +632,15 @@ class TokenizerManager:
|
|||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||||
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
||||||
self.rid_to_state[obj.rid] = state
|
self.rid_to_state[obj.rid] = state
|
||||||
|
return state
|
||||||
|
|
||||||
async def _wait_one_response(
|
async def _wait_one_response(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
|
state: ReqState,
|
||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
):
|
):
|
||||||
"""Wait for the response of one request."""
|
"""Wait for the response of one request."""
|
||||||
state = self.rid_to_state[obj.rid]
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(state.event.wait(), timeout=4)
|
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||||
@@ -709,16 +710,16 @@ class TokenizerManager:
|
|||||||
|
|
||||||
for i, tokenized_obj in enumerate(tokenized_objs):
|
for i, tokenized_obj in enumerate(tokenized_objs):
|
||||||
tmp_obj = obj[i]
|
tmp_obj = obj[i]
|
||||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||||
generators.append(self._wait_one_response(tmp_obj, request))
|
generators.append(self._wait_one_response(tmp_obj, state, request))
|
||||||
rids.append(tmp_obj.rid)
|
rids.append(tmp_obj.rid)
|
||||||
else:
|
else:
|
||||||
# Sequential tokenization and processing
|
# Sequential tokenization and processing
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
tmp_obj = obj[i]
|
tmp_obj = obj[i]
|
||||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||||
generators.append(self._wait_one_response(tmp_obj, request))
|
generators.append(self._wait_one_response(tmp_obj, state, request))
|
||||||
rids.append(tmp_obj.rid)
|
rids.append(tmp_obj.rid)
|
||||||
else:
|
else:
|
||||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||||
@@ -743,8 +744,8 @@ class TokenizerManager:
|
|||||||
tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
|
tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
|
||||||
tokenized_obj.sampling_params.max_new_tokens = 0
|
tokenized_obj.sampling_params.max_new_tokens = 0
|
||||||
tokenized_obj.stream = False
|
tokenized_obj.stream = False
|
||||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||||
await self._wait_one_response(tmp_obj, request).__anext__()
|
await self._wait_one_response(tmp_obj, state, request).__anext__()
|
||||||
|
|
||||||
# Expand requests, assign new rids for them, and send them
|
# Expand requests, assign new rids for them, and send them
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
@@ -752,8 +753,8 @@ class TokenizerManager:
|
|||||||
tmp_obj = copy.copy(objs[i])
|
tmp_obj = copy.copy(objs[i])
|
||||||
tokenized_obj = copy.copy(tokenized_objs[i])
|
tokenized_obj = copy.copy(tokenized_objs[i])
|
||||||
tokenized_obj.rid = tmp_obj.regenerate_rid()
|
tokenized_obj.rid = tmp_obj.regenerate_rid()
|
||||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||||
generators.append(self._wait_one_response(tmp_obj, request))
|
generators.append(self._wait_one_response(tmp_obj, state, request))
|
||||||
rids.append(tmp_obj.rid)
|
rids.append(tmp_obj.rid)
|
||||||
|
|
||||||
# Wait for all requests
|
# Wait for all requests
|
||||||
@@ -789,6 +790,9 @@ class TokenizerManager:
|
|||||||
req = AbortReq(rid)
|
req = AbortReq(rid)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
|
if self.enable_metrics:
|
||||||
|
self.metrics_collector.observe_one_aborted_request()
|
||||||
|
|
||||||
async def start_profile(
|
async def start_profile(
|
||||||
self,
|
self,
|
||||||
output_dir: Optional[str] = None,
|
output_dir: Optional[str] = None,
|
||||||
|
|||||||
@@ -35,10 +35,6 @@ def validate_input_length(
|
|||||||
f"the maximum allowed length ({max_req_input_len} tokens). "
|
f"the maximum allowed length ({max_req_input_len} tokens). "
|
||||||
f"Use a shorter input or enable --allow-auto-truncate."
|
f"Use a shorter input or enable --allow-auto-truncate."
|
||||||
)
|
)
|
||||||
logger.error(error_msg)
|
|
||||||
req.finished_reason = FINISH_ABORT(
|
|
||||||
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
|
||||||
)
|
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -402,6 +402,12 @@ class TokenizerMetricsCollector:
|
|||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.num_aborted_requests_total = Counter(
|
||||||
|
name="sglang:num_aborted_requests",
|
||||||
|
documentation="Number of requests aborted.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
)
|
||||||
|
|
||||||
if bucket_time_to_first_token is None:
|
if bucket_time_to_first_token is None:
|
||||||
bucket_time_to_first_token = [
|
bucket_time_to_first_token = [
|
||||||
0.1,
|
0.1,
|
||||||
@@ -533,3 +539,6 @@ class TokenizerMetricsCollector:
|
|||||||
if adjusted_interval <= bound:
|
if adjusted_interval <= bound:
|
||||||
his._buckets[i].inc(num_new_tokens)
|
his._buckets[i].inc(num_new_tokens)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def observe_one_aborted_request(self):
|
||||||
|
self.num_aborted_requests_total.labels(**self.labels).inc(1)
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from sglang.srt import two_batch_overlap
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
||||||
@@ -133,28 +132,27 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|||||||
if capture_bs is None:
|
if capture_bs is None:
|
||||||
if server_args.speculative_algorithm is None:
|
if server_args.speculative_algorithm is None:
|
||||||
if server_args.disable_cuda_graph_padding:
|
if server_args.disable_cuda_graph_padding:
|
||||||
capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
|
capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
|
||||||
else:
|
else:
|
||||||
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
|
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
|
||||||
else:
|
else:
|
||||||
# Since speculative decoding requires more cuda graph memory, we
|
# Since speculative decoding requires more cuda graph memory, we
|
||||||
# capture less.
|
# capture less.
|
||||||
capture_bs = (
|
capture_bs = (
|
||||||
list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
|
list(range(1, 9))
|
||||||
|
+ list(range(10, 33, 2))
|
||||||
|
+ list(range(40, 64, 8))
|
||||||
|
+ list(range(80, 161, 16))
|
||||||
)
|
)
|
||||||
|
|
||||||
gpu_mem = get_device_memory_capacity()
|
gpu_mem = get_device_memory_capacity()
|
||||||
if gpu_mem is not None and gpu_mem > 96 * 1024:
|
if gpu_mem is not None and gpu_mem > 96 * 1024:
|
||||||
capture_bs += list(range(160, 257, 8))
|
capture_bs += list(range(160, 257, 8))
|
||||||
if gpu_mem is not None and gpu_mem > 180 * 1000:
|
|
||||||
capture_bs += list(range(256, 528, 16))
|
|
||||||
|
|
||||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||||
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
||||||
# is very small. We add more values here to make sure we capture the maximum bs.
|
# is very small. We add more values here to make sure we capture the maximum bs.
|
||||||
capture_bs += [model_runner.req_to_token_pool.size - 1] + [
|
capture_bs += [model_runner.req_to_token_pool.size]
|
||||||
model_runner.req_to_token_pool.size
|
|
||||||
]
|
|
||||||
|
|
||||||
if server_args.enable_two_batch_overlap:
|
if server_args.enable_two_batch_overlap:
|
||||||
capture_bs = [bs for bs in capture_bs if bs >= 2]
|
capture_bs = [bs for bs in capture_bs if bs >= 2]
|
||||||
@@ -167,7 +165,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|||||||
)
|
)
|
||||||
capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
|
capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
|
||||||
capture_bs = list(sorted(set(capture_bs)))
|
capture_bs = list(sorted(set(capture_bs)))
|
||||||
assert len(capture_bs) > 0 and capture_bs[0] > 0
|
assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
|
||||||
compile_bs = (
|
compile_bs = (
|
||||||
[bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
|
[bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
|
||||||
if server_args.enable_torch_compile
|
if server_args.enable_torch_compile
|
||||||
|
|||||||
@@ -918,7 +918,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
if self.req_to_token_pool is None:
|
if self.req_to_token_pool is None:
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
size=max_num_reqs + 1,
|
size=max_num_reqs,
|
||||||
max_context_len=self.model_config.context_len + 4,
|
max_context_len=self.model_config.context_len + 4,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
|
|||||||
@@ -2055,6 +2055,12 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
|
|||||||
is_hopper_with_cuda_12_3 = lambda: _check(9)
|
is_hopper_with_cuda_12_3 = lambda: _check(9)
|
||||||
|
|
||||||
|
|
||||||
|
def is_blackwell():
|
||||||
|
if not is_cuda():
|
||||||
|
return False
|
||||||
|
return torch.cuda.get_device_capability()[0] == 10
|
||||||
|
|
||||||
|
|
||||||
def get_free_port():
|
def get_free_port():
|
||||||
# try ipv4
|
# try ipv4
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -127,6 +127,10 @@ def send_one_prompt(args):
|
|||||||
if args.batch_size > 1:
|
if args.batch_size > 1:
|
||||||
ret = ret[0]
|
ret = ret[0]
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(ret)
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
latency = ret["meta_info"]["e2e_latency"]
|
latency = ret["meta_info"]["e2e_latency"]
|
||||||
|
|
||||||
if "spec_verify_ct" in ret["meta_info"]:
|
if "spec_verify_ct" in ret["meta_info"]:
|
||||||
|
|||||||
@@ -881,20 +881,24 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
|
|||||||
return rouge_l_scores
|
return rouge_l_scores
|
||||||
|
|
||||||
|
|
||||||
STDERR_FILENAME = "stderr.txt"
|
STDERR_FILENAME = "/tmp/stderr.txt"
|
||||||
STDOUT_FILENAME = "stdout.txt"
|
STDOUT_FILENAME = "/tmp/stdout.txt"
|
||||||
|
|
||||||
|
|
||||||
def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
|
def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
|
||||||
"""Print the output in real time with another thread."""
|
"""Print the output in real time with another thread."""
|
||||||
while not os.path.exists(filename):
|
while not os.path.exists(filename):
|
||||||
time.sleep(1)
|
time.sleep(0.01)
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
while pt >= 0:
|
while pt >= 0:
|
||||||
if pt > 0 and not os.path.exists(filename):
|
if pt > 0 and not os.path.exists(filename):
|
||||||
break
|
break
|
||||||
lines = open(filename).readlines()
|
try:
|
||||||
|
lines = open(filename).readlines()
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"{pt=}, {os.path.exists(filename)=}")
|
||||||
|
raise
|
||||||
for line in lines[pt:]:
|
for line in lines[pt:]:
|
||||||
print(line, end="", flush=True)
|
print(line, end="", flush=True)
|
||||||
output_lines.append(line)
|
output_lines.append(line)
|
||||||
|
|||||||
@@ -1,25 +1,33 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# Show current GPU status
|
if [ "$1" = "rocm" ]; then
|
||||||
nvidia-smi
|
echo "Running in ROCm mode"
|
||||||
|
|
||||||
# Clean SGLang processes
|
# Clean SGLang processes
|
||||||
pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
|
pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
|
||||||
|
|
||||||
# Clean all GPU processes if any argument is provided
|
else
|
||||||
if [ $# -gt 0 ]; then
|
# Show current GPU status
|
||||||
# Check if sudo is available
|
nvidia-smi
|
||||||
if command -v sudo >/dev/null 2>&1; then
|
|
||||||
sudo apt-get update
|
# Clean SGLang processes
|
||||||
sudo apt-get install -y lsof
|
pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
|
||||||
else
|
|
||||||
apt-get update
|
# Clean all GPU processes if any argument is provided
|
||||||
apt-get install -y lsof
|
if [ $# -gt 0 ]; then
|
||||||
|
# Check if sudo is available
|
||||||
|
if command -v sudo >/dev/null 2>&1; then
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y lsof
|
||||||
|
else
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y lsof
|
||||||
|
fi
|
||||||
|
kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null
|
||||||
|
lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null
|
||||||
fi
|
fi
|
||||||
kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null
|
|
||||||
lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null
|
|
||||||
|
# Show GPU status after clean up
|
||||||
|
nvidia-smi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
# Show GPU status after clean up
|
|
||||||
nvidia-smi
|
|
||||||
|
|||||||
Reference in New Issue
Block a user