adapt to sglang v0.5.2rc1 on dcu

This commit is contained in:
maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

166
python/pyproject.toml Normal file
View File

@@ -0,0 +1,166 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.5.2rc1"
description = "SGLang is yet another fast serving framework for large language models and vision language models."
readme = "README.md"
requires-python = ">=3.10"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle"]
[project.optional-dependencies]
runtime_common = [
"blobfile==3.0.0",
"build",
"compressed-tensors",
"datasets",
"einops",
"fastapi",
"hf_transfer",
"huggingface_hub",
"interegular",
"llguidance>=0.7.11,<0.8.0",
"modelscope",
"msgspec",
"ninja",
"openai==1.99.1",
"openai-harmony==0.0.4",
"orjson",
"outlines==0.1.11",
"packaging",
"partial_json_parser",
"pillow",
"prometheus-client>=0.20.0",
"psutil",
"pybase64",
"pydantic",
"pynvml",
"python-multipart",
"pyzmq>=25.1.2",
"sentencepiece",
"soundfile==0.13.1",
"scipy",
"timm==1.0.16",
"tiktoken",
"torchao==0.9.0",
"transformers==4.56.0",
"uvicorn",
"uvloop",
"xgrammar==0.1.23",
]
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.3.8",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.3.0",
]
blackwell = [
"sglang[runtime_common]",
"sgl-kernel",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.3.0",
]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
srt_hip = [
"sglang[runtime_common]",
"torch",
"petit_kernel==0.0.2",
"wave-lang==1.0.1",
]
# https://docs.sglang.ai/platforms/cpu_server.html
srt_cpu = ["sglang[runtime_common]"]
# https://docs.sglang.ai/platforms/ascend_npu.html
srt_npu = ["sglang[runtime_common]"]
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
srt_xpu = ["sglang[runtime_common]"]
# For Intel Gaudi (device: hpu), follow the installation guide
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
srt_hpu = ["sglang[runtime_common]"]
openai = ["openai==1.99.1", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver==0.0.8"]
decord = ["decord"]
test = [
"accelerate",
"expecttest",
"jsonlines",
"matplotlib",
"pandas",
"peft",
"sentence_transformers",
"pytest",
"tabulate",
]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"]
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
dev = ["sglang[all]", "sglang[test]"]
dev_hip = ["sglang[all_hip]", "sglang[test]"]
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.package-data]
"sglang" = [
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
"srt/layers/quantization/configs/*.json",
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
]
[tool.setuptools.packages.find]
exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
[tool.wheel]
exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
[tool.codespell]
ignore-words-list = "ans, als, hel, boostrap, childs, te, vas, hsa, ment"
skip = "*.json,*.jsonl,*.patch,*.txt"

View File

@@ -0,0 +1,150 @@
Metadata-Version: 2.4
Name: sglang
Version: 0.5.2rc1
Summary: SGLang is yet another fast serving framework for large language models and vision language models.
Project-URL: Homepage, https://github.com/sgl-project/sglang
Project-URL: Bug Tracker, https://github.com/sgl-project/sglang/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: Apache Software License
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: aiohttp
Requires-Dist: requests
Requires-Dist: tqdm
Requires-Dist: numpy
Requires-Dist: IPython
Requires-Dist: setproctitle
Provides-Extra: runtime-common
Requires-Dist: blobfile==3.0.0; extra == "runtime-common"
Requires-Dist: build; extra == "runtime-common"
Requires-Dist: compressed-tensors; extra == "runtime-common"
Requires-Dist: datasets; extra == "runtime-common"
Requires-Dist: einops; extra == "runtime-common"
Requires-Dist: fastapi; extra == "runtime-common"
Requires-Dist: hf_transfer; extra == "runtime-common"
Requires-Dist: huggingface_hub; extra == "runtime-common"
Requires-Dist: interegular; extra == "runtime-common"
Requires-Dist: llguidance<0.8.0,>=0.7.11; extra == "runtime-common"
Requires-Dist: modelscope; extra == "runtime-common"
Requires-Dist: msgspec; extra == "runtime-common"
Requires-Dist: ninja; extra == "runtime-common"
Requires-Dist: openai==1.99.1; extra == "runtime-common"
Requires-Dist: openai-harmony==0.0.4; extra == "runtime-common"
Requires-Dist: orjson; extra == "runtime-common"
Requires-Dist: outlines==0.1.11; extra == "runtime-common"
Requires-Dist: packaging; extra == "runtime-common"
Requires-Dist: partial_json_parser; extra == "runtime-common"
Requires-Dist: pillow; extra == "runtime-common"
Requires-Dist: prometheus-client>=0.20.0; extra == "runtime-common"
Requires-Dist: psutil; extra == "runtime-common"
Requires-Dist: pybase64; extra == "runtime-common"
Requires-Dist: pydantic; extra == "runtime-common"
Requires-Dist: pynvml; extra == "runtime-common"
Requires-Dist: python-multipart; extra == "runtime-common"
Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
Requires-Dist: sentencepiece; extra == "runtime-common"
Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
Requires-Dist: scipy; extra == "runtime-common"
Requires-Dist: timm==1.0.16; extra == "runtime-common"
Requires-Dist: tiktoken; extra == "runtime-common"
Requires-Dist: torchao==0.9.0; extra == "runtime-common"
Requires-Dist: transformers==4.56.0; extra == "runtime-common"
Requires-Dist: uvicorn; extra == "runtime-common"
Requires-Dist: uvloop; extra == "runtime-common"
Requires-Dist: xgrammar==0.1.23; extra == "runtime-common"
Provides-Extra: srt
Requires-Dist: sglang[runtime_common]; extra == "srt"
Requires-Dist: sgl-kernel==0.3.8; extra == "srt"
Requires-Dist: torch==2.8.0; extra == "srt"
Requires-Dist: torchaudio==2.8.0; extra == "srt"
Requires-Dist: torchvision; extra == "srt"
Requires-Dist: cuda-python; extra == "srt"
Requires-Dist: flashinfer_python==0.3.0; extra == "srt"
Provides-Extra: blackwell
Requires-Dist: sglang[runtime_common]; extra == "blackwell"
Requires-Dist: sgl-kernel; extra == "blackwell"
Requires-Dist: torch==2.8.0; extra == "blackwell"
Requires-Dist: torchaudio==2.8.0; extra == "blackwell"
Requires-Dist: torchvision; extra == "blackwell"
Requires-Dist: cuda-python; extra == "blackwell"
Requires-Dist: flashinfer_python==0.3.0; extra == "blackwell"
Provides-Extra: srt-hip
Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
Requires-Dist: torch; extra == "srt-hip"
Requires-Dist: petit_kernel==0.0.2; extra == "srt-hip"
Requires-Dist: wave-lang==1.0.1; extra == "srt-hip"
Provides-Extra: srt-cpu
Requires-Dist: sglang[runtime_common]; extra == "srt-cpu"
Provides-Extra: srt-npu
Requires-Dist: sglang[runtime_common]; extra == "srt-npu"
Provides-Extra: srt-xpu
Requires-Dist: sglang[runtime_common]; extra == "srt-xpu"
Provides-Extra: srt-hpu
Requires-Dist: sglang[runtime_common]; extra == "srt-hpu"
Provides-Extra: openai
Requires-Dist: openai==1.99.1; extra == "openai"
Requires-Dist: tiktoken; extra == "openai"
Provides-Extra: anthropic
Requires-Dist: anthropic>=0.20.0; extra == "anthropic"
Provides-Extra: litellm
Requires-Dist: litellm>=1.0.0; extra == "litellm"
Provides-Extra: torch-memory-saver
Requires-Dist: torch_memory_saver==0.0.8; extra == "torch-memory-saver"
Provides-Extra: decord
Requires-Dist: decord; extra == "decord"
Provides-Extra: test
Requires-Dist: accelerate; extra == "test"
Requires-Dist: expecttest; extra == "test"
Requires-Dist: jsonlines; extra == "test"
Requires-Dist: matplotlib; extra == "test"
Requires-Dist: pandas; extra == "test"
Requires-Dist: peft; extra == "test"
Requires-Dist: sentence_transformers; extra == "test"
Requires-Dist: pytest; extra == "test"
Requires-Dist: tabulate; extra == "test"
Provides-Extra: all
Requires-Dist: sglang[srt]; extra == "all"
Requires-Dist: sglang[openai]; extra == "all"
Requires-Dist: sglang[anthropic]; extra == "all"
Requires-Dist: sglang[torch_memory_saver]; extra == "all"
Requires-Dist: sglang[decord]; extra == "all"
Provides-Extra: all-hip
Requires-Dist: sglang[srt_hip]; extra == "all-hip"
Requires-Dist: sglang[openai]; extra == "all-hip"
Requires-Dist: sglang[anthropic]; extra == "all-hip"
Requires-Dist: sglang[decord]; extra == "all-hip"
Provides-Extra: all-xpu
Requires-Dist: sglang[srt_xpu]; extra == "all-xpu"
Requires-Dist: sglang[openai]; extra == "all-xpu"
Requires-Dist: sglang[anthropic]; extra == "all-xpu"
Requires-Dist: sglang[decord]; extra == "all-xpu"
Provides-Extra: all-hpu
Requires-Dist: sglang[srt_hpu]; extra == "all-hpu"
Requires-Dist: sglang[openai]; extra == "all-hpu"
Requires-Dist: sglang[anthropic]; extra == "all-hpu"
Requires-Dist: sglang[decord]; extra == "all-hpu"
Provides-Extra: all-cpu
Requires-Dist: sglang[srt_cpu]; extra == "all-cpu"
Requires-Dist: sglang[openai]; extra == "all-cpu"
Requires-Dist: sglang[anthropic]; extra == "all-cpu"
Requires-Dist: sglang[decord]; extra == "all-cpu"
Provides-Extra: all-npu
Requires-Dist: sglang[srt_npu]; extra == "all-npu"
Requires-Dist: sglang[openai]; extra == "all-npu"
Requires-Dist: sglang[anthropic]; extra == "all-npu"
Requires-Dist: sglang[decord]; extra == "all-npu"
Provides-Extra: dev
Requires-Dist: sglang[all]; extra == "dev"
Requires-Dist: sglang[test]; extra == "dev"
Provides-Extra: dev-hip
Requires-Dist: sglang[all_hip]; extra == "dev-hip"
Requires-Dist: sglang[test]; extra == "dev-hip"
Provides-Extra: dev-xpu
Requires-Dist: sglang[all_xpu]; extra == "dev-xpu"
Requires-Dist: sglang[test]; extra == "dev-xpu"
Provides-Extra: dev-hpu
Requires-Dist: sglang[all_hpu]; extra == "dev-hpu"
Requires-Dist: sglang[test]; extra == "dev-hpu"
Provides-Extra: dev-cpu
Requires-Dist: sglang[all_cpu]; extra == "dev-cpu"
Requires-Dist: sglang[test]; extra == "dev-cpu"

View File

@@ -0,0 +1,883 @@
pyproject.toml
sglang/__init__.py
sglang/bench_offline_throughput.py
sglang/bench_one_batch.py
sglang/bench_one_batch_server.py
sglang/bench_serving.py
sglang/check_env.py
sglang/compile_deep_gemm.py
sglang/global_config.py
sglang/launch_server.py
sglang/profiler.py
sglang/utils.py
sglang/version.py
sglang.egg-info/PKG-INFO
sglang.egg-info/SOURCES.txt
sglang.egg-info/dependency_links.txt
sglang.egg-info/requires.txt
sglang.egg-info/top_level.txt
sglang/eval/llama3_eval.py
sglang/eval/loogle_eval.py
sglang/lang/api.py
sglang/lang/chat_template.py
sglang/lang/choices.py
sglang/lang/compiler.py
sglang/lang/interpreter.py
sglang/lang/ir.py
sglang/lang/tracer.py
sglang/lang/backend/anthropic.py
sglang/lang/backend/base_backend.py
sglang/lang/backend/litellm.py
sglang/lang/backend/openai.py
sglang/lang/backend/runtime_endpoint.py
sglang/lang/backend/vertexai.py
sglang/srt/_custom_ops.py
sglang/srt/aio_rwlock.py
sglang/srt/bench_utils.py
sglang/srt/constants.py
sglang/srt/custom_op.py
sglang/srt/hf_transformers_utils.py
sglang/srt/host_shared_memory.py
sglang/srt/offloader.py
sglang/srt/operations.py
sglang/srt/operations_strategy.py
sglang/srt/patch_torch.py
sglang/srt/poll_based_barrier.py
sglang/srt/server_args.py
sglang/srt/torch_memory_saver_adapter.py
sglang/srt/two_batch_overlap.py
sglang/srt/utils.py
sglang/srt/warmup.py
sglang/srt/configs/__init__.py
sglang/srt/configs/chatglm.py
sglang/srt/configs/dbrx.py
sglang/srt/configs/deepseekvl2.py
sglang/srt/configs/device_config.py
sglang/srt/configs/exaone.py
sglang/srt/configs/internvl.py
sglang/srt/configs/janus_pro.py
sglang/srt/configs/kimi_vl.py
sglang/srt/configs/kimi_vl_moonvit.py
sglang/srt/configs/load_config.py
sglang/srt/configs/longcat_flash.py
sglang/srt/configs/model_config.py
sglang/srt/configs/step3_vl.py
sglang/srt/configs/update_config.py
sglang/srt/configs/utils.py
sglang/srt/connector/__init__.py
sglang/srt/connector/base_connector.py
sglang/srt/connector/redis.py
sglang/srt/connector/s3.py
sglang/srt/connector/utils.py
sglang/srt/connector/serde/__init__.py
sglang/srt/connector/serde/safe_serde.py
sglang/srt/connector/serde/serde.py
sglang/srt/constrained/base_grammar_backend.py
sglang/srt/constrained/llguidance_backend.py
sglang/srt/constrained/outlines_backend.py
sglang/srt/constrained/outlines_jump_forward.py
sglang/srt/constrained/reasoner_grammar_backend.py
sglang/srt/constrained/xgrammar_backend.py
sglang/srt/constrained/triton_ops/bitmask_ops.py
sglang/srt/debug_utils/__init__.py
sglang/srt/debug_utils/dump_comparator.py
sglang/srt/debug_utils/dumper.py
sglang/srt/debug_utils/text_comparator.py
sglang/srt/disaggregation/decode.py
sglang/srt/disaggregation/decode_schedule_batch_mixin.py
sglang/srt/disaggregation/kv_events.py
sglang/srt/disaggregation/launch_lb.py
sglang/srt/disaggregation/mini_lb.py
sglang/srt/disaggregation/prefill.py
sglang/srt/disaggregation/utils.py
sglang/srt/disaggregation/ascend/__init__.py
sglang/srt/disaggregation/ascend/conn.py
sglang/srt/disaggregation/ascend/transfer_engine.py
sglang/srt/disaggregation/base/__init__.py
sglang/srt/disaggregation/base/conn.py
sglang/srt/disaggregation/common/__init__.py
sglang/srt/disaggregation/common/conn.py
sglang/srt/disaggregation/common/utils.py
sglang/srt/disaggregation/fake/__init__.py
sglang/srt/disaggregation/fake/conn.py
sglang/srt/disaggregation/mooncake/__init__.py
sglang/srt/disaggregation/mooncake/conn.py
sglang/srt/disaggregation/mooncake/transfer_engine.py
sglang/srt/disaggregation/nixl/__init__.py
sglang/srt/disaggregation/nixl/conn.py
sglang/srt/distributed/__init__.py
sglang/srt/distributed/communication_op.py
sglang/srt/distributed/naive_distributed.py
sglang/srt/distributed/parallel_state.py
sglang/srt/distributed/utils.py
sglang/srt/distributed/device_communicators/cuda_wrapper.py
sglang/srt/distributed/device_communicators/custom_all_reduce.py
sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
sglang/srt/distributed/device_communicators/hpu_communicator.py
sglang/srt/distributed/device_communicators/npu_communicator.py
sglang/srt/distributed/device_communicators/pymscclpp.py
sglang/srt/distributed/device_communicators/pynccl.py
sglang/srt/distributed/device_communicators/pynccl_allocator.py
sglang/srt/distributed/device_communicators/pynccl_wrapper.py
sglang/srt/distributed/device_communicators/quick_all_reduce.py
sglang/srt/distributed/device_communicators/shm_broadcast.py
sglang/srt/distributed/device_communicators/xpu_communicator.py
sglang/srt/entrypoints/EngineBase.py
sglang/srt/entrypoints/context.py
sglang/srt/entrypoints/engine.py
sglang/srt/entrypoints/harmony_utils.py
sglang/srt/entrypoints/http_server.py
sglang/srt/entrypoints/http_server_engine.py
sglang/srt/entrypoints/tool.py
sglang/srt/entrypoints/openai/__init__.py
sglang/srt/entrypoints/openai/protocol.py
sglang/srt/entrypoints/openai/serving_base.py
sglang/srt/entrypoints/openai/serving_chat.py
sglang/srt/entrypoints/openai/serving_completions.py
sglang/srt/entrypoints/openai/serving_embedding.py
sglang/srt/entrypoints/openai/serving_rerank.py
sglang/srt/entrypoints/openai/serving_responses.py
sglang/srt/entrypoints/openai/serving_score.py
sglang/srt/entrypoints/openai/tool_server.py
sglang/srt/entrypoints/openai/usage_processor.py
sglang/srt/entrypoints/openai/utils.py
sglang/srt/eplb/__init__.py
sglang/srt/eplb/eplb_manager.py
sglang/srt/eplb/expert_distribution.py
sglang/srt/eplb/expert_location.py
sglang/srt/eplb/expert_location_dispatch.py
sglang/srt/eplb/expert_location_updater.py
sglang/srt/eplb/eplb_algorithms/__init__.py
sglang/srt/eplb/eplb_algorithms/deepseek.py
sglang/srt/eplb/eplb_algorithms/deepseek_vec.py
sglang/srt/eplb/eplb_simulator/__init__.py
sglang/srt/eplb/eplb_simulator/reader.py
sglang/srt/function_call/base_format_detector.py
sglang/srt/function_call/core_types.py
sglang/srt/function_call/deepseekv31_detector.py
sglang/srt/function_call/deepseekv3_detector.py
sglang/srt/function_call/ebnf_composer.py
sglang/srt/function_call/function_call_parser.py
sglang/srt/function_call/glm4_moe_detector.py
sglang/srt/function_call/gpt_oss_detector.py
sglang/srt/function_call/kimik2_detector.py
sglang/srt/function_call/llama32_detector.py
sglang/srt/function_call/mistral_detector.py
sglang/srt/function_call/pythonic_detector.py
sglang/srt/function_call/qwen25_detector.py
sglang/srt/function_call/qwen3_coder_detector.py
sglang/srt/function_call/step3_detector.py
sglang/srt/function_call/utils.py
sglang/srt/layers/activation.py
sglang/srt/layers/amx_utils.py
sglang/srt/layers/communicator.py
sglang/srt/layers/dp_attention.py
sglang/srt/layers/elementwise.py
sglang/srt/layers/flashinfer_comm_fusion.py
sglang/srt/layers/layernorm.py
sglang/srt/layers/linear.py
sglang/srt/layers/logits_processor.py
sglang/srt/layers/model_parallel.py
sglang/srt/layers/multimodal.py
sglang/srt/layers/parameter.py
sglang/srt/layers/pooler.py
sglang/srt/layers/radix_attention.py
sglang/srt/layers/rotary_embedding.py
sglang/srt/layers/sampler.py
sglang/srt/layers/torchao_utils.py
sglang/srt/layers/utils.py
sglang/srt/layers/vocab_parallel_embedding.py
sglang/srt/layers/attention/aiter_backend.py
sglang/srt/layers/attention/ascend_backend.py
sglang/srt/layers/attention/base_attn_backend.py
sglang/srt/layers/attention/cutlass_mla_backend.py
sglang/srt/layers/attention/double_sparsity_backend.py
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
sglang/srt/layers/attention/flashattention_backend.py
sglang/srt/layers/attention/flashinfer_backend.py
sglang/srt/layers/attention/flashinfer_mla_backend.py
sglang/srt/layers/attention/flashmla_backend.py
sglang/srt/layers/attention/hybrid_attn_backend.py
sglang/srt/layers/attention/intel_amx_backend.py
sglang/srt/layers/attention/merge_state.py
sglang/srt/layers/attention/tbo_backend.py
sglang/srt/layers/attention/torch_native_backend.py
sglang/srt/layers/attention/triton_backend.py
sglang/srt/layers/attention/trtllm_mha_backend.py
sglang/srt/layers/attention/trtllm_mla_backend.py
sglang/srt/layers/attention/utils.py
sglang/srt/layers/attention/vision.py
sglang/srt/layers/attention/vision_utils.py
sglang/srt/layers/attention/wave_backend.py
sglang/srt/layers/attention/triton_ops/decode_attention.py
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
sglang/srt/layers/attention/triton_ops/extend_attention.py
sglang/srt/layers/attention/triton_ops/merge_state.py
sglang/srt/layers/attention/triton_ops/prefill_attention.py
sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py
sglang/srt/layers/attention/wave_ops/decode_attention.py
sglang/srt/layers/attention/wave_ops/extend_attention.py
sglang/srt/layers/attention/wave_ops/prefill_attention.py
sglang/srt/layers/moe/__init__.py
sglang/srt/layers/moe/cutlass_moe.py
sglang/srt/layers/moe/cutlass_moe_params.py
sglang/srt/layers/moe/cutlass_w4a8_moe.py
sglang/srt/layers/moe/fused_moe_native.py
sglang/srt/layers/moe/rocm_moe_utils.py
sglang/srt/layers/moe/router.py
sglang/srt/layers/moe/topk.py
sglang/srt/layers/moe/utils.py
sglang/srt/layers/moe/ep_moe/__init__.py
sglang/srt/layers/moe/ep_moe/kernels.py
sglang/srt/layers/moe/ep_moe/layer.py
sglang/srt/layers/moe/fused_moe_triton/__init__.py
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
sglang/srt/layers/moe/fused_moe_triton/layer.py
sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/moe/moe_runner/__init__.py
sglang/srt/layers/moe/moe_runner/base.py
sglang/srt/layers/moe/token_dispatcher/__init__.py
sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
sglang/srt/layers/moe/token_dispatcher/deepep.py
sglang/srt/layers/moe/token_dispatcher/standard.py
sglang/srt/layers/quantization/__init__.py
sglang/srt/layers/quantization/awq.py
sglang/srt/layers/quantization/awq_triton.py
sglang/srt/layers/quantization/base_config.py
sglang/srt/layers/quantization/blockwise_int8.py
sglang/srt/layers/quantization/fp8.py
sglang/srt/layers/quantization/fp8_kernel.py
sglang/srt/layers/quantization/fp8_utils.py
sglang/srt/layers/quantization/fpgemm_fp8.py
sglang/srt/layers/quantization/gptq.py
sglang/srt/layers/quantization/int8_kernel.py
sglang/srt/layers/quantization/int8_utils.py
sglang/srt/layers/quantization/kv_cache.py
sglang/srt/layers/quantization/marlin_utils.py
sglang/srt/layers/quantization/marlin_utils_fp8.py
sglang/srt/layers/quantization/modelopt_quant.py
sglang/srt/layers/quantization/moe_wna16.py
sglang/srt/layers/quantization/mxfp4.py
sglang/srt/layers/quantization/mxfp4_tensor.py
sglang/srt/layers/quantization/petit.py
sglang/srt/layers/quantization/petit_utils.py
sglang/srt/layers/quantization/qoq.py
sglang/srt/layers/quantization/unquant.py
sglang/srt/layers/quantization/utils.py
sglang/srt/layers/quantization/w4afp8.py
sglang/srt/layers/quantization/w8a8_fp8.py
sglang/srt/layers/quantization/w8a8_int8.py
sglang/srt/layers/quantization/compressed_tensors/__init__.py
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
sglang/srt/layers/quantization/compressed_tensors/utils.py
sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
sglang/srt/layers/quantization/quark/__init__.py
sglang/srt/layers/quantization/quark/quark.py
sglang/srt/layers/quantization/quark/quark_moe.py
sglang/srt/layers/quantization/quark/utils.py
sglang/srt/layers/quantization/quark/schemes/__init__.py
sglang/srt/layers/quantization/quark/schemes/quark_scheme.py
sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
sglang/srt/lora/layers.py
sglang/srt/lora/lora.py
sglang/srt/lora/lora_config.py
sglang/srt/lora/lora_manager.py
sglang/srt/lora/lora_registry.py
sglang/srt/lora/mem_pool.py
sglang/srt/lora/utils.py
sglang/srt/lora/backend/base_backend.py
sglang/srt/lora/backend/triton_backend.py
sglang/srt/lora/triton_ops/__init__.py
sglang/srt/lora/triton_ops/gate_up_lora_b.py
sglang/srt/lora/triton_ops/qkv_lora_b.py
sglang/srt/lora/triton_ops/sgemm_lora_a.py
sglang/srt/lora/triton_ops/sgemm_lora_b.py
sglang/srt/managers/cache_controller.py
sglang/srt/managers/configure_logging.py
sglang/srt/managers/data_parallel_controller.py
sglang/srt/managers/detokenizer_manager.py
sglang/srt/managers/io_struct.py
sglang/srt/managers/mm_utils.py
sglang/srt/managers/multi_tokenizer_mixin.py
sglang/srt/managers/multimodal_processor.py
sglang/srt/managers/schedule_batch.py
sglang/srt/managers/schedule_policy.py
sglang/srt/managers/scheduler.py
sglang/srt/managers/scheduler_input_blocker.py
sglang/srt/managers/scheduler_metrics_mixin.py
sglang/srt/managers/scheduler_output_processor_mixin.py
sglang/srt/managers/scheduler_profiler_mixin.py
sglang/srt/managers/scheduler_recv_skipper.py
sglang/srt/managers/scheduler_update_weights_mixin.py
sglang/srt/managers/session_controller.py
sglang/srt/managers/template_manager.py
sglang/srt/managers/tokenizer_manager.py
sglang/srt/managers/tp_worker.py
sglang/srt/managers/tp_worker_overlap_thread.py
sglang/srt/managers/utils.py
sglang/srt/mem_cache/allocator.py
sglang/srt/mem_cache/allocator_ascend.py
sglang/srt/mem_cache/base_prefix_cache.py
sglang/srt/mem_cache/chunk_cache.py
sglang/srt/mem_cache/flush_cache.py
sglang/srt/mem_cache/hicache_storage.py
sglang/srt/mem_cache/hiradix_cache.py
sglang/srt/mem_cache/lora_radix_cache.py
sglang/srt/mem_cache/memory_pool.py
sglang/srt/mem_cache/memory_pool_host.py
sglang/srt/mem_cache/multimodal_cache.py
sglang/srt/mem_cache/radix_cache.py
sglang/srt/mem_cache/radix_cache_cpp.py
sglang/srt/mem_cache/swa_radix_cache.py
sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
sglang/srt/mem_cache/storage/mooncake_store/unit_test.py
sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
sglang/srt/mem_cache/storage/nixl/nixl_utils.py
sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
sglang/srt/metrics/collector.py
sglang/srt/metrics/func_timer.py
sglang/srt/model_executor/cuda_graph_runner.py
sglang/srt/model_executor/forward_batch_info.py
sglang/srt/model_executor/model_runner.py
sglang/srt/model_executor/npu_graph_runner.py
sglang/srt/model_loader/__init__.py
sglang/srt/model_loader/loader.py
sglang/srt/model_loader/utils.py
sglang/srt/model_loader/weight_utils.py
sglang/srt/models/arcee.py
sglang/srt/models/baichuan.py
sglang/srt/models/bailing_moe.py
sglang/srt/models/bert.py
sglang/srt/models/chatglm.py
sglang/srt/models/clip.py
sglang/srt/models/commandr.py
sglang/srt/models/dbrx.py
sglang/srt/models/deepseek.py
sglang/srt/models/deepseek_janus_pro.py
sglang/srt/models/deepseek_nextn.py
sglang/srt/models/deepseek_v2.py
sglang/srt/models/deepseek_vl2.py
sglang/srt/models/ernie4.py
sglang/srt/models/ernie4_eagle.py
sglang/srt/models/exaone.py
sglang/srt/models/gemma.py
sglang/srt/models/gemma2.py
sglang/srt/models/gemma2_reward.py
sglang/srt/models/gemma3_causal.py
sglang/srt/models/gemma3_mm.py
sglang/srt/models/gemma3n_audio.py
sglang/srt/models/gemma3n_causal.py
sglang/srt/models/gemma3n_mm.py
sglang/srt/models/glm4.py
sglang/srt/models/glm4_moe.py
sglang/srt/models/glm4_moe_nextn.py
sglang/srt/models/glm4v.py
sglang/srt/models/glm4v_moe.py
sglang/srt/models/gpt2.py
sglang/srt/models/gpt_bigcode.py
sglang/srt/models/gpt_oss.py
sglang/srt/models/granite.py
sglang/srt/models/granitemoe.py
sglang/srt/models/grok.py
sglang/srt/models/hunyuan.py
sglang/srt/models/idefics2.py
sglang/srt/models/internlm2.py
sglang/srt/models/internlm2_reward.py
sglang/srt/models/interns1.py
sglang/srt/models/internvl.py
sglang/srt/models/kimi_vl.py
sglang/srt/models/kimi_vl_moonvit.py
sglang/srt/models/llama.py
sglang/srt/models/llama4.py
sglang/srt/models/llama_classification.py
sglang/srt/models/llama_eagle.py
sglang/srt/models/llama_eagle3.py
sglang/srt/models/llama_embedding.py
sglang/srt/models/llama_reward.py
sglang/srt/models/llava.py
sglang/srt/models/llavavid.py
sglang/srt/models/longcat_flash.py
sglang/srt/models/longcat_flash_nextn.py
sglang/srt/models/mimo.py
sglang/srt/models/mimo_mtp.py
sglang/srt/models/minicpm.py
sglang/srt/models/minicpm3.py
sglang/srt/models/minicpmo.py
sglang/srt/models/minicpmv.py
sglang/srt/models/mistral.py
sglang/srt/models/mixtral.py
sglang/srt/models/mixtral_quant.py
sglang/srt/models/mllama.py
sglang/srt/models/mllama4.py
sglang/srt/models/nemotron_nas.py
sglang/srt/models/olmo.py
sglang/srt/models/olmo2.py
sglang/srt/models/olmoe.py
sglang/srt/models/persimmon.py
sglang/srt/models/phi.py
sglang/srt/models/phi3_small.py
sglang/srt/models/phi4mm.py
sglang/srt/models/phi4mm_audio.py
sglang/srt/models/phi4mm_utils.py
sglang/srt/models/phimoe.py
sglang/srt/models/pixtral.py
sglang/srt/models/qwen.py
sglang/srt/models/qwen2.py
sglang/srt/models/qwen2_5_vl.py
sglang/srt/models/qwen2_audio.py
sglang/srt/models/qwen2_classification.py
sglang/srt/models/qwen2_eagle.py
sglang/srt/models/qwen2_moe.py
sglang/srt/models/qwen2_rm.py
sglang/srt/models/qwen2_vl.py
sglang/srt/models/qwen3.py
sglang/srt/models/qwen3_classification.py
sglang/srt/models/qwen3_moe.py
sglang/srt/models/registry.py
sglang/srt/models/roberta.py
sglang/srt/models/siglip.py
sglang/srt/models/stablelm.py
sglang/srt/models/step3_vl.py
sglang/srt/models/torch_native_llama.py
sglang/srt/models/transformers.py
sglang/srt/models/vila.py
sglang/srt/models/xverse.py
sglang/srt/models/xverse_moe.py
sglang/srt/models/yivl.py
sglang/srt/multimodal/mm_utils.py
sglang/srt/multimodal/processors/base_processor.py
sglang/srt/multimodal/processors/clip.py
sglang/srt/multimodal/processors/deepseek_vl_v2.py
sglang/srt/multimodal/processors/gemma3.py
sglang/srt/multimodal/processors/gemma3n.py
sglang/srt/multimodal/processors/glm4v.py
sglang/srt/multimodal/processors/internvl.py
sglang/srt/multimodal/processors/janus_pro.py
sglang/srt/multimodal/processors/kimi_vl.py
sglang/srt/multimodal/processors/llava.py
sglang/srt/multimodal/processors/minicpm.py
sglang/srt/multimodal/processors/mlama.py
sglang/srt/multimodal/processors/mllama4.py
sglang/srt/multimodal/processors/phi4mm.py
sglang/srt/multimodal/processors/pixtral.py
sglang/srt/multimodal/processors/qwen_audio.py
sglang/srt/multimodal/processors/qwen_vl.py
sglang/srt/multimodal/processors/step3_vl.py
sglang/srt/multimodal/processors/vila.py
sglang/srt/parser/code_completion_parser.py
sglang/srt/parser/conversation.py
sglang/srt/parser/harmony_parser.py
sglang/srt/parser/jinja_template_utils.py
sglang/srt/parser/reasoning_parser.py
sglang/srt/sampling/custom_logit_processor.py
sglang/srt/sampling/sampling_batch_info.py
sglang/srt/sampling/sampling_params.py
sglang/srt/sampling/penaltylib/__init__.py
sglang/srt/sampling/penaltylib/frequency_penalty.py
sglang/srt/sampling/penaltylib/min_new_tokens.py
sglang/srt/sampling/penaltylib/orchestrator.py
sglang/srt/sampling/penaltylib/presence_penalty.py
sglang/srt/speculative/build_eagle_tree.py
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
sglang/srt/speculative/eagle_utils.py
sglang/srt/speculative/eagle_worker.py
sglang/srt/speculative/spec_info.py
sglang/srt/tokenizer/tiktoken_tokenizer.py
sglang/srt/weight_sync/tensor_bucket.py
sglang/srt/weight_sync/utils.py
sglang/test/__init__.py
sglang/test/doc_patch.py
sglang/test/few_shot_gsm8k.py
sglang/test/few_shot_gsm8k_engine.py
sglang/test/run_eval.py
sglang/test/runners.py
sglang/test/send_one.py
sglang/test/simple_eval_common.py
sglang/test/simple_eval_gpqa.py
sglang/test/simple_eval_humaneval.py
sglang/test/simple_eval_math.py
sglang/test/simple_eval_mgsm.py
sglang/test/simple_eval_mmlu.py
sglang/test/test_activation.py
sglang/test/test_block_fp8.py
sglang/test/test_block_fp8_deep_gemm_blackwell.py
sglang/test/test_block_fp8_ep.py
sglang/test/test_custom_ops.py
sglang/test/test_cutlass_moe.py
sglang/test/test_cutlass_w4a8_moe.py
sglang/test/test_deepep_utils.py
sglang/test/test_dynamic_grad_mode.py
sglang/test/test_fp4_moe.py
sglang/test/test_layernorm.py
sglang/test/test_marlin_moe.py
sglang/test/test_marlin_utils.py
sglang/test/test_programs.py
sglang/test/test_utils.py
sglang/test/attention/__init__.py
sglang/test/attention/test_flashattn_backend.py
sglang/test/attention/test_flashattn_mla_backend.py
sglang/test/attention/test_prefix_chunk_info.py
sglang/test/attention/test_trtllm_mla_backend.py

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,165 @@
aiohttp
requests
tqdm
numpy
IPython
setproctitle
[all]
sglang[srt]
sglang[openai]
sglang[anthropic]
sglang[torch_memory_saver]
sglang[decord]
[all_cpu]
sglang[srt_cpu]
sglang[openai]
sglang[anthropic]
sglang[decord]
[all_hip]
sglang[srt_hip]
sglang[openai]
sglang[anthropic]
sglang[decord]
[all_hpu]
sglang[srt_hpu]
sglang[openai]
sglang[anthropic]
sglang[decord]
[all_npu]
sglang[srt_npu]
sglang[openai]
sglang[anthropic]
sglang[decord]
[all_xpu]
sglang[srt_xpu]
sglang[openai]
sglang[anthropic]
sglang[decord]
[anthropic]
anthropic>=0.20.0
[blackwell]
sglang[runtime_common]
sgl-kernel
torch==2.8.0
torchaudio==2.8.0
torchvision
cuda-python
flashinfer_python==0.3.0
[decord]
decord
[dev]
sglang[all]
sglang[test]
[dev_cpu]
sglang[all_cpu]
sglang[test]
[dev_hip]
sglang[all_hip]
sglang[test]
[dev_hpu]
sglang[all_hpu]
sglang[test]
[dev_xpu]
sglang[all_xpu]
sglang[test]
[litellm]
litellm>=1.0.0
[openai]
openai==1.99.1
tiktoken
[runtime_common]
blobfile==3.0.0
build
compressed-tensors
datasets
einops
fastapi
hf_transfer
huggingface_hub
interegular
llguidance<0.8.0,>=0.7.11
modelscope
msgspec
ninja
openai==1.99.1
openai-harmony==0.0.4
orjson
outlines==0.1.11
packaging
partial_json_parser
pillow
prometheus-client>=0.20.0
psutil
pybase64
pydantic
pynvml
python-multipart
pyzmq>=25.1.2
sentencepiece
soundfile==0.13.1
scipy
timm==1.0.16
tiktoken
torchao==0.9.0
transformers==4.56.0
uvicorn
uvloop
xgrammar==0.1.23
[srt]
sglang[runtime_common]
sgl-kernel==0.3.8
torch==2.8.0
torchaudio==2.8.0
torchvision
cuda-python
flashinfer_python==0.3.0
[srt_cpu]
sglang[runtime_common]
[srt_hip]
sglang[runtime_common]
torch
petit_kernel==0.0.2
wave-lang==1.0.1
[srt_hpu]
sglang[runtime_common]
[srt_npu]
sglang[runtime_common]
[srt_xpu]
sglang[runtime_common]
[test]
accelerate
expecttest
jsonlines
matplotlib
pandas
peft
sentence_transformers
pytest
tabulate
[torch_memory_saver]
torch_memory_saver==0.0.8

