|
|
|
|
@@ -20,6 +20,7 @@ import logging
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import tempfile
|
|
|
|
|
from token import OP
|
|
|
|
|
from typing import List, Literal, Optional, Union
|
|
|
|
|
|
|
|
|
|
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
|
|
|
|
@@ -46,31 +47,28 @@ class ServerArgs:
|
|
|
|
|
tokenizer_path: Optional[str] = None
|
|
|
|
|
tokenizer_mode: str = "auto"
|
|
|
|
|
skip_tokenizer_init: bool = False
|
|
|
|
|
skip_server_warmup: bool = False
|
|
|
|
|
load_format: str = "auto"
|
|
|
|
|
model_loader_extra_config: str = "{}"
|
|
|
|
|
trust_remote_code: bool = False
|
|
|
|
|
dtype: str = "auto"
|
|
|
|
|
kv_cache_dtype: str = "auto"
|
|
|
|
|
quantization: Optional[str] = None
|
|
|
|
|
quantization_param_path: Optional[str] = None
|
|
|
|
|
context_length: Optional[int] = None
|
|
|
|
|
device: Optional[str] = None
|
|
|
|
|
served_model_name: Optional[str] = None
|
|
|
|
|
chat_template: Optional[str] = None
|
|
|
|
|
completion_template: Optional[str] = None
|
|
|
|
|
is_embedding: bool = False
|
|
|
|
|
enable_multimodal: Optional[bool] = None
|
|
|
|
|
revision: Optional[str] = None
|
|
|
|
|
hybrid_kvcache_ratio: Optional[float] = None
|
|
|
|
|
swa_full_tokens_ratio: float = 0.8
|
|
|
|
|
impl: str = "auto"
|
|
|
|
|
model_impl: str = "auto"
|
|
|
|
|
|
|
|
|
|
# Port for the HTTP server
|
|
|
|
|
# HTTP server
|
|
|
|
|
host: str = "127.0.0.1"
|
|
|
|
|
port: int = 30000
|
|
|
|
|
skip_server_warmup: bool = False
|
|
|
|
|
warmups: Optional[str] = None
|
|
|
|
|
nccl_port: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
# Quantization and data type
|
|
|
|
|
dtype: str = "auto"
|
|
|
|
|
quantization: Optional[str] = None
|
|
|
|
|
quantization_param_path: Optional[str] = None
|
|
|
|
|
kv_cache_dtype: str = "auto"
|
|
|
|
|
|
|
|
|
|
# Memory and scheduling
|
|
|
|
|
mem_fraction_static: Optional[float] = None
|
|
|
|
|
max_running_requests: Optional[int] = None
|
|
|
|
|
@@ -81,8 +79,12 @@ class ServerArgs:
|
|
|
|
|
schedule_conservativeness: float = 1.0
|
|
|
|
|
cpu_offload_gb: int = 0
|
|
|
|
|
page_size: int = 1
|
|
|
|
|
hybrid_kvcache_ratio: Optional[float] = None
|
|
|
|
|
swa_full_tokens_ratio: float = 0.8
|
|
|
|
|
disable_hybrid_swa_memory: bool = False
|
|
|
|
|
|
|
|
|
|
# Other runtime options
|
|
|
|
|
# Runtime options
|
|
|
|
|
device: Optional[str] = None
|
|
|
|
|
tp_size: int = 1
|
|
|
|
|
pp_size: int = 1
|
|
|
|
|
max_micro_batch_size: Optional[int] = None
|
|
|
|
|
@@ -107,8 +109,8 @@ class ServerArgs:
|
|
|
|
|
enable_metrics: bool = False
|
|
|
|
|
enable_metrics_for_all_schedulers: bool = False
|
|
|
|
|
bucket_time_to_first_token: Optional[List[float]] = None
|
|
|
|
|
bucket_e2e_request_latency: Optional[List[float]] = None
|
|
|
|
|
bucket_inter_token_latency: Optional[List[float]] = None
|
|
|
|
|
bucket_e2e_request_latency: Optional[List[float]] = None
|
|
|
|
|
collect_tokens_histogram: bool = False
|
|
|
|
|
decode_log_interval: int = 40
|
|
|
|
|
enable_request_time_stats_logging: bool = False
|
|
|
|
|
@@ -116,6 +118,9 @@ class ServerArgs:
|
|
|
|
|
|
|
|
|
|
# API related
|
|
|
|
|
api_key: Optional[str] = None
|
|
|
|
|
served_model_name: Optional[str] = None
|
|
|
|
|
chat_template: Optional[str] = None
|
|
|
|
|
completion_template: Optional[str] = None
|
|
|
|
|
file_storage_path: str = "sglang_storage"
|
|
|
|
|
enable_cache_report: bool = False
|
|
|
|
|
reasoning_parser: Optional[str] = None
|
|
|
|
|
@@ -179,6 +184,14 @@ class ServerArgs:
|
|
|
|
|
deepep_config: Optional[str] = None
|
|
|
|
|
moe_dense_tp_size: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
# Hierarchical cache
|
|
|
|
|
enable_hierarchical_cache: bool = False
|
|
|
|
|
hicache_ratio: float = 2.0
|
|
|
|
|
hicache_size: int = 0
|
|
|
|
|
hicache_write_policy: str = "write_through_selective"
|
|
|
|
|
hicache_io_backend: str = ""
|
|
|
|
|
hicache_storage_backend: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
# Double Sparsity
|
|
|
|
|
enable_double_sparsity: bool = False
|
|
|
|
|
ds_channel_config_path: Optional[str] = None
|
|
|
|
|
@@ -200,7 +213,6 @@ class ServerArgs:
|
|
|
|
|
disable_custom_all_reduce: bool = False
|
|
|
|
|
enable_mscclpp: bool = False
|
|
|
|
|
disable_overlap_schedule: bool = False
|
|
|
|
|
disable_overlap_cg_plan: bool = False
|
|
|
|
|
enable_mixed_chunk: bool = False
|
|
|
|
|
enable_dp_attention: bool = False
|
|
|
|
|
enable_dp_lm_head: bool = False
|
|
|
|
|
@@ -217,20 +229,12 @@ class ServerArgs:
|
|
|
|
|
enable_memory_saver: bool = False
|
|
|
|
|
allow_auto_truncate: bool = False
|
|
|
|
|
enable_custom_logit_processor: bool = False
|
|
|
|
|
enable_hierarchical_cache: bool = False
|
|
|
|
|
hicache_ratio: float = 2.0
|
|
|
|
|
hicache_size: int = 0
|
|
|
|
|
hicache_write_policy: str = "write_through_selective"
|
|
|
|
|
hicache_io_backend: str = ""
|
|
|
|
|
hicache_storage_backend: Optional[str] = None
|
|
|
|
|
flashinfer_mla_disable_ragged: bool = False
|
|
|
|
|
disable_shared_experts_fusion: bool = False
|
|
|
|
|
disable_chunked_prefix_cache: bool = False
|
|
|
|
|
disable_fast_image_processor: bool = False
|
|
|
|
|
enable_return_hidden_states: bool = False
|
|
|
|
|
enable_triton_kernel_moe: bool = False
|
|
|
|
|
warmups: Optional[str] = None
|
|
|
|
|
disable_hybrid_swa_memory: bool = False
|
|
|
|
|
|
|
|
|
|
# Debug tensor dumps
|
|
|
|
|
debug_tensor_dump_output_folder: Optional[str] = None
|
|
|
|
|
@@ -238,7 +242,7 @@ class ServerArgs:
|
|
|
|
|
debug_tensor_dump_inject: bool = False
|
|
|
|
|
debug_tensor_dump_prefill_only: bool = False
|
|
|
|
|
|
|
|
|
|
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
|
|
|
|
|
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
|
|
|
|
|
disaggregation_mode: str = "null"
|
|
|
|
|
disaggregation_transfer_backend: str = "mooncake"
|
|
|
|
|
disaggregation_bootstrap_port: int = 8998
|
|
|
|
|
@@ -273,6 +277,7 @@ class ServerArgs:
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Set missing default values
|
|
|
|
|
if self.tokenizer_path is None:
|
|
|
|
|
self.tokenizer_path = self.model_path
|
|
|
|
|
@@ -333,56 +338,12 @@ class ServerArgs:
|
|
|
|
|
self.mem_fraction_static = 0.88
|
|
|
|
|
|
|
|
|
|
# Lazy init to avoid circular import
|
|
|
|
|
# Multimodal models need more memory for the image processor
|
|
|
|
|
from sglang.srt.configs.model_config import ModelConfig
|
|
|
|
|
|
|
|
|
|
# Multimodal models need more memory for the image processor
|
|
|
|
|
model_config = ModelConfig.from_server_args(self)
|
|
|
|
|
|
|
|
|
|
vision_config = getattr(model_config.hf_config, "vision_config", None)
|
|
|
|
|
|
|
|
|
|
if model_config.is_multimodal and vision_config:
|
|
|
|
|
# roughly reduce the mem_fraction_static base on params of Vit
|
|
|
|
|
original_server_arg_mem_fraction = self.mem_fraction_static
|
|
|
|
|
# a base mem_fraction_static factor for regular Vit
|
|
|
|
|
base_mem_fraction_reduction_ratio = 0.95
|
|
|
|
|
|
|
|
|
|
vit_num_layers = getattr(vision_config, "num_hidden_layers", 24)
|
|
|
|
|
vit_hidden_size = getattr(vision_config, "hidden_size", 1024)
|
|
|
|
|
|
|
|
|
|
# baseline ViT params (ViT-L/14)
|
|
|
|
|
baseline_vit_layers = 24
|
|
|
|
|
baseline_vit_hidden_size = 1024
|
|
|
|
|
|
|
|
|
|
# weight params count
|
|
|
|
|
current_complexity_score = vit_num_layers * (vit_hidden_size**2)
|
|
|
|
|
baseline_complexity_score = baseline_vit_layers * (
|
|
|
|
|
baseline_vit_hidden_size**2
|
|
|
|
|
)
|
|
|
|
|
complexity_ratio = (
|
|
|
|
|
current_complexity_score / baseline_complexity_score
|
|
|
|
|
if baseline_complexity_score > 0
|
|
|
|
|
else 1.0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# every time the complexity grows 100%, adjust final factor for 10%
|
|
|
|
|
sensitivity_scale = 0.1
|
|
|
|
|
dynamic_adjustment_factor = 1.0 - sensitivity_scale * (
|
|
|
|
|
complexity_ratio - 1.0
|
|
|
|
|
)
|
|
|
|
|
dynamic_adjustment_factor = max(
|
|
|
|
|
0.8, min(1.05, dynamic_adjustment_factor)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
final_overall_factor = (
|
|
|
|
|
base_mem_fraction_reduction_ratio * dynamic_adjustment_factor
|
|
|
|
|
)
|
|
|
|
|
self.mem_fraction_static = (
|
|
|
|
|
original_server_arg_mem_fraction * final_overall_factor
|
|
|
|
|
)
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Multimodal model: Dynamically adjusted --mem-fraction-static "
|
|
|
|
|
f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}."
|
|
|
|
|
)
|
|
|
|
|
if model_config.is_multimodal:
|
|
|
|
|
self.adjust_mem_fraction_for_vlm(model_config)
|
|
|
|
|
|
|
|
|
|
# Set chunked prefill size, which depends on the gpu memory capacity
|
|
|
|
|
if self.chunked_prefill_size is None:
|
|
|
|
|
@@ -406,23 +367,6 @@ class ServerArgs:
|
|
|
|
|
else:
|
|
|
|
|
self.cuda_graph_max_bs = 80
|
|
|
|
|
|
|
|
|
|
assert self.moe_dense_tp_size in {
|
|
|
|
|
1,
|
|
|
|
|
None,
|
|
|
|
|
}, "moe_dense_tp_size only support 1 and None currently"
|
|
|
|
|
|
|
|
|
|
if self.attention_backend == "flashmla":
|
|
|
|
|
logger.warning(
|
|
|
|
|
"FlashMLA only supports a page_size of 64, change page_size to 64."
|
|
|
|
|
)
|
|
|
|
|
self.page_size = 64
|
|
|
|
|
|
|
|
|
|
if self.attention_backend == "cutlass_mla":
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Cutlass MLA only supports a page_size of 128, change page_size to 128."
|
|
|
|
|
)
|
|
|
|
|
self.page_size = 128
|
|
|
|
|
|
|
|
|
|
# Set kernel backends for hpu device
|
|
|
|
|
if self.device == "hpu":
|
|
|
|
|
self.attention_backend = "torch_native"
|
|
|
|
|
@@ -451,6 +395,18 @@ class ServerArgs:
|
|
|
|
|
)
|
|
|
|
|
self.page_size = 128
|
|
|
|
|
|
|
|
|
|
if self.attention_backend == "flashmla":
|
|
|
|
|
logger.warning(
|
|
|
|
|
"FlashMLA only supports a page_size of 64, change page_size to 64."
|
|
|
|
|
)
|
|
|
|
|
self.page_size = 64
|
|
|
|
|
|
|
|
|
|
if self.attention_backend == "cutlass_mla":
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Cutlass MLA only supports a page_size of 128, change page_size to 128."
|
|
|
|
|
)
|
|
|
|
|
self.page_size = 128
|
|
|
|
|
|
|
|
|
|
# Choose grammar backend
|
|
|
|
|
if self.grammar_backend is None:
|
|
|
|
|
self.grammar_backend = "xgrammar"
|
|
|
|
|
@@ -482,12 +438,6 @@ class ServerArgs:
|
|
|
|
|
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.pp_size > 1:
|
|
|
|
|
self.disable_overlap_schedule = True
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Pipeline parallelism is incompatible with overlap schedule."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
|
|
|
|
|
self.expert_distribution_recorder_mode = "stat"
|
|
|
|
|
logger.info(
|
|
|
|
|
@@ -513,6 +463,13 @@ class ServerArgs:
|
|
|
|
|
elif self.expert_distribution_recorder_mode is not None:
|
|
|
|
|
self.expert_distribution_recorder_buffer_size = 1000
|
|
|
|
|
|
|
|
|
|
# Pipeline parallelism
|
|
|
|
|
if self.pp_size > 1:
|
|
|
|
|
self.disable_overlap_schedule = True
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Pipeline parallelism is incompatible with overlap schedule."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Speculative Decoding
|
|
|
|
|
if self.speculative_algorithm == "NEXTN":
|
|
|
|
|
# NEXTN shares the same implementation of EAGLE
|
|
|
|
|
@@ -533,8 +490,7 @@ class ServerArgs:
|
|
|
|
|
"eagle speculative decoding."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
model_arch = get_model_arch(self)
|
|
|
|
|
|
|
|
|
|
model_arch = self.get_hf_config().architectures[0]
|
|
|
|
|
if model_arch == "DeepseekV3ForCausalLM":
|
|
|
|
|
# Auto set draft_model_path DeepSeek-V3/R1
|
|
|
|
|
if self.speculative_draft_model_path is None:
|
|
|
|
|
@@ -624,17 +580,9 @@ class ServerArgs:
|
|
|
|
|
if self.custom_weight_loader is None:
|
|
|
|
|
self.custom_weight_loader = []
|
|
|
|
|
|
|
|
|
|
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
|
|
|
|
|
larger_tp = max(decode_tp, prefill_tp)
|
|
|
|
|
smaller_tp = min(decode_tp, prefill_tp)
|
|
|
|
|
assert larger_tp % smaller_tp == 0, (
|
|
|
|
|
"Different tp size is supported only when one tp is multiple of the other. "
|
|
|
|
|
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def add_cli_args(parser: argparse.ArgumentParser):
|
|
|
|
|
# Model and port args
|
|
|
|
|
# Model and tokenizer
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--model-path",
|
|
|
|
|
"--model",
|
|
|
|
|
@@ -648,24 +596,6 @@ class ServerArgs:
|
|
|
|
|
default=ServerArgs.tokenizer_path,
|
|
|
|
|
help="The path of the tokenizer.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--host",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.host,
|
|
|
|
|
help="The host of the HTTP server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--port",
|
|
|
|
|
type=int,
|
|
|
|
|
default=ServerArgs.port,
|
|
|
|
|
help="The port of the HTTP server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--nccl-port",
|
|
|
|
|
type=int,
|
|
|
|
|
default=ServerArgs.nccl_port,
|
|
|
|
|
help="The port for NCCL distributed environment setup. Defaults to a random port.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--tokenizer-mode",
|
|
|
|
|
type=str,
|
|
|
|
|
@@ -680,11 +610,6 @@ class ServerArgs:
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="If set, skip init tokenizer and pass input_ids in generate request.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--skip-server-warmup",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="If set, skip warmup.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--load-format",
|
|
|
|
|
type=str,
|
|
|
|
|
@@ -730,6 +655,77 @@ class ServerArgs:
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--context-length",
|
|
|
|
|
type=int,
|
|
|
|
|
default=ServerArgs.context_length,
|
|
|
|
|
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--is-embedding",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Whether to use a CausalLM as an embedding model.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--enable-multimodal",
|
|
|
|
|
default=ServerArgs.enable_multimodal,
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--revision",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help="The specific model version to use. It can be a branch "
|
|
|
|
|
"name, a tag name, or a commit id. If unspecified, will use "
|
|
|
|
|
"the default version.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--model-impl",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.model_impl,
|
|
|
|
|
help="Which implementation of the model to use.\n\n"
|
|
|
|
|
'* "auto" will try to use the SGLang implementation if it exists '
|
|
|
|
|
"and fall back to the Transformers implementation if no SGLang "
|
|
|
|
|
"implementation is available.\n"
|
|
|
|
|
'* "sglang" will use the SGLang model implementation.\n'
|
|
|
|
|
'* "transformers" will use the Transformers model '
|
|
|
|
|
"implementation.\n",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# HTTP server
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--host",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.host,
|
|
|
|
|
help="The host of the HTTP server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--port",
|
|
|
|
|
type=int,
|
|
|
|
|
default=ServerArgs.port,
|
|
|
|
|
help="The port of the HTTP server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--skip-server-warmup",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="If set, skip warmup.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--warmups",
|
|
|
|
|
type=str,
|
|
|
|
|
required=False,
|
|
|
|
|
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
|
|
|
|
|
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--nccl-port",
|
|
|
|
|
type=int,
|
|
|
|
|
default=ServerArgs.nccl_port,
|
|
|
|
|
help="The port for NCCL distributed environment setup. Defaults to a random port.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Quantization and data type
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--dtype",
|
|
|
|
|
type=str,
|
|
|
|
|
@@ -744,13 +740,6 @@ class ServerArgs:
|
|
|
|
|
'* "float" is shorthand for FP32 precision.\n'
|
|
|
|
|
'* "float32" for FP32 precision.',
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--kv-cache-dtype",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.kv_cache_dtype,
|
|
|
|
|
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
|
|
|
|
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--quantization",
|
|
|
|
|
type=str,
|
|
|
|
|
@@ -785,65 +774,11 @@ class ServerArgs:
|
|
|
|
|
"default to 1.0, which may cause accuracy issues. ",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--context-length",
|
|
|
|
|
type=int,
|
|
|
|
|
default=ServerArgs.context_length,
|
|
|
|
|
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--device",
|
|
|
|
|
"--kv-cache-dtype",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.device,
|
|
|
|
|
help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--served-model-name",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.served_model_name,
|
|
|
|
|
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--chat-template",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.chat_template,
|
|
|
|
|
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--completion-template",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.completion_template,
|
|
|
|
|
help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--is-embedding",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Whether to use a CausalLM as an embedding model.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--enable-multimodal",
|
|
|
|
|
default=ServerArgs.enable_multimodal,
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--revision",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help="The specific model version to use. It can be a branch "
|
|
|
|
|
"name, a tag name, or a commit id. If unspecified, will use "
|
|
|
|
|
"the default version.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--impl",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.impl,
|
|
|
|
|
help="Which implementation of the model to use.\n\n"
|
|
|
|
|
'* "auto" will try to use the SGLang implementation if it exists '
|
|
|
|
|
"and fall back to the Transformers implementation if no SGLang "
|
|
|
|
|
"implementation is available.\n"
|
|
|
|
|
'* "sglang" will use the SGLang model implementation.\n'
|
|
|
|
|
'* "transformers" will use the Transformers model '
|
|
|
|
|
"implementation.\n",
|
|
|
|
|
default=ServerArgs.kv_cache_dtype,
|
|
|
|
|
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
|
|
|
|
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Memory and scheduling
|
|
|
|
|
@@ -928,7 +863,13 @@ class ServerArgs:
|
|
|
|
|
help="Disable the hybrid SWA memory.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Other runtime options
|
|
|
|
|
# Runtime options
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--device",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.device,
|
|
|
|
|
help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--tensor-parallel-size",
|
|
|
|
|
"--tp-size",
|
|
|
|
|
@@ -970,7 +911,7 @@ class ServerArgs:
|
|
|
|
|
"--constrained-json-whitespace-pattern",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.constrained_json_whitespace_pattern,
|
|
|
|
|
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
|
|
|
|
|
help="(outlines backend only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--watchdog-timeout",
|
|
|
|
|
@@ -1083,12 +1024,6 @@ class ServerArgs:
|
|
|
|
|
default=ServerArgs.collect_tokens_histogram,
|
|
|
|
|
help="Collect prompt/generation tokens histogram.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--kv-events-config",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--decode-log-interval",
|
|
|
|
|
type=int,
|
|
|
|
|
@@ -1101,6 +1036,12 @@ class ServerArgs:
|
|
|
|
|
default=ServerArgs.enable_request_time_stats_logging,
|
|
|
|
|
help="Enable per request time stats logging",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--kv-events-config",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# API related
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
@@ -1109,6 +1050,24 @@ class ServerArgs:
|
|
|
|
|
default=ServerArgs.api_key,
|
|
|
|
|
help="Set API key of the server. It is also used in the OpenAI API compatible server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--served-model-name",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.served_model_name,
|
|
|
|
|
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--chat-template",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.chat_template,
|
|
|
|
|
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--completion-template",
|
|
|
|
|
type=str,
|
|
|
|
|
default=ServerArgs.completion_template,
|
|
|
|
|
help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--file-storage-path",
|
|
|
|
|
type=str,
|
|
|
|
|
@@ -1427,6 +1386,46 @@ class ServerArgs:
|
|
|
|
|
help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Hierarchical cache
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--enable-hierarchical-cache",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Enable hierarchical cache",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-ratio",
|
|
|
|
|
type=float,
|
|
|
|
|
default=ServerArgs.hicache_ratio,
|
|
|
|
|
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-size",
|
|
|
|
|
type=int,
|
|
|
|
|
default=ServerArgs.hicache_size,
|
|
|
|
|
help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-write-policy",
|
|
|
|
|
type=str,
|
|
|
|
|
choices=["write_back", "write_through", "write_through_selective"],
|
|
|
|
|
default=ServerArgs.hicache_write_policy,
|
|
|
|
|
help="The write policy of hierarchical cache.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-io-backend",
|
|
|
|
|
type=str,
|
|
|
|
|
choices=["direct", "kernel"],
|
|
|
|
|
default=ServerArgs.hicache_io_backend,
|
|
|
|
|
help="The IO backend for KV cache transfer between CPU and GPU",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-storage-backend",
|
|
|
|
|
type=str,
|
|
|
|
|
choices=["file"], # todo, mooncake
|
|
|
|
|
default=ServerArgs.hicache_storage_backend,
|
|
|
|
|
help="The storage backend for hierarchical KV cache.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Double Sparsity
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--enable-double-sparsity",
|
|
|
|
|
@@ -1619,44 +1618,6 @@ class ServerArgs:
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--enable-hierarchical-cache",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Enable hierarchical cache",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-ratio",
|
|
|
|
|
type=float,
|
|
|
|
|
default=ServerArgs.hicache_ratio,
|
|
|
|
|
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-size",
|
|
|
|
|
type=int,
|
|
|
|
|
default=ServerArgs.hicache_size,
|
|
|
|
|
help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-write-policy",
|
|
|
|
|
type=str,
|
|
|
|
|
choices=["write_back", "write_through", "write_through_selective"],
|
|
|
|
|
default=ServerArgs.hicache_write_policy,
|
|
|
|
|
help="The write policy of hierarchical cache.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-io-backend",
|
|
|
|
|
type=str,
|
|
|
|
|
choices=["direct", "kernel"],
|
|
|
|
|
default=ServerArgs.hicache_io_backend,
|
|
|
|
|
help="The IO backend for KV cache transfer between CPU and GPU",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--hicache-storage-backend",
|
|
|
|
|
type=str,
|
|
|
|
|
choices=["file"], # todo, mooncacke
|
|
|
|
|
default=ServerArgs.hicache_storage_backend,
|
|
|
|
|
help="The storage backend for hierarchical KV cache.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--flashinfer-mla-disable-ragged",
|
|
|
|
|
action="store_true",
|
|
|
|
|
@@ -1687,13 +1648,6 @@ class ServerArgs:
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Use triton moe grouped gemm kernel.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--warmups",
|
|
|
|
|
type=str,
|
|
|
|
|
required=False,
|
|
|
|
|
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
|
|
|
|
|
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Debug tensor dumps
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
@@ -1720,7 +1674,7 @@ class ServerArgs:
|
|
|
|
|
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Disaggregation
|
|
|
|
|
# PD disaggregation
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--disaggregation-mode",
|
|
|
|
|
type=str,
|
|
|
|
|
@@ -1779,6 +1733,8 @@ class ServerArgs:
|
|
|
|
|
default=None,
|
|
|
|
|
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Custom weight loader
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--custom-weight-loader",
|
|
|
|
|
type=str,
|
|
|
|
|
@@ -1791,6 +1747,8 @@ class ServerArgs:
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Enable PD-Multiplexing, PD running on greenctx stream.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# For PD-Multiplexing
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--sm-group-num",
|
|
|
|
|
type=int,
|
|
|
|
|
@@ -1818,6 +1776,17 @@ class ServerArgs:
|
|
|
|
|
else:
|
|
|
|
|
return f"http://{self.host}:{self.port}"
|
|
|
|
|
|
|
|
|
|
def get_hf_config(self):
|
|
|
|
|
kwargs = {}
|
|
|
|
|
hf_config = get_config(
|
|
|
|
|
self.model_path,
|
|
|
|
|
trust_remote_code=self.trust_remote_code,
|
|
|
|
|
revision=self.revision,
|
|
|
|
|
model_override_args=json.loads(self.json_model_override_args),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
return hf_config
|
|
|
|
|
|
|
|
|
|
def check_server_args(self):
|
|
|
|
|
assert (
|
|
|
|
|
self.tp_size * self.pp_size
|
|
|
|
|
@@ -1842,6 +1811,11 @@ class ServerArgs:
|
|
|
|
|
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
|
|
|
|
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
|
|
|
|
|
|
|
|
|
|
assert self.moe_dense_tp_size in {
|
|
|
|
|
1,
|
|
|
|
|
None,
|
|
|
|
|
}, "moe_dense_tp_size only support 1 and None currently"
|
|
|
|
|
|
|
|
|
|
if isinstance(self.lora_paths, list):
|
|
|
|
|
lora_paths = self.lora_paths
|
|
|
|
|
self.lora_paths = {}
|
|
|
|
|
@@ -1852,6 +1826,56 @@ class ServerArgs:
|
|
|
|
|
else:
|
|
|
|
|
self.lora_paths[lora_path] = lora_path
|
|
|
|
|
|
|
|
|
|
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
|
|
|
|
|
larger_tp = max(decode_tp, prefill_tp)
|
|
|
|
|
smaller_tp = min(decode_tp, prefill_tp)
|
|
|
|
|
assert larger_tp % smaller_tp == 0, (
|
|
|
|
|
"Different tp size is supported only when one tp is multiple of the other. "
|
|
|
|
|
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def adjust_mem_fraction_for_vlm(self, model_config):
|
|
|
|
|
vision_config = getattr(model_config.hf_config, "vision_config", None)
|
|
|
|
|
if vision_config is None:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# roughly reduce the mem_fraction_static base on params of Vit
|
|
|
|
|
original_server_arg_mem_fraction = self.mem_fraction_static
|
|
|
|
|
# a base mem_fraction_static factor for regular Vit
|
|
|
|
|
base_mem_fraction_reduction_ratio = 0.95
|
|
|
|
|
|
|
|
|
|
vit_num_layers = getattr(vision_config, "num_hidden_layers", 24)
|
|
|
|
|
vit_hidden_size = getattr(vision_config, "hidden_size", 1024)
|
|
|
|
|
|
|
|
|
|
# baseline ViT params (ViT-L/14)
|
|
|
|
|
baseline_vit_layers = 24
|
|
|
|
|
baseline_vit_hidden_size = 1024
|
|
|
|
|
|
|
|
|
|
# weight params count
|
|
|
|
|
current_complexity_score = vit_num_layers * (vit_hidden_size**2)
|
|
|
|
|
baseline_complexity_score = baseline_vit_layers * (baseline_vit_hidden_size**2)
|
|
|
|
|
complexity_ratio = (
|
|
|
|
|
current_complexity_score / baseline_complexity_score
|
|
|
|
|
if baseline_complexity_score > 0
|
|
|
|
|
else 1.0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# every time the complexity grows 100%, adjust final factor for 10%
|
|
|
|
|
sensitivity_scale = 0.1
|
|
|
|
|
dynamic_adjustment_factor = 1.0 - sensitivity_scale * (complexity_ratio - 1.0)
|
|
|
|
|
dynamic_adjustment_factor = max(0.8, min(1.05, dynamic_adjustment_factor))
|
|
|
|
|
|
|
|
|
|
final_overall_factor = (
|
|
|
|
|
base_mem_fraction_reduction_ratio * dynamic_adjustment_factor
|
|
|
|
|
)
|
|
|
|
|
self.mem_fraction_static = (
|
|
|
|
|
original_server_arg_mem_fraction * final_overall_factor
|
|
|
|
|
)
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Multimodal model: Dynamically adjusted --mem-fraction-static "
|
|
|
|
|
f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_server_args(argv: List[str]) -> ServerArgs:
|
|
|
|
|
"""
|
|
|
|
|
@@ -1895,16 +1919,16 @@ class PortArgs:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
|
|
|
|
|
if server_args.nccl_port is None:
|
|
|
|
|
port = server_args.port + random.randint(100, 1000)
|
|
|
|
|
nccl_port = server_args.port + random.randint(100, 1000)
|
|
|
|
|
while True:
|
|
|
|
|
if is_port_available(port):
|
|
|
|
|
if is_port_available(nccl_port):
|
|
|
|
|
break
|
|
|
|
|
if port < 60000:
|
|
|
|
|
port += 42
|
|
|
|
|
if nccl_port < 60000:
|
|
|
|
|
nccl_port += 42
|
|
|
|
|
else:
|
|
|
|
|
port -= 43
|
|
|
|
|
nccl_port -= 43
|
|
|
|
|
else:
|
|
|
|
|
port = server_args.nccl_port
|
|
|
|
|
nccl_port = server_args.nccl_port
|
|
|
|
|
|
|
|
|
|
if not server_args.enable_dp_attention:
|
|
|
|
|
# Normal case, use IPC within a single node
|
|
|
|
|
@@ -1912,7 +1936,7 @@ class PortArgs:
|
|
|
|
|
tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
|
|
|
|
scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
|
|
|
|
detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
|
|
|
|
nccl_port=port,
|
|
|
|
|
nccl_port=nccl_port,
|
|
|
|
|
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
|
|
|
|
metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
|
|
|
|
|
)
|
|
|
|
|
@@ -1942,7 +1966,7 @@ class PortArgs:
|
|
|
|
|
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
|
|
|
|
|
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
|
|
|
|
|
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
|
|
|
|
|
nccl_port=port,
|
|
|
|
|
nccl_port=nccl_port,
|
|
|
|
|
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
|
|
|
|
|
metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
|
|
|
|
|
)
|
|
|
|
|
@@ -1969,31 +1993,13 @@ class DeprecatedAction(argparse.Action):
|
|
|
|
|
raise ValueError(self.help)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_arch(args: ServerArgs):
|
|
|
|
|
hf_config = get_config(
|
|
|
|
|
args.model_path,
|
|
|
|
|
trust_remote_code=args.trust_remote_code,
|
|
|
|
|
revision=args.revision,
|
|
|
|
|
model_override_args=json.loads(args.json_model_override_args),
|
|
|
|
|
)
|
|
|
|
|
return hf_config.architectures[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_choose_speculative_params(self: ServerArgs):
|
|
|
|
|
"""
|
|
|
|
|
Automatically choose the parameters for speculative decoding.
|
|
|
|
|
|
|
|
|
|
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
|
|
|
|
|
"""
|
|
|
|
|
kwargs = {}
|
|
|
|
|
|
|
|
|
|
hf_config = get_config(
|
|
|
|
|
self.model_path,
|
|
|
|
|
trust_remote_code=self.trust_remote_code,
|
|
|
|
|
revision=self.revision,
|
|
|
|
|
model_override_args=json.loads(self.json_model_override_args),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
hf_config = self.get_hf_config()
|
|
|
|
|
arch = hf_config.architectures[0]
|
|
|
|
|
|
|
|
|
|
if arch in ["LlamaForCausalLM"]:
|
|
|
|
|
|