Clean up server_args, triton cache manager (#8332)
This commit is contained in:
@@ -80,7 +80,7 @@ class ServerArgs:
|
||||
schedule_policy: str = "fcfs"
|
||||
schedule_conservativeness: float = 1.0
|
||||
cpu_offload_gb: int = 0
|
||||
page_size: int = 1
|
||||
page_size: Optional[int] = None
|
||||
hybrid_kvcache_ratio: Optional[float] = None
|
||||
swa_full_tokens_ratio: float = 0.8
|
||||
disable_hybrid_swa_memory: bool = False
|
||||
@@ -266,31 +266,20 @@ class ServerArgs:
|
||||
|
||||
def __post_init__(self):
|
||||
# Expert parallelism
|
||||
# We put it here first due to some internal ckpt conversation issues.
|
||||
if self.enable_ep_moe:
|
||||
self.ep_size = self.tp_size
|
||||
logger.warning(
|
||||
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||
)
|
||||
if self.enable_flashinfer_moe:
|
||||
assert (
|
||||
self.quantization == "modelopt_fp4"
|
||||
), "modelopt_fp4 quantization is required for Flashinfer MOE"
|
||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||
self.disable_shared_experts_fusion = True
|
||||
logger.warning(
|
||||
f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
|
||||
)
|
||||
|
||||
# Set missing default values
|
||||
if self.tokenizer_path is None:
|
||||
self.tokenizer_path = self.model_path
|
||||
|
||||
if self.device is None:
|
||||
self.device = get_device()
|
||||
|
||||
if self.served_model_name is None:
|
||||
self.served_model_name = self.model_path
|
||||
|
||||
if self.device is None:
|
||||
self.device = get_device()
|
||||
if self.random_seed is None:
|
||||
self.random_seed = random.randint(0, 1 << 30)
|
||||
|
||||
@@ -359,7 +348,6 @@ class ServerArgs:
|
||||
self.chunked_prefill_size = 16384
|
||||
else:
|
||||
self.chunked_prefill_size = 4096
|
||||
assert self.chunked_prefill_size % self.page_size == 0
|
||||
|
||||
# Set cuda graph max batch size
|
||||
if self.cuda_graph_max_bs is None:
|
||||
@@ -410,6 +398,14 @@ class ServerArgs:
|
||||
)
|
||||
self.page_size = 128
|
||||
|
||||
# Set page size
|
||||
if self.page_size is None:
|
||||
self.page_size = 1
|
||||
|
||||
# AMD-specific Triton attention KV splits default number
|
||||
if is_hip():
|
||||
self.triton_attention_num_kv_splits = 16
|
||||
|
||||
# Choose grammar backend
|
||||
if self.grammar_backend is None:
|
||||
self.grammar_backend = "xgrammar"
|
||||
@@ -431,6 +427,17 @@ class ServerArgs:
|
||||
self.enable_dp_attention
|
||||
), "Please enable dp attention when setting enable_dp_lm_head. "
|
||||
|
||||
# MoE kernel
|
||||
if self.enable_flashinfer_moe:
|
||||
assert (
|
||||
self.quantization == "modelopt_fp4"
|
||||
), "modelopt_fp4 quantization is required for Flashinfer MOE"
|
||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||
self.disable_shared_experts_fusion = True
|
||||
logger.warning(
|
||||
f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
|
||||
)
|
||||
|
||||
# DeepEP MoE
|
||||
if self.enable_deepep_moe:
|
||||
if self.deepep_mode == "normal":
|
||||
@@ -502,14 +509,6 @@ class ServerArgs:
|
||||
logger.warning(
|
||||
"DeepSeek MTP does not require setting speculative_draft_model_path."
|
||||
)
|
||||
elif "Llama4" in model_arch:
|
||||
# TODO: remove this after Llama4 supports in other backends
|
||||
if self.attention_backend != "fa3":
|
||||
self.attention_backend = "fa3"
|
||||
logger.warning(
|
||||
"Llama4 requires using fa3 attention backend. "
|
||||
"Attention backend is automatically set to fa3."
|
||||
)
|
||||
|
||||
# Auto choose parameters
|
||||
if self.speculative_num_steps is None:
|
||||
@@ -542,12 +541,11 @@ class ServerArgs:
|
||||
) and check_gguf_file(self.model_path):
|
||||
self.quantization = self.load_format = "gguf"
|
||||
|
||||
# Model loading
|
||||
if is_remote_url(self.model_path):
|
||||
self.load_format = "remote"
|
||||
|
||||
# AMD-specific Triton attention KV splits default number
|
||||
if is_hip():
|
||||
self.triton_attention_num_kv_splits = 16
|
||||
if self.custom_weight_loader is None:
|
||||
self.custom_weight_loader = []
|
||||
|
||||
# PD disaggregation
|
||||
if self.disaggregation_mode == "decode":
|
||||
@@ -572,6 +570,7 @@ class ServerArgs:
|
||||
self.disable_cuda_graph = True
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
|
||||
# Propagate env vars
|
||||
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
||||
"1" if self.enable_torch_compile else "0"
|
||||
)
|
||||
@@ -580,9 +579,6 @@ class ServerArgs:
|
||||
"1" if self.disable_outlines_disk_cache else "0"
|
||||
)
|
||||
|
||||
if self.custom_weight_loader is None:
|
||||
self.custom_weight_loader = []
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
# Model and tokenizer
|
||||
@@ -1227,6 +1223,13 @@ class ServerArgs:
|
||||
default=ServerArgs.grammar_backend,
|
||||
help="Choose the backend for grammar-guided decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mm-attention-backend",
|
||||
type=str,
|
||||
choices=["sdpa", "fa3", "triton_attn"],
|
||||
default=ServerArgs.mm_attention_backend,
|
||||
help="Set multimodal attention backend.",
|
||||
)
|
||||
|
||||
# Speculative decoding
|
||||
parser.add_argument(
|
||||
@@ -1276,13 +1279,6 @@ class ServerArgs:
|
||||
help="The path of the draft model's small vocab table.",
|
||||
default=ServerArgs.speculative_token_map,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mm-attention-backend",
|
||||
type=str,
|
||||
choices=["sdpa", "fa3", "triton_attn"],
|
||||
default=ServerArgs.mm_attention_backend,
|
||||
help="Set multimodal attention backend.",
|
||||
)
|
||||
|
||||
# Expert parallelism
|
||||
parser.add_argument(
|
||||
@@ -1530,11 +1526,6 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-overlap-cg-plan",
|
||||
action="store_true",
|
||||
help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-mixed-chunk",
|
||||
action="store_true",
|
||||
@@ -1792,11 +1783,11 @@ class ServerArgs:
|
||||
return hf_config
|
||||
|
||||
def check_server_args(self):
|
||||
# Check parallel size constraints
|
||||
assert (
|
||||
self.tp_size * self.pp_size
|
||||
) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
|
||||
|
||||
# FIXME pp constraints
|
||||
if self.pp_size > 1:
|
||||
assert (
|
||||
self.disable_overlap_schedule
|
||||
@@ -1807,11 +1798,7 @@ class ServerArgs:
|
||||
assert not (
|
||||
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
|
||||
), "multi-node data parallel is not supported unless dp attention!"
|
||||
assert (
|
||||
self.max_loras_per_batch > 0
|
||||
# FIXME
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and radix attention is in progress"
|
||||
|
||||
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
||||
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
|
||||
|
||||
@@ -1820,9 +1807,32 @@ class ServerArgs:
|
||||
None,
|
||||
}, "moe_dense_tp_size only support 1 and None currently"
|
||||
|
||||
# Check model architecture
|
||||
model_arch = self.get_hf_config().architectures[0]
|
||||
if "Llama4" in model_arch:
|
||||
assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
|
||||
|
||||
# Check LoRA
|
||||
self.check_lora_server_args()
|
||||
|
||||
# Check speculative decoding
|
||||
if self.speculative_algorithm is not None:
|
||||
assert (
|
||||
not self.enable_mixed_chunk
|
||||
), "enable_mixed_chunk is required for speculative decoding"
|
||||
|
||||
# Check chunked prefill
|
||||
assert (
|
||||
self.chunked_prefill_size % self.page_size == 0
|
||||
), "chunked_prefill_size must be divisible by page_size"
|
||||
|
||||
def check_lora_server_args(self):
|
||||
assert (
|
||||
self.max_loras_per_batch > 0
|
||||
# FIXME
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and radix attention is in progress"
|
||||
|
||||
# Enable LoRA if any LoRA paths are provided for backward compatibility.
|
||||
if self.lora_paths:
|
||||
if self.enable_lora is None:
|
||||
|
||||
Reference in New Issue
Block a user