View File

@@ -0,0 +1 @@
sglang

16
python/sglang/README.md Normal file
View File

@@ -0,0 +1,16 @@
# Code Structures
- `eval`: The evaluation utilities.
- `lang`: The frontend language.
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
- `test`: The test utilities.
- `api.py`: The public APIs.
- `bench_offline_throughput.py`: Benchmark the performance in the offline mode.
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
- `bench_serving.py`: Benchmark online serving with dynamic requests.
- `check_env.py`: Check the environment variables and dependencies.
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `utils.py`: Common utilities.
- `version.py`: Version info.

83
python/sglang/__init__.py Normal file
View File

@@ -0,0 +1,83 @@
# SGLang public APIs

# Frontend Language APIs
from sglang.global_config import global_config
from sglang.lang.api import (
    Engine,
    Runtime,
    assistant,
    assistant_begin,
    assistant_end,
    flush_cache,
    function,
    gen,
    gen_int,
    gen_string,
    get_server_info,
    image,
    select,
    separate_reasoning,
    set_default_backend,
    system,
    system_begin,
    system_end,
    user,
    user_begin,
    user_end,
    video,
)
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import (
    greedy_token_selection,
    token_length_normalized,
    unconditional_likelihood_normalized,
)

# Lazy import some libraries
from sglang.utils import LazyImport
from sglang.version import __version__

# Optional third-party backends are loaded lazily so importing sglang does not
# require their packages to be installed.
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

# Runtime Engine APIs
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
# NOTE(review): this deliberately rebinds the `Engine` name imported from
# sglang.lang.api above, so `sglang.Engine` resolves to the srt engine.
Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")

# Public API surface of the `sglang` package.
__all__ = [
    "Engine",
    "Runtime",
    "assistant",
    "assistant_begin",
    "assistant_end",
    "flush_cache",
    "function",
    "gen",
    "gen_int",
    "gen_string",
    "get_server_info",
    "image",
    "select",
    "separate_reasoning",
    "set_default_backend",
    "system",
    "system_begin",
    "system_end",
    "user",
    "user_begin",
    "user_end",
    "video",
    "RuntimeEndpoint",
    "greedy_token_selection",
    "token_length_normalized",
    "unconditional_likelihood_normalized",
    "ServerArgs",
    "Anthropic",
    "LiteLLM",
    "OpenAI",
    "VertexAI",
    "global_config",
    "__version__",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,452 @@
"""
Benchmark the throughput in the offline mode.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
# Usage
## Sharegpt dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
## Random dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
"""
import argparse
import asyncio
import dataclasses
import inspect
import json
import logging
import os
import random
import time
from typing import Dict, List, Optional
import numpy as np
from sglang.bench_serving import (
DatasetRow,
get_dataset,
get_tokenizer,
sample_random_requests,
set_ulimit,
)
from sglang.lang.backend.runtime_endpoint import Runtime
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.server_args import ServerArgs
@dataclasses.dataclass
class BenchArgs:
    """Benchmark-side arguments for the offline throughput test.

    Each field mirrors one CLI flag registered by :meth:`add_cli_args`;
    :meth:`from_cli_args` rebuilds an instance from the parsed namespace.
    """

    backend: str = "engine"
    result_filename: str = ""
    dataset_name: str = "sharegpt"
    dataset_path: str = ""
    num_prompts: int = 1000
    sharegpt_output_len: Optional[int] = None
    sharegpt_context_len: Optional[int] = None
    random_input_len: int = 1024
    random_output_len: int = 1024
    random_range_ratio: float = 0.0
    gsp_num_groups: int = 64
    gsp_prompts_per_group: int = 16
    gsp_system_prompt_len: int = 2048
    gsp_question_len: int = 128
    gsp_output_len: int = 256
    seed: int = 1
    disable_ignore_eos: bool = False
    extra_request_body: Optional[str] = None
    apply_chat_template: bool = False
    profile: bool = False
    skip_warmup: bool = False
    do_not_exit: bool = False
    prompt_suffix: str = ""

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark flags on *parser*.

        Defaults reference the dataclass fields so the CLI defaults and the
        dataclass defaults cannot drift apart.
        """
        parser.add_argument("--backend", type=str, default=BenchArgs.backend)
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument(
            "--dataset-name",
            type=str,
            default=BenchArgs.dataset_name,
            choices=["sharegpt", "random", "generated-shared-prefix"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgs.dataset_path,
            help="Path to the dataset.",
        )
        parser.add_argument(
            "--num-prompts",
            type=int,
            default=BenchArgs.num_prompts,
            help="Number of prompts to process. Default is 1000.",
        )
        parser.add_argument(
            "--sharegpt-output-len",
            type=int,
            default=BenchArgs.sharegpt_output_len,
            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
        )
        parser.add_argument(
            "--sharegpt-context-len",
            type=int,
            default=BenchArgs.sharegpt_context_len,
            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
        )
        parser.add_argument(
            "--random-input-len",
            type=int,
            default=BenchArgs.random_input_len,
            help="Number of input tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-output-len",
            type=int,
            default=BenchArgs.random_output_len,
            help="Number of output tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-range-ratio",
            type=float,
            default=BenchArgs.random_range_ratio,
            help="Range of sampled ratio of input/output length, "
            "used only for random dataset.",
        )
        # Fix: the following help strings previously lost a space at the string
        # concatenation boundary ("usedonly for generate-shared-prefix").
        parser.add_argument(
            "--gsp-num-groups",
            type=int,
            default=BenchArgs.gsp_num_groups,
            help="Number of groups with shared prefix, used "
            "only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-prompts-per-group",
            type=int,
            default=BenchArgs.gsp_prompts_per_group,
            help="Number of prompts per group of shared prefix, used "
            "only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-system-prompt-len",
            type=int,
            default=BenchArgs.gsp_system_prompt_len,
            help="System prompt length, used only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-question-len",
            type=int,
            default=BenchArgs.gsp_question_len,
            help="Question length, used only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-output-len",
            type=int,
            default=BenchArgs.gsp_output_len,
            help="Target length in tokens for outputs in generated-shared-prefix dataset",
        )
        parser.add_argument(
            "--seed", type=int, default=BenchArgs.seed, help="The random seed."
        )
        parser.add_argument(
            "--disable-ignore-eos",
            action="store_true",
            help="Disable ignore EOS token",
        )
        parser.add_argument(
            "--extra-request-body",
            metavar='{"key1": "value1", "key2": "value2"}',
            type=str,
            default=BenchArgs.extra_request_body,
            help="Append given JSON object to the request payload. You can use this to specify"
            "additional generate params like sampling params.",
        )
        parser.add_argument(
            "--apply-chat-template",
            action="store_true",
            help="Apply chat template",
        )
        parser.add_argument(
            "--profile",
            action="store_true",
            help="Use Torch Profiler. The endpoint must be launched with "
            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
        )
        parser.add_argument(
            "--skip-warmup",
            action="store_true",
            help="Skip the warmup batches.",
        )
        parser.add_argument(
            "--do-not-exit",
            action="store_true",
            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
        )
        parser.add_argument(
            "--prompt-suffix",
            type=str,
            default=BenchArgs.prompt_suffix,
            help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from an argparse namespace produced by add_cli_args."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
def throughput_test_once(
    backend_name: str,
    backend,
    reqs: "List[DatasetRow]",
    ignore_eos: bool,
    extra_request_body: Dict,
    profile: bool,
):
    """Run one batched generation pass and return throughput metrics.

    Args:
        backend_name: Either "engine" or "runtime"; "runtime" responses are
            JSON strings and are decoded here.
        backend: The engine/runtime object; must expose ``generate`` and
            ``get_server_info`` (and the profiler hooks when *profile* is set).
        reqs: Dataset rows providing ``prompt``, ``prompt_len`` and ``output_len``.
        ignore_eos: Forwarded into each request's sampling params.
        extra_request_body: Extra keys merged into every sampling-params dict.
        profile: When True, wrap the run with the Torch profiler
            (requires SGLANG_TORCH_PROFILER_DIR).

    Returns:
        A dict of measured throughput numbers; fields initialized to -1 are
        overwritten once the run completes.
    """
    measurement_results = {
        "backend": backend_name,
        "successful_requests": len(reqs),
        "total_latency": -1,
        "total_input_tokens": sum(r.prompt_len for r in reqs),
        "total_output_tokens": -1,
        "request_throughput": -1,
        "input_throughput": -1,
        "output_throughput": -1,
        "total_throughput": -1,
    }

    prompt = [r.prompt for r in reqs]
    sampling_params = [
        {
            "temperature": 0,
            "max_new_tokens": r.output_len,
            "ignore_eos": ignore_eos,
            **extra_request_body,
        }
        for r in reqs
    ]

    if profile:
        assert (
            "SGLANG_TORCH_PROFILER_DIR" in os.environ
        ), "Please set SGLANG_TORCH_PROFILER_DIR."
        os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
        backend.start_profile()

    st = time.perf_counter()
    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
    latency = time.perf_counter() - st

    if profile:
        # Fix: renamed from `dir`, which shadowed the builtin.
        trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
        known_files = set(os.listdir(trace_dir))
        backend.stop_profile()
        # Block until the newly written trace file stops growing.
        monitor_trace_file(known_files, trace_dir)

    if backend_name == "runtime":
        gen_out = json.loads(gen_out)

    server_info = backend.get_server_info()

    measurement_results["total_latency"] = latency
    measurement_results["total_output_tokens"] = sum(
        o["meta_info"]["completion_tokens"] for o in gen_out
    )
    measurement_results["request_throughput"] = (
        measurement_results["successful_requests"] / latency
    )
    measurement_results["input_throughput"] = (
        measurement_results["total_input_tokens"] / latency
    )
    measurement_results["output_throughput"] = (
        measurement_results["total_output_tokens"] / latency
    )
    measurement_results["total_throughput"] = (
        measurement_results["total_input_tokens"]
        + measurement_results["total_output_tokens"]
    ) / latency

    # Engine backends may return a coroutine here; resolve it synchronously.
    if inspect.isawaitable(server_info):
        server_info = asyncio.run(server_info)
    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
        "last_gen_throughput"
    ]
    return measurement_results
def monitor_trace_file(known_files, directory, interval=1):
    """Block until a file newly created in *directory* has finished growing.

    Polls *directory* every *interval* seconds for files that are not in
    *known_files*; once any such file's size stops increasing between two
    consecutive polls, the function returns. A file that disappears while
    being watched is skipped and polling continues.
    """
    print(f"Monitoring {directory} for new trace files...")
    done = False
    while not done:
        time.sleep(interval)
        for new_file in set(os.listdir(directory)) - known_files:
            watched_path = os.path.join(directory, new_file)
            print(f"New file detected: {new_file}")
            last_size = 0
            while True:
                try:
                    size_now = os.path.getsize(watched_path)
                except FileNotFoundError:
                    # The writer removed/renamed it mid-watch; give up on it.
                    print(f"File {new_file} is no longer accessible.")
                    break
                if size_now <= last_size:
                    # Two polls with no growth: the trace dump is complete.
                    done = True
                    break
                last_size = size_now
                time.sleep(interval)
            if done:
                break
def throughput_test(
    server_args: ServerArgs,
    bench_args: BenchArgs,
):
    """Launch the chosen backend, run warmup plus the benchmark, and print a report.

    Args:
        server_args: Server/engine configuration (same as launch_server.py).
        bench_args: Benchmark configuration parsed from the CLI.

    Returns:
        The measurement dict produced by ``throughput_test_once`` for the
        main (non-warmup) run. Results are also appended to
        ``bench_args.result_filename`` when set.
    """
    if bench_args.backend == "engine":
        backend = Engine(**dataclasses.asdict(server_args))
        if not backend:
            raise ValueError("Please provide valid engine arguments")
    elif bench_args.backend == "runtime":
        backend = Runtime(**dataclasses.asdict(server_args))
    else:
        raise ValueError('Please set backend to either "engine" or "runtime"')

    tokenizer_id = server_args.tokenizer_path or server_args.model_path
    tokenizer = get_tokenizer(tokenizer_id)

    # Set global environments
    set_ulimit()
    random.seed(bench_args.seed)
    np.random.seed(bench_args.seed)

    # Parse args
    extra_request_body = {}
    if bench_args.extra_request_body:
        # Bug fix: this previously read the module-global `args` (which only
        # exists when run as a script), crashing when called as a library.
        extra_request_body = json.loads(bench_args.extra_request_body)

    # Read dataset
    input_requests = get_dataset(bench_args, tokenizer)

    warmup_requests = sample_random_requests(
        input_len=256,
        output_len=16,
        num_prompts=min(bench_args.num_prompts, 16),
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=bench_args.dataset_path,
    )

    # Warm up
    if not bench_args.skip_warmup:
        logging.info("\nWarmup...")
        throughput_test_once(
            backend_name=bench_args.backend,
            backend=backend,
            reqs=warmup_requests,
            ignore_eos=not bench_args.disable_ignore_eos,
            extra_request_body=extra_request_body,
            profile=False,
        )
        time.sleep(0.5)

    logging.info("\nBenchmark...")
    result = throughput_test_once(
        backend_name=bench_args.backend,
        backend=backend,
        reqs=input_requests,
        ignore_eos=not bench_args.disable_ignore_eos,
        extra_request_body=extra_request_body,
        profile=bench_args.profile,
    )
    backend.shutdown()

    if bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    print(
        "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
    )
    print("{:<40} {:<10}".format("Backend:", result["backend"]))
    print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
    print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
    print(
        "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Last generation throughput (tok/s):", result["last_gen_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", result["request_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", result["input_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", result["output_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", result["total_throughput"]
        )
    )
    print("=" * 50)
    return result
if __name__ == "__main__":
    # Merge server-side and benchmark-side flags into a single CLI.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()

    # Handle ModelScope model downloads when SGLANG_USE_MODELSCOPE is enabled:
    # a local path is used as-is, otherwise the model is fetched and
    # args.model_path is rewritten to the downloaded snapshot directory.
    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
        if os.path.exists(args.model_path):
            print(f"Using local model path: {args.model_path}")
        else:
            try:
                from modelscope import snapshot_download

                print(f"Using ModelScope to download model: {args.model_path}")
                # download the model and replace args.model_path
                args.model_path = snapshot_download(
                    args.model_path,
                )
                print(f"Model downloaded to: {args.model_path}")
            except Exception as e:
                print(f"ModelScope download failed: {str(e)}")
                raise e

    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    throughput_test(server_args, bench_args)

    # Deliberate busy-wait: --do-not-exit keeps the process alive so external
    # profilers (e.g. nsys with --duration/--delay) can finish attaching.
    while bench_args.do_not_exit:
        pass

View File

@@ -0,0 +1,665 @@
"""
Benchmark the latency of running a single static batch without a server.
This script does not launch a server and uses the low-level APIs.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
# Usage (latency test)
## with dummy weights:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
## run with profiling:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
# Usage (correctness test):
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
## Reference output (of the correctness test above, can be gpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
device='cuda:0')
prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
device='cuda:0')
========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the
========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""
import argparse
import copy
import dataclasses
import itertools
import json
import logging
import multiprocessing
import os
import time
from typing import Tuple
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
configure_logger,
get_bool_env_var,
kill_process_tree,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
suppress_other_loggers,
)
@dataclasses.dataclass
class BenchArgs:
    """Benchmark arguments for the single-static-batch latency test.

    The tuple-valued fields (``batch_size``, ``input_len``, ``output_len``)
    allow sweeping over multiple data points in one invocation.
    """

    run_name: str = "default"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    prompt_filename: str = ""
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the benchmark flags on *parser*, defaulting to the field values."""
        add = parser.add_argument
        add("--run-name", type=str, default=BenchArgs.run_name)
        add("--batch-size", type=int, nargs="+", default=BenchArgs.batch_size)
        add("--input-len", type=int, nargs="+", default=BenchArgs.input_len)
        add("--output-len", type=int, nargs="+", default=BenchArgs.output_len)
        add("--prompt-filename", type=str, default=BenchArgs.prompt_filename)
        add("--result-filename", type=str, default=BenchArgs.result_filename)
        add("--correctness-test", action="store_true")
        add("--cut-len", type=int, default=BenchArgs.cut_len)
        add(
            "--log-decode-step",
            type=int,
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step, default is set to zero to disable.",
        )
        add("--profile", action="store_true", help="Use Torch Profiler.")
        add(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        add(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
            help="Prefix of the profiling file names. The full profiling result file(s) be "
            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed CLI args.

        Each value is coerced through the type of its dataclass default,
        e.g. argparse's nargs="+" lists become tuples again.
        """
        kwargs = {}
        for field in dataclasses.fields(cls):
            kwargs[field.name] = type(field.default)(getattr(args, field.name))
        return cls(**kwargs)
def load_model(server_args, port_args, tp_rank):
    """Build a ModelRunner and tokenizer for one tensor-parallel rank.

    Args:
        server_args: ServerArgs with model path and parallelism settings.
        port_args: PortArgs providing the NCCL rendezvous port.
        tp_rank: Tensor-parallel rank of this process (also used as gpu_id).

    Returns:
        Tuple of (model_runner, tokenizer).
    """
    suppress_other_loggers()
    # Only rank 0 prints; other ranks get a no-op to avoid duplicated logs.
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Derive the expert-parallel rank from the TP rank (EP groups partition the TP world).
    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
    model_config = ModelConfig.from_server_args(server_args)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        moe_ep_rank=moe_ep_rank,
        moe_ep_size=server_args.ep_size,
        # Pipeline parallelism is not exercised by this benchmark.
        pp_rank=0,
        pp_size=1,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    # Wait for every rank to finish loading before the benchmark starts.
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
    """Build truncated requests for the correctness test.

    Each prompt is cut to its first ``bench_args.cut_len`` tokens; the remainder
    is appended later by ``prepare_extend_inputs_for_correctness_test`` so that
    the prefill-with-KV-cache path is exercised.

    Args:
        bench_args: BenchArgs; ``cut_len`` and ``output_len`` are read.
        tokenizer: Tokenizer used to encode the prompts.
        custom_prompts: Optional list of prompt strings; built-ins are used
            when empty.

    Returns:
        (input_ids, reqs): the full token-id lists and the truncated Req objects.
    """
    prompts = (
        custom_prompts
        if custom_prompts
        else [
            "The capital of France is",
            "The capital of the United Kindom is",
            "Today is a sunny day and I like",
        ]
    )
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        # Fix: use the configured per-run value instead of the class-level
        # default. `BenchArgs.output_len` is the tuple `(16,)`, which ignores
        # --output-len and is not the int SamplingParams expects.
        max_new_tokens=bench_args.output_len[0],
    )

    reqs = []
    for i in range(len(prompts)):
        # The prompt must be longer than the cut, otherwise there is nothing
        # left for the extend phase.
        assert len(input_ids[i]) > bench_args.cut_len

        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        # No prefix cache yet: the whole truncated prompt is the extend input.
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)
    return input_ids, reqs
def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    """Extend each request past the prefill cut and point its prefix at the KV cache.

    Appends the tokens held back by the first prefill and marks the first
    ``cut_len`` tokens as already cached, so the next extend reuses their KV
    slots.
    """
    cut = bench_args.cut_len
    req_to_token = model_runner.req_to_token_pool.req_to_token
    for idx, (req, ids) in enumerate(zip(reqs, input_ids)):
        # Append the remainder of the original prompt.
        req.fill_ids += ids[cut:]
        # The first `cut` tokens already live in the KV cache for this request slot.
        req.prefix_indices = req_to_token[idx, :cut]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs
def prepare_synthetic_inputs_for_latency_test(
    batch_size, input_len, custom_inputs=None
):
    """Create `batch_size` requests of `input_len` random token ids each.

    Args:
        batch_size: Number of requests to build.
        input_len: Tokens per request when generating random inputs.
        custom_inputs: Optional pre-tokenized prompts used verbatim instead of
            random ids.

    Returns:
        List of Req objects ready to pass to `extend()`.
    """
    input_ids = (
        custom_inputs
        if custom_inputs
        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
    )
    # NOTE(review): `BenchArgs.output_len` is the class-level default tuple
    # `(16,)`, not the value parsed from --output-len. The decode loop in
    # latency_test_run_once drives generation length itself, so this value
    # appears unused in practice — confirm before relying on it.
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )
    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        # No prefix cache: every token is part of the extend input.
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)
    return reqs
@torch.no_grad
def extend(reqs, model_runner):
    """Run one prefill (extend) forward pass over `reqs`.

    Returns:
        (next_token_ids, next_token_logits, batch): the sampled tokens, their
        logits, and the ScheduleBatch (reused afterwards by `decode`).
    """
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        # No prefix-cache tree: the benchmark manages prefixes explicitly.
        tree_cache=None,
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
    )
    batch.prepare_for_extend()
    # Data-parallel / MLP-sync setups need batch metadata gathered across ranks.
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch
@torch.no_grad
def decode(input_token_ids, batch, model_runner):
    """Run one decode step: feed back `input_token_ids` and sample the next tokens.

    Returns:
        (next_token_ids, next_token_logits) for the whole batch.
    """
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    # Data-parallel / MLP-sync setups need batch metadata gathered across ranks.
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    """Gather batch metadata across data-parallel ranks when MLP sync is required.

    No-op unless `require_mlp_sync(server_args)` holds. Mirrors the scheduler's
    normal preparation path but with speculative decoding disabled, since the
    benchmark never runs a draft model.
    """
    if require_mlp_sync(model_runner.server_args):
        Scheduler.prepare_mlp_sync_batch_raw(
            batch,
            dp_size=model_runner.server_args.dp_size,
            attn_tp_size=1,
            tp_group=model_runner.tp_group,
            get_idle_batch=None,
            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            speculative_num_draft_tokens=None,
            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
            disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
        )
def _read_prompts_from_file(prompt_file, rank_print):
    """Read custom prompts from the file specified by `--prompt-filename`.

    Returns the file's raw lines, or an empty list when no filename was given
    or the file does not exist (a notice is printed in the latter case).
    """
    if prompt_file and os.path.exists(prompt_file):
        with open(prompt_file, "r") as pf:
            return pf.readlines()
    if prompt_file:
        rank_print(
            f"Custom prompt file {prompt_file} not found. Using default inputs..."
        )
    return []
def _save_profile_trace_results(profiler, filename):
    """Export `profiler`'s chrome trace to `filename` and print a summary table."""
    # Make sure the target directory exists before exporting.
    target_dir = os.path.dirname(os.path.abspath(filename))
    os.makedirs(target_dir, exist_ok=True)
    profiler.export_chrome_trace(filename)
    summary_table = profiler.key_averages(group_by_input_shape=True).table(
        sort_by="self_cpu_time_total"
    )
    print(summary_table)
def correctness_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Sanity-check generation: prefill, extend with KV cache, then greedy decode.

    Prints the logits after each prefill stage and the final decoded text for
    every prompt so results can be compared by eye against a reference run.
    """
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
    # Prepare inputs
    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    input_ids, reqs = prepare_inputs_for_correctness_test(
        bench_args, tokenizer, custom_prompts
    )
    rank_print(f"\n{input_ids=}\n")
    # First prefill only covers the truncated prompts (first cut_len tokens).
    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")
    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )
    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")
    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])
    # Print output texts
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")
def synchronize(device):
    """Block until all queued work on `device` has finished (for accurate timing)."""
    torch.get_device_module(device).synchronize()
def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profile_filename_prefix,
):
    """Time one prefill plus `output_len - 1` decode steps and collect metrics.

    Returns:
        Dict of latency/throughput measurements, or None when the KV cache
        cannot hold batch_size * (input_len + output_len) tokens (the case is
        skipped). When `profile` is truthy, chrome traces are exported for the
        prefill phase and for one decode step at the midpoint of generation.
    """
    # Skip configurations that cannot fit into the KV cache.
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return
    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()
    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }
    tot_latency = 0
    profiler = None
    if profile:
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            with_stack=True,
            record_shapes=profile_record_shapes,
        )
        profiler.start()
    # Prefill
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput
    if profile:
        profiler.stop()
        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
        _save_profile_trace_results(profiler, profile_filename)
        rank_print(
            f"torch profiler chrome trace for prefill saved to {profile_filename}"
        )
    # Decode
    decode_latencies = []
    for i in range(output_len - 1):
        synchronize(device)
        # Re-arm the profiler for exactly one decode step at the midpoint.
        # NOTE(review): `i == output_len / 2` compares int to float, so it only
        # matches when output_len is even — confirm this is intended.
        if profile and i == output_len / 2:
            profiler = None
            profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                record_shapes=profile_record_shapes,
            )
            profiler.start()
        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        # Always log the first few steps, then every log_decode_step-th step.
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )
        if profile and i == output_len / 2:
            profiler.stop()
            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
            _save_profile_trace_results(profiler, profile_filename)
            rank_print(
                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
            )
    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput
    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results
def latency_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Warm up, then sweep every (batch_size, input_len, output_len) combination.

    Results are appended (jsonlines) to `bench_args.result_filename` by rank 0.
    Custom prompts from --prompt-filename, when present, are truncated/padded
    to match each batch size.
    """
    initialize_moe_config(server_args)
    # Set CPU affinity
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank)
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )
    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
        server_args.device,
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profile_filename_prefix="",  # not used
    )
    rank_print("Benchmark ...")
    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
    custom_input_len = len(custom_inputs)
    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        # Align the custom prompt list with the current batch size.
        bs_aligned_inputs = []
        if custom_inputs:
            if custom_input_len == bs:
                bs_aligned_inputs = custom_inputs
            elif custom_input_len > bs:
                rank_print(
                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
                    f"Using the first {bs} prompts."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
            else:
                rank_print(
                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
                    f"Pad to the desired batch_size with the last prompt."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs)
                bs_aligned_inputs.extend(
                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
                )
        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
            bench_args.log_decode_step,
            # Only rank 0 profiles to avoid duplicated traces.
            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profile_filename_prefix,
        )
        if ret is not None:
            result_list.append(ret)
    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            for result in result_list:
                fout.write(json.dumps(result) + "\n")
    if server_args.tp_size > 1:
        destroy_distributed_environment()
def main(server_args, bench_args):
    """Dispatch to the correctness or latency benchmark, one process per TP rank."""
    # Size CUDA graphs for the largest batch in the sweep.
    server_args.cuda_graph_max_bs = max(bench_args.batch_size)
    _set_envs_and_config(server_args)
    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )
    port_args = PortArgs.init_new(server_args)
    if server_args.tp_size == 1:
        # Single rank runs in-process; no workers needed.
        work_func(server_args, port_args, bench_args, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    port_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)
        for proc in workers:
            proc.join()
        # NOTE(review): only the last-joined worker is terminated here; by this
        # point all workers have already exited, so this looks like a no-op —
        # confirm before relying on it for cleanup.
        proc.terminate()
if __name__ == "__main__":
    # Combine server CLI options with benchmark-specific ones.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )
    try:
        main(server_args, bench_args)
    finally:
        # Multi-rank runs spawn worker processes; reap them all on exit.
        if server_args.tp_size != 1:
            kill_process_tree(os.getpid(), include_parent=False)

View File

@@ -0,0 +1,430 @@
"""
Benchmark the latency of running a single batch with a server.
This script launches a server and uses the HTTP interface.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
Usage:
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
"""
import argparse
import dataclasses
import itertools
import json
import multiprocessing
import os
import time
from typing import List, Tuple
import requests
from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary
@dataclasses.dataclass
class BenchArgs:
    """CLI options for the HTTP one-batch benchmark (server-side variant)."""

    run_name: str = "default"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    temperature: float = 0.0
    return_logprob: bool = False
    client_stream_interval: int = 1
    input_len_step_percentage: float = 0.0
    result_filename: str = "result.jsonl"
    base_url: str = ""
    skip_warmup: bool = False
    show_report: bool = False
    profile: bool = False
    profile_steps: int = 3
    profile_by_stage: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every benchmark option on *parser*; defaults mirror the dataclass."""
        add = parser.add_argument
        add("--run-name", type=str, default=BenchArgs.run_name)
        add("--batch-size", type=int, nargs="+", default=BenchArgs.batch_size)
        add("--input-len", type=int, nargs="+", default=BenchArgs.input_len)
        add("--output-len", type=int, nargs="+", default=BenchArgs.output_len)
        add("--temperature", type=float, default=BenchArgs.temperature)
        add("--return-logprob", action="store_true")
        add(
            "--client-stream-interval",
            type=int,
            default=BenchArgs.client_stream_interval,
        )
        add(
            "--input-len-step-percentage",
            type=float,
            default=BenchArgs.input_len_step_percentage,
        )
        add("--result-filename", type=str, default=BenchArgs.result_filename)
        add("--base-url", type=str, default=BenchArgs.base_url)
        add("--skip-warmup", action="store_true")
        add("--show-report", action="store_true")
        add("--profile", action="store_true")
        add("--profile-steps", type=int, default=BenchArgs.profile_steps)
        add("--profile-by-stage", action="store_true")

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed args, casting each value via its field default's type."""
        values = {
            field.name: type(field.default)(getattr(args, field.name))
            for field in dataclasses.fields(cls)
        }
        return cls(**values)
