Clean up server_args, triton cache manager (#8332)
@@ -71,7 +71,6 @@ from sglang.srt.utils import (
is_cuda,
kill_process_tree,
launch_dummy_health_check_server,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
set_prometheus_multiproc_dir,
set_ulimit,
@@ -637,11 +636,6 @@ def _set_envs_and_config(server_args: ServerArgs):
# Set ulimit
set_ulimit()

# Fix triton bugs
if server_args.tp_size * server_args.dp_size > 1:
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()

# Check flashinfer version
if server_args.attention_backend == "flashinfer":
assert_pkg_version(

@@ -107,6 +107,8 @@ from sglang.version import __version__
logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))


# Store global states
@dataclasses.dataclass
@@ -212,9 +214,6 @@ async def validate_json_request(raw_request: Request):
)


HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))


##### Native API endpoints #####


@@ -807,6 +806,24 @@ async def retrieve_model(model: str):
)


@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
async def v1_score_request(request: ScoringRequest, raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
return await raw_request.app.state.openai_serving_score.handle_request(
request, raw_request
)


@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)
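The two endpoints added above use FastAPI's route-level dependency pattern: validate_json_request runs before the handler body. A minimal self-contained sketch of that pattern, with a hypothetical content-type check standing in for the real validator (whose body is not shown in this diff):

from fastapi import Depends, FastAPI, HTTPException, Request

app = FastAPI()

async def validate_json_request(raw_request: Request):
    # Hypothetical check for illustration; the real validator in sglang may differ.
    content_type = raw_request.headers.get("content-type", "").split(";")[0]
    if content_type != "application/json":
        raise HTTPException(status_code=400, detail="Content-Type must be application/json")

@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
async def score(raw_request: Request):
    # The dependency has already run by the time this handler executes.
    payload = await raw_request.json()
    return {"received": payload}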


## SageMaker API
@app.get("/ping")
async def sagemaker_health() -> Response:
@@ -852,24 +869,6 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
return ORJSONResponse({"predictions": ret})


@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
async def v1_score_request(request: ScoringRequest, raw_request: Request):
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
return await raw_request.app.state.openai_serving_score.handle_request(
request, raw_request
)


@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)


def _create_error_response(e):
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
@@ -916,15 +915,6 @@ def launch_server(
add_prometheus_middleware(app)
enable_func_timer()

image_token_text = None
if (
tokenizer_manager.image_token_id is not None
and not server_args.skip_tokenizer_init
):
image_token_text = tokenizer_manager.tokenizer.decode(
[tokenizer_manager.image_token_id]
)

# Send a warmup request - we will create the thread launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
@@ -932,7 +922,6 @@ def launch_server(
args=(
server_args,
pipe_finish_writer,
image_token_text,
launch_callback,
),
)
@@ -1066,7 +1055,6 @@ def _execute_server_warmup(
def _wait_and_warmup(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
if not server_args.skip_server_warmup:

@@ -15,7 +15,7 @@
from __future__ import annotations

import math
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
from typing import Callable, NamedTuple, Optional

import torch
import torch.nn.functional as F
@@ -39,10 +39,10 @@ from sglang.srt.utils import (

_is_cuda = is_cuda()
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _is_cuda:
from sgl_kernel import moe_fused_gate
@@ -54,7 +54,6 @@ if _use_aiter:
from aiter import biased_grouped_topk as aiter_biased_grouped_topk
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")

if _is_npu:
import torch_npu


@@ -653,6 +653,9 @@ class Scheduler(
)
)

embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
init_embedding_cache(embedding_cache_size * 1024 * 1024)

def init_profier(self):
self.torch_profiler = None
self.torch_profiler_output_dir: Optional[str] = None
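Moving the embedding-cache setup into the Scheduler constructor also collapses the old three-line environment lookup (removed from run_scheduler_process further down) into a single os.environ.get call. A small sketch of that pattern, using the same SGLANG_VLM_CACHE_SIZE_MB variable and its 100 MB default:

import os

# Read the cache budget in MB (falls back to 100 when the variable is unset),
# then convert to bytes before sizing the cache.
embedding_cache_size_mb = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
embedding_cache_bytes = embedding_cache_size_mb * 1024 * 1024
print(embedding_cache_bytes)  # 104857600 with the default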
@@ -2895,9 +2898,9 @@ def run_scheduler_process(
prefix += f" PP{pp_rank}"

# Config the process
kill_itself_when_parent_died()
setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
faulthandler.enable()
kill_itself_when_parent_died()
parent_process = psutil.Process().parent()

# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
@@ -2912,10 +2915,6 @@ def run_scheduler_process(
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

embedding_cache_size = 100
if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
init_embedding_cache(embedding_cache_size * 1024 * 1024)
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
@@ -2926,8 +2925,8 @@ def run_scheduler_process(
"max_req_input_len": scheduler.max_req_input_len,
}
)
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
if disaggregation_mode == DisaggregationMode.NULL:
if server_args.pp_size > 1:
scheduler.event_loop_pp()

@@ -74,8 +74,6 @@ class ForwardMode(IntEnum):
MIXED = auto()
# No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
IDLE = auto()
# Split Prefill for PD multiplexing
SPLIT_PREFILL = auto()

# Used in speculative decoding: verify a batch in the target model.
TARGET_VERIFY = auto()
@@ -86,6 +84,9 @@ class ForwardMode(IntEnum):
# It is now used for triggering the sampling_info_done event for the first prefill batch.
DUMMY_FIRST = auto()

# Split Prefill for PD multiplexing
SPLIT_PREFILL = auto()

def is_prefill(self):
return self.is_extend()

@@ -103,12 +104,12 @@ class ForwardMode(IntEnum):
def is_mixed(self):
return self == ForwardMode.MIXED

def is_split_prefill(self):
return self == ForwardMode.SPLIT_PREFILL

def is_idle(self):
return self == ForwardMode.IDLE

def is_decode_or_idle(self):
return self == ForwardMode.DECODE or self == ForwardMode.IDLE

def is_target_verify(self):
return self == ForwardMode.TARGET_VERIFY

@@ -132,8 +133,8 @@ class ForwardMode(IntEnum):
def is_dummy_first(self):
return self == ForwardMode.DUMMY_FIRST

def is_decode_or_idle(self):
return self == ForwardMode.DECODE or self == ForwardMode.IDLE
def is_split_prefill(self):
return self == ForwardMode.SPLIT_PREFILL


@total_ordering
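The ForwardMode hunks above only move SPLIT_PREFILL and its predicate next to related members; no behavior changes. For reference, a small self-contained sketch of the IntEnum-plus-predicate pattern these hunks touch (MiniForwardMode is an illustrative stand-in, not the real class):

from enum import IntEnum, auto

class MiniForwardMode(IntEnum):
    # Trimmed-down stand-in for ForwardMode; only a few members are shown.
    DECODE = auto()
    IDLE = auto()
    SPLIT_PREFILL = auto()

    def is_split_prefill(self) -> bool:
        return self == MiniForwardMode.SPLIT_PREFILL

    def is_decode_or_idle(self) -> bool:
        return self in (MiniForwardMode.DECODE, MiniForwardMode.IDLE)

assert MiniForwardMode.SPLIT_PREFILL.is_split_prefill()
assert MiniForwardMode.IDLE.is_decode_or_idle()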

@@ -109,7 +109,6 @@ from sglang.srt.utils import (
get_bool_env_var,
get_cpu_ids_by_node,
init_custom_process_group,
is_cuda,
is_fa3_default_architecture,
is_flashinfer_available,
is_hip,

@@ -80,7 +80,7 @@ class ServerArgs:
schedule_policy: str = "fcfs"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
page_size: int = 1
page_size: Optional[int] = None
hybrid_kvcache_ratio: Optional[float] = None
swa_full_tokens_ratio: float = 0.8
disable_hybrid_swa_memory: bool = False
@@ -266,31 +266,20 @@ class ServerArgs:

def __post_init__(self):
# Expert parallelism
# We put it here first due to some internal ckpt conversation issues.
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.warning(
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
if self.enable_flashinfer_moe:
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
os.environ["TRTLLM_ENABLE_PDL"] = "1"
self.disable_shared_experts_fusion = True
logger.warning(
f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
)

# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path

if self.device is None:
self.device = get_device()

if self.served_model_name is None:
self.served_model_name = self.model_path

if self.device is None:
self.device = get_device()
if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30)

@@ -359,7 +348,6 @@ class ServerArgs:
self.chunked_prefill_size = 16384
else:
self.chunked_prefill_size = 4096
assert self.chunked_prefill_size % self.page_size == 0

# Set cuda graph max batch size
if self.cuda_graph_max_bs is None:
@@ -410,6 +398,14 @@ class ServerArgs:
)
self.page_size = 128

# Set page size
if self.page_size is None:
self.page_size = 1

# AMD-specific Triton attention KV splits default number
if is_hip():
self.triton_attention_num_kv_splits = 16

# Choose grammar backend
if self.grammar_backend is None:
self.grammar_backend = "xgrammar"
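page_size now defaults to None on the dataclass and is resolved in __post_init__ (to 1 unless a backend forces another value, such as the 128 seen above). A toy sketch of that deferred-default pattern, with ToyArgs as an illustrative stand-in for ServerArgs:

import dataclasses
from typing import Optional

@dataclasses.dataclass
class ToyArgs:
    # None means "not set by the user"; the concrete default is chosen in
    # __post_init__, so later logic can distinguish explicit values from fallbacks.
    page_size: Optional[int] = None
    grammar_backend: Optional[str] = None

    def __post_init__(self):
        if self.page_size is None:
            self.page_size = 1
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

assert ToyArgs().page_size == 1
assert ToyArgs(page_size=128).page_size == 128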
@@ -431,6 +427,17 @@ class ServerArgs:
self.enable_dp_attention
), "Please enable dp attention when setting enable_dp_lm_head. "

# MoE kernel
if self.enable_flashinfer_moe:
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
os.environ["TRTLLM_ENABLE_PDL"] = "1"
self.disable_shared_experts_fusion = True
logger.warning(
f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
)

# DeepEP MoE
if self.enable_deepep_moe:
if self.deepep_mode == "normal":
@@ -502,14 +509,6 @@ class ServerArgs:
logger.warning(
"DeepSeek MTP does not require setting speculative_draft_model_path."
)
elif "Llama4" in model_arch:
# TODO: remove this after Llama4 supports in other backends
if self.attention_backend != "fa3":
self.attention_backend = "fa3"
logger.warning(
"Llama4 requires using fa3 attention backend. "
"Attention backend is automatically set to fa3."
)

# Auto choose parameters
if self.speculative_num_steps is None:
@@ -542,12 +541,11 @@ class ServerArgs:
) and check_gguf_file(self.model_path):
self.quantization = self.load_format = "gguf"

# Model loading
if is_remote_url(self.model_path):
self.load_format = "remote"

# AMD-specific Triton attention KV splits default number
if is_hip():
self.triton_attention_num_kv_splits = 16
if self.custom_weight_loader is None:
self.custom_weight_loader = []

# PD disaggregation
if self.disaggregation_mode == "decode":
@@ -572,6 +570,7 @@ class ServerArgs:
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")

# Propagate env vars
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
"1" if self.enable_torch_compile else "0"
)
@@ -580,9 +579,6 @@ class ServerArgs:
"1" if self.disable_outlines_disk_cache else "0"
)

if self.custom_weight_loader is None:
self.custom_weight_loader = []

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
@@ -1227,6 +1223,13 @@ class ServerArgs:
default=ServerArgs.grammar_backend,
help="Choose the backend for grammar-guided decoding.",
)
parser.add_argument(
"--mm-attention-backend",
type=str,
choices=["sdpa", "fa3", "triton_attn"],
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)

# Speculative decoding
parser.add_argument(
@@ -1276,13 +1279,6 @@ class ServerArgs:
help="The path of the draft model's small vocab table.",
default=ServerArgs.speculative_token_map,
)
parser.add_argument(
"--mm-attention-backend",
type=str,
choices=["sdpa", "fa3", "triton_attn"],
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)

# Expert parallelism
parser.add_argument(
@@ -1530,11 +1526,6 @@ class ServerArgs:
action="store_true",
help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
)
parser.add_argument(
"--disable-overlap-cg-plan",
action="store_true",
help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
@@ -1792,11 +1783,11 @@ class ServerArgs:
return hf_config

def check_server_args(self):
# Check parallel size constraints
assert (
self.tp_size * self.pp_size
) % self.nnodes == 0, "tp_size must be divisible by number of nodes"

# FIXME pp constraints
if self.pp_size > 1:
assert (
self.disable_overlap_schedule
@@ -1807,11 +1798,7 @@ class ServerArgs:
assert not (
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
), "multi-node data parallel is not supported unless dp attention!"
assert (
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and radix attention is in progress"

assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

@@ -1820,9 +1807,32 @@ class ServerArgs:
None,
}, "moe_dense_tp_size only support 1 and None currently"

# Check model architecture
model_arch = self.get_hf_config().architectures[0]
if "Llama4" in model_arch:
assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"

# Check LoRA
self.check_lora_server_args()

# Check speculative decoding
if self.speculative_algorithm is not None:
assert (
not self.enable_mixed_chunk
), "enable_mixed_chunk is required for speculative decoding"

# Check chunked prefill
assert (
self.chunked_prefill_size % self.page_size == 0
), "chunked_prefill_size must be divisible by page_size"

def check_lora_server_args(self):
assert (
self.max_loras_per_batch > 0
# FIXME
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and radix attention is in progress"

# Enable LoRA if any LoRA paths are provided for backward compatibility.
if self.lora_paths:
if self.enable_lora is None:

@@ -336,7 +336,6 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
forward_batch.positions = self.positions[:num_tokens]

# Special handle for seq_len_cpu used when flashinfer mla is used
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(self.seq_len_fill_value)

@@ -937,71 +937,6 @@ def monkey_patch_vllm_gguf_config():
setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)


def maybe_set_triton_cache_manager() -> None:
"""Set environment variable to tell Triton to use a
custom cache manager"""
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
if cache_manger is None:
manager = "sglang.srt.utils:CustomCacheManager"
logger.debug("Setting Triton cache manager to: %s", manager)
os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
# Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
def __init__(self, key, override=False, dump=False):
from sglang.srt.distributed.parallel_state import get_tp_group

self.key = key
self.lock_path = None

try:
module_path = "triton.runtime.cache"
cache_module = importlib.import_module(module_path)

default_cache_dir = getattr(cache_module, "default_cache_dir", None)
default_dump_dir = getattr(cache_module, "default_dump_dir", None)
default_override_dir = getattr(cache_module, "default_override_dir", None)
except (ModuleNotFoundError, AttributeError) as e:
default_cache_dir = None
default_dump_dir = None
default_override_dir = None

if dump:
self.cache_dir = (
default_dump_dir()
if default_dump_dir is not None
else os.path.join(Path.home(), ".triton", "dump")
)
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif override:
self.cache_dir = (
default_override_dir()
if default_override_dir is not None
else os.path.join(Path.home(), ".triton", "override")
)
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
default_cache_dir()
if default_cache_dir is not None
else os.path.join(Path.home(), ".triton", "cache")
)
if self.cache_dir:
try:
self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
except:
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")


def set_ulimit(target_soft_limit=65535):
# number of open files
resource_type = resource.RLIMIT_NOFILE
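set_ulimit is unchanged context here; it raises the soft limit on open file descriptors so a server handling many sockets does not hit "Too many open files". A minimal sketch of that pattern with the standard-library resource module (the rest of sglang's function is cut off in this view, so the error handling below is an assumption):

import logging
import resource

def set_ulimit(target_soft_limit: int = 65535) -> None:
    # Raise the soft RLIMIT_NOFILE up to target_soft_limit, keeping the hard limit.
    resource_type = resource.RLIMIT_NOFILE
    soft, hard = resource.getrlimit(resource_type)
    if soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, hard))
        except (ValueError, OSError) as e:
            logging.warning("Failed to raise RLIMIT_NOFILE to %d: %s", target_soft_limit, e)

set_ulimit()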