def launch_server_internal(server_args):
    """Run the sglang HTTP server in this process, killing all children on exit.

    The previous `except Exception as e: raise e` was a no-op re-raise; a plain
    try/finally preserves the original traceback while still guaranteeing the
    process tree is cleaned up.
    """
    try:
        launch_server(server_args)
    finally:
        # Reap any worker processes the server spawned, but keep this process
        # alive so the exception (if any) propagates normally.
        kill_process_tree(os.getpid(), include_parent=False)
def launch_server_process(server_args: ServerArgs):
    """Start the HTTP server in a child process and wait until it is healthy.

    Polls GET /v1/models every 10 seconds until the server answers 200.

    Args:
        server_args: Server configuration; host/port build the base URL.

    Returns:
        (proc, base_url): the child process handle and the server's base URL.

    Raises:
        TimeoutError: if the server is not healthy within 600 seconds.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = 600

    # Use a monotonic clock so the deadline is immune to wall-clock adjustments
    # (NTP steps, DST) during the potentially long model load.
    start_time = time.monotonic()
    while time.monotonic() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            response = requests.get(f"{base_url}/v1/models", headers=headers)
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            # Server not accepting connections yet; keep polling.
            pass
        time.sleep(10)
    raise TimeoutError("Server failed to start within the timeout period.")
def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    temperature: float,
    return_logprob: bool,
    stream_interval: int,
    input_len_step_percentage: float,
    run_name: str,
    result_filename: str,
    tokenizer,
    profile: bool = False,
    profile_steps: int = 3,
    profile_by_stage: bool = False,
):
    """Benchmark one (batch_size, input_len, output_len) case over HTTP.

    Flushes the server cache, streams one /generate request for the whole
    batch, measures TTFT and total latency, optionally appends a jsonline to
    `result_filename`, and returns the raw metrics tuple.

    NOTE(review): `input_len_step_percentage` is accepted but unused in this
    body — confirm whether it was meant to perturb per-request input lengths.
    """
    requests.post(url + "/flush_cache")
    input_requests = sample_random_requests(
        input_len=input_len,
        output_len=output_len,
        num_prompts=batch_size,
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path="",
        random_sample=True,
        return_text=False,
    )
    # Debug toggle left in place: exercises the structured-output (JSON) path.
    use_structured_outputs = False
    if use_structured_outputs:
        texts = []
        for _ in range(batch_size):
            texts.append(
                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
                * 50
                + "Assistant:"
            )
        json_schema = "$$ANY$$"
    else:
        json_schema = None
    profile_link = None
    if profile:
        profile_link: str = run_profile(
            url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
        )
    tic = time.perf_counter()
    response = requests.post(
        url + "/generate",
        json={
            "input_ids": [req.prompt for req in input_requests],
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
                "json_schema": json_schema,
                "stream_interval": stream_interval,
            },
            "return_logprob": return_logprob,
            "stream": True,
        },
        stream=True,
    )
    # The TTFT of the last request in the batch
    ttft = 0.0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            if "error" in data:
                raise RuntimeError(f"Request has failed. {data}.")
            # Requests must only stop because they hit max_new_tokens.
            assert (
                data["meta_info"]["finish_reason"] is None
                or data["meta_info"]["finish_reason"]["type"] == "length"
            )
            if data["meta_info"]["completion_tokens"] == 1:
                ttft = time.perf_counter() - tic
    latency = time.perf_counter() - tic
    input_throughput = batch_size * input_len / ttft
    output_throughput = batch_size * output_len / (latency - ttft)
    overall_throughput = batch_size * (input_len + output_len) / latency
    server_info = requests.get(url + "/get_server_info").json()
    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
    print(f"batch size: {batch_size}")
    print(f"input_len: {input_len}")
    print(f"output_len: {output_len}")
    print(f"latency: {latency:.2f} s")
    print(f"ttft: {ttft:.2f} s")
    print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
    print(f"input throughput: {input_throughput:.2f} tok/s")
    if output_len != 1:
        print(f"output throughput: {output_throughput:.2f} tok/s")
    if result_filename:
        with open(result_filename, "a") as fout:
            res = {
                "run_name": run_name,
                "batch_size": batch_size,
                "input_len": input_len,
                "output_len": output_len,
                "latency": round(latency, 4),
                "output_throughput": round(output_throughput, 2),
                "overall_throughput": round(overall_throughput, 2),
                "last_gen_throughput": round(last_gen_throughput, 2),
            }
            fout.write(json.dumps(res) + "\n")
    return (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        overall_throughput,
        last_gen_throughput,
        acc_length,
        profile_link if profile else None,
    )
def get_report_summary(
    result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
):
    """Render the collected per-case metrics as a GitHub-flavored markdown table.

    Args:
        result: Tuples as returned by `run_one_case`.
        server_args: Used for tp_size when estimating cost.
        bench_args: Used for the header line and the profile column toggle.

    Returns:
        The markdown summary string.
    """
    import tabulate

    summary = (
        f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
    )
    headers = [
        "batch size",
        "latency (s)",
        "input throughput (tok/s)",
        "output throughput (tok/s)",
        "acc length",
        "ITL (ms)",
        "input cost ($/1M)",
        "output cost ($/1M)",
    ]
    if bench_args.profile:
        headers.append("profile")
    rows = []
    for (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        _,
        _,
        acc_length,
        trace_link,
    ) in result:
        # Rough $/hour assumptions used to convert throughput into $/1M tokens.
        if is_blackwell():
            hourly_cost_per_gpu = 4  # $4/hour for one B200
        else:
            hourly_cost_per_gpu = 2  # $2/hour for one H100
        hourly_cost = hourly_cost_per_gpu * server_args.tp_size
        # Assume input tokens only keep the GPU ~70% utilized when pricing them.
        input_util = 0.7
        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
        # Inter-token latency: seconds per token per request, in milliseconds.
        itl = 1 / (output_throughput / batch_size) * 1000
        input_cost = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost
        output_cost = 1e6 / output_throughput / 3600 * hourly_cost
        row = [
            batch_size,
            latency,
            input_throughput,
            output_throughput,
            accept_length,
            itl,
            input_cost,
            output_cost,
        ]
        if trace_link:
            row.append(f"[Profile]({trace_link})")
        rows.append(row)
    summary += tabulate.tabulate(
        rows, headers=headers, tablefmt="github", floatfmt=".2f"
    )
    return summary
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
    """Drive the full HTTP benchmark: launch/attach to a server, sweep, report.

    Uses `bench_args.base_url` when provided, otherwise launches a local server
    and kills it afterwards. Optionally repeats the sweep with profiling and
    attaches trace links, then prints (and, in CI, publishes) a summary table.
    """
    if bench_args.base_url:
        proc, base_url = None, bench_args.base_url
    else:
        proc, base_url = launch_server_process(server_args)
    server_info = requests.get(base_url + "/get_server_info").json()
    # NOTE(review): if server_info has neither "tokenizer_path" nor "prefill",
    # tokenizer_path is unbound and the next line raises UnboundLocalError —
    # confirm all supported server modes report one of the two keys.
    if "tokenizer_path" in server_info:
        tokenizer_path = server_info["tokenizer_path"]
    elif "prefill" in server_info:
        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
    tokenizer = get_tokenizer(tokenizer_path)
    # warmup
    if not bench_args.skip_warmup:
        print("=" * 8 + " Warmup Begin " + "=" * 8)
        run_one_case(
            base_url,
            batch_size=16,
            input_len=1024,
            output_len=16,
            temperature=bench_args.temperature,
            return_logprob=bench_args.return_logprob,
            stream_interval=bench_args.client_stream_interval,
            input_len_step_percentage=bench_args.input_len_step_percentage,
            run_name="",
            result_filename="",
            tokenizer=tokenizer,
        )
        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
    # benchmark
    result = []
    bench_result = []
    try:
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            result.append(
                run_one_case(
                    base_url,
                    bs,
                    il,
                    ol,
                    temperature=bench_args.temperature,
                    return_logprob=bench_args.return_logprob,
                    stream_interval=bench_args.client_stream_interval,
                    input_len_step_percentage=bench_args.input_len_step_percentage,
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
                    tokenizer=tokenizer,
                )
            )
        # Second pass with profiling enabled; only the trace link is kept and
        # merged into the non-profiled results. Profiling failures are non-fatal.
        if bench_args.profile:
            try:
                for bs, il, ol in itertools.product(
                    bench_args.batch_size, bench_args.input_len, bench_args.output_len
                ):
                    bench_result.append(
                        (
                            run_one_case(
                                base_url,
                                bs,
                                il,
                                ol,
                                temperature=bench_args.temperature,
                                return_logprob=bench_args.return_logprob,
                                stream_interval=bench_args.client_stream_interval,
                                input_len_step_percentage=bench_args.input_len_step_percentage,
                                run_name=bench_args.run_name,
                                result_filename=bench_args.result_filename,
                                tokenizer=tokenizer,
                                profile=bench_args.profile,
                                profile_steps=bench_args.profile_steps,
                                profile_by_stage=bench_args.profile_by_stage,
                            )[-1],
                        )
                    )
                result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
            except Exception as e:
                print(f"Error profiling, there will be no profile trace dump: {e}")
    finally:
        # Only kill the server if we launched it ourselves.
        if proc:
            kill_process_tree(proc.pid)
    print(f"\nResults are saved to {bench_args.result_filename}")
    if not bench_args.show_report:
        return
    summary = get_report_summary(result, server_args, bench_args)
    print(summary)
    if is_in_ci():
        write_github_step_summary(summary)
def main():
    """Parse combined server + benchmark CLI arguments and run the benchmark."""
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)
    run_benchmark(server_args, bench_args)


if __name__ == "__main__":
    main()

File diff suppressed because it is too large Load Diff

305
python/sglang/check_env.py Normal file
View File

@@ -0,0 +1,305 @@
"""Check environment configurations and dependency versions."""
import importlib.metadata
import os
import resource
import subprocess
import sys
from collections import OrderedDict, defaultdict
import torch
from sglang.srt.utils import is_hip
def is_cuda_v2():
    """Return True when this PyTorch build targets CUDA (torch.version.cuda is set)."""
    cuda_version = torch.version.cuda
    return cuda_version is not None
# List of packages to check versions.
# Fix: "torchao" appeared twice; duplicates only produce a repeated line in
# the report, so keep a single entry.
PACKAGE_LIST = [
    "sglang",
    "sgl_kernel",
    "flashinfer_python",
    "triton",
    "transformers",
    "torchao",
    "numpy",
    "aiohttp",
    "fastapi",
    "hf_transfer",
    "huggingface_hub",
    "interegular",
    "modelscope",
    "orjson",
    "outlines",
    "packaging",
    "psutil",
    "pydantic",
    "python-multipart",
    "pyzmq",
    "uvicorn",
    "uvloop",
    "vllm",
    "xgrammar",
    "openai",
    "tiktoken",
    "anthropic",
    "litellm",
    "decord",
]
def get_package_versions(packages):
    """Return the installed version of each named distribution.

    Args:
        packages: Iterable of distribution names, optionally carrying a version
            specifier ("pkg==1.0", "pkg>=1.0", "pkg<=1.0") which is stripped.

    Returns:
        Dict mapping the bare package name to its version string, or
        "Module Not Found" when the distribution is not installed.
    """
    versions = {}
    for package in packages:
        # Strip a trailing version specifier, keeping only the name.
        package_name = package.split("==")[0].split(">=")[0].split("<=")[0]
        try:
            versions[package_name] = importlib.metadata.version(package_name)
        except importlib.metadata.PackageNotFoundError:
            # Catch the exception importlib.metadata documents for missing
            # distributions (the original caught its ModuleNotFoundError base).
            versions[package_name] = "Module Not Found"
    return versions
def get_cuda_info():
    """
    Get CUDA-related information if available.

    Returns a dict of CUDA (or ROCm) facts. NOTE(review): when the PyTorch
    build targets neither CUDA nor HIP this falls through and returns None,
    which breaks callers that pass the result to dict.update() — confirm.
    """
    if is_cuda_v2():
        cuda_info = {"CUDA available": torch.cuda.is_available()}
        if cuda_info["CUDA available"]:
            cuda_info.update(_get_gpu_info())
            cuda_info.update(_get_cuda_version_info())
        return cuda_info
    elif is_hip():
        # ROCm builds expose the same torch.cuda API surface.
        cuda_info = {"ROCM available": torch.cuda.is_available()}
        if cuda_info["ROCM available"]:
            cuda_info.update(_get_gpu_info())
            cuda_info.update(_get_cuda_version_info())
        return cuda_info
def _get_gpu_info():
    """
    Get information about available GPUs.

    Returns a dict mapping grouped device ids to names and compute
    capabilities (devices sharing a name or capability are grouped).
    """
    devices = defaultdict(list)
    capabilities = defaultdict(list)
    for idx in range(torch.cuda.device_count()):
        devices[torch.cuda.get_device_name(idx)].append(str(idx))
        major, minor = torch.cuda.get_device_capability(idx)
        capabilities[f"{major}.{minor}"].append(str(idx))

    gpu_info = {}
    for name, device_ids in devices.items():
        gpu_info[f"GPU {','.join(device_ids)}"] = name

    if len(capabilities) == 1:
        # All GPUs have the same compute capability
        cap, gpu_ids = next(iter(capabilities.items()))
        gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
    else:
        # GPUs have different compute capabilities
        for cap, gpu_ids in capabilities.items():
            gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
    return gpu_info
def _get_cuda_version_info():
    """
    Get CUDA version information.

    Returns a dict with CUDA_HOME/ROCM_HOME and, when the toolkit directory
    exists, compiler and driver versions. Returns {"CUDA_HOME": ""} on builds
    that target neither CUDA nor HIP.
    """
    if is_cuda_v2():
        # Local import: torch.utils.cpp_extension is only needed on this path.
        from torch.utils.cpp_extension import CUDA_HOME

        cuda_info = {"CUDA_HOME": CUDA_HOME}

        if CUDA_HOME and os.path.isdir(CUDA_HOME):
            cuda_info.update(_get_nvcc_info())
            cuda_info.update(_get_cuda_driver_version())
        return cuda_info
    elif is_hip():
        from torch.utils.cpp_extension import ROCM_HOME as ROCM_HOME

        cuda_info = {"ROCM_HOME": ROCM_HOME}

        if ROCM_HOME and os.path.isdir(ROCM_HOME):
            cuda_info.update(_get_nvcc_info())
            cuda_info.update(_get_cuda_driver_version())
        return cuda_info
    else:
        cuda_info = {"CUDA_HOME": ""}
        return cuda_info
def _get_nvcc_info():
    """
    Get NVCC version information.

    Runs `nvcc -V` (CUDA) or `hipcc --version` (HIP) and extracts the version
    line; returns "Not Available" when the compiler cannot be invoked.
    """
    if is_cuda_v2():
        from torch.utils.cpp_extension import CUDA_HOME

        try:
            # shell=True with the path quoted; CUDA_HOME comes from the local
            # torch install, not from untrusted input.
            nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
            nvcc_output = (
                subprocess.check_output(f'"{nvcc}" -V', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Keep only the slice between the version banner and the build line.
            return {
                "NVCC": nvcc_output[
                    nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind(
                        "Build"
                    )
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"NVCC": "Not Available"}
    elif is_hip():
        from torch.utils.cpp_extension import ROCM_HOME

        try:
            hipcc = os.path.join(ROCM_HOME, "bin/hipcc")
            hipcc_output = (
                subprocess.check_output(f'"{hipcc}" --version', shell=True)
                .decode("utf-8")
                .strip()
            )
            return {
                "HIPCC": hipcc_output[
                    hipcc_output.rfind("HIP version") : hipcc_output.rfind("AMD clang")
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"HIPCC": "Not Available"}
    else:
        return {"NVCC": "Not Available"}
def _get_cuda_driver_version():
    """
    Get CUDA driver version.

    Queries nvidia-smi (CUDA) or rocm-smi (HIP); returns "Not Available" when
    the tool cannot be invoked.
    """
    versions = set()
    if is_cuda_v2():
        try:
            output = subprocess.check_output(
                [
                    "nvidia-smi",
                    "--query-gpu=driver_version",
                    "--format=csv,noheader,nounits",
                ]
            )
            # One line per GPU; usually all GPUs share the same driver.
            versions = set(output.decode().strip().split("\n"))
            if len(versions) == 1:
                return {"CUDA Driver Version": versions.pop()}
            else:
                return {"CUDA Driver Versions": ", ".join(sorted(versions))}
        except subprocess.SubprocessError:
            return {"CUDA Driver Version": "Not Available"}
    elif is_hip():
        try:
            output = subprocess.check_output(
                [
                    "rocm-smi",
                    "--showdriverversion",
                    "--csv",
                ]
            )
            versions = set(output.decode().strip().split("\n"))
            # Drop the CSV header row, then unwrap the remaining value line.
            versions.discard("name, value")
            ver = versions.pop()
            ver = ver.replace('"Driver version", ', "").replace('"', "")
            return {"ROCM Driver Version": ver}
        except subprocess.SubprocessError:
            return {"ROCM Driver Version": "Not Available"}
    else:
        return {"CUDA Driver Version": "Not Available"}
def get_gpu_topology():
    """
    Get GPU topology information.

    Runs the vendor SMI topology command and returns its stdout prefixed with
    a newline, or None when no GPU stack is present or the command fails.
    """
    if is_cuda_v2():
        cmd = ["nvidia-smi", "topo", "-m"]
    elif is_hip():
        cmd = ["rocm-smi", "--showtopotype"]
    else:
        return None
    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
    except subprocess.SubprocessError:
        return None
    return "\n" + result.stdout if result.returncode == 0 else None
def get_hypervisor_vendor():
    """Return the hypervisor vendor reported by `lscpu`, or None.

    None means either lscpu is unavailable/failed or the machine is not
    virtualized (no "Hypervisor vendor:" line).
    """
    try:
        output = subprocess.check_output(["lscpu"], text=True)
    # Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt
    # and SystemExit. OSError covers a missing lscpu binary; SubprocessError
    # covers non-zero exit codes.
    except (OSError, subprocess.SubprocessError):
        return None
    for line in output.split("\n"):
        if "Hypervisor vendor:" in line:
            return line.split(":")[1].strip()
    return None
def check_env():
    """
    Check and print environment information.

    Prints one "key: value" line per fact: Python/PyTorch versions, CUDA or
    ROCm details, package versions, GPU topology, hypervisor, and the soft
    open-file limit.
    """
    env_info = OrderedDict()
    env_info["Python"] = sys.version.replace("\n", "")
    # NOTE(review): get_cuda_info() returns None on builds that target neither
    # CUDA nor HIP, which would make this update() raise — confirm supported
    # platforms.
    env_info.update(get_cuda_info())
    env_info["PyTorch"] = torch.__version__
    env_info.update(get_package_versions(PACKAGE_LIST))
    gpu_topo = get_gpu_topology()
    if gpu_topo:
        if is_cuda_v2():
            env_info["NVIDIA Topology"] = gpu_topo
        elif is_hip():
            env_info["AMD Topology"] = gpu_topo
    hypervisor_vendor = get_hypervisor_vendor()
    if hypervisor_vendor:
        env_info["Hypervisor vendor"] = hypervisor_vendor
    # Soft RLIMIT_NOFILE: low values are a common cause of server failures.
    ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
    env_info["ulimit soft"] = ulimit_soft
    for k, v in env_info.items():
        print(f"{k}: {v}")


if __name__ == "__main__":
    check_env()

View File

@@ -0,0 +1,184 @@
"""
Compile DeepGEMM Kernels for a model with specify server arguments
This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
It accepts server arguments (the same as launch_server.py).
Usage:
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
"""
import argparse
import dataclasses
import multiprocessing
import os
import time
import requests
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.srt.warmup import warmup
# Launch children with "spawn" so each server process starts from a fresh
# interpreter rather than inheriting this process's state via fork.
multiprocessing.set_start_method("spawn", force=True)
# Reduce warning
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
# Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
@dataclasses.dataclass
class CompileArgs:
    """CLI arguments specific to the DeepGEMM pre-compilation step."""

    # Maximum number of seconds to wait for compilation to finish.
    timeout: int = 3600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register this dataclass's fields on an argparse parser."""
        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a CompileArgs from parsed CLI args.

        Each value is cast with the type of the corresponding field's default
        so attributes end up with the declared types.
        """
        casters = {f.name: type(f.default) for f in dataclasses.fields(cls)}
        return cls(**{name: cast(getattr(args, name)) for name, cast in casters.items()})
@warmup("compile-deep-gemm")
async def warm_up_compile(
    disaggregation_mode: str, tokenizer_manager: TokenizerManager
):
    """Registered warmup hook: issue one tiny deterministic generation.

    Sends a 4-token prompt asking for 8 greedy tokens so the server runs the
    GEMM code paths that trigger DeepGEMM JIT compilation.
    """
    print("\nGenerate warm up request for compiling DeepGEMM...\n")
    generate_req_input = GenerateReqInput(
        input_ids=[0, 1, 2, 3],
        sampling_params={
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": True,
        },
    )
    if disaggregation_mode != "null":
        # NOTE(review): presumably fake bootstrap info is enough to satisfy
        # the disaggregated prefill/decode handshake — confirm.
        generate_req_input.bootstrap_room = 0
        generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
    # Pull only the first streamed result; one step suffices for warmup.
    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
def launch_server_internal(server_args):
    """Run the HTTP server in this process, cleaning up children on exit.

    Used as a multiprocessing target; the child process tree is killed
    whether the server returns normally or raises.
    """
    try:
        launch_server(server_args)
    finally:
        # The original `except Exception as e: raise e` was a no-op and has
        # been removed; only the cleanup below is required.
        kill_process_tree(os.getpid(), include_parent=False)
def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    """Start the server in a child process and drive one warmup generation.

    Polls the server's HTTP endpoint until it is up (or `compile_args.timeout`
    elapses). On the rank-0 node, one /generate request is sent to synchronize
    all nodes; non-rank-0 nodes wait for the rank-0 process to finish.
    Returns the child `multiprocessing.Process`; raises TimeoutError on
    startup or wait timeout.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = compile_args.timeout

    start_time = time.perf_counter()
    while time.perf_counter() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            # Rank 0 serves the full API; other ranks only expose /health.
            if server_args.node_rank == 0:
                response = requests.get(f"{base_url}/v1/models", headers=headers)
            else:
                # This http api is created by launch_dummy_health_check_server for non-rank0 nodes.
                response = requests.get(f"{base_url}/health", headers=headers)
            if response.status_code == 200:
                # Rank-0 node sends a request to sync with the other nodes and then returns.
                if server_args.node_rank == 0:
                    response = requests.post(
                        f"{base_url}/generate",
                        json={
                            "input_ids": [0, 1, 2, 3],
                            "sampling_params": {
                                "max_new_tokens": 8,
                                "temperature": 0,
                            },
                        },
                        timeout=600,
                    )
                    if response.status_code != 200:
                        error = response.json()
                        raise RuntimeError(f"Sync request failed: {error}")
                # Other nodes should wait for the exit signal from the Rank-0 node.
                else:
                    start_time_waiting = time.perf_counter()
                    while proc.is_alive():
                        if time.perf_counter() - start_time_waiting < timeout:
                            time.sleep(10)
                        else:
                            raise TimeoutError("Waiting for main node timeout!")
                return proc
        except requests.RequestException:
            # Server not accepting connections yet; retry after a pause.
            pass
        time.sleep(10)

    raise TimeoutError(
        "DeepGEMM Kernels compilation timeout."
        "\n\nFeel free and please restart the command."
    )
def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
    """Adjust server arguments in place for a compile-only run."""
    # Disable cuda graph and torch compile to save time
    server_args.disable_cuda_graph = True
    server_args.enable_torch_compile = False
    # Bug fix: this was an f-string with no placeholders.
    print("Disable CUDA Graph and Torch Compile to save time...")

    # Set watchdog timeout to compile_args.timeout because compilation will take a long time
    server_args.watchdog_timeout = compile_args.timeout
    # Register the warmup hook defined above so the server triggers it.
    server_args.warmups = "compile-deep-gemm"
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
    """Drive one end-to-end compile pass, then tear the server down."""
    print(
        "Begin DeepGEMM Kernels compilation...\n"
        "It may take a long time and timeout maybe raised "
        "while the compilation is still in progress.\n"
        "Just feel free to restart the command "
        "until the compilation is fully finished.\n"
    )

    server_proc = launch_server_process_and_send_one_request(server_args, compile_args)

    print("\nDeepGEMM Kernels compilation finished successfully.")

    # Sleep for safety
    time.sleep(10)

    if server_proc.is_alive():
        # This is the rank0 node.
        kill_process_tree(server_proc.pid)
    else:
        # Process already exited; cleanup is best-effort.
        try:
            kill_process_tree(server_proc.pid)
        except Exception:
            pass
if __name__ == "__main__":
    # Accept the full launch_server argument surface plus compile options.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    CompileArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    compile_args = CompileArgs.from_cli_args(args)
    refine_server_args(server_args, compile_args)
    run_compile(server_args, compile_args)

View File

@@ -0,0 +1,315 @@
# Adapt from https://github.com/fw-ai/llm_eval_meta
import argparse
import asyncio
import os
import pickle
import re
import shutil
from collections import defaultdict
from dataclasses import dataclass
import httpx
import numpy as np
import openai
from datasets import load_dataset
from openai import AsyncOpenAI
from tqdm import tqdm
# Mapping providers to their clients and models
# Mapping providers to their clients and models.
# All providers currently serve the same Llama 3.1 checkpoints; the inner
# key ("8b"/"70b"/"405b") selects the model size.
provider_to_models = {
    "b10": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "oai": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "sgl": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
}
async def fetch_responses(
    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
):
    """Fetch one completion and cache it as output_dir/response_{index}.pkl.

    Skips work if the cache file already exists. On a bad-request marker the
    string "bad_response" is pickled instead of a Completion object.
    """
    output_file = os.path.join(output_dir, f"response_{index}.pkl")
    if os.path.exists(output_file):
        print(f"File {output_file} already exists, skipping.")
        return

    async with semaphore:
        response = await client.completions.create(
            model=provider_to_models[provider][model_size],
            prompt=prompt,
            temperature=0.0,
            max_tokens=max_tokens,
        )
        # NOTE(review): the SDK normally *raises* BadRequestError rather than
        # returning it; this check is kept as defensive handling.
        if isinstance(response, openai.BadRequestError):
            with open(output_file, "wb") as f:
                pickle.dump("bad_response", f)
            # Bug fix: return here — previously execution fell through to the
            # assert below (which would fail) and re-dumped the file.
            return

        assert isinstance(response, openai.types.completion.Completion)
        # Save response to a file
        with open(output_file, "wb") as f:
            pickle.dump(response, f)
# Max generation length per eval set; plain MMLU needs only 1 token because
# just the answer letter is generated.
TASK_TO_MAX_TOKENS = {
    "evals__mmlu__details": 1,
    "evals__mmlu__0_shot__cot__details": 1024,
    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
    "evals__mmlu_pro__details": 2048,
    "evals__gsm8k__details": 1024,
}
# Short CLI task names -> HuggingFace eval-set config suffix.
TASK_TO_EVAL_SET = {
    "mmlu": "evals__mmlu__details",
    "mmlu_cot": "evals__mmlu__0_shot__cot__details",
    "mmlu_pro": "evals__mmlu_pro__details",
    "gsm8k": "evals__gsm8k__details",
}
class CustomAsyncHTTPXClient(httpx.AsyncClient):
    """HTTPX client that rewrites every request URL to the Baseten endpoint."""

    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
        # Baseten serves the model behind a single fixed predict URL, so the
        # path chosen by the OpenAI SDK is replaced wholesale.
        request.url = httpx.URL(
            f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
        )
        return await super().send(request, *args, **kwargs)
def get_client(provider):
    """Return an AsyncOpenAI client for `provider` ("oai", "b10", or "sgl").

    For non-Baseten providers a dummy API key is installed when none is set,
    since local OpenAI-compatible servers ignore it.
    """
    # Bug fix: `provider not in "b10"` was a substring test (so e.g. "0" or
    # "b" counted as Baseten); compare for equality instead. Also use
    # `is None` rather than `== None`.
    if provider != "b10":
        if os.getenv("OPENAI_API_KEY") is None:
            os.environ["OPENAI_API_KEY"] = "EMPTY"
    return {
        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
        "b10": AsyncOpenAI(
            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
            http_client=CustomAsyncHTTPXClient(),
        ),
        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
    }[provider]
# Define the benchmark function
async def benchmark(args):
ds = load_dataset(
"meta-llama/Llama-3.1-405B-Instruct-evals",
f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
)
semaphore = asyncio.Semaphore(args.concurrency) # Limit to 16 concurrent tasks
if args.num_examples is None:
args.num_examples = len(ds["latest"]["input_final_prompts"])
prompts = ds["latest"]["input_final_prompts"][: args.num_examples]
# Create the output directory if it does not exist
os.makedirs(args.output_dir, exist_ok=True)
tasks = []
# Create the tasks with tqdm progress bar
max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
client = get_client(args.provider)
for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
tasks.append(
asyncio.create_task(
fetch_responses(
client,
f"<|begin_of_text|>{prompt[0]}",
semaphore,
idx,
args.provider,
args.model_size,
args.output_dir,
max_tokens=max_tokens,
)
)
)
# Run the tasks with tqdm progress bar
for future in tqdm(
asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
):
await future
def get_mmlu_answer(response):
    """Extract the MMLU answer letter from a single-token completion.

    Returns the upper-cased, period-stripped text, or None for a missing
    response.
    """
    if response is None:
        return None
    answer_text = response.choices[0].text
    return answer_text.strip().upper().replace(".", "")
def get_mmlu_cot_answer(response):
    """Extract the final answer from a chain-of-thought response.

    Tries several phrasings in priority order; returns None if none match.
    Only the first phrasing also strips markdown asterisks (preserving the
    original behavior exactly).
    """
    completion_text = response.choices[0].text
    # (regex, also strip "*"?) — order matters.
    attempts = (
        (r"The best answer is (.+)\.?", True),
        (r"the best answer is (.+)\.?", False),
        (r"The correct answer is (.+)\.?", False),
        (r"the correct answer is (.+)\.?", False),
    )
    for pattern, strip_asterisks in attempts:
        found = re.search(pattern, completion_text)
        if found is None:
            continue
        answer = found.group(1).replace(".", "")
        if strip_asterisks:
            answer = answer.replace("*", "")
        return answer
    return None
def get_answer_gsm8k(response):
    """Extract the final GSM8K answer, dropping "%" and "$" symbols.

    Returns None when the "The final answer is ..." phrase is absent.
    """
    found = re.search(r"The final answer is (.+)\.?", response.choices[0].text)
    if found is None:
        return None
    answer = found.group(1)
    for symbol in ("%", "$"):
        answer = answer.replace(symbol, "")
    return answer
# Dispatch table: eval-set name -> function that extracts the model's answer
# from a cached completion response.
TASK_TO_ANSWER_EXTRACTOR = {
    "evals__mmlu__details": get_mmlu_answer,
    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
    "evals__gsm8k__details": get_answer_gsm8k,
    "evals__mmlu_pro__details": get_mmlu_cot_answer,
}
def get_dataset_from_task(task, response_path, model_size):
    """Load the eval dataset, reordered to match the 405B prompt order.

    The 405B eval set fixes the canonical prompt order (responses were cached
    in that order). For 8B/70B, rows of the matching reference dataset are
    re-sorted by prompt hash so they line up with the cached responses.

    NOTE(review): `response_path` is unused here — presumably kept for
    signature symmetry; confirm before removing.
    """
    ds_405b = load_dataset(
        f"meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{task}",
    )
    # Canonical ordering key: first prompt hash of each row.
    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]
    if "70b" in model_size or "8b" in model_size:
        if "70" in model_size:
            ref_model_ds = load_dataset(
                f"meta-llama/Llama-3.1-70B-Instruct-evals",
                f"Llama-3.1-70B-Instruct-{task}",
            )
        else:
            ref_model_ds = load_dataset(
                f"meta-llama/Llama-3.1-8B-Instruct-evals",
                f"Llama-3.1-8B-Instruct-{task}",
            )
        hash_to_row = {}
        for row in ref_model_ds["latest"]:
            hash_to_row[row["input_final_prompts_hash"][0]] = row
        reordered_rows = []
        for prompt_hash in ds_405b_hash_order:
            reordered_rows.append(hash_to_row[prompt_hash])
        # Replace the split with a plain list of rows in 405B order.
        ref_model_ds["latest"] = reordered_rows
        return ref_model_ds
    return ds_405b
def analyze(task, response_path, model_size):
    """Score cached responses for `task`; print macro/micro accuracies.

    "meta_*" numbers recompute accuracy from the dataset's own `is_correct`
    flags (Meta's reported results) for comparison against our responses.
    """
    ds = get_dataset_from_task(task, response_path, model_size)
    responses = []
    total = len(ds["latest"])
    # Load the pickled responses in dataset order (cached by `benchmark`).
    for i in range(0, total):
        response = pickle.load(
            open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
        )
        responses.append(response)

    @dataclass
    class Stats:
        correct: int = 0
        total: int = 0
        meta_correct: int = 0
        average: float = None

    subtask_name_to_stats = defaultdict(lambda: Stats())
    for response, ds_row in zip(responses, ds["latest"]):
        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)
        subtask = ds_row["subtask_name"]
        is_eval_correct = model_answer in ds_row["input_correct_responses"]
        if is_eval_correct:
            subtask_name_to_stats[subtask].correct += 1
        if ds_row["is_correct"]:
            subtask_name_to_stats[subtask].meta_correct += 1
        subtask_name_to_stats[subtask].total += 1
    micro_stats = Stats()
    for subtask, stats in subtask_name_to_stats.items():
        stats.average = stats.correct / stats.total
        # meta_average is attached dynamically (not a declared field).
        stats.meta_average = stats.meta_correct / stats.total
        micro_stats.correct += stats.correct
        micro_stats.total += stats.total
        micro_stats.meta_correct += stats.meta_correct
    micro_stats.average = micro_stats.correct / micro_stats.total
    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total
    # Macro = mean over subtasks; micro = pooled over all examples.
    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
    print(
        "Meta Macro average",
        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
    )
    print("Micro average", micro_stats.average)
    print("Meta Micro average", micro_stats.meta_average)
# Entry point for the script
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Script to run model with specified parameters."
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="8b",
        help="Size of the model (e.g., 8b or 70b)",
    )
    parser.add_argument(
        "--provider",
        type=str,
        default="sgl",
        help="Provider name (e.g., sgl, oai, b10)",
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
    )
    parser.add_argument(
        "--num-examples", type=int, default=None, help="Number of examples to process"
    )
    parser.add_argument("--concurrency", type=int, default=16)
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tmp-output-dir",
        help="Directory to save responses",
    )
    args = parser.parse_args()
    asyncio.run(benchmark(args))
    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
    # NOTE(review): this always removes the *default* directory, not
    # args.output_dir — custom output dirs are kept; confirm intent.
    shutil.rmtree("tmp-output-dir", ignore_errors=True)

View File

@@ -0,0 +1,164 @@
import argparse
import asyncio
import os
import pickle
from pathlib import Path
from typing import List
import openai
import torch
from bert_score import BERTScorer
from datasets import load_dataset
from tqdm import tqdm
def get_client(api_url: str) -> openai.AsyncOpenAI:
    """Build an async OpenAI-compatible client for the given base URL.

    Local servers ignore the API key, but the SDK requires one to be set.
    """
    os.environ.setdefault("OPENAI_API_KEY", "EMPTY")
    return openai.AsyncOpenAI(base_url=api_url)
def get_dataset():
    """Load the LooGLE long-dependency QA test split."""
    return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
async def fetch_response(
    client: openai.AsyncOpenAI,
    context: str,
    question: str,
    semaphore: asyncio.Semaphore,
    index: int,
    model: str,
    output_dir: Path,
):
    """Query the model once and cache the raw response as a pickle file.

    Skips examples already present in `output_dir`; a BadRequestError is
    cached as an error dict so analysis can skip it.
    """
    cache_file = output_dir / f"response_{index}.pkl"
    if cache_file.exists():
        # Already answered in a previous run; keep the cached result.
        return

    prompt = (
        "Please answer the question based on the long texts below.\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer:"
    )
    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    async with semaphore:
        try:
            completion = await client.chat.completions.create(
                model=model,
                messages=chat_messages,
                temperature=0.0,
                max_tokens=512,
            )
        except openai.BadRequestError as err:
            with open(cache_file, "wb") as fh:
                pickle.dump({"error": str(err)}, fh)
            return
        with open(cache_file, "wb") as fh:
            pickle.dump(completion, fh)
async def benchmark(args):
    """Fan out LooGLE QA requests with bounded concurrency, caching replies."""
    dataset = get_dataset()
    response_dir = Path(args.output_dir)
    response_dir.mkdir(parents=True, exist_ok=True)

    api_client = get_client(args.api_url)
    gate = asyncio.Semaphore(args.max_concurrency)

    pending: List[asyncio.Task] = []
    for index, example in enumerate(dataset):
        # Honor the --num-prompts cap.
        if index >= args.num_prompts:
            break
        pending.append(
            asyncio.create_task(
                fetch_response(
                    api_client,
                    example["context"],
                    example["question"],
                    gate,
                    index,
                    args.model,
                    response_dir,
                )
            )
        )

    for finished in tqdm(
        asyncio.as_completed(pending), total=len(pending), desc="Running benchmark"
    ):
        await finished
def analyse(args):
    """Score cached responses against references with BERTScore F1.

    Errored/missing responses are skipped; raises FileNotFoundError when a
    response pickle is absent (i.e. benchmark did not finish).
    """
    dataset = get_dataset()
    output_dir = Path(args.output_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)

    hyps: List[str] = []
    refs: List[str] = []

    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        if idx >= args.num_prompts:
            break
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            raise FileNotFoundError(pkl_file)
        response = pickle.load(open(pkl_file, "rb"))
        # Skip responses cached as {"error": ...} by fetch_response.
        if isinstance(response, dict) and "error" in response:
            continue
        hyps.append(response.choices[0].message.content.strip())
        refs.append(ex["answer"])

    if not hyps:
        print("No valid responses to score!")
        return

    # Score in batches to bound GPU/CPU memory usage.
    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend([float(x) for x in f1_scores])

    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")
if __name__ == "__main__":
    # Run the benchmark, then immediately score the cached responses.
    parser = argparse.ArgumentParser(
        description="Run benchmark and evaluation in one go."
    )
    parser.add_argument(
        "--api-url",
        default="http://127.0.0.1:30000/v1",
        help="OpenAIcompatible API base URL",
    )
    parser.add_argument(
        "--model",
        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        help="Model name or ID, only used for model name",
    )
    parser.add_argument(
        "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
    )
    parser.add_argument(
        "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
    )
    parser.add_argument(
        "--num-prompts", type=int, default=10000, help="Number of prompts to run"
    )
    args = parser.parse_args()
    asyncio.run(benchmark(args))
    analyse(args)

View File

@@ -0,0 +1,53 @@
"""Global configurations"""
import os
class GlobalConfig:
    """
    Store some global constants.

    See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
    many global runtime arguments as well.
    """

    def __init__(self):
        # Verbosity level
        # 0: do not output anything
        # 2: output final text after every run
        self.verbosity = 0

        # Default backend of the language
        self.default_backend = None

        # Runtime constants: New generation token ratio estimation
        self.default_init_new_token_ratio = float(
            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
        )
        self.default_min_new_token_ratio_factor = float(
            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
        )
        self.default_new_token_ratio_decay_steps = float(
            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
        )
        # In seconds. Set if you observe high memory accumulation over a long
        # serving period.
        self.torch_empty_cache_interval = float(
            os.environ.get("SGLANG_EMPTY_CACHE_INTERVAL", -1)
        )

        # Runtime constants: others
        self.retract_decode_steps = 20
        # Bug fix: the env override used to come back as a raw string (all
        # sibling settings cast their env values); cast to int so the
        # workspace size is numeric whether or not the env var is set.
        self.flashinfer_workspace_size = int(
            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
        )

        # Output tokenization configs
        self.skip_special_tokens_in_output = True
        self.spaces_between_special_tokens_in_out = True

        # Language frontend interpreter optimization configs
        self.enable_precache_with_tracing = True
        self.enable_parallel_encoding = True
# Module-level singleton shared by the rest of the package.
global_config = GlobalConfig()

Binary file not shown.

Binary file not shown.

Binary file not shown.

286
python/sglang/lang/api.py Normal file
View File

@@ -0,0 +1,286 @@
"""Public APIs of the language."""
import re
from typing import Callable, List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
from sglang.lang.ir import (
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglSeparateReasoning,
SglVideo,
)
def function(
    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
):
    """Decorator turning a plain function into an SglFunction.

    Works both bare (``@function``) and parameterized
    (``@function(num_api_spec_tokens=...)``).
    """
    if func:
        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)

    def wrap(inner):
        return SglFunction(inner, num_api_spec_tokens=num_api_spec_tokens)

    return wrap
def Runtime(*args, **kwargs):
    """Construct a frontend Runtime; the import is deferred deliberately."""
    # Avoid importing unnecessary dependency
    from sglang.lang.backend.runtime_endpoint import Runtime

    return Runtime(*args, **kwargs)
def Engine(*args, **kwargs):
    """Construct an srt Engine; the import is deferred deliberately."""
    # Avoid importing unnecessary dependency
    from sglang.srt.entrypoints.engine import Engine

    return Engine(*args, **kwargs)
def set_default_backend(backend: BaseBackend):
    """Set the process-wide backend used when none is passed explicitly."""
    global_config.default_backend = backend
def flush_cache(backend: Optional[BaseBackend] = None):
    """Flush the cache of `backend` (default: the global default backend).

    Returns False when no backend is available; otherwise forwards to the
    backend's ``flush_cache``.
    """
    target = backend if backend else global_config.default_backend
    if target is None:
        return False
    # Runtime wrappers expose the real backend via `.endpoint`.
    target = getattr(target, "endpoint", target)
    return target.flush_cache()
def get_server_info(backend: Optional[BaseBackend] = None):
    """Query server info from `backend` (default: the global default backend).

    Returns None when no backend is available.
    """
    target = backend if backend else global_config.default_backend
    if target is None:
        return None
    # Runtime wrappers expose the real backend via `.endpoint`.
    target = getattr(target, "endpoint", target)
    return target.get_server_info()
def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    min_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
    dtype: Optional[Union[type, str]] = None,
    choices: Optional[List[str]] = None,
    choices_method: Optional[ChoicesSamplingMethod] = None,
    regex: Optional[str] = None,
    json_schema: Optional[str] = None,
):
    """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
    if choices:
        # NOTE: with `choices`, only name/temperature/choices_method are
        # honored; the remaining sampling arguments are ignored.
        return SglSelect(
            name,
            choices,
            0.0 if temperature is None else temperature,
            token_length_normalized if choices_method is None else choices_method,
        )

    # Validate the regex eagerly so an invalid pattern fails here rather than
    # at execution time. (The previous `try/except re.error as e: raise e`
    # around this call was a no-op and has been removed.)
    if regex is not None:
        re.compile(regex)

    return SglGen(
        name,
        max_tokens,
        min_tokens,
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        dtype,
        regex,
        json_schema,
    )
def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """Generate an integer value (dtype=int). See `gen` for argument meanings."""
    # Positional layout mirrors `gen`'s SglGen call, with min_tokens=None,
    # dtype=int, and regex=None.
    return SglGen(
        name,
        max_tokens,
        None,
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        int,
        None,
    )
def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """Generate a quoted string value (dtype=str). See `gen` for argument meanings."""
    # Positional layout mirrors `gen`'s SglGen call, with min_tokens=None,
    # dtype=str, and regex=None.
    return SglGen(
        name,
        max_tokens,
        None,
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        str,
        None,
    )
def image(expr: SglExpr):
    """Insert an image expression into the prompt."""
    return SglImage(expr)
def video(path: str, num_frames: int):
    """Insert a video (sampled to `num_frames` frames) into the prompt."""
    return SglVideo(path, num_frames)
def select(
    name: Optional[str] = None,
    choices: Optional[List[str]] = None,
    temperature: float = 0.0,
    choices_method: ChoicesSamplingMethod = token_length_normalized,
):
    """Pick one of `choices` by model likelihood, scored by `choices_method`."""
    assert choices is not None
    return SglSelect(name, choices, temperature, choices_method)
def _role_common(name: str, expr: Optional[SglExpr] = None):
    """Wrap `expr` (if any) between begin/end markers for role `name`."""
    if expr is None:
        parts = [SglRoleBegin(name), SglRoleEnd(name)]
    else:
        parts = [SglRoleBegin(name), expr, SglRoleEnd(name)]
    return SglExprList(parts)
def system(expr: Optional[SglExpr] = None):
    """Wrap `expr` in a system-role span (or emit empty begin/end markers)."""
    return _role_common("system", expr)
def user(expr: Optional[SglExpr] = None):
    """Wrap `expr` in a user-role span (or emit empty begin/end markers)."""
    return _role_common("user", expr)
def assistant(expr: Optional[SglExpr] = None):
    """Wrap `expr` in an assistant-role span (or emit empty begin/end markers)."""
    return _role_common("assistant", expr)
def system_begin():
    """Open a system-role span (close it with `system_end`)."""
    return SglRoleBegin("system")
def system_end():
    """Close a system-role span opened by `system_begin`."""
    return SglRoleEnd("system")
def user_begin():
    """Open a user-role span (close it with `user_end`)."""
    return SglRoleBegin("user")
def user_end():
    """Close a user-role span opened by `user_begin`."""
    return SglRoleEnd("user")
def assistant_begin():
    """Open an assistant-role span (close it with `assistant_end`)."""
    return SglRoleBegin("assistant")
def assistant_end():
    """Close an assistant-role span opened by `assistant_begin`."""
    return SglRoleEnd("assistant")
def separate_reasoning(
    expr: Optional[SglExpr] = None, model_type: Optional[str] = None
):
    """Append a reasoning-separation marker after `expr`.

    NOTE(review): calling with `expr=None` would place None inside the
    SglExprList — presumably callers always pass an expression; confirm.
    """
    return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])

View File

@@ -0,0 +1,73 @@
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import anthropic
except ImportError as e:
anthropic = e
class Anthropic(BaseBackend):
    """Backend that serves sgl programs through the Anthropic Messages API."""

    def __init__(self, model_name, *args, **kwargs):
        super().__init__()
        # The module-level import stores the ImportError when anthropic is
        # missing; surface it only when this backend is actually used.
        if isinstance(anthropic, Exception):
            raise anthropic
        self.model_name = model_name
        self.chat_template = get_chat_template("claude")
        self.client = anthropic.Anthropic(*args, **kwargs)

    def get_chat_template(self):
        """Return the chat template used for Claude models."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation; returns (text, meta_info)."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]
        # Anthropic takes the system prompt as a separate argument.
        # NOTE(review): pop(0) mutates s.messages_ in place when it was used
        # above — presumably acceptable for single-shot calls; confirm.
        if messages and messages[0]["role"] == "system":
            system = messages.pop(0)["content"]
        else:
            system = ""

        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.content[0].text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs from a streaming generation."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]
        # Same system-prompt extraction (and in-place pop) as generate().
        if messages and messages[0]["role"] == "system":
            system = messages.pop(0)["content"]
        else:
            system = ""

        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        ) as stream:
            for text in stream.text_stream:
                yield text, {}

View File

@@ -0,0 +1,82 @@
from typing import List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
class BaseBackend:
    """Abstract interface for language backends.

    Cache/program lifecycle hooks are no-ops by default; generation and
    selection must be implemented by subclasses.
    """

    def __init__(self) -> None:
        # Whether the backend supports concatenate_and_append batching.
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    def get_model_name(self):
        raise NotImplementedError()

    def get_chat_template(self):
        """Return the chat template for this backend's model."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        # Optional: pre-cache a shared prompt prefix.
        pass

    def uncache_prefix(self, rid: str):
        # Optional: release a cached prefix by request id.
        pass

    def end_request(self, rid: Union[str, List[str]]):
        # Optional: finalize one or more requests.
        pass

    def begin_program(self, s: StreamExecutor):
        # Optional: hook called when a program starts executing.
        pass

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        # Optional: hook called when a program finishes executing.
        pass

    def commit_lazy_operations(self, s: StreamExecutor):
        # Optional: flush any buffered/lazy operations.
        pass

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        # Optional: hook called when a program is forked into copies.
        pass

    def fill_image(self, s: StreamExecutor):
        # Optional: upload/attach pending image data.
        pass

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation; must return (text, meta_info)."""
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs; must be a generator."""
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: Optional[ChoicesSamplingMethod] = None,
    ) -> ChoicesDecision:
        """Pick one of `choices` by model likelihood."""
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        raise NotImplementedError()

    def shutdown(self):
        # Optional: release backend resources.
        pass

    def flush_cache(self):
        # Optional: drop any cached prefixes/state.
        pass

    def get_server_info(self):
        # Optional: return backend/server metadata.
        pass

View File

@@ -0,0 +1,90 @@
from typing import Mapping, Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import litellm
except ImportError as e:
litellm = e
litellm.num_retries = 1
class LiteLLM(BaseBackend):
    """Backend that routes chat completions through the litellm library."""

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        super().__init__()
        # The module-level import stores the ImportError when litellm is
        # missing; surface it only when this backend is actually used.
        if isinstance(litellm, Exception):
            raise litellm

        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        # Passed through to every litellm.completion call.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        """Return the chat template selected for this model."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming chat completion; returns (text, meta_info)."""
        if s.messages_:
            messages = s.messages_
        else:
            # Fall back to the raw text as a single user turn.
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs from a streaming completion."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            # Skip keep-alive/empty delta chunks.
            if text is not None:
                yield text, {}

View File

@@ -0,0 +1,475 @@
import dataclasses
import logging
import time
import warnings
from typing import List, Optional, Union
import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import openai
import tiktoken
except ImportError as e:
openai = tiktoken = e
logger = logging.getLogger(__name__)
def create_logit_bias_int(tokenizer):
    """Get logit bias for integer numbers.

    Scans the tiktoken vocabulary for tokens whose decoded text is all
    digits (or a single space) and biases up to 299 of them plus the
    end-of-text token, respecting the OpenAI logit_bias size limit.
    """
    digit_token_ids = []
    for _token, token_id in tokenizer._mergeable_ranks.items():
        decoded = tokenizer.decode([token_id])
        if decoded == " " or all(ch.isdigit() for ch in decoded):
            digit_token_ids.append(token_id)
        if len(digit_token_ids) >= 300:  # OpenAI API limit
            break
    bias = {tid: 100 for tid in digit_token_ids[:299]}
    # Always allow end-of-text so generation can terminate.
    bias[tokenizer._special_tokens["<|endoftext|>"]] = 100
    return bias
# Models that use the legacy completions API rather than the chat API.
INSTRUCT_MODEL_NAMES = [
    "gpt-3.5-turbo-instruct",
]
@dataclasses.dataclass
class TokenUsage:
    """Running count of prompt/completion tokens consumed via the API."""

    prompt_tokens: int
    completion_tokens: int

    def reset(self):
        """Zero both counters."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
class OpenAI(BaseBackend):
    """Backend that serves sgl programs through the OpenAI (or Azure) API."""

    def __init__(
        self,
        model_name: str,
        is_chat_model: Optional[bool] = None,
        chat_template: Optional[ChatTemplate] = None,
        is_azure: bool = False,
        *args,
        **kwargs,
    ):
        """Create a client for `model_name`; extra args go to the SDK client.

        `is_chat_model` defaults to True unless the model is a known
        completion-style ("instruct") model.
        """
        super().__init__()
        # The module-level import stores the ImportError when openai/tiktoken
        # are missing; surface it only when this backend is actually used.
        if isinstance(openai, Exception):
            raise openai

        if is_azure:
            self.client = openai.AzureOpenAI(*args, **kwargs)
        else:
            self.client = openai.OpenAI(*args, **kwargs)

        self.model_name = model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model name: fall back to the common cl100k encoding.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        if is_chat_model is not None:
            self.is_chat_model = is_chat_model
        else:
            if model_name in INSTRUCT_MODEL_NAMES:
                self.is_chat_model = False
            else:
                self.is_chat_model = True

        # Prefix that precedes assistant turns; used to validate gen() placement.
        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]

        # Usage
        self.token_usage = TokenUsage(0, 0)

        # API speculative execution
        # TODO(ying): This does not support multi-threading (run_batch)
        self.spec_kwargs = {}
        self.spec_format = []
        self.spec_max_num_tries = 3
def get_chat_template(self):
return self.chat_template
def _prepare_spec_execution(
self,
sampling_params: SglSamplingParams,
num_api_spec_tokens: int,
spec_var_name: str,
):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else:
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
if key in ["stop"]:
continue
if key in ["max_tokens"]:
warnings.warn(
"The parameter max_tokens will be overwritten by speculated number of tokens."
)
continue
if key not in self.spec_kwargs:
self.spec_kwargs[key] = value
else:
assert (
value == self.spec_kwargs[key]
), "sampling parameters should be consistent if turn on api speculative execution."
self.spec_format.append(
{"text": "", "stop": params["stop"], "name": spec_var_name}
)
return "", {}
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
spec_var_name: str = None,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if s.num_api_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
)
prompt = s.messages_
else:
return self._prepare_spec_execution(
sampling_params, s.num_api_spec_tokens, spec_var_name
)
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
if (
self.model_name.startswith("o1")
or self.model_name.startswith("o3")
or "o1" in self.model_name
):
kwargs.pop("max_tokens", None)
else:
kwargs.pop("max_completion_tokens", None)
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
# Keep the returned list (or string) as is.
elif sampling_params.dtype in [str, "str", "string"]:
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_ + '"',
stop='"',
**kwargs,
)
# Wrap each element in quotes if we have a list.
if isinstance(comp, list):
comp = ['"' + x + '"' for x in comp]
else:
comp = '"' + comp + '"'
elif sampling_params.dtype in [int, "int"]:
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_,
logit_bias=self.logit_bias_int,
stop=[" "],
**kwargs,
)
# Leave as a list if that's what is returned.
else:
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
return comp, {}
def spec_fill(self, value: str):
assert self.is_chat_model
self.spec_format.append({"text": value, "stop": None, "name": None})
def spec_pattern_match(self, comp):
for i, term in enumerate(self.spec_format):
text = term["text"]
if text != "":
if comp.startswith(text):
comp = comp[len(text) :]
else:
return False
else:
pos = comp.find(term["stop"])
if pos != -1:
term["text"] = comp[:pos]
comp = comp[pos:]
else:
if i == len(self.spec_format) - 1:
term["text"] = comp
else:
return False
return True
def role_end_generate(
self,
s: StreamExecutor,
):
if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
return
comp = ""
if not all(x["name"] is None for x in self.spec_format):
# TODO(ying): throw errors or warnings
for i in range(self.spec_max_num_tries):
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.messages_,
**self.spec_kwargs,
)
# Use a string for pattern matching.
comp_for_match = comp[0] if isinstance(comp, list) else comp
if self.spec_pattern_match(comp_for_match):
break
for term in self.spec_format:
s.text_ += term["text"]
name = term["name"]
if name is not None:
s.variables[name] = term["text"]
s.meta_info[name] = {}
s.variable_event[name].set()
self.spec_kwargs = {}
self.spec_format = []
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
)
prompt = s.messages_
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
generator = openai_completion_stream(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
return generator
else:
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
choices_method: ChoicesSamplingMethod,
) -> ChoicesDecision:
"""Note: `choices_method` is not used by the OpenAI backend."""
if self.is_chat_model:
raise NotImplementedError(
"select/choices is not supported for chat models. "
"Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
)
n_choices = len(choices)
token_ids = [self.tokenizer.encode(x) for x in choices]
scores = [0] * n_choices
valid = [len(x) > 0 for x in token_ids]
prompt_tokens = self.tokenizer.encode(s.text_)
max_len = max([len(x) for x in token_ids])
for step in range(max_len):
# Build logit bias
logit_bias = {}
for i in range(n_choices):
if valid[i]:
logit_bias[token_ids[i][step]] = 100
# Call API
ret = self.client.completions.create(
model=self.model_name,
prompt=prompt_tokens,
logit_bias=logit_bias,
max_tokens=1,
temperature=temperature,
)
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
self.token_usage.completion_tokens = ret.usage.completion_tokens
# TODO:
# 1. return logits as the scores
# 2. compute logits of the full choice
# 3. consider chunk-based decoding
# Update valid
hit = False
for i in range(n_choices):
if valid[i]:
if step == len(token_ids[i]) - 1:
valid[i] = False
if ret_token == token_ids[i][step]:
scores[i] += 1
hit = True
else:
valid[i] = False
assert hit
if np.sum(valid) <= 1:
break
prompt_tokens.append(ret_token)
return ChoicesDecision(
decision=choices[np.argmax(scores)],
meta_info={"scores": scores},
)
def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
) -> Union[str, List[str]]:
    """Call the OpenAI completion API with simple retry-on-error.

    Chat mode sends *prompt* as a message list to ``chat.completions``;
    otherwise *prompt* is forwarded to the legacy ``completions`` endpoint.
    Returns a single string for a single completion, or a list of strings
    when the prompt is a batch or multiple choices come back.  Token counts
    are accumulated onto *token_usage*.
    """
    # EBNF constraints are not supported by OpenAI endpoints; drop them.
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit stop=None.
                if "stop" in kwargs and kwargs["stop"] is None:
                    del kwargs["stop"]
                ret = client.chat.completions.create(messages=prompt, **kwargs)
                texts = [choice.message.content for choice in ret.choices]
                comp = texts[0] if len(texts) == 1 else texts
            else:
                ret = client.completions.create(prompt=prompt, **kwargs)
                texts = [choice.text for choice in ret.choices]
                if isinstance(prompt, (list, tuple)) or len(texts) > 1:
                    comp = texts
                else:
                    comp = texts[0]

            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e

    return comp
def openai_completion_stream(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    """Stream a completion from the OpenAI API, yielding (text, meta) pairs.

    Uses ``stream_options={"include_usage": True}`` so the final streamed
    chunk carries usage statistics, which are accumulated onto
    *token_usage* after the stream ends.

    NOTE(review): the retry loop wraps the yields, so an API error raised
    mid-stream would restart the request and re-yield earlier chunks —
    confirm whether partial-stream retries are intended.
    """
    # if "ebnf" is in kwargs, warn and remove
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]
    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit stop=None.
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
                    messages=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    # Usage-only chunks have no choices; skip them here and
                    # read `ret.usage` after the loop (ret keeps the last chunk).
                    if len(ret.choices) == 0:
                        continue
                    try:
                        content = ret.choices[0].delta.content
                    except IndexError:
                        content = None
                    yield content or "", {}
            else:
                generator = client.completions.create(
                    prompt=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    content = ret.choices[0].text
                    yield content or "", {}
            # `ret` is the last chunk of the stream, which carries usage
            # when include_usage is set.
            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e

View File

@@ -0,0 +1,527 @@
import atexit
import json
import multiprocessing
import warnings
from typing import Dict, List, Optional, Union
import aiohttp
import requests
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import (
REGEX_BOOL,
REGEX_FLOAT,
REGEX_INT,
REGEX_STR,
SglSamplingParams,
)
from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend):
    """Frontend backend that talks to a running SGLang HTTP server.

    All operations go through ``http_request`` against ``base_url``.  The
    server's model info is fetched eagerly at construction time and used to
    pick a chat template.
    """

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        super().__init__()
        self.support_concate_and_append = True
        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        # Fail fast if the server is unreachable.
        res = http_request(
            self.base_url + "/get_model_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        self.model_info = res.json()

        if chat_template_name:
            self.chat_template = get_chat_template(chat_template_name)
        else:
            self.chat_template = get_chat_template_by_model_path(
                self.model_info["model_path"]
            )

    def get_model_name(self):
        """Return the server-reported model path."""
        return self.model_info["model_path"]

    def flush_cache(self):
        """Ask the server to flush its KV/radix cache."""
        res = http_request(
            self.base_url + "/flush_cache",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def get_server_info(self):
        """Fetch the server's runtime info as a dict."""
        res = http_request(
            self.base_url + "/get_server_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def get_chat_template(self):
        """Return the chat template used to render messages."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Warm the server cache with a prefix (generates 0 new tokens)."""
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def start_profile(self):
        """Start server-side profiling."""
        res = http_request(
            self.base_url + "/start_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def stop_profile(self):
        """Stop server-side profiling."""
        res = http_request(
            self.base_url + "/stop_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def commit_lazy_operations(self, s: StreamExecutor):
        """Push the executor's accumulated text to the server (0 new tokens)."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def fill_image(self, s: StreamExecutor):
        """Send the executor's text plus attached image to the server."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
        """Translate a constrained dtype into a regex on *sampling_params*.

        Mutates the sampling params in place: sets ``regex`` and, for
        numeric types, extends the stop strings with space/newline.
        """
        if sampling_params.dtype is None:
            return

        if sampling_params.stop == ():
            sampling_params.stop = []

        dtype_regex = None
        if sampling_params.dtype in ["int", int]:

            dtype_regex = REGEX_INT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["float", float]:

            dtype_regex = REGEX_FLOAT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["str", str]:

            dtype_regex = REGEX_STR
        elif sampling_params.dtype in ["bool", bool]:

            dtype_regex = REGEX_BOOL
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        if dtype_regex is not None and sampling_params.regex is not None:
            warnings.warn(
                f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
            )

        sampling_params.regex = dtype_regex

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation; return (text, meta_info)."""
        self._handle_dtype_to_regex(sampling_params)
        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Forward logprob-related options only when explicitly set.
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a generation; yield (delta_text, meta_info) pairs."""
        self._handle_dtype_to_regex(sampling_params)

        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Forward logprob-related options only when explicitly set.
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        data["stream"] = True
        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        pos = 0

        # Server-sent events: each "data:" line carries the full text so far;
        # `pos` tracks what has already been yielded.
        for chunk in res.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text, meta_info

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Pick among *choices* by comparing prompt logprobs (greedy only)."""
        assert temperature <= 1e-5

        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        obj = self._generate_http_request(s, data)
        prompt_len = obj["meta_info"]["prompt_tokens"]
        logprob_start_len = max(prompt_len - 2, 0)  # For token healing

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {
                "max_new_tokens": 0,
                "temperature": 0,
            },
            "return_logprob": True,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        }
        obj = self._generate_http_request(s, data)

        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
        normalized_prompt_logprobs = [
            compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
            for r in obj
        ]

        # Remove extra token if no token healing occurred
        for i in range(len(input_token_logprobs)):
            healed_token_str = input_token_logprobs[i][0][-1]
            if s.text_.endswith(healed_token_str):
                # No healing: drop the first (pre-prefix) token and re-average
                # the normalized logprob without it.
                healed_token_logprob = input_token_logprobs[i][0][0]
                normalized_prompt_logprobs[i] = (
                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
                    - healed_token_logprob
                ) / (len(input_token_logprobs[i]) - 1)
                input_token_logprobs[i] = input_token_logprobs[i][1:]

        # Compute unconditional logprobs if required
        if choices_method.requires_unconditional_logprobs:
            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
            data = {
                "input_ids": input_ids,
                "sampling_params": {"max_new_tokens": 0},
                "return_logprob": True,
            }
            obj = self._generate_http_request(s, data)
            unconditional_token_logprobs = [
                r["meta_info"]["input_token_logprobs"] for r in obj
            ]
        else:
            unconditional_token_logprobs = None

        return choices_method(
            choices=choices,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
            input_token_logprobs=input_token_logprobs,
            output_token_logprobs=output_token_logprobs,
            unconditional_token_logprobs=unconditional_token_logprobs,
        )

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Ask the server to fuse the given source requests into *dst_rid*."""
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _generate_http_request(self, s: StreamExecutor, data):
        """POST *data* to /generate and return the decoded JSON response."""
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def _add_images(self, s: StreamExecutor, data):
        """Attach the executor's image (if any) to the request payload."""
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]

    def _assert_success(self, res):
        """Raise RuntimeError with the server's error body on non-200."""
        if res.status_code != 200:
            try:
                content = res.json()
            except json.JSONDecodeError:
                content = res.text
            raise RuntimeError(content)
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean logprob over prompt tokens.

    Each entry of *input_logprobs* is a ``(logprob, token_id, ...)`` tuple;
    the first token's logprob may be ``None`` (no preceding context), so
    ``None`` entries are skipped.

    BUGFIX: the previous truthiness filter (``if x[0]``) also dropped
    legitimate 0.0 logprobs (probability 1.0 tokens), skewing the average.
    """
    values = [x[0] for x in input_logprobs if x[0] is not None]
    return sum(values) / len(values)
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the command line interface.
    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing without the frontend language.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
        # client code without installing SRT server and its dependency if they want.
        from sglang.srt.entrypoints.http_server import launch_server
        from sglang.srt.server_args import ServerArgs
        from sglang.srt.utils import is_port_available

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        # Launch the server in a separate (spawned) process; it reports
        # readiness back through the pipe.
        ctx = multiprocessing.get_context("spawn")
        proc = ctx.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)

        # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process tree; safe to call multiple times."""
        from sglang.srt.utils import kill_process_tree

        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def start_profile(self):
        """Start server-side profiling."""
        self.endpoint.start_profile()

    def stop_profile(self):
        """Stop server-side profiling."""
        self.endpoint.stop_profile()

    def cache_prefix(self, prefix: str):
        """Warm the server cache with a prompt prefix."""
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        """Load the HF tokenizer matching the server's configuration."""
        from sglang.srt.hf_transformers_utils import get_tokenizer

        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        """Stream generated text deltas for *prompt* (async generator)."""
        if self.server_args.skip_tokenizer_init:
            # Tokenizer is disabled on the server; callers must pass token ids.
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                # Server-sent events: each "data:" chunk carries the full text
                # so far; `pos` tracks what has already been yielded.
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    # Alias kept for API parity with the engine interface.
    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        """Blocking generation; returns the server response as a JSON string."""
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        """Embed *prompt* via /encode; returns the response as a JSON string."""
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        """Fetch the server's runtime info; raises RuntimeError on failure."""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()

View File

@@ -0,0 +1,148 @@
import os
import warnings
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import vertexai
from vertexai.preview.generative_models import (
GenerationConfig,
GenerativeModel,
Image,
)
except ImportError as e:
GenerativeModel = e
class VertexAI(BaseBackend):
    """SGLang backend for Google Vertex AI generative models.

    Requires the ``vertexai`` package plus the ``GCP_PROJECT_ID`` (and
    optional ``GCP_LOCATION``) environment variables.
    """

    def __init__(self, model_name, safety_settings=None):
        super().__init__()

        # The module-level import may have failed; re-raise the ImportError
        # lazily so only users of this backend pay for the missing package.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings

    def get_chat_template(self):
        """Return the chat template used to render messages."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation; return (text, meta_info)."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a generation; yield (delta_text, meta_info) pairs."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_vertexai_input(self, text, images):
        """Interleave text segments and images into a Vertex AI parts list.

        *text* contains one image placeholder token per entry of *images*;
        the text is split on that token and images are spliced in between.
        """
        parts = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                parts.append(text_seg)
            parts.append(Image.from_bytes(image_base64_data))
        # Trailing text after the last image.
        text_seg = text_segs.pop(0)
        if text_seg != "":
            parts.append(text_seg)
        return parts

    def messages_to_vertexai_input(self, messages):
        """Convert OpenAI-style messages to the Vertex AI message format.

        System prompts are not supported natively, so they are emulated as a
        user turn followed by a canned model acknowledgment.
        """
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                warnings.warn("Warning: system prompt is not supported in VertexAI.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            if msg["role"] == "user":
                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }
            else:
                # BUGFIX: previously an unknown role fell through and left
                # `vertexai_msg` unbound, raising a confusing NameError.
                raise ValueError(f"Unsupported message role: {msg['role']}")

            # images
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            vertexai_message.append(vertexai_msg)
        return vertexai_message

View File

@@ -0,0 +1,662 @@
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple
class ChatTemplateStyle(Enum):
    # How a template stitches roles together:
    PLAIN = auto()  # simple per-role prefix/suffix concatenation
    LLAMA2 = auto()  # system prompt is folded into the first user turn
@dataclass
class ChatTemplate:
    """Declarative description of how to render a chat transcript as text."""

    # Unique key used by the template registry.
    name: str
    # Substituted when a system message arrives with content None.
    default_system_prompt: str
    # role -> (prefix, suffix) wrapped around each message's content.
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    # Stop strings for generation.  Annotated as a tuple: the default really
    # is () and registered templates pass tuples (a mutable list default
    # would not be a valid dataclass default anyway).
    stop_str: Tuple[str, ...] = ()
    # Placeholder tokens for multimodal inputs.
    image_token: str = "<image>"
    audio_token: str = "<audio>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(
        self, role: str, hist_messages: List[Dict]
    ) -> Tuple[str, str]:
        """Return the (prefix, suffix) for *role* given prior messages."""
        prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))

        if self.style == ChatTemplateStyle.LLAMA2:
            # LLAMA2 merges the system prompt into the first user turn.
            if role == "system" and not hist_messages:
                user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
                system_prefix, system_suffix = self.role_prefix_and_suffix.get(
                    "system", ("", "")
                )
                return (user_prefix + system_prefix, system_suffix)
            elif (
                role == "user"
                and len(hist_messages) == 1
                and hist_messages[0]["content"] is not None
            ):
                # The user prefix was already emitted with the system prompt.
                return ("", suffix)

        return prefix, suffix

    def get_prompt(self, messages: List[Dict]) -> str:
        """Render *messages* (OpenAI-style dicts) into one prompt string."""
        prompt = ""
        for i, message in enumerate(messages):
            role, content = message["role"], message["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
                if content is None:
                    # No default system prompt either: skip the message.
                    continue

            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += f"{prefix}{content}{suffix}"

        return prompt
# Global registries: template name -> template, plus model-path matcher
# hooks that map a model path to a registered template name.
chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []
def register_chat_template(template):
    """Register *template* in the global registry under its own name."""
    name = template.name
    chat_template_registry[name] = template
def register_chat_template_matching_function(func):
    """Register a model-path matcher; see get_chat_template_by_model_path."""
    matching_function_registry.append(func)
def get_chat_template(name):
    """Look up a registered chat template by name (KeyError if missing)."""
    template = chat_template_registry[name]
    return template
def get_chat_template_by_model_path(model_path):
    """Resolve a chat template for *model_path*.

    Tries each registered matcher in order; the first non-None name wins.
    Falls back to the "default" template when nothing matches.
    """
    for matcher in matching_function_registry:
        name = matcher(model_path)
        if name is not None:
            return get_chat_template(name)
    return get_chat_template("default")
# ---- Built-in chat template registrations (plain / chatml / llama families).
register_chat_template(
    ChatTemplate(
        name="default",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("SYSTEM:", "\n"),
            "user": ("USER:", "\n"),
            "assistant": ("ASSISTANT:", "\n"),
        },
    )
)

register_chat_template(
    ChatTemplate(
        name="claude",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("\n\nHuman: ", ""),
            "assistant": ("\n\nAssistant:", ""),
        },
    )
)

register_chat_template(
    ChatTemplate(
        name="chatml",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
    )
)

register_chat_template(
    ChatTemplate(
        name="chatml-llava",
        default_system_prompt="You are a helpful assistant.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
        image_token="<image>\n",
    )
)

# There is default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
register_chat_template(
    ChatTemplate(
        name="qwen",
        default_system_prompt="You are a helpful assistant.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
    )
)

# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_chat_template(
    ChatTemplate(
        name="qwen2-vl",
        default_system_prompt="You are a helpful assistant.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
        image_token="<|vision_start|><|image_pad|><|vision_end|>",
    )
)

# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
register_chat_template(
    ChatTemplate(
        name="vicuna_v1.1",
        default_system_prompt=(
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        ),
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("USER:", " "),
            "assistant": ("ASSISTANT:", "</s>"),
        },
        image_token=" <image>\n",
    )
)

register_chat_template(
    ChatTemplate(
        name="llama-2-chat",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        # LLAMA2 style merges the system block into the first [INST] turn.
        style=ChatTemplateStyle.LLAMA2,
    )
)

# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
register_chat_template(
    ChatTemplate(
        name="mistral",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        stop_str=("</s>",),
        image_token="[IMG]",
    )
)

register_chat_template(
    ChatTemplate(
        name="llama-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_header_id|>system<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "user": (
                "<|start_header_id|>user<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "assistant": (
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
        },
        stop_str=("<|eot_id|>",),
        image_token="<|image|>",
    )
)

# https://huggingface.co/openbmb/MiniCPM-V-2_6
register_chat_template(
    ChatTemplate(
        name="minicpmv",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("user:", " "),
            "assistant": ("assistant:", "</s>"),
        },
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
    )
)
# https://huggingface.co/deepseek-ai/Janus-Pro-7B
register_chat_template(
    ChatTemplate(
        name="janus-pro",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            # BUGFIX: this key was previously "User" (capitalized), but
            # ChatTemplate.get_prompt() looks roles up by the lowercase
            # "user" role string from the message dicts, so the "<User>"
            # prefix was silently dropped. Lowercase matches the sibling
            # "janus" template.
            "user": (
                "<User>",
                "",
            ),
            "assistant": (
                "<Assistant>",
                "<end▁of▁sentence>",
            ),
        },
        stop_str=("<end▁of▁sentence>",),
        image_token="<image_placeholder>\n",
    )
)
# ---- Built-in chat template registrations (continued).
# https://huggingface.co/openbmb/MiniCPM-o-2_6
register_chat_template(
    ChatTemplate(
        name="minicpmo",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("user:", " "),
            "assistant": ("assistant:", "</s>"),
        },
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
        audio_token="(<audio>./</audio>)",
    )
)

register_chat_template(
    ChatTemplate(
        name="janus",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            "user": (
                "<User>",
                "",
            ),
            "assistant": (
                "<Assistant>",
                "<end▁of▁sentence>",
            ),
        },
        stop_str=("<end▁of▁sentence>",),
        image_token="<image_placeholder>\n",
    )
)

# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
    ChatTemplate(
        name="llama-3-instruct-llava",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_header_id|>system<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "user": (
                "<|start_header_id|>user<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "assistant": (
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
        },
        stop_str=("<|eot_id|>",),
        image_token="<image>\n",
    )
)

# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_chat_template(
    ChatTemplate(
        name="llama-4",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|header_start|>system<|header_end|>\n\n",
                "<|eot|>",
            ),
            "user": (
                "<|header_start|>user<|header_end|>\n\n",
                "<|eot|>",
            ),
            "assistant": (
                "<|header_start|>assistant<|header_end|>\n\n",
                "<|eot|>",
            ),
        },
        stop_str=("<|eot|>",),
        image_token="<|image|>",
    )
)

# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
    ChatTemplate(
        name="yi-1.5",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
            "assistant": ("", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
    )
)

# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
register_chat_template(
    ChatTemplate(
        name="yi-vl",
        default_system_prompt=(
            "This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
            "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
        ),
        role_prefix_and_suffix={
            "system": ("", "\n\n"),
            "user": ("### Human:", "\n"),
            "assistant": ("### Assistant:", "\n"),
        },
        image_token=" <image_placeholder>\n",
    )
)

register_chat_template(
    ChatTemplate(
        name="gemma-it",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
            "assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)

register_chat_template(
    ChatTemplate(
        name="dbrx-instruct",
        default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>"),
            "user": ("\n<|im_start|>user\n", "<|im_end|>"),
            "assistant": ("\n<|im_start|>assistant\n", "<|im_end|>"),
        },
        stop_str=("<|im_end|>",),
    )
)

register_chat_template(
    ChatTemplate(
        name="c4ai-command-r",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
                "<|END_OF_TURN_TOKEN|>",
            ),
            "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
            "assistant": (
                "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
                "<|END_OF_TURN_TOKEN|>",
            ),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)

# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
register_chat_template(
    ChatTemplate(
        name="internvl-2-5",
        default_system_prompt="你是书生·万象英文名是InternVL是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)

register_chat_template(
    ChatTemplate(
        name="interns1",
        default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)

register_chat_template(
    ChatTemplate(
        name="granite-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_of_role|>system<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "user": (
                "<|start_of_role|>user<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "assistant": (
                "<|start_of_role|>assistant<|end_of_role|>",
                "<|end_of_text|>",
            ),
        },
        stop_str=("<|end_of_text|>",),
    )
)

register_chat_template(
    ChatTemplate(
        name="deepseek-v3",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            "user": (
                "<User>",
                "",
            ),
            "assistant": (
                "<Assistant>",
                "<end▁of▁sentence>",
            ),
        },
        stop_str=("<end▁of▁sentence>",),
    )
)

# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
register_chat_template(
    ChatTemplate(
        name="glm-4v",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|system|>\n", "\n"),
            "user": ("<|user|>\n", "\n"),
            "assistant": ("<|assistant|>\n", "\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
        image_token="<|image|>",
    )
)
# ---- Model-path matchers, tried in registration order by
# get_chat_template_by_model_path().  Each returns a template name or None.
@register_chat_template_matching_function
def match_deepseek(model_path: str):
    # DeepSeek V3 / R1 chat checkpoints (but not the "base" variants).
    if re.search(r"deepseek-(v3|r1)", model_path, re.IGNORECASE) and not re.search(
        r"base", model_path, re.IGNORECASE
    ):
        return "deepseek-v3"


@register_chat_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    if re.search(r"janus", model_path, re.IGNORECASE):
        return "janus-pro"


@register_chat_template_matching_function
def match_dbrx(model_path: str):
    if re.search(r"dbrx", model_path, re.IGNORECASE) and re.search(
        r"instruct", model_path, re.IGNORECASE
    ):
        return "dbrx-instruct"


@register_chat_template_matching_function
def match_vicuna(model_path: str):
    if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
        return "vicuna_v1.1"


@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
    if re.search(
        r"llama-2.*chat|codellama.*instruct",
        model_path,
        re.IGNORECASE,
    ):
        return "llama-2-chat"


@register_chat_template_matching_function
def match_mistral(model_path: str):
    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
        return "mistral"


@register_chat_template_matching_function
def match_llama3_instruct(model_path: str):
    # NOTE(review): the "llama-4" template is registered above but no matcher
    # returns it — llama-4 paths must select it by explicit name; confirm.
    if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
        return "llama-3-instruct"


@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    # Order matters: the qwen-vl / glm-4v checks must run before the generic
    # qwen chat/instruct check.
    if re.search(r"tinyllama", model_path, re.IGNORECASE):
        return "chatml"
    if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
        return "qwen2-vl"
    if re.search(r"glm[-_]?4(\.\d+)?v", model_path, re.IGNORECASE):
        return "glm-4v"
    if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
        r"llava", model_path, re.IGNORECASE
    ):
        return "qwen"
    if re.search(
        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
        model_path,
        re.IGNORECASE,
    ):
        return "chatml-llava"


@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    if re.search(r"yi-vl", model_path, re.IGNORECASE) and not re.search(
        r"llava", model_path, re.IGNORECASE
    ):
        return "yi-vl"
    elif re.search(r"yi-1\.5.*chat", model_path, re.IGNORECASE):
        return "yi-1.5"


@register_chat_template_matching_function
def match_gemma_it(model_path: str):
    if re.search(r"gemma.*it", model_path, re.IGNORECASE):
        return "gemma-it"


@register_chat_template_matching_function
def match_openbmb_minicpm(model_path: str):
    if re.search(r"minicpm-v", model_path, re.IGNORECASE):
        return "minicpmv"
    elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
        return "minicpmo"


@register_chat_template_matching_function
def match_c4ai_command_r(model_path: str):
    if re.search(r"c4ai-command-r", model_path, re.IGNORECASE):
        return "c4ai-command-r"


@register_chat_template_matching_function
def match_granite_instruct(model_path: str):
    if re.search(r"granite.*instruct", model_path, re.IGNORECASE):
        return "granite-3-instruct"


@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
    # Gemma 3 reuses the gemma-it turn format.
    if re.search(r"gemma-3", model_path, re.IGNORECASE):
        return "gemma-it"


@register_chat_template_matching_function
def match_internvl_chat(model_path: str):
    if re.search(r"internvl2_5", model_path, re.IGNORECASE):
        return "internvl-2-5"
@register_chat_template_matching_function
def match_interns1_chat(model_path: str):
    """Match Intern-S1 checkpoints, spelled with or without the hyphen."""
    # Previously two identical branches ("intern-s1" and "interns1"); the
    # optional-hyphen regex covers both.
    if re.search(r"intern-?s1", model_path, re.IGNORECASE):
        return "interns1"
if __name__ == "__main__":
    # Smoke test: render a short conversation with the llama-2 template.
    messages = [
        {"role": "system", "content": None},  # None means default
        # {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "What can you do?"},
        {"role": "assistant", "content": "I can chat with you."},
    ]

    template = get_chat_template("llama-2-chat")
    print(template.get_prompt(messages))

View File

@@ -0,0 +1,164 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import numpy as np
@dataclass
class ChoicesDecision:
    """Result of a choices sampling method."""

    # The selected choice string (one element of the candidate list).
    decision: str
    # Optional diagnostics (logprobs, score matrices, ...) for inspection.
    meta_info: Optional[Dict[str, Any]] = None
class ChoicesSamplingMethod(ABC):
    """Strategy interface for picking one choice from logprob-scored candidates."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # Subclasses that normalize by unconditional likelihood override this
        # so the caller also scores each choice without the prompt prefix.
        return False

    @abstractmethod
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision: ...
class TokenLengthNormalized(ChoicesSamplingMethod):
    """Pick the choice whose length-normalized prompt logprob is the largest."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest token length normalized prompt logprob."""
        winner_idx = int(np.argmax(normalized_prompt_logprobs))
        return ChoicesDecision(
            decision=choices[winner_idx],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
            },
        )


token_length_normalized = TokenLengthNormalized()
class GreedyTokenSelection(ChoicesSamplingMethod):
    """Pick the choice that wins a position-by-position greedy logprob contest."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option based on greedy logprob selection. For overlapping options
        where one option is a subset of a longer option, extend the shorter option using
        its average logprob for comparison against the longer option."""
        num_options = len(choices)
        max_tokens = max(len(option) for option in input_token_logprobs)
        logprob_matrix = self._build_logprob_matrix(
            input_token_logprobs, max_tokens, num_options
        )
        # remaining holds the surviving row indices; ties keep several, and the
        # first surviving index is taken as the winner.
        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)

        best_choice = choices[remaining[0]]
        meta_info = {
            "normalized_prompt_logprobs": normalized_prompt_logprobs,
            "input_token_logprobs": input_token_logprobs,
            "output_token_logprobs": output_token_logprobs,
            "greedy_logprob_matrix": logprob_matrix.tolist(),
        }
        return ChoicesDecision(decision=best_choice, meta_info=meta_info)

    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
        # Rows = options, columns = token positions.  Shorter options are
        # padded with their own average logprob so they remain comparable at
        # positions beyond their length.
        logprob_matrix = np.zeros((num_options, max_tokens))
        for i, option in enumerate(input_token_logprobs):
            # token[0] is read as the logprob entry of each token tuple —
            # assumed layout (logprob, token_id, ...); confirm with the caller.
            actual_logprobs = [token[0] for token in option]
            avg_logprob = np.mean(actual_logprobs)
            logprob_matrix[i, : len(option)] = actual_logprobs
            if len(option) < max_tokens:
                logprob_matrix[i, len(option) :] = avg_logprob
        return logprob_matrix

    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
        # At each position keep only the rows achieving the maximum logprob;
        # stop as soon as a single row survives.
        remaining = np.arange(num_options)
        for j in range(max_tokens):
            max_logprob = np.max(logprob_matrix[remaining, j])
            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]
            if len(remaining) == 1:
                break
        return remaining


greedy_token_selection = GreedyTokenSelection()
class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
    """Rank options by prompt logprob normalized by their unconditional logprob."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # Tells the caller to also score each choice without the prompt prefix.
        return True

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest average token logprob once normalized by
        the unconditional token logprobs.

        The first unconditional token logprob is assumed to be None. If so, it is
        replaced with 0 for the purposes of normalization."""
        if unconditional_token_logprobs is None:
            raise ValueError(
                "Unconditional token logprobs are required for this method."
            )

        normalized_unconditional_prompt_logprobs = self._normalize_logprobs(
            input_token_logprobs, unconditional_token_logprobs
        )

        best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)]
        meta_info = {
            "normalized_prompt_logprobs": normalized_prompt_logprobs,
            "input_token_logprobs": input_token_logprobs,
            "output_token_logprobs": output_token_logprobs,
            "unconditional_token_logprobs": unconditional_token_logprobs,
            "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs,
        }
        return ChoicesDecision(decision=best_choice, meta_info=meta_info)

    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
        # Per choice: mean of (conditional logprob - unconditional logprob).
        normalized_unconditional_prompt_logprobs = []
        for inputs, unconditionals in zip(
            input_token_logprobs, unconditional_token_logprobs
        ):
            inputs_logprobs = np.array([token[0] for token in inputs])
            unconditionals_logprobs = np.array([token[0] for token in unconditionals])
            # First unconditional entry is None (no preceding context); replace
            # with 0 so the subtraction is defined.  NOTE(review): `or 0` would
            # also zero an exact 0.0 logprob — harmless here, but intentional?
            unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0
            normalized_unconditional_prompt_logprobs.append(
                float(np.mean(inputs_logprobs - unconditionals_logprobs))
            )
        return normalized_unconditional_prompt_logprobs


unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()

View File

@@ -0,0 +1,231 @@
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
def compile_func(function, backend):
    """Trace *function* against *backend* and return its compiled graph form."""
    traced = function.trace(backend=backend)
    return CompiledFunction(traced, function)
class CompiledFunction:
    """A traced SGL function compiled into a dataflow graph of stream executors."""

    def __init__(self, tracer, function):
        """Build and topologically sort the graph from *tracer* for *function*."""
        self.function = function
        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}  # SglExpr -> CompGraphNode
        self.build_graph(tracer)
        self.topological_sort()

    def build_graph(self, tracer):
        """BFS backwards from the last node, wiring prev/source edges.

        Also renames stream pids to a dense 0..k-1 range; note this mutates
        the traced exprs' ``pid`` in place.
        """
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.nodes[-1]
        rename_pid = {}
        visited = set([tracer.last_node])
        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]

            # add prev node (the expr executed just before this one on its stream)
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)

            # add source node (the expr a variable reads its value from;
            # tracer.variables takes precedence over the expr's own source)
            if isinstance(cur_node.expr, SglVariable):
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)

            head += 1

            # rename pid (dense renumbering in discovery order)
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]

    def topological_sort(self):
        """Kahn's algorithm: reorder ``self.nodes`` so dependencies come first."""
        prevd = {}  # node -> number of unsatisfied predecessors (prev + source)
        cand = Queue()
        for x in self.nodes:
            prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
            if prevd[x] == 0:
                cand.put(x)

        new_list = []
        while cand.qsize() > 0:
            head = cand.get()
            new_list.append(head)
            for x in head.next_nodes:
                prevd[x] -= 1
                if prevd[x] == 0:
                    cand.put(x)

        self.nodes = new_list

    def print_graph(
        self,
    ):
        """Debug helper: print every node in topological order."""
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Execute the graph once with *kwargs*; return the final ProgramState."""
        # One StreamExecutor per (renamed) stream id; only the stream that
        # produces the final node receives the user arguments.
        stream_executor_ids = set([x.expr.pid for x in self.nodes])
        stream_executors = {}
        for x in stream_executor_ids:
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )

        # Submit exprs in topological order so cross-stream reads are wired
        # before they are consumed.
        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Make a copy for SglVariable
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute SglArgument
                expr = kwargs[expr.name]
            stream_executors[se_id].submit(expr)

        for stream_executor in stream_executors.values():
            stream_executor.end()

        return ProgramState(stream_executors[self.last_node.expr.pid])

    def run(
        self,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled function once; extra **kwargs are program arguments."""
        backend = backend or global_config.default_backend
        # Pre-bound arguments (from SglFunction.bind) override/extend kwargs.
        kwargs.update(self.function.bind_arguments)
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled function once per dict in *batch_kwargs*.

        ``num_threads="auto"`` uses min(cpu_count, batch size); 1 runs serially.
        """
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Extract prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            cache_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = []
            for arguments in batch_kwargs:
                rets.append(
                    self.run_internal(backend, arguments, default_sampling_para)
                )
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = []
                for arguments in batch_kwargs:
                    futures.append(
                        executor.submit(
                            self.run_internal, backend, arguments, default_sampling_para
                        )
                    )
                rets = [f.result() for f in futures]
        # Only the last program is synced; earlier ones may still be streaming.
        rets[-1].sync()
        return rets
class CompGraphNode:
    """One node of the compiled dataflow graph, wrapping a single SglExpr."""

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        self.prev_node = prev_node  # predecessor on the same stream
        self.source_node = source_node  # producer read through an SglVariable
        self.next_nodes = next_nodes or []  # downstream dependents

    def add_next_node(self, other):
        """Record *other* as a downstream dependent of this node."""
        self.next_nodes.append(other)

    def __repr__(self):
        desc = f"stream {self.expr.pid:2d}: " + f"%{self.expr.node_id} = "
        if self.prev_node is not None:
            desc += f"%{self.prev_node.expr.node_id} + "
        return desc + repr(self.expr)

File diff suppressed because it is too large Load Diff

635
python/sglang/lang/ir.py Normal file
View File

@@ -0,0 +1,635 @@
"""The intermediate representation."""
import dataclasses
import inspect
import warnings
from typing import List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.choices import ChoicesSamplingMethod
# Regex patterns used for dtype-constrained generation (sgl.gen(dtype=...)).
# Each allows trailing spaces/newlines so the model can terminate naturally.
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
REGEX_BOOL = r"(True|False)"
REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass
class SglSamplingParams:
    """Sampling hyper-parameters attached to a generation call.

    ``dtype`` and ``regex`` drive constrained generation and are only
    forwarded by :meth:`to_srt_kwargs` (the others warn and drop ``regex``).
    """

    max_new_tokens: int = 128
    min_new_tokens: int = 0
    n: int = 1  # number of parallel samples
    stop: Union[str, List[str]] = ()
    stop_token_ids: Optional[List[int]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    min_p: float = 0.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    ignore_eos: bool = False
    return_logprob: Optional[bool] = None
    # BUGFIX: these three previously defaulted to the 1-tuple ``(None,)``
    # (a typo: ``= (None,)`` instead of ``= None``).  A 1-tuple is truthy,
    # which breaks ``if params.return_text_in_logprobs:``-style checks.
    logprob_start_len: Optional[int] = None
    top_logprobs_num: Optional[int] = None
    return_text_in_logprobs: Optional[bool] = None
    json_schema: Optional[str] = None

    # for constrained generation, not included in to_xxx_kwargs
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return a copy of these parameters.

        NOTE(review): mirrors the original positional copy — ``dtype`` and
        ``regex`` are NOT carried over (they reset to None); confirm this is
        intentional before changing.
        """
        return SglSamplingParams(
            max_new_tokens=self.max_new_tokens,
            min_new_tokens=self.min_new_tokens,
            n=self.n,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            ignore_eos=self.ignore_eos,
            return_logprob=self.return_logprob,
            logprob_start_len=self.logprob_start_len,
            top_logprobs_num=self.top_logprobs_num,
            return_text_in_logprobs=self.return_text_in_logprobs,
            json_schema=self.json_schema,
        )

    def to_openai_kwargs(self):
        """Convert to OpenAI chat-completion kwargs.

        OpenAI does not support top_k, so it is dropped here.  Both
        ``max_tokens`` and ``max_completion_tokens`` are sent for model
        compatibility.
        """
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "max_completion_tokens": self.max_new_tokens,
            "n": self.n,
            "stop": self.stop or None,  # empty stop means "no stop sequences"
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_vertexai_kwargs(self):
        """Convert to VertexAI generation kwargs (top_k <= 0 maps to None)."""
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the VertexAI backend."
            )
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        """Convert to Anthropic kwargs.

        Anthropic does not support frequency_penalty or presence_penalty,
        so they are dropped here; ``stop`` is normalized to a list.
        """
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the Anthropic backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            "stop_sequences": (
                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
            ),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        """Convert to LiteLLM kwargs (regex unsupported, warned and dropped)."""
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_srt_kwargs(self):
        """Convert to SGLang runtime kwargs (keeps regex and json_schema)."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "n": self.n,
            "stop": self.stop,
            "stop_token_ids": self.stop_token_ids,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "min_p": self.min_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
            "json_schema": self.json_schema,
        }
class SglFunction:
    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
        """Wrap *func* (an sgl program) for execution.

        Args:
            func: the program callable; its first parameter must be ``s``.
            num_api_spec_tokens: speculative token budget for API backends.
            bind_arguments: argument values pre-bound via :meth:`bind`.
        """
        self.func = func
        self.num_api_spec_tokens = num_api_spec_tokens
        self.bind_arguments = bind_arguments or {}
        self.pin_prefix_rid = None

        # Parse arguments
        argspec = inspect.getfullargspec(func)
        assert argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]
        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
    def bind(self, **kwargs):
        """Return a new SglFunction with *kwargs* merged into the bound args.

        Every key must be one of the function's declared argument names.
        """
        assert all(key in self.arg_names for key in kwargs)
        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(self.func, bind_arguments=new_bind_dict)
    def run(
        self,
        *args,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        stream: bool = False,
        backend=None,
        use_thread: bool = True,
    ) -> "the ProgramState from run_program":  # type: ignore  # (annotation kept textual)
        ...
def run_batch(
self,
batch_kwargs,
*,
max_new_tokens: int = 128,
n: int = 1,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
backend=None,
num_threads: Union[str, int] = "auto",
progress_bar: bool = False,
generator_style: bool = False,
):
from sglang.lang.interpreter import run_program_batch
if stop is None:
stop = []
if stop_token_ids is None:
stop_token_ids = []
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
return []
if not isinstance(batch_kwargs[0], dict):
num_programs = len(batch_kwargs)
# change the list of argument values to dict of arg_name -> arg_value
batch_kwargs = [
{self.arg_names[i]: v for i, v in enumerate(arg_values)}
for arg_values in batch_kwargs
if isinstance(arg_values, (list, tuple))
and len(self.arg_names) - len(self.arg_defaults)
<= len(arg_values)
<= len(self.arg_names)
]
# Ensure to raise an exception if the number of arguments mismatch
if len(batch_kwargs) != num_programs:
raise Exception("Given arguments mismatch the SGL function signature")
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
n=n,
stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
return_text_in_logprobs=return_text_in_logprobs,
)
backend = backend or global_config.default_backend
return run_program_batch(
self,
backend,
batch_kwargs,
default_sampling_para,
num_threads,
progress_bar,
generator_style=generator_style,
)
def trace(self, *, backend=None, **kwargs):
from sglang.lang.tracer import trace_program
backend = backend or global_config.default_backend
return trace_program(self, kwargs, backend)
def cache(self, backend=None):
from sglang.lang.interpreter import cache_program
backend = backend or global_config.default_backend
return cache_program(self, backend)
def compile(self, *, backend=None):
from sglang.lang.compiler import compile_func
return compile_func(self, backend)
def __call__(self, *args, **kwargs):
from sglang.lang.tracer import TracingScope
tracing_scope = TracingScope.get_current_scope()
if tracing_scope is None:
return self.run(*args, **kwargs)
else:
kwargs["backend"] = tracing_scope.tracer_state.backend
return self.trace(*args, **kwargs)
class SglExpr:
    """Base class for all IR nodes; every node gets a unique id."""

    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        self.prev_node = None
        self.pid = None
        SglExpr.node_ct += 1

    def __add__(self, other):
        rhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(rhs, SglExpr)
        return self.concatenate_ir(self, rhs)

    def __radd__(self, other):
        lhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(lhs, SglExpr), f"{lhs}"
        return self.concatenate_ir(lhs, self)

    def concatenate_ir(self, a, b):
        """Combine two nodes into a single flat SglExprList."""
        a_is_list = isinstance(a, SglExprList)
        b_is_list = isinstance(b, SglExprList)
        if a_is_list and b_is_list:
            return SglExprList(a.expr_list + b.expr_list)
        if a_is_list:
            return SglExprList(a.expr_list + [b])
        if b_is_list:
            return SglExprList([a] + b.expr_list)
        return SglExprList([a, b])

    def print_graph_dfs(self):
        """Render the dependency graph rooted at this node, one line per node."""
        lines = []
        seen = set()

        def visit(node):
            if node is None or node in seen:
                return
            seen.add(node)

            # Emit dependencies first.
            if node.prev_node is not None:
                visit(node.prev_node)
            if isinstance(node, SglExprList):
                for child in node.expr_list:
                    visit(child)
            elif isinstance(node, SglVariable):
                visit(node.source)

            # Then emit the node itself.
            if isinstance(node, (SglFork, SglGetForkItem)):
                lines.append(f"%{node.node_id} = {node}\n")
            elif node.prev_node is not None:
                lines.append(f"%{node.node_id} = %{node.prev_node.node_id} + {node}\n")
            else:
                lines.append(f"%{node.node_id} = {node}\n")

        visit(self)
        return "".join(lines)
class SglExprList(SglExpr):
    """An IR node that holds a flat sequence of expressions."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        self.expr_list = expr_list

    def __repr__(self):
        return "ExprList({})".format(self.expr_list)
class SglArgument(SglExpr):
    """A placeholder for a program argument captured during tracing."""

    def __init__(self, name: str, value: str):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return f"Argument(name={self.name}, value={repr(self.value)})"

    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        # NOTE(review): returns the raw value, not int(self.value). With a
        # dummy None value this makes int(arg) raise TypeError, which the
        # prefix tracer catches to stop tracing — presumably intentional;
        # confirm before "fixing" to int(self.value).
        return self.value

    def __bool__(self):
        # NOTE(review): same pattern as __int__ — a non-bool return makes
        # truth-testing a dummy argument raise TypeError instead of silently
        # picking a branch during tracing.
        return self.value

    def __format__(self, *args):
        raise TypeError(
            "Cannot put argument inside a f-string. "
            "This is not compatible with the tracer. "
        )
class SglImage(SglExpr):
    """An IR node referencing an image input by path."""

    def __init__(self, path: str):
        # Fix: initialize the SglExpr base. Every sibling node calls
        # super().__init__(); without it this node lacks node_id/prev_node/pid
        # and breaks graph traversal (e.g. print_graph_dfs).
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"
class SglVideo(SglExpr):
    """An IR node referencing a video input by path and frame count."""

    def __init__(self, path: str, num_frames: int):
        # Fix: initialize the SglExpr base. Every sibling node calls
        # super().__init__(); without it this node lacks node_id/prev_node/pid
        # and breaks graph traversal (e.g. print_graph_dfs).
        super().__init__()
        self.path = path
        self.num_frames = num_frames

    def __repr__(self) -> str:
        return f"SglVideo({self.path}, {self.num_frames})"
class SglGen(SglExpr):
    """An IR node for a model generation call."""

    def __init__(
        self,
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        min_p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        ignore_eos: Optional[bool] = None,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        dtype: Optional[type] = None,
        regex: Optional[str] = None,
        json_schema: Optional[str] = None,
    ):
        """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
        super().__init__()
        self.name = name
        # Bundle all sampling knobs into one params object.
        params = dict(
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
            dtype=dtype,
            regex=regex,
            json_schema=json_schema,
        )
        self.sampling_params = SglSamplingParams(**params)

    def __repr__(self):
        return "Gen('{}')".format(self.name)
class SglConstantText(SglExpr):
    """A literal piece of text appended to the prompt."""

    def __init__(self, value: str):
        super().__init__()
        self.value = value

    def __repr__(self):
        return "Constant({!r})".format(self.value)
class SglRoleBegin(SglExpr):
    """Marks the start of a chat role section."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleBegin({})".format(self.role)
class SglRoleEnd(SglExpr):
    """Marks the end of a chat role section."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleEnd({})".format(self.role)
class SglSelect(SglExpr):
    """An IR node that selects one of several candidate strings."""

    def __init__(
        self,
        name: str,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature
        self.choices_method = choices_method

    def __repr__(self):
        return (
            f"Select({self.name}, choices={self.choices}, "
            f"choices_method={self.choices_method})"
        )
class SglFork(SglExpr):
    """An IR node that forks the program state into several branches."""

    def __init__(self, number: int, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        return "Fork(%{}, number={}, position_ids_offset={})".format(
            self.prev_node.node_id, self.number, self.position_ids_offset
        )
class SglGetForkItem(SglExpr):
    """An IR node that picks one branch out of a preceding fork."""

    def __init__(self, index: int):
        super().__init__()
        self.index = index

    def __repr__(self):
        return "GetForkItem(%{}, index={})".format(self.prev_node.node_id, self.index)
class SglVariable(SglExpr):
    """A reference to the value produced by another IR node."""

    def __init__(self, name: str, source):
        super().__init__()
        self.name = name
        self.source = source

    def __repr__(self):
        return "Variable('{}', source=%{})".format(self.name, self.source.node_id)
class SglVarScopeBegin(SglExpr):
    """Marks the start of a named variable scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeBegin('{}')".format(self.name)
class SglVarScopeEnd(SglExpr):
    """Marks the end of a named variable scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeEnd('{}')".format(self.name)
class SglConcateAndAppend(SglExpr):
    """An IR node that merges several states and appends the result."""

    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)
class SglCommitLazy(SglExpr):
    """An IR node marking a commit point for lazily-buffered operations."""

    def __init__(self):
        super().__init__()

    def __repr__(self):
        return "CommitLazy()"
class SglSeparateReasoning(SglExpr):
    """An IR node that splits reasoning content out of a generation result."""

    def __init__(self, model_type: str, expr: SglExpr):
        super().__init__()
        self.model_type = model_type
        self.expr = expr
        self.name = None
        self._process_expr(expr)

    def process_name_for_reasoning(self, name):
        """Derive the variable name used to store the reasoning content."""
        if not name:
            raise ValueError("name must be provided")
        return f"{name}_reasoning_content"

    def _process_expr(self, expr):
        # Walk the expression to find a named gen/select whose name we mirror.
        if isinstance(expr, (SglGen, SglSelect)):
            self.name = self.process_name_for_reasoning(expr.name)
        elif isinstance(expr, SglExprList):
            for child in expr.expr_list:
                self._process_expr(child)

    def __repr__(self):
        return (
            f"SeparateReasoning(model_type={self.model_type}, name={self.name})"
        )

View File

@@ -0,0 +1,279 @@
"""Tracing a program."""
import uuid
from typing import Any, Dict, List, Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglGen,
SglGetForkItem,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
class StopTracing(Exception):
    """Raised to abort tracing early (e.g. once the constant prefix is known)."""
def extract_prefix_by_tracing(program, backend):
    """Trace ``program`` with dummy arguments and return its constant text prefix."""
    # Create dummy arguments, then overlay any pre-bound values.
    arguments = {name: SglArgument(name, None) for name in program.arg_names}
    arguments.update(program.bind_arguments)

    # Trace until the first data-dependent operation.
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except (StopTracing, TypeError, AttributeError):
        # Some exceptions may not be caught
        pass

    # Collect the leading run of constant-text nodes.
    prefix_parts = []
    for expr in tracer.flatten_nodes():
        if not isinstance(expr, SglConstantText):
            break
        prefix_parts.append(expr.value)
    return "".join(prefix_parts)
def trace_program(program, arguments, backend):
    """Trace ``program`` with the given arguments and return the tracer state."""
    # Fall back to a dummy backend when none is given.
    if backend is None:
        backend = BaseBackend()

    # Fill in dummy values for arguments the caller did not supply,
    # then overlay any pre-bound values.
    for name in program.arg_names:
        if name not in arguments:
            arguments[name] = SglArgument(name, None)
    arguments.update(program.bind_arguments)

    # Trace the full program.
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer
class TracerProgramState(ProgramState):
    """A ProgramState that records IR nodes instead of executing them.

    Used by trace_program/extract_prefix_by_tracing to turn a user program
    into a chain of SglExpr nodes.
    """

    def __init__(self, backend, arguments, only_trace_prefix):
        # Unique id for this traced program instance.
        self.pid = uuid.uuid4().hex

        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        # When True, hitting anything non-constant raises StopTracing.
        self.only_trace_prefix = only_trace_prefix

        # Unwrap endpoint-style backends so get_chat_template() works below.
        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        # Recorded top-level nodes and the tail of the dependency chain.
        self.nodes = []
        self.last_node = None
        self.variables = {}
        self.ret_value = None

        # For completion
        # For chat
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # For multi states
        self.child_states = []

        # Register this state with every enclosing tracing scope.
        cur_scope = TracingScope.get_current_scope()
        if cur_scope is not None:
            cur_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
        """Record a fork into ``size`` child tracer states."""
        assert size >= 1

        if self.only_trace_prefix:
            # A fork is data-dependent; prefix tracing stops here.
            raise StopTracing()

        fork_node = SglFork(size)
        fork_node.prev_node = self.last_node

        # Each child continues from a GetForkItem node and a copy of
        # this state's chat bookkeeping.
        states = [
            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
            for _ in range(size)
        ]
        for i in range(size):
            node = SglGetForkItem(i)
            node.prev_node = fork_node
            states[i].last_node = node
            states[i].variables = dict(self.variables)
            states[i].messages_ = list(self.messages_)
            states[i].cur_role = self.cur_role
            states[i].chat_template = self.chat_template

        state_group = ProgramStateGroup(states, self)

        return state_group

    ##################################
    ########## Internal API ##########
    ##################################

    def _append_node(self, other: SglExpr):
        # Link ``other`` onto the chain and record it as a top-level node.
        self.nodes.append(other)
        other.prev_node = self.last_node
        self.last_node = other

    def _execute(self, other: SglExpr):
        """Dispatch one expression to the matching _execute_* recorder."""
        if isinstance(other, str):
            other = SglConstantText(other)

        other.pid = self.pid

        if isinstance(other, SglConstantText):
            self._execute_fill(other)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        else:
            # Unknown node kinds end prefix tracing; otherwise record as-is.
            if self.only_trace_prefix:
                raise StopTracing()
            else:
                self._append_node(other)

        return self

    def __iadd__(self, other):
        # ``s += expr`` records the expression.
        self._execute(other)
        return self

    def _execute_fill(self, expr: SglConstantText):
        if isinstance(expr, str):
            expr = SglConstantText(expr)
        self._append_node(expr)

    def _execute_gen(self, expr: SglGen):
        # Unnamed gens get a positional name so results stay addressable.
        name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_select(self, expr: SglSelect):
        name = (
            expr.name if expr.name is not None else "select_" + str(len(self.variables))
        )
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_role_begin(self, expr: SglRoleBegin):
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        # Only the prefix is emitted here; the suffix is emitted at role end.
        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(prefix)

    def _execute_role_end(self, expr: SglRoleEnd):
        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(suffix)

        # Content is left empty: during tracing no text is actually generated.
        self.messages_.append({"role": expr.role, "content": ""})
        self.cur_role = None

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        new_node = SglVariable(expr.name, source=self.last_node)
        self.variables[expr.name] = new_node

    def get_var(self, name):
        # Program arguments shadow traced variables.
        ret = self.arguments.get(name, None)
        if ret is not None:
            return ret
        v = self.variables[name]
        return SglVariable(v.name, v.source)

    def flatten_nodes(self):
        """Return all recorded nodes with nested SglExprLists flattened."""
        def traverse(cur):
            if isinstance(cur, SglExprList):
                for child in cur.expr_list:
                    traverse(child)
            else:
                ret.append(cur)

        ret = []
        for x in self.nodes:
            traverse(x)
        return ret

    def __del__(self):
        # NOTE(review): intentionally a no-op — presumably overrides cleanup
        # in ProgramState that should not run for traced states; confirm.
        pass
class TracingScope:
    """A dynamically-scoped stack of active tracer states."""

    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        self.tracer_state = tracer_state
        # Remember the enclosing scope so __exit__ can restore it.
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        return TracingScope.cur_scope

    def add_child_state(self, state: TracerProgramState):
        # Register the state with this scope and every enclosing one.
        scope = self
        while scope is not None:
            scope.tracer_state.child_states.append(state)
            scope = scope.last_scope

View File

@@ -0,0 +1,16 @@
"""Launch the inference server."""
import os
import sys
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
# Entry point: parse CLI args, launch the HTTP server, and guarantee that
# every spawned child process is cleaned up on exit.
if __name__ == "__main__":
    server_args = prepare_server_args(sys.argv[1:])

    try:
        launch_server(server_args)
    finally:
        # Kill the whole process tree but keep this parent alive so the
        # finally block itself can complete.
        kill_process_tree(os.getpid(), include_parent=False)

166
python/sglang/profiler.py Normal file
View File

@@ -0,0 +1,166 @@
"""
Run live profiling.
Usage:
python3 -m sglang.profiler
"""
import argparse
import json
import os
import time
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional
import requests
# Default root directory for dumped profiling traces.
PARENT_FOLDER = "/tmp/sglang-profile"
def _run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
) -> str:
    """Ask a running sglang server to profile itself.

    Returns the directory the traces are written to.
    """
    if output_dir is None:
        output_dir = PARENT_FOLDER
    # Normalize to an absolute Path, then nest under "profile_name/timestamp".
    output_dir = Path(os.path.abspath(os.path.normpath(output_dir)))
    if profile_name:
        output_dir = output_dir / profile_name
    output_dir = output_dir / str(time.time())
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"Dump profiling traces to {output_dir}")
    print(
        f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})"
    )

    # Dump server args.
    server_args_path = Path(output_dir) / "server_args.json"
    if not server_args_path.exists():
        response = requests.get(url + "/get_server_info")
        response.raise_for_status()
        server_args_data = response.json()
        with open(server_args_path, "w") as file:
            file.write(json.dumps(server_args_data))

    # Start profiler. The API replies when all steps are processed
    # and files are generated.
    payload = {
        "output_dir": str(output_dir),
        "num_steps": str(num_steps),
        "activities": activities,
        "profile_by_stage": profile_by_stage,
    }
    response = requests.post(url=url + "/start_profile", json=payload)
    response.raise_for_status()

    return str(output_dir)
def run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
):
    """Public wrapper around :func:`_run_profile`.

    The step-based profile self-terminates once ``num_steps`` is reached.
    """
    return _run_profile(
        url,
        num_steps,
        activities,
        output_dir,
        profile_name,
        profile_by_stage,
    )
# CLI entry point: collect the requested activity set and trigger a live
# profile on a running server.
if __name__ == "__main__":
    parser = ArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:30000",
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Profile directory to dump profile traces.",
    )
    parser.add_argument(
        "--profile-name",
        type=str,
        default=None,
        help="The name of this profile run.",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=5,
        help="The number of forward steps to profile.",
    )
    parser.add_argument(
        "--profile-by-stage",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        # Fix: help text was copy-pasted from --num-steps.
        help="Whether to profile each stage separately instead of by step.",
    )
    parser.add_argument(
        "--cpu",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=True,
        help="Whether to profile CPU activity",
    )
    parser.add_argument(
        "--gpu",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=True,
        help="Whether to profile GPU activity",
    )
    parser.add_argument(
        "--mem",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        # Fix: grammar ("Whether to memory usage").
        help="Whether to profile memory usage (https://pytorch.org/memory_viz)",
    )
    parser.add_argument(
        "--rpd",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
    )
    args = parser.parse_args()

    # Translate boolean flags into the server's activity names.
    activities = []
    if args.cpu:
        activities.append("CPU")
    if args.gpu:
        activities.append("GPU")
    if args.mem:
        activities.append("MEM")
    if args.rpd:
        activities.append("RPD")
    run_profile(
        args.url,
        args.num_steps,
        activities,
        args.output_dir,
        args.profile_name,
        args.profile_by_stage,
    )

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,177 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging
from typing import List, Optional, Tuple
import torch
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
logger = logging.getLogger(__name__)
# Select the custom all-reduce backend once at import time: vLLM's C extension
# when USE_VLLM_CUSTOM_ALLREDUCE is set (non-HIP only), otherwise sgl_kernel.
use_vllm_custom_allreduce = get_bool_env_var(
    "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
)

if not is_hpu():
    # ROCm does not use vllm custom allreduce
    if use_vllm_custom_allreduce and not is_hip():
        try:
            import vllm._C
        except ImportError as e:
            logger.warning("Failed to import from vllm._C with %r", e)
    else:
        try:
            import sgl_kernel
        except ImportError as e:
            logger.warning("Failed to import from custom_ar with %r", e)


if not is_hip() and not is_npu():
    # CUDA path: thin wrappers over whichever op table was selected above.
    if use_vllm_custom_allreduce:
        custom_op = torch.ops._C_custom_ar
    else:
        custom_op = sgl_kernel.allreduce

    # custom allreduce
    def init_custom_ar(
        ipc_tensors: List[torch.Tensor],
        rank_data: torch.Tensor,
        rank: int,
        full_nvlink: bool,
    ) -> int:
        # Returns an opaque handle ("fa") consumed by the other wrappers.
        return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)

    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

    def dispose(fa: int) -> None:
        custom_op.dispose(fa)

    def meta_size() -> int:
        return custom_op.meta_size()

    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
        return custom_op.register_buffer(fa, ipc_tensors)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
        return custom_op.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[List[int]], offsets: List[List[int]]
    ) -> None:
        custom_op.register_graph_buffers(fa, handles, offsets)

else:
    # ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        return sgl_kernel.allreduce.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        sgl_kernel.allreduce.dispose(fa)

    def meta_size() -> int:
        return sgl_kernel.allreduce.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return sgl_kernel.allreduce.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)

    # ROCM custom quick allreduce
    def init_custom_qr(
        rank: int, world_size: int, qr_max_size: Optional[int] = None
    ) -> int:
        # NOTE(review): the arguments are forwarded as (world_size, rank) —
        # swapped relative to this wrapper's signature. Presumably this matches
        # the kernel's native order; confirm against
        # sgl_kernel.allreduce.init_custom_qr before changing.
        return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)

    def qr_get_handle(fa: int) -> torch.Tensor:
        return sgl_kernel.allreduce.qr_get_handle(fa)

    def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
        sgl_kernel.allreduce.qr_open_handles(fa, handles)

    def qr_all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        quant_level: int,
        cast_bf2half: bool,
    ) -> None:
        sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)

    def qr_destroy(fa: int) -> None:
        sgl_kernel.allreduce.qr_destroy(fa)

    def qr_max_size() -> int:
        return sgl_kernel.allreduce.qr_max_size()

    def mscclpp_generate_unique_id() -> bytes:
        return sgl_kernel.allreduce.mscclpp_generate_unique_id()

    def mscclpp_init_context(
        unique_id: bytes,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        return sgl_kernel.allreduce.mscclpp_init_context(
            unique_id,
            rank,
            world_size,
            scratch,
            put_buffer,
            nranks_per_node,
            rank_to_node,
            rank_to_ib,
            context_selection,
        )

    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)

View File

@@ -0,0 +1,100 @@
import asyncio
class RWLock:
    """An asyncio readers-writer lock that favors waiting writers.

    Multiple readers may hold the lock simultaneously; a writer holds it
    exclusively. New readers are held back while any writer is queued, so
    writers cannot starve.
    """

    def __init__(self):
        # Protects internal state
        self._lock = asyncio.Lock()
        # Condition variable used to wait for state changes
        self._cond = asyncio.Condition(self._lock)
        # Number of readers currently holding the lock
        self._readers = 0
        # Whether a writer is currently holding the lock
        self._writer_active = False
        # How many writers are queued waiting for a turn
        self._waiting_writers = 0

    @property
    def reader_lock(self):
        """Context manager for shared (reader) access.

        Example:
            async with rwlock.reader_lock:
                ...  # read-only access
        """
        return _ReaderLock(self)

    @property
    def writer_lock(self):
        """Context manager for exclusive (writer) access.

        Example:
            async with rwlock.writer_lock:
                ...  # exclusive access
        """
        return _WriterLock(self)

    async def acquire_reader(self):
        async with self._lock:
            # Block while a writer holds the lock or is queued; letting
            # queued writers go first keeps them from starving.
            while self._writer_active or self._waiting_writers > 0:
                await self._cond.wait()
            self._readers += 1

    async def release_reader(self):
        async with self._lock:
            self._readers -= 1
            # The last reader out wakes everyone waiting (writers or readers).
            if not self._readers:
                self._cond.notify_all()

    async def acquire_writer(self):
        async with self._lock:
            # Queue up so new readers are held back while we wait.
            self._waiting_writers += 1
            try:
                # Wait for full exclusivity: no active writer, no readers.
                while self._writer_active or self._readers > 0:
                    await self._cond.wait()
                self._writer_active = True
            finally:
                # We are no longer queued once the wait is over.
                self._waiting_writers -= 1

    async def release_writer(self):
        async with self._lock:
            self._writer_active = False
            # Wake up anyone waiting (readers or writers)
            self._cond.notify_all()
class _ReaderLock:
    """Async context manager granting shared access to an RWLock."""

    def __init__(self, rwlock: RWLock):
        self._rwlock = rwlock

    async def __aenter__(self):
        await self._rwlock.acquire_reader()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._rwlock.release_reader()
class _WriterLock:
    """Async context manager granting exclusive access to an RWLock."""

    def __init__(self, rwlock: RWLock):
        self._rwlock = rwlock

    async def __aenter__(self):
        await self._rwlock.acquire_writer()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._rwlock.release_writer()

View File

@@ -0,0 +1,137 @@
import os
import sys
from contextlib import nullcontext
import torch
# NOTE copied and modified from DeepGEMM
class suppress_stdout_stderr:
    """Context manager that silences both Python-level and fd-level output.

    Redirects file descriptors 1/2 (so C/CUDA prints are muted too) and
    swaps sys.stdout/sys.stderr to /dev/null, restoring everything on exit.
    The dup/dup2 ordering below is load-bearing — do not reorder.
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        # Remember the original fd numbers backing stdout/stderr ...
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # ... and duplicate them so they can be restored in __exit__.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the original fds at /dev/null (covers native code output).
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Also swap the Python-level stream objects.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        # Restore Python-level streams first, then the underlying fds.
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
# NOTE copied and modified from DeepGEMM
def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: str = None,
    flush_l2: bool = True,
    with_multiple_kernels: bool = False,
):
    """Benchmark ``fn`` under torch's Kineto profiler.

    Returns the average GPU time (seconds) per kernel in ``kernel_names`` —
    a tuple when ``kernel_names`` is a tuple, otherwise a single float.
    Kernel times are recovered by text-matching names against the profiler's
    summary table, so names must uniquely identify rows unless
    ``with_multiple_kernels`` is set. Returns 1 when Nsight Systems is active.
    """
    # Conflict with Nsight Systems
    using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))

    # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = (
        suppress_stdout_stderr
        if suppress_kineto_output and not using_nsys
        else nullcontext
    )
    with suppress():
        # wait=0/warmup=1/active=1: the first outer iteration warms up, the
        # second is recorded.
        schedule = (
            torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
            if not using_nsys
            else None
        )
        profiler = (
            torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
            )
            if not using_nsys
            else nullcontext()
        )
        with profiler:
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(
                            flush_l2_size, dtype=torch.int, device="cuda"
                        ).zero_()
                    fn()

                if not using_nsys:
                    profiler.step()

    # Return 1 if using Nsight Systems
    if using_nsys:
        return 1

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)
    prof_lines = (
        profiler.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    if not with_multiple_kernels:
        # Each name must match exactly one table row, else timings are ambiguous.
        for name in kernel_names:
            assert (
                sum([name in line for line in prof_lines]) == 1
            ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                # The table's last two columns are per-call time and call count.
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += (
                            float(time_str.replace(unit, "")) / scale * int(num_str)
                        )
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num)

    return tuple(kernel_times) if is_tuple else kernel_times[0]

View File

@@ -0,0 +1,27 @@
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig,
)
# Public re-exports of sglang.srt.configs.
__all__ = [
    "ExaoneConfig",
    "ChatGLMConfig",
    "DbrxConfig",
    "DeepseekVL2Config",
    "LongcatFlashConfig",
    "MultiModalityConfig",
    "KimiVLConfig",
    "MoonViTConfig",
    "Step3VLConfig",
    "Step3TextConfig",
    "Step3VisionEncoderConfig",
]

View File

@@ -0,0 +1,78 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py
# ChatGLM2 and ChatGLM3 share the same config.
# ChatGLM4 is officially supported by Huggingface
# transformers >= 4.46.0 is required
# https://huggingface.co/docs/transformers/en/model_doc/glm
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
    """Configuration for ChatGLM2/ChatGLM3 models.

    ChatGLM2 and ChatGLM3 share this config. ChatGLM4 is supported natively
    by Huggingface transformers (>= 4.46) and does not use this class.
    """

    model_type = "chatglm"
    # Map standard HF attribute names onto ChatGLM's native config keys.
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        interleaved_qkv=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        # ChatGLM reports a padded vocabulary; expose it under both names.
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # It is to be compatible with long lora.
        self.max_position_embeddings = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        # PretrainedConfig consumes the remaining HF-standard kwargs.
        super().__init__(**kwargs)

View File

@@ -0,0 +1,279 @@
# Adapted from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py
"""Dbrx configuration."""
from typing import Any, Optional
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore
class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for Dbrx Attention.
    [`DbrxAttention`] class. It is used to instantiate attention layers
    according to the specified arguments, defining the layers architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to None):
            If not `None`, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
        rope_theta (float): The base frequency for rope.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta
        # `model_type` may tag along when this sub-config is built from a parent
        # Dbrx config dict; drop it before the unknown-kwargs check below.
        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        # Any remaining kwargs are unrecognized config keys -> fail loudly.
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Load this sub-config, unwrapping the nested `attn_config` section
        when pointed at a full Dbrx checkpoint."""
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # A full Dbrx config nests this section under "attn_config".
        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]
        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all configurations of "
                "models and can yield errors.",
                config_dict["model_type"],
                cls.model_type,
            )
        return cls.from_dict(config_dict, **kwargs)
class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for Dbrx FFN.
    [`DbrxFFN`] class. It is used to instantiate feedforward layers according to
    the specified arguments, defining the layers architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
            The dict should have a key 'name' with the value being the name of
            the activation function along with any additional keyword arguments.
        ffn_hidden_size (int, optional): The hidden size of the feedforward network.
        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
            This should only be used for benchmarking purposes.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        # NOTE(review): unlike DbrxAttentionConfig, kwargs are not forwarded to
        # super().__init__() here — this mirrors the vLLM port; confirm intentional.
        super().__init__()
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment
        # Drop a stray "model_type" carried in from a parent config dict.
        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        # Any remaining kwargs are unrecognized config keys -> fail loudly.
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Load this sub-config, unwrapping the nested `ffn_config` section
        when pointed at a full Dbrx checkpoint."""
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # A full Dbrx config nests this section under "ffn_config".
        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]
        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all "
                "configurations of models and can yield errors.",
                config_dict["model_type"],
                cls.model_type,
            )
        return cls.from_dict(config_dict, **kwargs)
class DbrxConfig(PretrainedConfig):
    """Configuration class for Dbrx.
    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        d_model (`int`, *optional*, defaults to 6144):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 48):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 32768):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel
    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig()
    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    # Map HF-standard attribute names onto Dbrx-specific field names.
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        # Coerce None / plain dict into a DbrxAttentionConfig instance.
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config
        # Coerce None / plain dict into a DbrxFFNConfig instance.
        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        # Dbrx ties no embeddings; reject an explicit request to tie them.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

View File

@@ -0,0 +1,688 @@
import math
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image, ImageOps
from transformers import (
AutoProcessor,
LlamaTokenizerFast,
PretrainedConfig,
ProcessorMixin,
)
def select_best_resolution(image_size, candidate_resolutions):
    """Pick the candidate (width, height) that best fits the image for cropping.

    A candidate is ranked first by the effective resolution it preserves after
    aspect-ratio-respecting downscale (capped at the original pixel count), and
    then by the least wasted area; the earliest candidate wins ties.
    """
    orig_w, orig_h = image_size

    def _rank(width, height):
        # Scale so the image fits inside the candidate without distortion.
        scale = min(width / orig_w, height / orig_h)
        scaled_w, scaled_h = int(orig_w * scale), int(orig_h * scale)
        effective = min(scaled_w * scaled_h, orig_w * orig_h)
        wasted = width * height - effective
        # Higher effective resolution is better; less waste breaks ties.
        return (effective, -wasted)

    best_fit = None
    best_rank = None
    for width, height in candidate_resolutions:
        rank = _rank(width, height)
        if best_rank is None or rank > best_rank:
            best_rank = rank
            best_fit = (width, height)
    return best_fit
class DictOutput(object):
    """Mixin exposing an object's attributes through a dict-style interface."""

    def keys(self):
        return vars(self).keys()

    def items(self):
        return vars(self).items()

    def __getitem__(self, item):
        return vars(self)[item]

    def __setitem__(self, key, value):
        vars(self)[key] = value

    def __contains__(self, key):
        return key in vars(self)
@dataclass
class VLChatProcessorOutput(DictOutput):
    """Single-sample output of DeepseekVLV2Processor (dict-style access via DictOutput)."""

    # Token ids of the full text + image-placeholder sequence.
    input_ids: torch.LongTensor
    # Training labels; prompt/image positions are masked with ignore_id.
    target_ids: torch.LongTensor
    pixel_values: (
        torch.Tensor
    )  # rename from "images" to "pixel_values" for compatibility
    # True at positions occupied by image tokens in input_ids.
    images_seq_mask: torch.BoolTensor
    # Per-image (num_width_tiles, num_height_tiles) crop layout.
    images_spatial_crop: torch.LongTensor

    def __len__(self):
        # Sample length == number of tokens in input_ids.
        return len(self.input_ids)
class ImageTransform(object):
    """Convert a PIL image to a tensor, optionally normalizing with mean/std."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True,
    ):
        self.mean = mean
        self.std = std
        self.normalize = normalize
        # Import lazily so torchvision is only required when images are processed.
        try:
            import torchvision.transforms as T
        except ImportError as err:
            raise ImportError(
                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
            ) from err
        steps = [T.ToTensor()]
        if normalize:
            steps.append(T.Normalize(mean, std))
        self.transform = T.Compose(steps)

    def __call__(self, pil_img: Image.Image):
        return self.transform(pil_img)
class DeepseekVLV2Processor(ProcessorMixin):
    """HF processor for Deepseek-VL2.

    Tokenizes text containing ``<image>`` tags and expands each image into a
    global view plus locally cropped tiles, producing token ids, label ids, an
    image-position mask, and per-image tiling metadata.
    """

    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
    attributes = ["tokenizer"]

    def __init__(
        self,
        tokenizer: LlamaTokenizerFast,
        candidate_resolutions: Tuple[Tuple[int, int]],
        patch_size: int,
        downsample_ratio: int,
        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        normalize: bool = True,
        image_token: str = "<image>",
        pad_token: str = "<▁pad▁>",
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.candidate_resolutions = candidate_resolutions
        # Base (square) view size comes from the first candidate resolution.
        self.image_size = candidate_resolutions[0][0]
        self.patch_size = patch_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.normalize = normalize
        self.downsample_ratio = downsample_ratio
        self.image_transform = ImageTransform(
            mean=image_mean, std=image_std, normalize=normalize
        )
        self.tokenizer = tokenizer
        # must set this; the padding side makes a difference in batch inference
        self.tokenizer.padding_side = "left"
        # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
        if tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": pad_token})
        # add image token
        image_token_id = self.tokenizer.vocab.get(image_token)
        if image_token_id is None:
            special_tokens = [image_token]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
        # Re-read: the id exists now whether it was present before or just added.
        self.image_token_id = self.tokenizer.vocab.get(image_token)
        # add five special tokens for grounding-related tasks
        # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
        special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        # add special tokens for SFT data
        special_tokens = ["<|User|>", "<|Assistant|>"]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.image_token = image_token
        self.pad_token = pad_token
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id
        super().__init__(
            tokenizer,
            **kwargs,
        )

    def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
        """play the role of format_messages_v2 and get_images_info in the last version"""
        tokenized_data = []
        masked_tokenized_data = []  # labels
        images_list = []
        images_seq_mask = []
        images_spatial_crop = []
        image_index = 0
        image_token_cnt = messages.count(self.image_token)
        # Tokenize the whole message string, consuming exactly as many images
        # as there are <image> tags; cropping is only enabled for <= 2 images.
        tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
            messages,
            pil_images[image_index : image_index + image_token_cnt],
            bos=True,
            eos=True,
            cropping=len(pil_images) <= 2,
            max_req_input_len=max_req_input_len,
        )
        image_index = image_token_cnt
        tokenized_data += tokenized_str
        # When masking the prompt, labels for all its tokens are ignore_id.
        if self.mask_prompt:
            masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
        else:
            masked_tokenized_data += tokenized_str
        images_list += images
        images_seq_mask += seq_mask
        images_spatial_crop += spatial_crop
        assert len(tokenized_data) == len(
            images_seq_mask
        ), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
        return (
            tokenized_data,
            masked_tokenized_data,
            images_list,
            images_seq_mask,
            images_spatial_crop,
        )

    @property
    def bos_id(self):
        # id of the beginning-of-sequence token
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        # id of the end-of-sequence token
        return self.tokenizer.eos_token_id

    @property
    def pad_id(self):
        # id of the padding token
        return self.tokenizer.pad_token_id

    def encode(self, text: str, bos: bool = True, eos: bool = False):
        """Tokenize `text`, optionally wrapping it with BOS/EOS ids."""
        t = self.tokenizer.encode(text, add_special_tokens=False)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int], **kwargs) -> str:
        """Decode token ids back to text via the underlying tokenizer."""
        return self.tokenizer.decode(t, **kwargs)

    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image.Image] = None,
        apply_sft_format: bool = False,
        inference_mode: bool = True,
        system_prompt: str = "",
        max_req_input_len: int = -1,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
                if conversations is not None, then it will always apply the SFT format to conversations;
            inference_mode (bool): if True, then remove the last eos token;
            system_prompt (str): the system prompt;
            **kwargs:
        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."
        (
            tokenized_str,
            masked_tokenized_str,
            images_list,
            images_seq_mask,
            images_spatial_crop,
        ) = self.format_messages_v2(conversations, images, max_req_input_len)
        assert (
            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
        ), (
            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
        )
        input_ids = torch.LongTensor(tokenized_str)
        target_ids = torch.LongTensor(masked_tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
            self.ignore_id
        )
        input_ids[input_ids < 0] = self.pad_id
        if inference_mode:
            # Drop the trailing EOS so generation can continue from the prompt.
            assert input_ids[-1] == self.eos_id
            input_ids = input_ids[:-1]
            target_ids = target_ids[:-1]
            images_seq_mask = images_seq_mask[:-1]
        if len(images_list) == 0:
            # No images: emit a single dummy image so downstream shapes hold.
            images = torch.zeros((1, 3, self.image_size, self.image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
        else:
            images = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
        images_spatial_crop = torch.stack(
            [images_spatial_crop], dim=0
        )  # stack the tensor to make it a batch of 1
        prepare = VLChatProcessorOutput(
            input_ids=input_ids,
            target_ids=target_ids,
            pixel_values=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
        )
        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image.Image] = None,
        apply_sft_format: bool = False,
        inference_mode: bool = True,
        system_prompt: str = "",
        max_req_input_len: int = -1,
        **kwargs,
    ):
        """Keyword-only entry point; delegates to process_one."""
        prepare = self.process_one(
            prompt=prompt,
            conversations=conversations,
            images=images,
            apply_sft_format=apply_sft_format,
            inference_mode=inference_mode,
            system_prompt=system_prompt,
            max_req_input_len=max_req_input_len,
        )
        return prepare

    def find_all_indices(self, messages, target_value):
        """Return every index of `target_value` in the sequence `messages`."""
        indices = []
        for index, item in enumerate(messages):
            if item == target_value:
                indices.append(index)
        return indices

    def tokenize_with_images(
        self,
        conversation: str,
        images: List[Image.Image],
        bos: bool = True,
        eos: bool = True,
        cropping: bool = True,
        max_req_input_len: int = -1,
    ):
        """Tokenize text with <image> tags."""
        images_list, images_seq_mask, images_spatial_crop = [], [], []
        text_splits = conversation.split(self.image_token)
        tokenized_str = []
        # Each (text segment, image) pair: tokenize text, then expand the image.
        for text_sep, image in zip(text_splits, images):
            """encode text_sep"""
            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)
            """select best resolution for anyres"""
            if cropping:
                best_width, best_height = select_best_resolution(
                    image.size, self.candidate_resolutions
                )
            else:
                best_width, best_height = self.image_size, self.image_size
            # print(image.size, (best_width, best_height)) # check the select_best_resolutions func
            """process the global view"""
            global_view = ImageOps.pad(
                image,
                (self.image_size, self.image_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            images_list.append(self.image_transform(global_view))
            """process the local views"""
            local_view = ImageOps.pad(
                image,
                (best_width, best_height),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            # Slice the padded image into image_size x image_size tiles,
            # row-major.
            for i in range(0, best_height, self.image_size):
                for j in range(0, best_width, self.image_size):
                    images_list.append(
                        self.image_transform(
                            local_view.crop(
                                (j, i, j + self.image_size, i + self.image_size)
                            )
                        )
                    )
            """record height / width crop num"""
            num_width_tiles, num_height_tiles = (
                best_width // self.image_size,
                best_height // self.image_size,
            )
            images_spatial_crop.append([num_width_tiles, num_height_tiles])
            """add image tokens"""
            h = w = math.ceil(
                (self.image_size // self.patch_size) / self.downsample_ratio
            )
            # global views tokens h * (w + 1), 1 is for line separator
            tokenized_image = [self.image_token_id] * h * (w + 1)
            # add a separator between global and local views
            tokenized_image += [self.image_token_id]
            # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
            tokenized_image += (
                [self.image_token_id]
                * (num_height_tiles * h)
                * (num_width_tiles * w + 1)
            )
            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)
            # print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens
        """process the last text split"""
        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
        # deal with video, limit with request len
        if max_req_input_len > -1:
            if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
                # Truncate image/text tokens to leave room for the tail text;
                # the extra 1024 is headroom carved out of the budget.
                rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
                tokenized_str = tokenized_str[:rest]
                images_seq_mask = images_seq_mask[:rest]
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)
        """add the bos and eos tokens"""
        if bos:
            tokenized_str = [self.bos_id] + tokenized_str
            images_seq_mask = [False] + images_seq_mask
        if eos:
            tokenized_str = tokenized_str + [self.eos_id]
            images_seq_mask = images_seq_mask + [False]
        assert len(tokenized_str) == len(
            images_seq_mask
        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
        return tokenized_str, images_list, images_seq_mask, images_spatial_crop
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
    """Config for the Deepseek-VL2 SigLIP vision encoder.

    NOTE: `weight_init`, `deterministic`, and `num_recomputing_layers` are
    class-level defaults only — they are not settable through `__init__`.
    """

    model_type: str = "vision"
    model_name: str = "siglip_large_patch16_384"
    image_size: int = 384
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: int = 384,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 24,
        heads: int = 16,
        mlp_ratio: int = 4,
        global_pool: str = "map",
        ignore_head: bool = True,
        class_token: bool = False,
        num_classes: int = 0,
        use_checkpoint: bool = False,
        **kwargs,
    ):
        self.model_name = model_name
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.global_pool = global_pool
        self.ignore_head = ignore_head
        self.class_token = class_token
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        super().__init__(**kwargs)
class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
    """Config for the MLP projector bridging vision features to the LM.

    NOTE: `token_pooling` is a class-level default only — it is not settable
    through `__init__`.
    """

    model_type = "mlp_projector"
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs,
    ):
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.downsample_ratio = downsample_ratio
        super().__init__(**kwargs)
class DeepseekV2Config(PretrainedConfig):
    """Configuration for the DeepSeek-V2 language model (MLA + MoE)."""

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
        n_routed_experts=None,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        # "gready" (sic) matches the value shipped in official DeepSeek-V2
        # checkpoint configs — do not "fix" the spelling.
        topk_method="gready",
        n_group=None,
        topk_group=None,
        num_experts_per_tok=None,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_mla=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        # Coerce to float: some checkpoint configs store this as a string.
        self.rms_norm_eps = float(rms_norm_eps)
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_mla = use_mla
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class DeepseekVL2Config(PretrainedConfig):
    """Top-level Deepseek-VL2 config composing vision, projector, and LM configs."""

    model_type = "deepseek_vl_v2"
    vision_config: DeepseekVL2VisionEncoderConfig
    projector_config: DeepseekVL2MlpProjectorConfig
    language_config: DeepseekV2Config
    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)

    def __init__(
        self,
        # NOTE(review): the literal default "tile_tag" differs from the
        # class-level default "2D"; checkpoint configs always pass a real
        # value, but confirm this default is intentional.
        tile_tag: str = "tile_tag",
        global_view_pos: str = "head",
        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Sub-configs arrive as plain dicts inside kwargs (or are defaulted).
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = DeepseekVL2VisionEncoderConfig(**vision_config)
        projector_config = kwargs.get("projector_config", {})
        self.projector_config = DeepseekVL2MlpProjectorConfig(**projector_config)
        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, DeepseekV2Config):
            self.language_config = language_config
        else:
            self.language_config = DeepseekV2Config(**language_config)
        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
        self.architectures = ["DeepseekVL2ForCausalLM"]
# Register the processor so AutoProcessor.from_pretrained resolves a
# DeepseekVL2Config checkpoint to DeepseekVLV2Processor.
AutoProcessor.register(DeepseekVL2Config, DeepseekVLV2Processor)

View File

@@ -0,0 +1,17 @@
import logging
from typing import Optional
import torch
logger = logging.getLogger(__name__)
class DeviceConfig:
    """Holds the torch device the runtime should place tensors on."""

    # Device-type strings accepted by the constructor.
    _SUPPORTED_DEVICES = ("cuda", "xpu", "hpu", "cpu", "npu")

    device: Optional[torch.device]

    def __init__(self, device: str = "cuda") -> None:
        # Reject anything outside the supported set up front.
        if device not in self._SUPPORTED_DEVICES:
            raise RuntimeError(f"Not supported device type: {device}")
        self.device_type = device
        self.device = torch.device(self.device_type)

View File

@@ -0,0 +1,195 @@
# coding=utf-8
# Copyright 2024 The LG AI Research EXAONE Lab. All rights reserved.
# Copyright 2024 The LG CNS AI Engineering Team.
# Copyright 2023-2024 SGLang Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" EXAONE model configuration """
from typing import Any, Dict
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, Any] = {}
# ruff: noqa: E501
class ExaoneConfig(PretrainedConfig):
    r"""Configuration class for :class:`~transformers.ExaoneModel`.

    Instantiating a configuration with the defaults yields a configuration
    similar to that of the Exaone model. Configuration objects inherit from
    :class:`~transformers.PretrainedConfig`; read its documentation for the
    inherited options.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 102400):
            Vocabulary size: number of distinct tokens representable by
            :obj:`inputs_ids`.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):
            Maximum sequence length the model may be used with.
        hidden_size (:obj:`int`, `optional`, defaults to 2048):
            Dimensionality of the encoder layers and the pooler layer.
        num_layers (:obj:`int`, `optional`, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 32):
            Number of attention heads per attention layer in the decoder.
        num_key_value_heads (:obj:`int`, `optional`):
            Number of key/value heads for Grouped Query Attention. Equal to
            ``num_attention_heads`` means MHA, ``1`` means MQA, anything in
            between is GQA. Defaults to ``num_attention_heads``.
        intermediate_size (:obj:`int`, `optional`, defaults to ``hidden_size * 4``):
            Dimensionality of the feed-forward layer.
        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"silu"`):
            Non-linear activation function in the decoder.
        rope_theta (:obj:`float`, `optional`, defaults to 10000.0):
            Base period of the RoPE embeddings.
        rope_scaling (:obj:`Dict`, `optional`):
            RoPE scaling configuration (``rope_type``, ``factor``,
            ``original_max_position_embeddings``, ... — see the transformers
            RoPE documentation for the full schema of each variant).
        embed_dropout (:obj:`float`, `optional`, defaults to 0.0):
            Dropout probability for fully connected layers in the embeddings,
            encoder, and pooler.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
            Dropout ratio for the attention probabilities.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
            Epsilon used by the layer normalization layers.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            Stddev of the truncated-normal initializer for weight matrices.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether the model returns the last key/value attentions
            (only relevant if ``configs.is_decoder=True``).
        bos_token_id (:obj:`int`, `optional`, defaults to 0):
            Beginning-of-stream token id.
        eos_token_id (:obj:`int`, `optional`, defaults to 2):
            End-of-stream token id.
        tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to tie weight embeddings.

    Example::
        >>> from transformers import EXAONEModel, ExaoneConfig
        >>> configuration = ExaoneConfig()
        >>> model = EXAONEModel(configuration)
    """

    model_type = "exaone"
    keys_to_ignore_at_inference = ["past_key_values"]
    # The HF-standard name `num_hidden_layers` aliases this model's `num_layers`.
    attribute_map = {"num_hidden_layers": "num_layers"}

    def __init__(
        self,
        vocab_size=102400,
        max_position_embeddings=2048,
        hidden_size=2048,
        num_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        intermediate_size=None,
        activation_function="silu",
        rope_theta=10000.0,
        rope_scaling=None,
        embed_dropout=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=0,
        eos_token_id=2,
        tie_word_embeddings=True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_layers
        # GQA: default to one key/value head per attention head (plain MHA).
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        # Feed-forward width falls back to the conventional 4x hidden size.
        self.intermediate_size = (
            intermediate_size if intermediate_size else hidden_size * 4
        )
        self.activation_function = activation_function
        self.embed_dropout = embed_dropout
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

View File

@@ -0,0 +1,706 @@
import copy
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import sentencepiece as spm
from transformers import (
TOKENIZER_MAPPING,
GptOssConfig,
LlamaConfig,
PretrainedConfig,
PreTrainedTokenizer,
Qwen2Config,
Qwen3Config,
Qwen3MoeConfig,
)
from sglang.utils import logger
# Copied from: https://github.com/OpenGVLab/InternVL/blob/34a81000402bf8f716bab8c9b57aff1f6b436bd0/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py#L21
# Relative path of the SentencePiece model file expected next to a checkpoint.
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
# No hub-hosted vocabulary files are registered for this tokenizer.
PRETRAINED_VOCAB_FILES_MAP = {}
# Modified from transformers.model.llama.configuration_llama.LlamaConfig
class InternLM2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
    an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 103168):
            Vocabulary size: number of distinct tokens representable by `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads per attention layer.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads for Grouped Query Attention. Equal to
            `num_attention_heads` means MHA, `1` means MQA, otherwise GQA.
            Defaults to `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            Maximum sequence length the model may be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Stddev of the truncated-normal initializer for weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model returns the last key/value attentions. Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        bias (`bool`, *optional*, defaults to `True`):
            Whether the attention projections use a bias term.
        rope_theta (`float`, *optional*, defaults to 10000):
            Base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            RoPE scaling dict with exactly two fields, `type`
            ("linear" | "dynamic") and `factor` (number >= 1).
        attn_implementation (`str`, *optional*, defaults to `"eager"`):
            Attention backend identifier; `None` falls back to "eager".
    """

    model_type = "internlm2"
    _auto_class = "AutoConfig"

    def __init__(  # pylint: disable=W0102
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation="eager",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.bias = bias
        # GQA: default to one key/value head per attention head (plain MHA).
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Fail fast on a malformed rope_scaling dict.
        self._rope_scaling_validation()
        self.attn_implementation = attn_implementation
        if self.attn_implementation is None:
            self.attn_implementation = "eager"
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: if `rope_scaling` is not a two-field dict of
                `type` in {"linear", "dynamic"} and numeric `factor` >= 1.
        """
        if self.rope_scaling is None:
            return
        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            # Fixed the doubled word ("with with") in the original message.
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if (
            rope_scaling_factor is None
            or not isinstance(rope_scaling_factor, (float, int))
            or rope_scaling_factor < 1.0
        ):
            raise ValueError(
                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
            )
        # Removed dead code: the original coerced the *local*
        # `rope_scaling_factor` to float here, which had no effect on the
        # stored `self.rope_scaling` dict.
class InternVisionConfig(PretrainedConfig):
    r"""Configuration for [`InternVisionModel`], the InternVL vision encoder.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs; see its documentation for inherited options.

    Args:
        num_channels (`int`, *optional*, defaults to 3): Color channels of the input images.
        patch_size (`int`, *optional*, defaults to 14): Resolution of each patch.
        image_size (`int`, *optional*, defaults to 224): Resolution of each image.
        qkv_bias (`bool`, *optional*, defaults to `False`): Add bias to queries/values in self-attention.
        hidden_size (`int`, *optional*, defaults to 3200): Width of the encoder layers and the pooler.
        num_attention_heads (`int`, *optional*, defaults to 25): Attention heads per encoder layer.
        intermediate_size (`int`, *optional*, defaults to 12800): Feed-forward width.
        qk_normalization (`bool`, *optional*, defaults to `True`): Normalize queries and keys.
        num_hidden_layers (`int`, *optional*, defaults to 48): Encoder depth.
        use_flash_attn (`bool`, *optional*, defaults to `True`): Use the flash attention mechanism.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): Activation function.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6): Layer-norm epsilon.
        dropout (`float`, *optional*, defaults to 0.0): Dropout for fully connected layers.
        drop_path_rate (`float`, *optional*, defaults to 0.0): Stochastic-depth dropout rate.
        attention_dropout (`float`, *optional*, defaults to 0.0): Attention-probability dropout.
        initializer_range (`float`, *optional*, defaults to 0.02): Weight-init stddev.
        initializer_factor (`float`, *optional*, defaults to 0.1): Layer-scale factor.
    """

    model_type = "intern_vit_6b"

    def __init__(
        self,
        num_channels=3,
        patch_size=14,
        image_size=224,
        qkv_bias=False,
        hidden_size=3200,
        num_attention_heads=25,
        intermediate_size=12800,
        qk_normalization=True,
        num_hidden_layers=48,
        use_flash_attn=True,
        hidden_act="gelu",
        layer_norm_eps=1e-6,
        dropout=0.0,
        drop_path_rate=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Regularization.
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Patch embedding / input geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        # Initialization.
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        # Attention options.
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.use_flash_attn = use_flash_attn

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load the vision config, descending into a composite checkpoint's
        nested `vision_config` sub-dict when present."""
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        if "vision_config" in config_dict:
            config_dict = config_dict["vision_config"]
        type_is_declared = "model_type" in config_dict and hasattr(cls, "model_type")
        if type_is_declared and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        return cls.from_dict(config_dict, **kwargs)
class InternVLChatConfig(PretrainedConfig):
    """Composite configuration for InternVL chat models: a vision tower
    (`InternVisionConfig`) plus a language model whose config class is chosen
    from the declared LLM architecture."""

    model_type = "internvl_chat"
    is_composition = True

    # Maps the LLM's first declared architecture to its config class.
    # Replaces a long if/elif chain and avoids re-reading the key per branch.
    _LLM_ARCH_TO_CONFIG = {
        "LlamaForCausalLM": LlamaConfig,
        "InternLM2ForCausalLM": InternLM2Config,
        "Qwen2ForCausalLM": Qwen2Config,
        "Qwen3MoeForCausalLM": Qwen3MoeConfig,
        "Qwen3ForCausalLM": Qwen3Config,
        "GptOssForCausalLM": GptOssConfig,
    }

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        pad2square=False,
        select_layer=-1,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
        ps_version="v1",
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if vision_config is None:
            vision_config = {"architectures": ["InternVisionModel"]}
            logger.info(
                "vision_config is None. Initializing the InternVisionConfig with default values."
            )
        if llm_config is None:
            llm_config = {"architectures": ["InternLM2ForCausalLM"]}
            # Fixed log message: the default architecture is InternLM2, not Llama.
            logger.info(
                "llm_config is None. Initializing the llm config with default values (`InternLM2Config`)."
            )
        self.vision_config = InternVisionConfig(**vision_config)
        llm_arch = llm_config.get("architectures")[0]
        llm_config_cls = self._LLM_ARCH_TO_CONFIG.get(llm_arch)
        if llm_config_cls is None:
            raise ValueError("Unsupported architecture: {}".format(llm_arch))
        self.llm_config = llm_config_cls(**llm_config)
        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.hidden_size = self.llm_config.hidden_size
        # By default, we use tie_word_embeddings=False for models of all sizes.
        self.tie_word_embeddings = False
        self.llm_config.tie_word_embeddings = self.tie_word_embeddings

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        # Sub-configs must be serialized recursively, not left as objects.
        output["vision_config"] = self.vision_config.to_dict()
        output["llm_config"] = self.llm_config.to_dict()
        output["model_type"] = self.__class__.model_type
        output["use_backbone_lora"] = self.use_backbone_lora
        output["use_llm_lora"] = self.use_llm_lora
        output["select_layer"] = self.select_layer
        output["force_image_size"] = self.force_image_size
        output["downsample_ratio"] = self.downsample_ratio
        output["template"] = self.template
        output["dynamic_image_size"] = self.dynamic_image_size
        output["use_thumbnail"] = self.use_thumbnail
        output["ps_version"] = self.ps_version
        output["min_dynamic_patch"] = self.min_dynamic_patch
        output["max_dynamic_patch"] = self.max_dynamic_patch
        return output
# # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
# class InternLM2TokenizerFast(PreTrainedTokenizerFast):
# vocab_files_names = VOCAB_FILES_NAMES
# slow_tokenizer_class = InternLM2Tokenizer
# padding_side = 'left'
# model_input_names = ['input_ids', 'attention_mask']
# _auto_class = 'AutoTokenizer'
#
# def __init__(
# self,
# vocab_file,
# unk_token='<unk>',
# bos_token='<s>',
# eos_token='</s>',
# pad_token='</s>',
# sp_model_kwargs: Optional[Dict[str, Any]] = None,
# add_bos_token=True,
# add_eos_token=False,
# decode_with_prefix_space=False,
# clean_up_tokenization_spaces=False,
# **kwargs,
# ):
# super().__init__(
# vocab_file=vocab_file,
# unk_token=unk_token,
# bos_token=bos_token,
# eos_token=eos_token,
# pad_token=pad_token,
# sp_model_kwargs=sp_model_kwargs,
# add_bos_token=add_bos_token,
# add_eos_token=add_eos_token,
# decode_with_prefix_space=decode_with_prefix_space,
# clean_up_tokenization_spaces=clean_up_tokenization_spaces,
# **kwargs,
# )
# self._add_bos_token = add_bos_token
# self._add_eos_token = add_eos_token
# self.update_post_processor()
# self.vocab_file = vocab_file
#
# @property
# def can_save_slow_tokenizer(self) -> bool:
# return os.path.isfile(self.vocab_file) if self.vocab_file else False
#
# def update_post_processor(self):
# """
# Updates the underlying post processor with the current `bos_token` and `eos_token`.
# """
# bos = self.bos_token
# bos_token_id = self.bos_token_id
# if bos is None and self.add_bos_token:
# raise ValueError('add_bos_token = True but bos_token = None')
#
# eos = self.eos_token
# eos_token_id = self.eos_token_id
# if eos is None and self.add_eos_token:
# raise ValueError('add_eos_token = True but eos_token = None')
#
# single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
# pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
#
# special_tokens = []
# if self.add_bos_token:
# special_tokens.append((bos, bos_token_id))
# if self.add_eos_token:
# special_tokens.append((eos, eos_token_id))
# self._tokenizer.post_processor = processors.TemplateProcessing(
# single=single, pair=pair, special_tokens=special_tokens
# )
#
# @property
# def add_eos_token(self):
# return self._add_eos_token
#
# @property
# def add_bos_token(self):
# return self._add_bos_token
#
# @add_eos_token.setter
# def add_eos_token(self, value):
# self._add_eos_token = value
# self.update_post_processor()
#
# @add_bos_token.setter
# def add_bos_token(self, value):
# self._add_bos_token = value
# self.update_post_processor()
#
# def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
# if not self.can_save_slow_tokenizer:
# raise ValueError(
# 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
# 'tokenizer.'
# )
#
# if not os.path.isdir(save_directory):
# logger.error(f'Vocabulary path ({save_directory}) should be a directory')
# return
# out_vocab_file = os.path.join(
# save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
# )
#
# if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
# copyfile(self.vocab_file, out_vocab_file)
#
# return (out_vocab_file,)
# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
class InternLM2Tokenizer(PreTrainedTokenizer):
    """
    Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding
    via a SentencePiece model.

    Args:
        vocab_file (`str`):
            Path to the SentencePiece vocabulary file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    _auto_class = "AutoTokenizer"

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="</s>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        decode_with_prefix_space=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        # Removed leftover debug statement: print("register succeed").
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.decode_with_prefix_space = decode_with_prefix_space
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        # Built lazily by the `no_prefix_space_tokens` property.
        self._no_prefix_space_tokens = None
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def no_prefix_space_tokens(self):
        """Indices of vocab pieces that do NOT start with the SentencePiece
        word-boundary marker."""
        if self._no_prefix_space_tokens is None:
            vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
            # Bug fix: the original tested startswith("") which is always
            # True, leaving the set empty — the "▁" marker was evidently lost
            # to a mis-encoding (matches the upstream InternLM tokenizer).
            self._no_prefix_space_tokens = {
                i for i, tok in enumerate(vocab) if not tok.startswith("▁")
            }
        return self._no_prefix_space_tokens

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    @property
    def bos_token_id(self) -> Optional[int]:
        # NOTE(review): overrides the HF token-based id with the id stored in
        # the SentencePiece model itself — confirm both agree for checkpoints.
        return self.sp_model.bos_id()

    @property
    def eos_token_id(self) -> Optional[int]:
        # See note on bos_token_id.
        return self.sp_model.eos_id()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def _maybe_add_prefix_space(self, tokens, decoded):
        # NOTE(review): `tokens` holds token *strings* while the set holds
        # vocab *indices* (same as upstream InternLM), so membership is
        # effectively always False and a space is always prepended; the
        # leading character is stripped again in convert_tokens_to_string.
        if tokens and tokens[0] not in self.no_prefix_space_tokens:
            return " " + decoded
        else:
            return decoded

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        out_string = self.clean_up_tokenization(out_string)
        out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
        # Drop the (always-prepended) leading character; see
        # _maybe_add_prefix_space above.
        return out_string[1:]

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )
        # Copy the on-disk model when it exists and differs from the target;
        # otherwise re-serialize the in-memory SentencePiece model.
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Prepend BOS / append EOS around one or two sequences per the
        `add_bos_token` / `add_eos_token` flags."""
        if self.add_bos_token:
            bos_token_ids = [self.bos_token_id]
        else:
            bos_token_ids = []
        output = bos_token_ids + token_ids_0
        if token_ids_1 is not None:
            output = output + token_ids_1
        if self.add_eos_token:
            output = output + [self.eos_token_id]
        return output

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
        use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
# Register the slow tokenizer so AutoTokenizer resolves InternVLChatConfig
# checkpoints to InternLM2Tokenizer (no fast variant registered).
TOKENIZER_MAPPING.register(
    InternVLChatConfig, (InternLM2Tokenizer, None), exist_ok=True
)

View File

@@ -0,0 +1,634 @@
# Adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
import numpy as np
import PIL
import torch
from PIL.Image import Image
from transformers import (
BaseImageProcessor,
BatchFeature,
LlamaConfig,
LlamaTokenizerFast,
PretrainedConfig,
ProcessorMixin,
)
from transformers.image_utils import to_numpy_array
from sglang.srt.configs.utils import register_image_processor, register_processor
from sglang.srt.multimodal.mm_utils import expand2square
class DictToObject(dict):
    """A dict whose keys are also exposed as attributes, recursively.

    Nested dict values are wrapped in DictToObject so `obj.a.b` works.
    """

    def __init__(self, dictionary):
        # Bug fix: the original called `super(self).__init__(...)` — `super()`
        # takes a type, not an instance, so that raised TypeError at runtime.
        super().__init__(dictionary)
        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)
class VisionConfig(PretrainedConfig):
    """Names the vision-encoder class and its constructor params."""

    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Accept either a class object or its name; store the name only.
        raw_cls = kwargs.get("cls", "")
        self.cls = raw_cls if isinstance(raw_cls, str) else raw_cls.__name__
        self.params = kwargs.get("params", {})
class GenAlignerConfig(PretrainedConfig):
    """Sub-config holding the generation-aligner class name and its params."""

    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # `cls` may arrive as a class object; store only its name.
        cls_field = kwargs.get("cls", "")
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
class GenHeadConfig(PretrainedConfig):
    """Sub-config holding the generation-head class name and its params."""

    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # `cls` may arrive as a class object; store only its name.
        cls_field = kwargs.get("cls", "")
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
class AlignerConfig(PretrainedConfig):
    """Sub-config holding the understanding-aligner class name and its params."""

    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # `cls` may arrive as a class object; store only its name.
        cls_field = kwargs.get("cls", "")
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
class GenVisionConfig(PretrainedConfig):
    """Sub-config holding the generation vision tokenizer class name and params."""

    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # `cls` may arrive as a class object; store only its name.
        cls_field = kwargs.get("cls", "")
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
@dataclass
class SigLIPVisionCfg:
    """Architecture hyper-parameters of a SigLIP vision tower."""

    width: int = 1152  # embedding dimension
    layers: Union[Tuple[int, int, int, int], int] = 27  # per-stage tuple or total depth
    heads: int = 16
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336  # (H, W) or square side
    global_pool: str = "map"  # pooling strategy name
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0  # 0 disables the classification head
    use_checkpoint: bool = False  # activation checkpointing
class MultiModalityConfig(PretrainedConfig):
    """Top-level Janus config bundling vision, aligner, generation and language parts."""

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Each sub-config is built from its (possibly absent) kwargs dict.
        self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
        self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
        self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
        self.gen_aligner_config = GenAlignerConfig(
            **kwargs.get("gen_aligner_config", {})
        )
        self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))

        # The language config may already be a LlamaConfig instance.
        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)
class VLMImageProcessor(BaseImageProcessor):
    """Image preprocessor for the Janus / DeepSeek-VLM family.

    Pipeline: scale the long side to ``image_size`` -> pad to a square with
    ``background_color`` -> HWC->CHW -> rescale to [0, 1] -> optional
    mean/std normalization.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        # Square-padding color; derived from the mean so padded pixels land
        # near zero after normalization (mid-gray fallback when no mean).
        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """
        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """
        width, height = pil_img.size
        max_size = max(width, height)
        # Target (height, width): long side -> image_size, short side scaled
        # proportionally but never below min_size.
        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            # print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        def resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        ):
            # Local helper (intentionally shadows the method name): `size` is
            # either a single int (short-side resize) or a (height, width)
            # pair. NOTE(review): with antialias=True this passes
            # reducing_gap=None to PIL — confirm that matches the intended
            # antialiasing behavior.
            if isinstance(size, int):
                w, h = pil_img.size
                if (w <= h and w == size) or (h <= w and h == size):
                    return pil_img
                if w < h:
                    ow = size
                    oh = int(size * h / w)
                else:
                    oh = size
                    ow = int(size * w / h)
                size = (ow, oh)
            else:
                # PIL expects (width, height); flip from (height, width).
                size = (size[1], size[0])
            return pil_img.resize(
                size, resample=interpolation, reducing_gap=None if antialias else 3.0
            )

        pil_img = resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        )
        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)
        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))
        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        """Resize/pad, rescale and (optionally) normalize one or more images."""
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        if not isinstance(images, list):
            images = [images]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        # Keep only the RGB planes (drops alpha if present).
        images = [image[:3, ...] for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        # Shape of one processed image: [channels, H, W].
        return [3, self.image_size, self.image_size]
class DictOutput(object):
    """Mixin exposing instance attributes through a dict-like interface."""

    def keys(self):
        return vars(self).keys()

    def items(self):
        return vars(self).items()

    def __getitem__(self, item):
        return vars(self)[item]

    def __setitem__(self, key, value):
        vars(self)[key] = value

    def __contains__(self, key):
        return key in vars(self)
@dataclass
class VLChatProcessorOutput(DictOutput):
    """Single-sample output of VLChatProcessor (before batching)."""

    sft_format: str  # the formatted prompt string
    input_ids: torch.Tensor  # [N + expanded image tokens]
    pixel_values: torch.Tensor  # [n_images, 3, H, W]
    num_image_tokens: torch.IntTensor  # per-image expanded token counts

    def __len__(self):
        # Sequence length in tokens (image placeholders already expanded).
        return len(self.input_ids)
@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    """Batched (left-padded) inputs produced by VLChatProcessor.batchify."""

    sft_format: List[str]  # one formatted prompt per sample
    input_ids: torch.Tensor  # [batch, max_len], left-padded
    pixel_values: torch.Tensor  # [batch, max_n_images, 3, H, W]
    attention_mask: torch.Tensor  # 1 on real tokens, 0 on padding
    images_seq_mask: torch.BoolTensor  # True at image-token positions
    images_emb_mask: torch.BoolTensor  # [batch, max_n_images, num_image_tokens]
# FIXME: had to place Official Processor here, since image_processor module would not be imported in all threads,
# hence AutoProcessor registration would not be affective in some cases
class VLChatProcessor(ProcessorMixin):
    """Chat processor for Janus-style VLMs.

    Expands each image placeholder token in the tokenized prompt into
    ``<begin_of_image>`` + ``num_image_tokens`` placeholders +
    ``<end_of_image>``, preprocesses the images, and batches samples with
    left padding.
    """

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<▁pad▁>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        # Register the image tag as a special token if the tokenizer lacks it.
        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            # print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag

        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.ignore_id = ignore_id
        # NOTE(review): `mask_prompt` is accepted but not stored/used here.

        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

    @property
    def image_token(self):
        # The placeholder tag string, e.g. "<image_placeholder>".
        return self.image_tag

    @property
    def image_id(self) -> int:
        # Vocab id of the image placeholder token.
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        # Vocab id of the begin-of-image marker.
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        # Vocab id of the end-of-image marker.
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        # Vocab id used for left padding in batchify.
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """
        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """
        input_slices = []

        start = 0
        for index in image_indices:
            # When add_special_token is set, the original placeholder token is
            # kept (slice end is index + 1); otherwise it is replaced.
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        # NOTE(review): a `conversations` kwarg passed by __call__ lands in
        # **kwargs and is unused here — the prompt must already be formatted.
        sft_format = prompt
        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """
        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        # Keep at least one image slot so pixel_values has a fixed rank even
        # for text-only batches.
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
class VLMImageProcessorConfig(PretrainedConfig):
    """HF-config mirror of VLMImageProcessor's constructor arguments."""

    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        # Record all preprocessing knobs before delegating to PretrainedConfig.
        for name, value in (
            ("image_size", image_size),
            ("min_size", min_size),
            ("image_mean", image_mean),
            ("image_std", image_std),
            ("rescale_factor", rescale_factor),
            ("do_normalize", do_normalize),
        ):
            setattr(self, name, value)

        super().__init__(**kwargs)
# Make the Janus processor / image processor discoverable through the Auto*
# factories for models whose config is MultiModalityConfig.
register_processor(MultiModalityConfig, VLChatProcessor)
register_image_processor(MultiModalityConfig, VLMImageProcessor)

View File

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from typing import Optional, Union
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
class KimiVLConfig(PretrainedConfig):
    """Kimi-VL config: a MoonViT vision tower plus a DeepSeek-V2 text model."""

    model_type = "kimi_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        **kwargs
    ):
        # Coerce dict / None inputs into proper sub-config objects.
        if isinstance(vision_config, dict):
            vision_config = MoonViTConfig(**vision_config)
        elif vision_config is None:
            vision_config = MoonViTConfig()
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config = DeepseekV2Config(**text_config)
        elif text_config is None:
            text_config = DeepseekV2Config()
        self.text_config = text_config

        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id

        super().__init__(pad_token_id=pad_token_id, **kwargs)

View File

@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from transformers.configuration_utils import PretrainedConfig
class MoonViTConfig(PretrainedConfig):
    """Configuration of the MoonViT vision tower used by Kimi-VL."""

    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Patch / positional-embedding geometry.
        self.patch_size = patch_size
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer trunk dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        # Spatial merging of patch features after the encoder.
        self.merge_kernel_size = merge_kernel_size

View File

@@ -0,0 +1,89 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
class LoadFormat(str, enum.Enum):
    """Weight-loading formats accepted by LoadConfig.load_format."""

    AUTO = "auto"  # safetensors if available, else pytorch .bin
    PT = "pt"  # pytorch .bin checkpoints
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"  # pytorch weights + numpy cache for faster reloads
    DUMMY = "dummy"  # random-initialized weights, mainly for profiling
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"  # nf4-quantized weights
    MISTRAL = "mistral"
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"
@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    decryption_key_file: If set, decrypts the output files with a password read
        from this file (after PBKDF2).
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None

    def __post_init__(self):
        # Normalize `model_loader_extra_config` to a dict: JSON strings are
        # parsed; a falsy value (e.g. explicit None) becomes {}.
        # BUG FIX: the original computed `self.model_loader_extra_config or {}`
        # into a local but wrote it back only in the str branch, so passing
        # None left the attribute as None instead of {}.
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        self.model_loader_extra_config = extra_config
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Coerce a str load_format to LoadFormat; reject ROCm-unsupported ones."""
        if not isinstance(self.load_format, str):
            return
        # LoadFormat(...) raises ValueError on unknown format strings.
        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        # Currently empty — every format is allowed on ROCm.
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f
                for f in LoadFormat.__members__
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )

View File

@@ -0,0 +1,104 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Kept for parity with HF config-module conventions; no archived
# checkpoint URLs are registered for this model.
FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class LongcatFlashConfig(PretrainedConfig):
    """Configuration for the LongCat-Flash architecture (MLA attention + MoE
    with zero-computation experts)."""

    model_type = "longcat_flash"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=6144,
        intermediate_size=None,  # falls back to ffn_hidden_size
        ffn_hidden_size=12288,
        expert_ffn_hidden_size=2048,  # per-expert MoE intermediate size
        num_layers=28,
        num_hidden_layers=None,  # falls back to num_layers
        num_attention_heads=64,
        ep_size=1,  # expert-parallel world size
        # MLA (multi-head latent attention) dimensions.
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=128,
        qk_nope_head_dim=128,
        v_head_dim=128,
        # MoE routing.
        n_routed_experts=512,
        moe_topk=12,
        norm_topk_prob=False,
        max_position_embeddings=131072,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mla_scale_q_lora=True,
        mla_scale_kv_lora=True,
        torch_dtype="bfloat16",
        params_dtype="bfloat16",
        # NOTE: "rounter" (sic) — kept as-is; presumably matches the key used
        # by released checkpoints' config.json. Do not rename.
        rounter_params_dtype="float32",
        router_bias=False,
        topk_method=None,
        routed_scaling_factor=6.0,
        # Zero-computation experts (identity passthrough slots).
        zero_expert_num=256,
        zero_expert_type="identity",
        nextn_use_scmoe=False,
        num_nextn_predict_layers=1,  # NextN/MTP draft layers
        **kwargs,
    ):
        # Token ids, dtypes and NextN knobs are forwarded to PretrainedConfig
        # so they round-trip through config.json serialization.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            torch_dtype=torch_dtype,
            params_dtype=params_dtype,
            rounter_params_dtype=rounter_params_dtype,
            topk_method=topk_method,
            router_bias=router_bias,
            nextn_use_scmoe=nextn_use_scmoe,
            num_nextn_predict_layers=num_nextn_predict_layers,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        # Prefer the HF-conventional names when given; otherwise fall back to
        # the checkpoint's native names (num_layers / ffn_hidden_size).
        self.num_hidden_layers = (
            num_hidden_layers if num_hidden_layers is not None else num_layers
        )
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else ffn_hidden_size
        )
        self.moe_intermediate_size = expert_ffn_hidden_size
        self.num_attention_heads = num_attention_heads
        self.ep_size = ep_size
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.n_routed_experts = n_routed_experts
        self.moe_topk = moe_topk
        self.norm_topk_prob = norm_topk_prob
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mla_scale_q_lora = mla_scale_q_lora
        self.mla_scale_kv_lora = mla_scale_kv_lora
        self.zero_expert_num = zero_expert_num
        self.zero_expert_type = zero_expert_type
        self.routed_scaling_factor = routed_scaling_factor
        # Activation is fixed for this architecture.
        self.hidden_act = "silu"

View File

@@ -0,0 +1,794 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union
import torch
from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_generation_config,
get_hf_text_config,
get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
class AttentionArch(IntEnum):
    """Attention architecture family of a model."""

    MLA = auto()  # Multi-head Latent Attention (DeepSeek-style)
    MHA = auto()  # standard Multi-Head Attention (incl. GQA/MQA variants)
class ModelImpl(str, Enum):
    """Which model implementation to run."""

    AUTO = "auto"  # pick automatically
    SGLANG = "sglang"  # native sglang implementation
    TRANSFORMERS = "transformers"  # fall back to HF transformers
class ModelConfig:
def __init__(
    self,
    model_path: str,
    trust_remote_code: bool = True,
    revision: Optional[str] = None,
    context_length: Optional[int] = None,
    model_override_args: str = "{}",  # JSON string merged into the HF config
    is_embedding: Optional[bool] = None,
    enable_multimodal: Optional[bool] = None,
    dtype: str = "auto",
    quantization: Optional[str] = None,
    override_config_file: Optional[str] = None,
    is_draft_model: bool = False,  # speculative-decoding draft model
    hybrid_kvcache_ratio: Optional[float] = None,
    model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
    """Load the HF config for `model_path` and derive sglang-specific
    metadata (modality flags, context length, attention architecture,
    head/layer counts, quantization checks)."""
    # Parse args
    self.model_path = model_path
    self.revision = revision
    self.quantization = quantization
    self.model_impl = model_impl

    # May rewrite self.model_path when it points at a remote store.
    self.maybe_pull_model_tokenizer_from_remote()
    self.model_override_args = json.loads(model_override_args)
    kwargs = {}
    if override_config_file and override_config_file.strip():
        kwargs["_configuration_file"] = override_config_file.strip()

    self.hf_config = get_config(
        self.model_path,
        trust_remote_code=trust_remote_code,
        revision=revision,
        model_override_args=self.model_override_args,
        **kwargs,
    )

    self.hf_generation_config = get_generation_config(
        self.model_path,
        trust_remote_code=trust_remote_code,
        revision=revision,
        **kwargs,
    )

    # Text sub-config for multimodal models; the config itself otherwise.
    self.hf_text_config = get_hf_text_config(self.hf_config)
    self.attention_chunk_size = getattr(
        self.hf_text_config, "attention_chunk_size", None
    )

    # Hybrid models mix sliding-window (SWA) and full-attention layers.
    self.is_hybrid = is_hybrid_model(
        self.hf_config.architectures,
        hybrid_kvcache_ratio=hybrid_kvcache_ratio,
        context_length=context_length,
        attention_chunk_size=self.attention_chunk_size,
    )
    if self.is_hybrid is not None:
        self.swa_attention_layer_ids, self.full_attention_layer_ids = (
            get_hybrid_layer_ids(
                self.hf_config.architectures, self.hf_text_config.num_hidden_layers
            )
        )

    # Decide multimodal support unless the caller forced a value.
    if enable_multimodal is None:
        mm_disabled_models = [
            "Gemma3ForConditionalGeneration",
            "Llama4ForConditionalGeneration",
            "Step3VLForConditionalGeneration",
        ]
        if self.hf_config.architectures[0] in mm_disabled_models:
            enable_multimodal = False
            logger.info(
                f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
            )
        else:
            enable_multimodal = True

    # Draft models reuse the target's config but swap in the matching
    # NextN / MTP architecture name.
    if (
        is_draft_model
        and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
    ):
        self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

    if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
        self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

    if (
        is_draft_model
        and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
    ):
        self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
        # The draft runs only the NextN predict layers.
        self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers

    if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
        self.hf_config.architectures[0] = "MiMoMTP"
    if (
        is_draft_model
        and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
    ):
        self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"

    # Check model type
    self.is_generation = is_generation_model(
        self.hf_config.architectures, is_embedding
    )
    self.is_multimodal = enable_multimodal and is_multimodal_model(
        self.hf_config.architectures
    )
    self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
        self.hf_config.architectures
    )
    self.is_image_gen = enable_multimodal and is_image_gen_model(
        self.hf_config.architectures
    )
    self.is_audio_model = enable_multimodal and is_audio_model(
        self.hf_config.architectures
    )
    self.is_multimodal_chunked_prefill_supported = (
        enable_multimodal
        and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
    )
    self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
    self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

    # Derive context length
    derived_context_len = get_context_length(self.hf_text_config)
    if context_length is not None:
        if context_length > derived_context_len:
            # A longer-than-derived context is only allowed behind an env
            # var (or in CI), since it can produce wrong outputs.
            reason = "Target model's" if is_draft_model else "User-specified"
            msg = (
                f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
            )
            if (
                get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
                or is_in_ci()  # FIXME: fix this special case
            ):
                logger.warning(msg)
                self.context_len = context_length
            else:
                raise ValueError(
                    f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                )
        else:
            self.context_len = context_length
    else:
        self.context_len = derived_context_len

    # Unify the config keys for hf_text_config
    self.head_dim = getattr(
        self.hf_text_config,
        "head_dim",
        self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
    )

    # FIXME: temporary special judge for MLA architecture
    if (
        "DeepseekV2ForCausalLM" in self.hf_config.architectures
        or "DeepseekV3ForCausalLM" in self.hf_config.architectures
        or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
        or "LongcatFlashForCausalLM" in self.hf_config.architectures
        or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
    ):
        self.head_dim = 256
        self.attention_arch = AttentionArch.MLA
        self.kv_lora_rank = self.hf_config.kv_lora_rank
        self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
        self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        self.v_head_dim = self.hf_config.v_head_dim

        # Handle rope scaling with yarn
        self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
        if self.hf_config.rope_scaling:
            mscale_all_dim = self.hf_config.rope_scaling.get(
                "mscale_all_dim", False
            )
            scaling_factor = self.hf_config.rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
    elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
        self.head_dim = 128
        self.attention_arch = AttentionArch.MLA
        self.kv_lora_rank = self.hf_config.kv_lora_rank
        self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
    elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
        self.hf_text_config, "use_mla", True
    ):
        self.head_dim = 256
        self.attention_arch = AttentionArch.MLA
        self.kv_lora_rank = self.hf_text_config.kv_lora_rank
        self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
    elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
        self.head_dim = 256
        self.attention_arch = AttentionArch.MLA
        self.kv_lora_rank = self.hf_text_config.kv_lora_rank
        self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        self.v_head_dim = self.hf_text_config.v_head_dim
        self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
    else:
        if (
            "MistralModel" in self.hf_config.architectures
            or "MixtralForCausalLM" in self.hf_config.architectures
            or "MistralForCausalLM" in self.hf_config.architectures
        ):
            if getattr(self, "head_dim", None) is None:
                self.head_dim = (
                    self.hf_config.hidden_size // self.hf_config.num_attention_heads
                )
                # In transformers==4.52.3, the head_dim is null in MistralConfig
                if (
                    not hasattr(self.hf_text_config, "head_dim")
                    or self.hf_text_config.head_dim is None
                ):
                    setattr(self.hf_text_config, "head_dim", self.head_dim)

        self.attention_arch = AttentionArch.MHA

    self.num_attention_heads = self.hf_text_config.num_attention_heads
    self.num_key_value_heads = getattr(
        self.hf_text_config, "num_key_value_heads", None
    )

    # for Dbrx and MPT models
    if self.hf_config.model_type in ["dbrx", "mpt"]:
        self.num_key_value_heads = getattr(
            self.hf_config.attn_config, "kv_n_heads", None
        )

    # Models without GQA use one KV head per attention head.
    if self.num_key_value_heads is None:
        self.num_key_value_heads = self.num_attention_heads
    self.hidden_size = self.hf_text_config.hidden_size
    self.num_hidden_layers = self.hf_text_config.num_hidden_layers
    self.num_attention_layers = self.num_hidden_layers
    # LongCat-Flash runs two attention sub-layers per hidden layer.
    if "LongcatFlashForCausalLM" in self.hf_config.architectures:
        self.num_attention_layers = self.num_hidden_layers * 2
    self.num_nextn_predict_layers = getattr(
        self.hf_text_config, "num_nextn_predict_layers", None
    )
    self.vocab_size = self.hf_text_config.vocab_size

    # Verify quantization
    self._verify_quantization()

    # Verify dual-chunk attention config
    self._verify_dual_chunk_attention_config()

    # Cache attributes
    self.hf_eos_token_id = self.get_hf_eos_token_id()

    # multimodal
    self.image_token_id = getattr(
        self.hf_config, "image_token_id", None
    ) or getattr(self.hf_config, "image_token_index", None)
@staticmethod
def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
    """Build a ModelConfig from parsed server arguments.

    `model_path` overrides `server_args.model_path` when given; extra
    kwargs are forwarded to the ModelConfig constructor.
    """
    init_kwargs = dict(
        model_path=model_path or server_args.model_path,
        trust_remote_code=server_args.trust_remote_code,
        revision=server_args.revision,
        context_length=server_args.context_length,
        model_override_args=server_args.json_model_override_args,
        is_embedding=server_args.is_embedding,
        enable_multimodal=server_args.enable_multimodal,
        dtype=server_args.dtype,
        quantization=server_args.quantization,
        hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
        model_impl=server_args.model_impl,
    )
    return ModelConfig(**init_kwargs, **kwargs)
def get_total_num_attention_heads(self) -> int:
    """Total number of attention (query) heads across all ranks."""
    return self.num_attention_heads
def get_num_attention_heads(self, tensor_parallel_size) -> int:
    """Number of query heads per tensor-parallel rank (at least 1)."""
    heads_per_rank = self.num_attention_heads // tensor_parallel_size
    return heads_per_rank if heads_per_rank > 0 else 1
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
    """Returns the total number of KV heads."""
    cfg = self.hf_config
    text_cfg = self.hf_text_config
    model_type = cfg.model_type

    # Falcon family: when new_decoder_architecture is set, the multi_query
    # flag is ignored and n_head_kv carries the KV head count instead.
    is_new_arch_falcon = model_type in (
        "falcon",
        "RefinedWeb",
        "RefinedWebModel",
    ) and getattr(cfg, "new_decoder_architecture", False)

    # Multi-query attention (GPTBigCode & classic Falcon): one shared KV
    # head. Tensor parallelism is not supported in this case.
    if not is_new_arch_falcon and getattr(text_cfg, "multi_query", False):
        return 1

    # MPT keeps its attention settings in a plain dict.
    if model_type == "mpt":
        attn_config = cfg.attn_config
        if "kv_n_heads" in attn_config:
            return attn_config["kv_n_heads"]
        return cfg.num_attention_heads

    # DBRX keeps attn_config as an object with an optional kv_n_heads.
    if model_type == "dbrx":
        return getattr(cfg.attn_config, "kv_n_heads", cfg.num_attention_heads)

    # nemotron-nas declares per-block attention; all active blocks must agree.
    if model_type == "nemotron-nas":
        kv_head_counts = {
            cfg.num_attention_heads // block.attention.n_heads_in_group
            for block in cfg.block_configs
            if not block.attention.no_op
        }
        if not kv_head_counts:
            raise RuntimeError("Couldn't determine number of kv heads")
        if len(kv_head_counts) > 1:
            raise ValueError(
                "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
            )
        return kv_head_counts.pop()

    # Probe the attribute names used by the remaining architectures,
    # in priority order.
    for attr in (
        "n_head_kv",  # Falcon
        "num_kv_heads",  # Falcon
        "num_key_value_heads",  # LLaMA-2
        "multi_query_group_num",  # ChatGLM
        "num_attention_groups",  # Step3
    ):
        num_kv_heads = getattr(text_cfg, attr, None)
        if num_kv_heads is not None:
            return num_kv_heads

    # For non-grouped-query attention models, the number of KV heads is
    # equal to the number of attention heads.
    return text_cfg.num_attention_heads
def get_num_kv_heads(self, tensor_parallel_size) -> int:
    """Returns the number of KV heads per GPU.

    When the total KV head count is smaller than the tensor-parallel size,
    KV heads are replicated so that every rank keeps at least one.
    """
    per_rank = self.get_total_num_kv_heads() // tensor_parallel_size
    return per_rank if per_rank >= 1 else 1
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
# compressed-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None)
if quant_cfg is None:
# check if is modelopt or mixed-precision model -- Both of them don't have corresponding field
# in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
# example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
is_local = os.path.exists(self.model_path)
modelopt_quant_config = {"quant_method": "modelopt"}
if not is_local:
from huggingface_hub import HfApi
hf_api = HfApi()
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
quant_cfg = modelopt_quant_config
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
quant_config_file = os.path.join(
self.model_path, "hf_quant_config.json"
)
with open(quant_config_file) as f:
quant_config_dict = json.load(f)
json_quant_configs = quant_config_dict["quantization"]
quant_algo = json_quant_configs.get("quant_algo", None)
if quant_algo == "MIXED_PRECISION":
quant_cfg = {"quant_method": "w4afp8"}
else:
quant_cfg = modelopt_quant_config
return quant_cfg
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None:
    """Validate ``self.quantization`` against the checkpoint's config.

    Resolves the effective quantization method from the HF config (letting
    registered methods override detection), reconciles it with the
    user-supplied ``quantization`` argument, and rejects methods that are
    unknown or unsupported on ROCm. Emits a warning for methods that are
    not fully optimized.
    """
    supported_quantization = [*QUANTIZATION_METHODS]
    # Methods with working ROCm support.
    rocm_supported_quantization = [
        "awq",
        "gptq",
        "fp8",
        "compressed_tensors",
        "compressed-tensors",
        "fbgemm_fp8",
        "w8a8_fp8",
        "petit_nvfp4",
        "quark",
        "mxfp4",
    ]
    # Methods with tuned kernels; anything else triggers the slowness warning.
    optimized_quantization_methods = [
        "fp8",
        "marlin",
        "modelopt",
        "gptq_marlin_24",
        "gptq_marlin",
        "awq_marlin",
        "fbgemm_fp8",
        "compressed_tensors",
        "compressed-tensors",
        "experts_int8",
        "w8a8_int8",
        "w8a8_fp8",
        "moe_wna16",
        "qoq",
        "w4afp8",
        "petit_nvfp4",
    ]
    # User-requested method -> checkpoint methods it may legitimately wrap.
    compatible_quantization_methods = {
        "modelopt_fp4": ["modelopt"],
        "petit_nvfp4": ["modelopt"],
        "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
        "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
    }
    if self.quantization is not None:
        self.quantization = self.quantization.lower()

    # Parse quantization method from the HF model config, if available.
    quant_cfg = self._parse_quant_hf_config()

    if quant_cfg is not None:
        quant_method = quant_cfg.get(
            "quant_method", "" if not self.quantization else self.quantization
        ).lower()

        # Detect which checkpoint is it
        # (the first registered method whose override matches wins).
        for _, method in QUANTIZATION_METHODS.items():
            quantization_override = method.override_quantization_method(
                quant_cfg, self.quantization
            )
            if quantization_override:
                quant_method = quantization_override
                self.quantization = quantization_override
                break

        # Verify quantization configurations.
        if self.quantization is None:
            self.quantization = quant_method
        elif self.quantization != quant_method:
            if (
                self.quantization not in compatible_quantization_methods
                or quant_method
                not in compatible_quantization_methods[self.quantization]
            ):
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization})."
                )

    if self.quantization is not None:
        if self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}."
            )
        if is_hip() and self.quantization not in rocm_supported_quantization:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in ROCm."
            )
        if self.quantization not in optimized_quantization_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.",
                self.quantization,
            )
def _verify_dual_chunk_attention_config(self) -> None:
if hasattr(self.hf_config, "dual_chunk_attention_config"):
# Try loading the sparse attention config
sparse_attn_config = get_sparse_attention_config(self.model_path)
if not sparse_attn_config:
return
self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
sparse_attn_config
)
if (
"sparse_attention_enabled"
not in self.hf_config.dual_chunk_attention_config
):
self.hf_config.dual_chunk_attention_config[
"sparse_attention_enabled"
] = True
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
    """Collect EOS token ids from the HF config and generation config."""

    def _as_set(ids):
        # eos_token_id can be either a single int or a list of ints.
        return {ids} if isinstance(ids, int) else set(ids)

    raw_ids = getattr(self.hf_config, "eos_token_id", None)
    eos_ids = _as_set(raw_ids) if raw_ids is not None else set()
    if self.hf_generation_config:
        generation_ids = getattr(self.hf_generation_config, "eos_token_id", None)
        if generation_ids:
            eos_ids |= _as_set(generation_ids)
    return eos_ids
def maybe_pull_model_tokenizer_from_remote(self) -> None:
    """Pull the model config files to a temporary local directory when
    ``self.model_path`` is a remote URL; otherwise do nothing.

    On success, ``self.model_weights`` keeps the remote location and
    ``self.model_path`` is rewritten to the local download directory.
    """
    from sglang.srt.connector import create_remote_connector
    from sglang.srt.utils import is_remote_url

    if not is_remote_url(self.model_path):
        return
    logger.info("Pulling model configs from remote...")
    # BaseConnector implements __del__() to clean up the local dir.
    # Since config files need to exist all the time, so we DO NOT use
    # with statement to avoid closing the client.
    client = create_remote_connector(self.model_path)
    # (The previous inner `if is_remote_url(...)` re-check was redundant:
    # the guard above already guarantees a remote URL.)
    client.pull_files(allow_pattern=["*config.json"])
    self.model_weights = self.model_path
    self.model_path = client.get_local_dir()
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
# Maps user-facing dtype strings to torch dtypes; "half" and "float" are
# accepted aliases for float16 and float32 respectively.
_STR_DTYPE_TO_TORCH_DTYPE: dict[str, torch.dtype] = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if isinstance(config_dtype, str):
config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
if config_dtype is None:
config_dtype = torch.float32
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type.startswith("gemma"):
if config.model_type == "gemma":
gemma_version = ""
else:
gemma_version = config.model_type[5]
logger.info(
f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16."
)
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
return torch_dtype
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    """Decide whether the model is generative.

    A model is non-generative when its architecture is one of the known
    embedding / reward / classification heads, or when the server was
    launched with the `is_embedding` argument.
    """
    non_generative_archs = {
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
        "InternLM2ForRewardModel",
        "Qwen2ForRewardModel",
        "Qwen2ForSequenceClassification",
        "Qwen3ForSequenceClassification",
        "CLIPModel",
        "BertModel",
        "Contriever",
        "BertForSequenceClassification",
        "XLMRobertaModel",
        "XLMRobertaForSequenceClassification",
    }
    if any(arch in non_generative_archs for arch in model_architectures):
        return False
    return not is_embedding
# Architectures that take image / audio / video inputs in addition to text.
multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Glm4vForConditionalGeneration",
    "Glm4vMoeForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
]


def is_multimodal_model(model_architectures: List[str]):
    """Return True when any architecture is a known multimodal one."""
    return any(arch in multimodal_model_archs for arch in model_architectures)
def is_multimodal_gen_model(model_architectures: List[str]):
    """Whether the model generates multimodal output; currently always False."""
    return False
def is_image_gen_model(model_architectures: List[str]):
    """Whether the model generates images; currently always False."""
    return False
def is_audio_model(model_architectures: List[str]):
    """Whether the model is an audio model; currently always False."""
    return False
def is_encoder_decoder_model(model_architectures: List[str]):
    """Mllama is the only encoder-decoder architecture handled here."""
    return any(arch == "MllamaForConditionalGeneration" for arch in model_architectures)
def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    unsupported = {
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    }
    return not any(arch in unsupported for arch in model_architectures)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """YaRN attention scaling factor; identity for scale <= 1."""
    if scale <= 1:
        return 1.0
    return mscale * 0.1 * math.log(scale) + 1.0
def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    """Return the hybrid KV-cache ratio when applicable, else None.

    The hybrid KV cache only applies to Llama4 with a positive ratio and a
    context window longer than the attention chunk.
    """
    if hybrid_kvcache_ratio is None:
        return None
    is_llama4 = model_architectures[0] == "Llama4ForConditionalGeneration"
    if (
        hybrid_kvcache_ratio > 0
        and is_llama4
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    return None
def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    """Split layer indices into sliding-window vs full-attention lists.

    For Llama4, every 4th layer (1-based) uses full attention and all
    others use sliding-window attention; other models get (None, None).
    """
    if "Llama4ForConditionalGeneration" not in model_architectures:
        return None, None
    layer_ids = range(num_hidden_layers)
    swa_attention_layer_ids = [i for i in layer_ids if (i + 1) % 4 != 0]
    full_attention_layer_ids = [i for i in layer_ids if (i + 1) % 4 == 0]
    return swa_attention_layer_ids, full_attention_layer_ids

View File

@@ -0,0 +1,172 @@
from typing import Any, Optional, Union
from transformers.configuration_utils import PretrainedConfig
class Step3VisionEncoderConfig(PretrainedConfig):
    """HF-style configuration for the Step3 vision encoder."""

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        """Store the encoder hyperparameters; extra kwargs are forwarded
        to ``PretrainedConfig``."""
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Size of the encoder's output features (presumably matched to the
        # text model's input projection -- verify against the projector).
        self.output_hidden_size = output_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        super().__init__(**kwargs)
class Step3TextConfig(PretrainedConfig):
    """HF-style configuration for the Step3 text decoder.

    Defaults mirror the released Step3 checkpoint. ``moe_layers_enum``
    lists the layer indices that use MoE blocks: with 61 hidden layers
    the default covers layers 4..59, i.e. everything but the first four
    layers and the last one.
    """

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        # Equivalent to the original 56-element literal tuple (4, 5, ..., 59);
        # expressed as a range for readability.
        moe_layers_enum: tuple[int] = tuple(range(4, 60)),
        **kwargs,
    ) -> None:
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embedding = max_position_embedding
        self.share_expert_dim = share_expert_dim
        self.share_q_dim = share_q_dim
        self.head_dim = head_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum
        super().__init__(**kwargs)
class Step3VLConfig(PretrainedConfig):
    """Top-level config tying the Step3 vision encoder to the text model."""

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        # Normalize dict / None inputs into proper config objects.
        if isinstance(vision_config, dict):
            vision_config = Step3VisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = Step3VisionEncoderConfig()
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config = Step3TextConfig(**text_config)
        elif text_config is None:
            text_config = Step3TextConfig()
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirror the text model's hidden size at the top level.
        self.hidden_size = text_config.hidden_size
        self.image_token_id = image_token_id
        super().__init__(**kwargs)

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
from typing import TYPE_CHECKING
# Fallback alignment for MoE intermediate-size padding when the quantization
# config does not dictate a weight block size.
DEFAULT_MOE_PADDING_SIZE = 32
if TYPE_CHECKING:
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
def may_get_weight_block_size(model_config, load_config):
    """Return the quantization weight block size for the model, or None."""
    from sglang.srt.model_loader.loader import _get_quantization_config
    from sglang.srt.model_loader.utils import get_model_architecture

    model_class, _ = get_model_architecture(model_config)
    packed_mapping = getattr(model_class, "packed_modules_mapping", {})
    quant_config = _get_quantization_config(model_config, load_config, packed_mapping)
    # A missing quant config or attribute both mean "no block size".
    return getattr(quant_config, "weight_block_size", None)
def get_moe_padding_size(weight_block_size):
    """Alignment for MoE intermediate sizes: the quant block size if set."""
    if weight_block_size is None:
        return DEFAULT_MOE_PADDING_SIZE
    # See NOTE(HandH1998): To ensure proper alignment of the block-wise
    # quantization scales, the output_size of the weights for both the gate
    # and up layers must be divisible by block_n.
    assert len(weight_block_size) == 2, "Only len(weight_block_size) == 2 is supported"
    assert (
        weight_block_size[0] == weight_block_size[1]
    ), "Only weight_block_size[0] == weight_block_size[1] is supported"
    return weight_block_size[0]
def get_num_heads_padding_size(tp_size, weight_block_size):
    """Head-count padding unit: double an odd tp_size when block-quantized."""
    if weight_block_size is not None and tp_size % 2 == 1:
        return tp_size * 2
    return tp_size
def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
    """Pad *attr_name* on the config up to a multiple of the padding size.

    Reads the attribute from ``model_config.hf_config`` when present,
    otherwise from ``model_config`` itself, then writes the (possibly
    padded) value back.
    """
    if hasattr(model_config, "hf_config") and hasattr(
        model_config.hf_config, attr_name
    ):
        attr_value = getattr(model_config.hf_config, attr_name)
    elif hasattr(model_config, attr_name):
        attr_value = getattr(model_config, attr_name)
    else:
        # Attribute absent everywhere: fall back to the padding size itself.
        attr_value = intermediate_padding_size

    if attr_value % intermediate_padding_size != 0:
        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        attr_value = pad_vocab_size(attr_value, intermediate_padding_size)

    # NOTE(review): the write-back below intentionally mirrors the original
    # dangling-else: without hf_text_config, the value lands directly on
    # model_config even when hf_config was also updated.
    if hasattr(model_config, "hf_config"):
        setattr(model_config.hf_config, attr_name, attr_value)
    if hasattr(model_config, "hf_text_config"):
        setattr(model_config.hf_text_config, attr_name, attr_value)
    else:
        setattr(model_config, attr_name, attr_value)
    return model_config
def adjust_config_with_unaligned_cpu_tp(
    model_config: ModelConfig, load_config: LoadConfig, tp_size: int
) -> ModelConfig:
    """Pad head counts and intermediate sizes so the model splits evenly
    across ``tp_size`` ranks, mutating ``model_config`` in place.

    The pre-padding head counts are stashed as ``original_*`` attributes on
    both ``hf_config`` and ``hf_text_config`` so model code can recover
    the true values.
    """
    # Support the case where the num_attention_heads is not divisible by the TP size.
    weight_block_size = may_get_weight_block_size(model_config, load_config)

    # Remember the pre-padding counts before any mutation below.
    model_config.hf_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )
    model_config.hf_text_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )
    model_config.hf_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )
    model_config.hf_text_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )
    if (
        model_config.num_attention_heads % tp_size != 0
        or model_config.get_total_num_kv_heads() % tp_size != 0
    ):
        # Compute the head_dim using the model_config.num_attention_heads before padding
        if not hasattr(model_config.hf_config, "head_dim"):
            model_config.hf_config.head_dim = (
                model_config.hidden_size // model_config.num_attention_heads
            )
        # Keep the query-to-KV grouping ratio fixed while padding KV heads.
        query_heads_per_kv = (
            model_config.num_attention_heads // model_config.get_total_num_kv_heads()
        )
        total_kv_heads = model_config.get_total_num_kv_heads()
        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)
        model_config.num_key_value_heads = num_key_value_heads
        model_config.hf_config.num_key_value_heads = num_key_value_heads
        model_config.hf_text_config.num_key_value_heads = num_key_value_heads
        num_attention_heads = num_key_value_heads * query_heads_per_kv
        model_config.num_attention_heads = num_attention_heads
        model_config.hf_config.num_attention_heads = num_attention_heads
        model_config.hf_text_config.num_attention_heads = num_attention_heads

    # MoE / MLP widths must also be divisible across ranks (and respect the
    # quantization block size).
    intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
    model_config = update_intermediate_size(
        model_config, "moe_intermediate_size", intermediate_padding_size
    )
    model_config = update_intermediate_size(
        model_config, "intermediate_size", intermediate_padding_size
    )
    model_config = update_intermediate_size(
        model_config, "intermediate_size_mlp", intermediate_padding_size
    )

    # Apply the same treatment to a SigLIP vision tower, if present.
    if (
        hasattr(model_config.hf_config, "vision_config")
        and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
    ):
        model_config.hf_config.vision_config.original_num_attention_heads = (
            model_config.num_attention_heads
        )
        if model_config.hf_config.vision_config.num_attention_heads % tp_size != 0:
            model_config.hf_config.vision_config.head_dim = (
                model_config.hf_config.vision_config.hidden_size
                // model_config.hf_config.vision_config.num_attention_heads
            )
            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

            pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
            model_config.hf_config.vision_config.num_attention_heads = pad_vocab_size(
                model_config.hf_config.vision_config.num_attention_heads, pad_size
            )
        model_config.hf_config.vision_config = update_intermediate_size(
            model_config.hf_config.vision_config,
            "intermediate_size",
            intermediate_padding_size,
        )
    return model_config

View File

@@ -0,0 +1,25 @@
from typing import Type
from transformers import (
AutoImageProcessor,
AutoProcessor,
BaseImageProcessor,
PretrainedConfig,
ProcessorMixin,
)
def register_image_processor(
    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
    """
    Register a customized HF image processor for *config*, replacing the
    stock implementation. The two None arguments clear the other processor
    slots of AutoImageProcessor.register -- presumably the slow/fast
    variants; verify against the installed transformers version.
    """
    AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
    """
    Register a customized HF processor class for *config*; ``exist_ok``
    allows overriding a previously registered processor.
    """
    AutoProcessor.register(config, processor, exist_ok=True)

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
import enum
import logging
from sglang.srt.connector.base_connector import (
BaseConnector,
BaseFileConnector,
BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type
logger = logging.getLogger(__name__)
class ConnectorType(str, enum.Enum):
    """Category of a remote connector (file-based vs key-value)."""

    # The str mixin makes values compare and serialize as plain strings.
    FS = "filesystem"
    KV = "KV"
def create_remote_connector(url, **kwargs) -> BaseConnector:
    """Instantiate the connector matching the URL's scheme (redis / s3)."""
    connector_classes = {
        "redis": RedisConnector,
        "s3": S3Connector,
    }
    connector_cls = connector_classes.get(parse_connector_type(url))
    if connector_cls is None:
        raise ValueError(f"Invalid connector type: {url}")
    return connector_cls(url)
def get_connector_type(client: BaseConnector) -> ConnectorType:
    """Classify a connector instance as KV- or file-based."""
    for base_cls, connector_type in (
        (BaseKVConnector, ConnectorType.KV),
        (BaseFileConnector, ConnectorType.FS),
    ):
        if isinstance(client, base_cls):
            return connector_type
    raise ValueError(f"Invalid connector type: {client}")
# Public API of the connector package.
__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]

Some files were not shown because too many files have changed in this diff Show More