adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
166
python/pyproject.toml
Normal file
166
python/pyproject.toml
Normal file
@@ -0,0 +1,166 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sglang"
|
||||
version = "0.5.2rc1"
|
||||
description = "SGLang is yet another fast serving framework for large language models and vision language models."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = { file = "LICENSE" }
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
]
|
||||
dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
runtime_common = [
|
||||
"blobfile==3.0.0",
|
||||
"build",
|
||||
"compressed-tensors",
|
||||
"datasets",
|
||||
"einops",
|
||||
"fastapi",
|
||||
"hf_transfer",
|
||||
"huggingface_hub",
|
||||
"interegular",
|
||||
"llguidance>=0.7.11,<0.8.0",
|
||||
"modelscope",
|
||||
"msgspec",
|
||||
"ninja",
|
||||
"openai==1.99.1",
|
||||
"openai-harmony==0.0.4",
|
||||
"orjson",
|
||||
"outlines==0.1.11",
|
||||
"packaging",
|
||||
"partial_json_parser",
|
||||
"pillow",
|
||||
"prometheus-client>=0.20.0",
|
||||
"psutil",
|
||||
"pybase64",
|
||||
"pydantic",
|
||||
"pynvml",
|
||||
"python-multipart",
|
||||
"pyzmq>=25.1.2",
|
||||
"sentencepiece",
|
||||
"soundfile==0.13.1",
|
||||
"scipy",
|
||||
"timm==1.0.16",
|
||||
"tiktoken",
|
||||
"torchao==0.9.0",
|
||||
"transformers==4.56.0",
|
||||
"uvicorn",
|
||||
"uvloop",
|
||||
"xgrammar==0.1.23",
|
||||
]
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.3.8",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.3.0",
|
||||
]
|
||||
|
||||
blackwell = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.3.0",
|
||||
]
|
||||
|
||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
||||
srt_hip = [
|
||||
"sglang[runtime_common]",
|
||||
"torch",
|
||||
"petit_kernel==0.0.2",
|
||||
"wave-lang==1.0.1",
|
||||
]
|
||||
|
||||
# https://docs.sglang.ai/platforms/cpu_server.html
|
||||
srt_cpu = ["sglang[runtime_common]"]
|
||||
|
||||
# https://docs.sglang.ai/platforms/ascend_npu.html
|
||||
srt_npu = ["sglang[runtime_common]"]
|
||||
|
||||
# xpu is not enabled in public vllm and torch whl,
|
||||
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
||||
srt_xpu = ["sglang[runtime_common]"]
|
||||
|
||||
# For Intel Gaudi (device: hpu) follow the installation guide
|
||||
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
|
||||
srt_hpu = ["sglang[runtime_common]"]
|
||||
|
||||
openai = ["openai==1.99.1", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
torch_memory_saver = ["torch_memory_saver==0.0.8"]
|
||||
decord = ["decord"]
|
||||
test = [
|
||||
"accelerate",
|
||||
"expecttest",
|
||||
"jsonlines",
|
||||
"matplotlib",
|
||||
"pandas",
|
||||
"peft",
|
||||
"sentence_transformers",
|
||||
"pytest",
|
||||
"tabulate",
|
||||
]
|
||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"]
|
||||
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
|
||||
dev = ["sglang[all]", "sglang[test]"]
|
||||
dev_hip = ["sglang[all_hip]", "sglang[test]"]
|
||||
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
|
||||
dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
|
||||
dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://github.com/sgl-project/sglang"
|
||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"sglang" = [
|
||||
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
|
||||
"srt/layers/quantization/configs/*.json",
|
||||
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
exclude = [
|
||||
"assets*",
|
||||
"benchmark*",
|
||||
"docs*",
|
||||
"dist*",
|
||||
"playground*",
|
||||
"scripts*",
|
||||
"tests*",
|
||||
]
|
||||
|
||||
[tool.wheel]
|
||||
exclude = [
|
||||
"assets*",
|
||||
"benchmark*",
|
||||
"docs*",
|
||||
"dist*",
|
||||
"playground*",
|
||||
"scripts*",
|
||||
"tests*",
|
||||
]
|
||||
|
||||
[tool.codespell]
|
||||
ignore-words-list = "ans, als, hel, boostrap, childs, te, vas, hsa, ment"
|
||||
skip = "*.json,*.jsonl,*.patch,*.txt"
|
||||
150
python/sglang.egg-info/PKG-INFO
Normal file
150
python/sglang.egg-info/PKG-INFO
Normal file
@@ -0,0 +1,150 @@
|
||||
Metadata-Version: 2.4
|
||||
Name: sglang
|
||||
Version: 0.5.2rc1
|
||||
Summary: SGLang is yet another fast serving framework for large language models and vision language models.
|
||||
Project-URL: Homepage, https://github.com/sgl-project/sglang
|
||||
Project-URL: Bug Tracker, https://github.com/sgl-project/sglang/issues
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: License :: OSI Approved :: Apache Software License
|
||||
Requires-Python: >=3.10
|
||||
Description-Content-Type: text/markdown
|
||||
Requires-Dist: aiohttp
|
||||
Requires-Dist: requests
|
||||
Requires-Dist: tqdm
|
||||
Requires-Dist: numpy
|
||||
Requires-Dist: IPython
|
||||
Requires-Dist: setproctitle
|
||||
Provides-Extra: runtime-common
|
||||
Requires-Dist: blobfile==3.0.0; extra == "runtime-common"
|
||||
Requires-Dist: build; extra == "runtime-common"
|
||||
Requires-Dist: compressed-tensors; extra == "runtime-common"
|
||||
Requires-Dist: datasets; extra == "runtime-common"
|
||||
Requires-Dist: einops; extra == "runtime-common"
|
||||
Requires-Dist: fastapi; extra == "runtime-common"
|
||||
Requires-Dist: hf_transfer; extra == "runtime-common"
|
||||
Requires-Dist: huggingface_hub; extra == "runtime-common"
|
||||
Requires-Dist: interegular; extra == "runtime-common"
|
||||
Requires-Dist: llguidance<0.8.0,>=0.7.11; extra == "runtime-common"
|
||||
Requires-Dist: modelscope; extra == "runtime-common"
|
||||
Requires-Dist: msgspec; extra == "runtime-common"
|
||||
Requires-Dist: ninja; extra == "runtime-common"
|
||||
Requires-Dist: openai==1.99.1; extra == "runtime-common"
|
||||
Requires-Dist: openai-harmony==0.0.4; extra == "runtime-common"
|
||||
Requires-Dist: orjson; extra == "runtime-common"
|
||||
Requires-Dist: outlines==0.1.11; extra == "runtime-common"
|
||||
Requires-Dist: packaging; extra == "runtime-common"
|
||||
Requires-Dist: partial_json_parser; extra == "runtime-common"
|
||||
Requires-Dist: pillow; extra == "runtime-common"
|
||||
Requires-Dist: prometheus-client>=0.20.0; extra == "runtime-common"
|
||||
Requires-Dist: psutil; extra == "runtime-common"
|
||||
Requires-Dist: pybase64; extra == "runtime-common"
|
||||
Requires-Dist: pydantic; extra == "runtime-common"
|
||||
Requires-Dist: pynvml; extra == "runtime-common"
|
||||
Requires-Dist: python-multipart; extra == "runtime-common"
|
||||
Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
|
||||
Requires-Dist: sentencepiece; extra == "runtime-common"
|
||||
Requires-Dist: soundfile==0.13.1; extra == "runtime-common"
|
||||
Requires-Dist: scipy; extra == "runtime-common"
|
||||
Requires-Dist: timm==1.0.16; extra == "runtime-common"
|
||||
Requires-Dist: tiktoken; extra == "runtime-common"
|
||||
Requires-Dist: torchao==0.9.0; extra == "runtime-common"
|
||||
Requires-Dist: transformers==4.56.0; extra == "runtime-common"
|
||||
Requires-Dist: uvicorn; extra == "runtime-common"
|
||||
Requires-Dist: uvloop; extra == "runtime-common"
|
||||
Requires-Dist: xgrammar==0.1.23; extra == "runtime-common"
|
||||
Provides-Extra: srt
|
||||
Requires-Dist: sglang[runtime_common]; extra == "srt"
|
||||
Requires-Dist: sgl-kernel==0.3.8; extra == "srt"
|
||||
Requires-Dist: torch==2.8.0; extra == "srt"
|
||||
Requires-Dist: torchaudio==2.8.0; extra == "srt"
|
||||
Requires-Dist: torchvision; extra == "srt"
|
||||
Requires-Dist: cuda-python; extra == "srt"
|
||||
Requires-Dist: flashinfer_python==0.3.0; extra == "srt"
|
||||
Provides-Extra: blackwell
|
||||
Requires-Dist: sglang[runtime_common]; extra == "blackwell"
|
||||
Requires-Dist: sgl-kernel; extra == "blackwell"
|
||||
Requires-Dist: torch==2.8.0; extra == "blackwell"
|
||||
Requires-Dist: torchaudio==2.8.0; extra == "blackwell"
|
||||
Requires-Dist: torchvision; extra == "blackwell"
|
||||
Requires-Dist: cuda-python; extra == "blackwell"
|
||||
Requires-Dist: flashinfer_python==0.3.0; extra == "blackwell"
|
||||
Provides-Extra: srt-hip
|
||||
Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
|
||||
Requires-Dist: torch; extra == "srt-hip"
|
||||
Requires-Dist: petit_kernel==0.0.2; extra == "srt-hip"
|
||||
Requires-Dist: wave-lang==1.0.1; extra == "srt-hip"
|
||||
Provides-Extra: srt-cpu
|
||||
Requires-Dist: sglang[runtime_common]; extra == "srt-cpu"
|
||||
Provides-Extra: srt-npu
|
||||
Requires-Dist: sglang[runtime_common]; extra == "srt-npu"
|
||||
Provides-Extra: srt-xpu
|
||||
Requires-Dist: sglang[runtime_common]; extra == "srt-xpu"
|
||||
Provides-Extra: srt-hpu
|
||||
Requires-Dist: sglang[runtime_common]; extra == "srt-hpu"
|
||||
Provides-Extra: openai
|
||||
Requires-Dist: openai==1.99.1; extra == "openai"
|
||||
Requires-Dist: tiktoken; extra == "openai"
|
||||
Provides-Extra: anthropic
|
||||
Requires-Dist: anthropic>=0.20.0; extra == "anthropic"
|
||||
Provides-Extra: litellm
|
||||
Requires-Dist: litellm>=1.0.0; extra == "litellm"
|
||||
Provides-Extra: torch-memory-saver
|
||||
Requires-Dist: torch_memory_saver==0.0.8; extra == "torch-memory-saver"
|
||||
Provides-Extra: decord
|
||||
Requires-Dist: decord; extra == "decord"
|
||||
Provides-Extra: test
|
||||
Requires-Dist: accelerate; extra == "test"
|
||||
Requires-Dist: expecttest; extra == "test"
|
||||
Requires-Dist: jsonlines; extra == "test"
|
||||
Requires-Dist: matplotlib; extra == "test"
|
||||
Requires-Dist: pandas; extra == "test"
|
||||
Requires-Dist: peft; extra == "test"
|
||||
Requires-Dist: sentence_transformers; extra == "test"
|
||||
Requires-Dist: pytest; extra == "test"
|
||||
Requires-Dist: tabulate; extra == "test"
|
||||
Provides-Extra: all
|
||||
Requires-Dist: sglang[srt]; extra == "all"
|
||||
Requires-Dist: sglang[openai]; extra == "all"
|
||||
Requires-Dist: sglang[anthropic]; extra == "all"
|
||||
Requires-Dist: sglang[torch_memory_saver]; extra == "all"
|
||||
Requires-Dist: sglang[decord]; extra == "all"
|
||||
Provides-Extra: all-hip
|
||||
Requires-Dist: sglang[srt_hip]; extra == "all-hip"
|
||||
Requires-Dist: sglang[openai]; extra == "all-hip"
|
||||
Requires-Dist: sglang[anthropic]; extra == "all-hip"
|
||||
Requires-Dist: sglang[decord]; extra == "all-hip"
|
||||
Provides-Extra: all-xpu
|
||||
Requires-Dist: sglang[srt_xpu]; extra == "all-xpu"
|
||||
Requires-Dist: sglang[openai]; extra == "all-xpu"
|
||||
Requires-Dist: sglang[anthropic]; extra == "all-xpu"
|
||||
Requires-Dist: sglang[decord]; extra == "all-xpu"
|
||||
Provides-Extra: all-hpu
|
||||
Requires-Dist: sglang[srt_hpu]; extra == "all-hpu"
|
||||
Requires-Dist: sglang[openai]; extra == "all-hpu"
|
||||
Requires-Dist: sglang[anthropic]; extra == "all-hpu"
|
||||
Requires-Dist: sglang[decord]; extra == "all-hpu"
|
||||
Provides-Extra: all-cpu
|
||||
Requires-Dist: sglang[srt_cpu]; extra == "all-cpu"
|
||||
Requires-Dist: sglang[openai]; extra == "all-cpu"
|
||||
Requires-Dist: sglang[anthropic]; extra == "all-cpu"
|
||||
Requires-Dist: sglang[decord]; extra == "all-cpu"
|
||||
Provides-Extra: all-npu
|
||||
Requires-Dist: sglang[srt_npu]; extra == "all-npu"
|
||||
Requires-Dist: sglang[openai]; extra == "all-npu"
|
||||
Requires-Dist: sglang[anthropic]; extra == "all-npu"
|
||||
Requires-Dist: sglang[decord]; extra == "all-npu"
|
||||
Provides-Extra: dev
|
||||
Requires-Dist: sglang[all]; extra == "dev"
|
||||
Requires-Dist: sglang[test]; extra == "dev"
|
||||
Provides-Extra: dev-hip
|
||||
Requires-Dist: sglang[all_hip]; extra == "dev-hip"
|
||||
Requires-Dist: sglang[test]; extra == "dev-hip"
|
||||
Provides-Extra: dev-xpu
|
||||
Requires-Dist: sglang[all_xpu]; extra == "dev-xpu"
|
||||
Requires-Dist: sglang[test]; extra == "dev-xpu"
|
||||
Provides-Extra: dev-hpu
|
||||
Requires-Dist: sglang[all_hpu]; extra == "dev-hpu"
|
||||
Requires-Dist: sglang[test]; extra == "dev-hpu"
|
||||
Provides-Extra: dev-cpu
|
||||
Requires-Dist: sglang[all_cpu]; extra == "dev-cpu"
|
||||
Requires-Dist: sglang[test]; extra == "dev-cpu"
|
||||
883
python/sglang.egg-info/SOURCES.txt
Normal file
883
python/sglang.egg-info/SOURCES.txt
Normal file
@@ -0,0 +1,883 @@
|
||||
pyproject.toml
|
||||
sglang/__init__.py
|
||||
sglang/bench_offline_throughput.py
|
||||
sglang/bench_one_batch.py
|
||||
sglang/bench_one_batch_server.py
|
||||
sglang/bench_serving.py
|
||||
sglang/check_env.py
|
||||
sglang/compile_deep_gemm.py
|
||||
sglang/global_config.py
|
||||
sglang/launch_server.py
|
||||
sglang/profiler.py
|
||||
sglang/utils.py
|
||||
sglang/version.py
|
||||
sglang.egg-info/PKG-INFO
|
||||
sglang.egg-info/SOURCES.txt
|
||||
sglang.egg-info/dependency_links.txt
|
||||
sglang.egg-info/requires.txt
|
||||
sglang.egg-info/top_level.txt
|
||||
sglang/eval/llama3_eval.py
|
||||
sglang/eval/loogle_eval.py
|
||||
sglang/lang/api.py
|
||||
sglang/lang/chat_template.py
|
||||
sglang/lang/choices.py
|
||||
sglang/lang/compiler.py
|
||||
sglang/lang/interpreter.py
|
||||
sglang/lang/ir.py
|
||||
sglang/lang/tracer.py
|
||||
sglang/lang/backend/anthropic.py
|
||||
sglang/lang/backend/base_backend.py
|
||||
sglang/lang/backend/litellm.py
|
||||
sglang/lang/backend/openai.py
|
||||
sglang/lang/backend/runtime_endpoint.py
|
||||
sglang/lang/backend/vertexai.py
|
||||
sglang/srt/_custom_ops.py
|
||||
sglang/srt/aio_rwlock.py
|
||||
sglang/srt/bench_utils.py
|
||||
sglang/srt/constants.py
|
||||
sglang/srt/custom_op.py
|
||||
sglang/srt/hf_transformers_utils.py
|
||||
sglang/srt/host_shared_memory.py
|
||||
sglang/srt/offloader.py
|
||||
sglang/srt/operations.py
|
||||
sglang/srt/operations_strategy.py
|
||||
sglang/srt/patch_torch.py
|
||||
sglang/srt/poll_based_barrier.py
|
||||
sglang/srt/server_args.py
|
||||
sglang/srt/torch_memory_saver_adapter.py
|
||||
sglang/srt/two_batch_overlap.py
|
||||
sglang/srt/utils.py
|
||||
sglang/srt/warmup.py
|
||||
sglang/srt/configs/__init__.py
|
||||
sglang/srt/configs/chatglm.py
|
||||
sglang/srt/configs/dbrx.py
|
||||
sglang/srt/configs/deepseekvl2.py
|
||||
sglang/srt/configs/device_config.py
|
||||
sglang/srt/configs/exaone.py
|
||||
sglang/srt/configs/internvl.py
|
||||
sglang/srt/configs/janus_pro.py
|
||||
sglang/srt/configs/kimi_vl.py
|
||||
sglang/srt/configs/kimi_vl_moonvit.py
|
||||
sglang/srt/configs/load_config.py
|
||||
sglang/srt/configs/longcat_flash.py
|
||||
sglang/srt/configs/model_config.py
|
||||
sglang/srt/configs/step3_vl.py
|
||||
sglang/srt/configs/update_config.py
|
||||
sglang/srt/configs/utils.py
|
||||
sglang/srt/connector/__init__.py
|
||||
sglang/srt/connector/base_connector.py
|
||||
sglang/srt/connector/redis.py
|
||||
sglang/srt/connector/s3.py
|
||||
sglang/srt/connector/utils.py
|
||||
sglang/srt/connector/serde/__init__.py
|
||||
sglang/srt/connector/serde/safe_serde.py
|
||||
sglang/srt/connector/serde/serde.py
|
||||
sglang/srt/constrained/base_grammar_backend.py
|
||||
sglang/srt/constrained/llguidance_backend.py
|
||||
sglang/srt/constrained/outlines_backend.py
|
||||
sglang/srt/constrained/outlines_jump_forward.py
|
||||
sglang/srt/constrained/reasoner_grammar_backend.py
|
||||
sglang/srt/constrained/xgrammar_backend.py
|
||||
sglang/srt/constrained/triton_ops/bitmask_ops.py
|
||||
sglang/srt/debug_utils/__init__.py
|
||||
sglang/srt/debug_utils/dump_comparator.py
|
||||
sglang/srt/debug_utils/dumper.py
|
||||
sglang/srt/debug_utils/text_comparator.py
|
||||
sglang/srt/disaggregation/decode.py
|
||||
sglang/srt/disaggregation/decode_schedule_batch_mixin.py
|
||||
sglang/srt/disaggregation/kv_events.py
|
||||
sglang/srt/disaggregation/launch_lb.py
|
||||
sglang/srt/disaggregation/mini_lb.py
|
||||
sglang/srt/disaggregation/prefill.py
|
||||
sglang/srt/disaggregation/utils.py
|
||||
sglang/srt/disaggregation/ascend/__init__.py
|
||||
sglang/srt/disaggregation/ascend/conn.py
|
||||
sglang/srt/disaggregation/ascend/transfer_engine.py
|
||||
sglang/srt/disaggregation/base/__init__.py
|
||||
sglang/srt/disaggregation/base/conn.py
|
||||
sglang/srt/disaggregation/common/__init__.py
|
||||
sglang/srt/disaggregation/common/conn.py
|
||||
sglang/srt/disaggregation/common/utils.py
|
||||
sglang/srt/disaggregation/fake/__init__.py
|
||||
sglang/srt/disaggregation/fake/conn.py
|
||||
sglang/srt/disaggregation/mooncake/__init__.py
|
||||
sglang/srt/disaggregation/mooncake/conn.py
|
||||
sglang/srt/disaggregation/mooncake/transfer_engine.py
|
||||
sglang/srt/disaggregation/nixl/__init__.py
|
||||
sglang/srt/disaggregation/nixl/conn.py
|
||||
sglang/srt/distributed/__init__.py
|
||||
sglang/srt/distributed/communication_op.py
|
||||
sglang/srt/distributed/naive_distributed.py
|
||||
sglang/srt/distributed/parallel_state.py
|
||||
sglang/srt/distributed/utils.py
|
||||
sglang/srt/distributed/device_communicators/cuda_wrapper.py
|
||||
sglang/srt/distributed/device_communicators/custom_all_reduce.py
|
||||
sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
|
||||
sglang/srt/distributed/device_communicators/hpu_communicator.py
|
||||
sglang/srt/distributed/device_communicators/npu_communicator.py
|
||||
sglang/srt/distributed/device_communicators/pymscclpp.py
|
||||
sglang/srt/distributed/device_communicators/pynccl.py
|
||||
sglang/srt/distributed/device_communicators/pynccl_allocator.py
|
||||
sglang/srt/distributed/device_communicators/pynccl_wrapper.py
|
||||
sglang/srt/distributed/device_communicators/quick_all_reduce.py
|
||||
sglang/srt/distributed/device_communicators/shm_broadcast.py
|
||||
sglang/srt/distributed/device_communicators/xpu_communicator.py
|
||||
sglang/srt/entrypoints/EngineBase.py
|
||||
sglang/srt/entrypoints/context.py
|
||||
sglang/srt/entrypoints/engine.py
|
||||
sglang/srt/entrypoints/harmony_utils.py
|
||||
sglang/srt/entrypoints/http_server.py
|
||||
sglang/srt/entrypoints/http_server_engine.py
|
||||
sglang/srt/entrypoints/tool.py
|
||||
sglang/srt/entrypoints/openai/__init__.py
|
||||
sglang/srt/entrypoints/openai/protocol.py
|
||||
sglang/srt/entrypoints/openai/serving_base.py
|
||||
sglang/srt/entrypoints/openai/serving_chat.py
|
||||
sglang/srt/entrypoints/openai/serving_completions.py
|
||||
sglang/srt/entrypoints/openai/serving_embedding.py
|
||||
sglang/srt/entrypoints/openai/serving_rerank.py
|
||||
sglang/srt/entrypoints/openai/serving_responses.py
|
||||
sglang/srt/entrypoints/openai/serving_score.py
|
||||
sglang/srt/entrypoints/openai/tool_server.py
|
||||
sglang/srt/entrypoints/openai/usage_processor.py
|
||||
sglang/srt/entrypoints/openai/utils.py
|
||||
sglang/srt/eplb/__init__.py
|
||||
sglang/srt/eplb/eplb_manager.py
|
||||
sglang/srt/eplb/expert_distribution.py
|
||||
sglang/srt/eplb/expert_location.py
|
||||
sglang/srt/eplb/expert_location_dispatch.py
|
||||
sglang/srt/eplb/expert_location_updater.py
|
||||
sglang/srt/eplb/eplb_algorithms/__init__.py
|
||||
sglang/srt/eplb/eplb_algorithms/deepseek.py
|
||||
sglang/srt/eplb/eplb_algorithms/deepseek_vec.py
|
||||
sglang/srt/eplb/eplb_simulator/__init__.py
|
||||
sglang/srt/eplb/eplb_simulator/reader.py
|
||||
sglang/srt/function_call/base_format_detector.py
|
||||
sglang/srt/function_call/core_types.py
|
||||
sglang/srt/function_call/deepseekv31_detector.py
|
||||
sglang/srt/function_call/deepseekv3_detector.py
|
||||
sglang/srt/function_call/ebnf_composer.py
|
||||
sglang/srt/function_call/function_call_parser.py
|
||||
sglang/srt/function_call/glm4_moe_detector.py
|
||||
sglang/srt/function_call/gpt_oss_detector.py
|
||||
sglang/srt/function_call/kimik2_detector.py
|
||||
sglang/srt/function_call/llama32_detector.py
|
||||
sglang/srt/function_call/mistral_detector.py
|
||||
sglang/srt/function_call/pythonic_detector.py
|
||||
sglang/srt/function_call/qwen25_detector.py
|
||||
sglang/srt/function_call/qwen3_coder_detector.py
|
||||
sglang/srt/function_call/step3_detector.py
|
||||
sglang/srt/function_call/utils.py
|
||||
sglang/srt/layers/activation.py
|
||||
sglang/srt/layers/amx_utils.py
|
||||
sglang/srt/layers/communicator.py
|
||||
sglang/srt/layers/dp_attention.py
|
||||
sglang/srt/layers/elementwise.py
|
||||
sglang/srt/layers/flashinfer_comm_fusion.py
|
||||
sglang/srt/layers/layernorm.py
|
||||
sglang/srt/layers/linear.py
|
||||
sglang/srt/layers/logits_processor.py
|
||||
sglang/srt/layers/model_parallel.py
|
||||
sglang/srt/layers/multimodal.py
|
||||
sglang/srt/layers/parameter.py
|
||||
sglang/srt/layers/pooler.py
|
||||
sglang/srt/layers/radix_attention.py
|
||||
sglang/srt/layers/rotary_embedding.py
|
||||
sglang/srt/layers/sampler.py
|
||||
sglang/srt/layers/torchao_utils.py
|
||||
sglang/srt/layers/utils.py
|
||||
sglang/srt/layers/vocab_parallel_embedding.py
|
||||
sglang/srt/layers/attention/aiter_backend.py
|
||||
sglang/srt/layers/attention/ascend_backend.py
|
||||
sglang/srt/layers/attention/base_attn_backend.py
|
||||
sglang/srt/layers/attention/cutlass_mla_backend.py
|
||||
sglang/srt/layers/attention/double_sparsity_backend.py
|
||||
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
|
||||
sglang/srt/layers/attention/flashattention_backend.py
|
||||
sglang/srt/layers/attention/flashinfer_backend.py
|
||||
sglang/srt/layers/attention/flashinfer_mla_backend.py
|
||||
sglang/srt/layers/attention/flashmla_backend.py
|
||||
sglang/srt/layers/attention/hybrid_attn_backend.py
|
||||
sglang/srt/layers/attention/intel_amx_backend.py
|
||||
sglang/srt/layers/attention/merge_state.py
|
||||
sglang/srt/layers/attention/tbo_backend.py
|
||||
sglang/srt/layers/attention/torch_native_backend.py
|
||||
sglang/srt/layers/attention/triton_backend.py
|
||||
sglang/srt/layers/attention/trtllm_mha_backend.py
|
||||
sglang/srt/layers/attention/trtllm_mla_backend.py
|
||||
sglang/srt/layers/attention/utils.py
|
||||
sglang/srt/layers/attention/vision.py
|
||||
sglang/srt/layers/attention/vision_utils.py
|
||||
sglang/srt/layers/attention/wave_backend.py
|
||||
sglang/srt/layers/attention/triton_ops/decode_attention.py
|
||||
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
|
||||
sglang/srt/layers/attention/triton_ops/extend_attention.py
|
||||
sglang/srt/layers/attention/triton_ops/merge_state.py
|
||||
sglang/srt/layers/attention/triton_ops/prefill_attention.py
|
||||
sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py
|
||||
sglang/srt/layers/attention/wave_ops/decode_attention.py
|
||||
sglang/srt/layers/attention/wave_ops/extend_attention.py
|
||||
sglang/srt/layers/attention/wave_ops/prefill_attention.py
|
||||
sglang/srt/layers/moe/__init__.py
|
||||
sglang/srt/layers/moe/cutlass_moe.py
|
||||
sglang/srt/layers/moe/cutlass_moe_params.py
|
||||
sglang/srt/layers/moe/cutlass_w4a8_moe.py
|
||||
sglang/srt/layers/moe/fused_moe_native.py
|
||||
sglang/srt/layers/moe/rocm_moe_utils.py
|
||||
sglang/srt/layers/moe/router.py
|
||||
sglang/srt/layers/moe/topk.py
|
||||
sglang/srt/layers/moe/utils.py
|
||||
sglang/srt/layers/moe/ep_moe/__init__.py
|
||||
sglang/srt/layers/moe/ep_moe/kernels.py
|
||||
sglang/srt/layers/moe/ep_moe/layer.py
|
||||
sglang/srt/layers/moe/fused_moe_triton/__init__.py
|
||||
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
|
||||
sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
|
||||
sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py
|
||||
sglang/srt/layers/moe/fused_moe_triton/layer.py
|
||||
sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py
|
||||
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json
|
||||
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/moe/moe_runner/__init__.py
|
||||
sglang/srt/layers/moe/moe_runner/base.py
|
||||
sglang/srt/layers/moe/token_dispatcher/__init__.py
|
||||
sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
|
||||
sglang/srt/layers/moe/token_dispatcher/deepep.py
|
||||
sglang/srt/layers/moe/token_dispatcher/standard.py
|
||||
sglang/srt/layers/quantization/__init__.py
|
||||
sglang/srt/layers/quantization/awq.py
|
||||
sglang/srt/layers/quantization/awq_triton.py
|
||||
sglang/srt/layers/quantization/base_config.py
|
||||
sglang/srt/layers/quantization/blockwise_int8.py
|
||||
sglang/srt/layers/quantization/fp8.py
|
||||
sglang/srt/layers/quantization/fp8_kernel.py
|
||||
sglang/srt/layers/quantization/fp8_utils.py
|
||||
sglang/srt/layers/quantization/fpgemm_fp8.py
|
||||
sglang/srt/layers/quantization/gptq.py
|
||||
sglang/srt/layers/quantization/int8_kernel.py
|
||||
sglang/srt/layers/quantization/int8_utils.py
|
||||
sglang/srt/layers/quantization/kv_cache.py
|
||||
sglang/srt/layers/quantization/marlin_utils.py
|
||||
sglang/srt/layers/quantization/marlin_utils_fp8.py
|
||||
sglang/srt/layers/quantization/modelopt_quant.py
|
||||
sglang/srt/layers/quantization/moe_wna16.py
|
||||
sglang/srt/layers/quantization/mxfp4.py
|
||||
sglang/srt/layers/quantization/mxfp4_tensor.py
|
||||
sglang/srt/layers/quantization/petit.py
|
||||
sglang/srt/layers/quantization/petit_utils.py
|
||||
sglang/srt/layers/quantization/qoq.py
|
||||
sglang/srt/layers/quantization/unquant.py
|
||||
sglang/srt/layers/quantization/utils.py
|
||||
sglang/srt/layers/quantization/w4afp8.py
|
||||
sglang/srt/layers/quantization/w8a8_fp8.py
|
||||
sglang/srt/layers/quantization/w8a8_int8.py
|
||||
sglang/srt/layers/quantization/compressed_tensors/__init__.py
|
||||
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
|
||||
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
|
||||
sglang/srt/layers/quantization/compressed_tensors/utils.py
|
||||
sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
|
||||
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
|
||||
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
|
||||
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
|
||||
sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py
|
||||
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
|
||||
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
|
||||
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
|
||||
sglang/srt/layers/quantization/quark/__init__.py
|
||||
sglang/srt/layers/quantization/quark/quark.py
|
||||
sglang/srt/layers/quantization/quark/quark_moe.py
|
||||
sglang/srt/layers/quantization/quark/utils.py
|
||||
sglang/srt/layers/quantization/quark/schemes/__init__.py
|
||||
sglang/srt/layers/quantization/quark/schemes/quark_scheme.py
|
||||
sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
|
||||
sglang/srt/lora/layers.py
|
||||
sglang/srt/lora/lora.py
|
||||
sglang/srt/lora/lora_config.py
|
||||
sglang/srt/lora/lora_manager.py
|
||||
sglang/srt/lora/lora_registry.py
|
||||
sglang/srt/lora/mem_pool.py
|
||||
sglang/srt/lora/utils.py
|
||||
sglang/srt/lora/backend/base_backend.py
|
||||
sglang/srt/lora/backend/triton_backend.py
|
||||
sglang/srt/lora/triton_ops/__init__.py
|
||||
sglang/srt/lora/triton_ops/gate_up_lora_b.py
|
||||
sglang/srt/lora/triton_ops/qkv_lora_b.py
|
||||
sglang/srt/lora/triton_ops/sgemm_lora_a.py
|
||||
sglang/srt/lora/triton_ops/sgemm_lora_b.py
|
||||
sglang/srt/managers/cache_controller.py
|
||||
sglang/srt/managers/configure_logging.py
|
||||
sglang/srt/managers/data_parallel_controller.py
|
||||
sglang/srt/managers/detokenizer_manager.py
|
||||
sglang/srt/managers/io_struct.py
|
||||
sglang/srt/managers/mm_utils.py
|
||||
sglang/srt/managers/multi_tokenizer_mixin.py
|
||||
sglang/srt/managers/multimodal_processor.py
|
||||
sglang/srt/managers/schedule_batch.py
|
||||
sglang/srt/managers/schedule_policy.py
|
||||
sglang/srt/managers/scheduler.py
|
||||
sglang/srt/managers/scheduler_input_blocker.py
|
||||
sglang/srt/managers/scheduler_metrics_mixin.py
|
||||
sglang/srt/managers/scheduler_output_processor_mixin.py
|
||||
sglang/srt/managers/scheduler_profiler_mixin.py
|
||||
sglang/srt/managers/scheduler_recv_skipper.py
|
||||
sglang/srt/managers/scheduler_update_weights_mixin.py
|
||||
sglang/srt/managers/session_controller.py
|
||||
sglang/srt/managers/template_manager.py
|
||||
sglang/srt/managers/tokenizer_manager.py
|
||||
sglang/srt/managers/tp_worker.py
|
||||
sglang/srt/managers/tp_worker_overlap_thread.py
|
||||
sglang/srt/managers/utils.py
|
||||
sglang/srt/mem_cache/allocator.py
|
||||
sglang/srt/mem_cache/allocator_ascend.py
|
||||
sglang/srt/mem_cache/base_prefix_cache.py
|
||||
sglang/srt/mem_cache/chunk_cache.py
|
||||
sglang/srt/mem_cache/flush_cache.py
|
||||
sglang/srt/mem_cache/hicache_storage.py
|
||||
sglang/srt/mem_cache/hiradix_cache.py
|
||||
sglang/srt/mem_cache/lora_radix_cache.py
|
||||
sglang/srt/mem_cache/memory_pool.py
|
||||
sglang/srt/mem_cache/memory_pool_host.py
|
||||
sglang/srt/mem_cache/multimodal_cache.py
|
||||
sglang/srt/mem_cache/radix_cache.py
|
||||
sglang/srt/mem_cache/radix_cache_cpp.py
|
||||
sglang/srt/mem_cache/swa_radix_cache.py
|
||||
sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py
|
||||
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
|
||||
sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp
|
||||
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
|
||||
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
|
||||
sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py
|
||||
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
|
||||
sglang/srt/mem_cache/storage/mooncake_store/unit_test.py
|
||||
sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
|
||||
sglang/srt/mem_cache/storage/nixl/nixl_utils.py
|
||||
sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py
|
||||
sglang/srt/metrics/collector.py
|
||||
sglang/srt/metrics/func_timer.py
|
||||
sglang/srt/model_executor/cuda_graph_runner.py
|
||||
sglang/srt/model_executor/forward_batch_info.py
|
||||
sglang/srt/model_executor/model_runner.py
|
||||
sglang/srt/model_executor/npu_graph_runner.py
|
||||
sglang/srt/model_loader/__init__.py
|
||||
sglang/srt/model_loader/loader.py
|
||||
sglang/srt/model_loader/utils.py
|
||||
sglang/srt/model_loader/weight_utils.py
|
||||
sglang/srt/models/arcee.py
|
||||
sglang/srt/models/baichuan.py
|
||||
sglang/srt/models/bailing_moe.py
|
||||
sglang/srt/models/bert.py
|
||||
sglang/srt/models/chatglm.py
|
||||
sglang/srt/models/clip.py
|
||||
sglang/srt/models/commandr.py
|
||||
sglang/srt/models/dbrx.py
|
||||
sglang/srt/models/deepseek.py
|
||||
sglang/srt/models/deepseek_janus_pro.py
|
||||
sglang/srt/models/deepseek_nextn.py
|
||||
sglang/srt/models/deepseek_v2.py
|
||||
sglang/srt/models/deepseek_vl2.py
|
||||
sglang/srt/models/ernie4.py
|
||||
sglang/srt/models/ernie4_eagle.py
|
||||
sglang/srt/models/exaone.py
|
||||
sglang/srt/models/gemma.py
|
||||
sglang/srt/models/gemma2.py
|
||||
sglang/srt/models/gemma2_reward.py
|
||||
sglang/srt/models/gemma3_causal.py
|
||||
sglang/srt/models/gemma3_mm.py
|
||||
sglang/srt/models/gemma3n_audio.py
|
||||
sglang/srt/models/gemma3n_causal.py
|
||||
sglang/srt/models/gemma3n_mm.py
|
||||
sglang/srt/models/glm4.py
|
||||
sglang/srt/models/glm4_moe.py
|
||||
sglang/srt/models/glm4_moe_nextn.py
|
||||
sglang/srt/models/glm4v.py
|
||||
sglang/srt/models/glm4v_moe.py
|
||||
sglang/srt/models/gpt2.py
|
||||
sglang/srt/models/gpt_bigcode.py
|
||||
sglang/srt/models/gpt_oss.py
|
||||
sglang/srt/models/granite.py
|
||||
sglang/srt/models/granitemoe.py
|
||||
sglang/srt/models/grok.py
|
||||
sglang/srt/models/hunyuan.py
|
||||
sglang/srt/models/idefics2.py
|
||||
sglang/srt/models/internlm2.py
|
||||
sglang/srt/models/internlm2_reward.py
|
||||
sglang/srt/models/interns1.py
|
||||
sglang/srt/models/internvl.py
|
||||
sglang/srt/models/kimi_vl.py
|
||||
sglang/srt/models/kimi_vl_moonvit.py
|
||||
sglang/srt/models/llama.py
|
||||
sglang/srt/models/llama4.py
|
||||
sglang/srt/models/llama_classification.py
|
||||
sglang/srt/models/llama_eagle.py
|
||||
sglang/srt/models/llama_eagle3.py
|
||||
sglang/srt/models/llama_embedding.py
|
||||
sglang/srt/models/llama_reward.py
|
||||
sglang/srt/models/llava.py
|
||||
sglang/srt/models/llavavid.py
|
||||
sglang/srt/models/longcat_flash.py
|
||||
sglang/srt/models/longcat_flash_nextn.py
|
||||
sglang/srt/models/mimo.py
|
||||
sglang/srt/models/mimo_mtp.py
|
||||
sglang/srt/models/minicpm.py
|
||||
sglang/srt/models/minicpm3.py
|
||||
sglang/srt/models/minicpmo.py
|
||||
sglang/srt/models/minicpmv.py
|
||||
sglang/srt/models/mistral.py
|
||||
sglang/srt/models/mixtral.py
|
||||
sglang/srt/models/mixtral_quant.py
|
||||
sglang/srt/models/mllama.py
|
||||
sglang/srt/models/mllama4.py
|
||||
sglang/srt/models/nemotron_nas.py
|
||||
sglang/srt/models/olmo.py
|
||||
sglang/srt/models/olmo2.py
|
||||
sglang/srt/models/olmoe.py
|
||||
sglang/srt/models/persimmon.py
|
||||
sglang/srt/models/phi.py
|
||||
sglang/srt/models/phi3_small.py
|
||||
sglang/srt/models/phi4mm.py
|
||||
sglang/srt/models/phi4mm_audio.py
|
||||
sglang/srt/models/phi4mm_utils.py
|
||||
sglang/srt/models/phimoe.py
|
||||
sglang/srt/models/pixtral.py
|
||||
sglang/srt/models/qwen.py
|
||||
sglang/srt/models/qwen2.py
|
||||
sglang/srt/models/qwen2_5_vl.py
|
||||
sglang/srt/models/qwen2_audio.py
|
||||
sglang/srt/models/qwen2_classification.py
|
||||
sglang/srt/models/qwen2_eagle.py
|
||||
sglang/srt/models/qwen2_moe.py
|
||||
sglang/srt/models/qwen2_rm.py
|
||||
sglang/srt/models/qwen2_vl.py
|
||||
sglang/srt/models/qwen3.py
|
||||
sglang/srt/models/qwen3_classification.py
|
||||
sglang/srt/models/qwen3_moe.py
|
||||
sglang/srt/models/registry.py
|
||||
sglang/srt/models/roberta.py
|
||||
sglang/srt/models/siglip.py
|
||||
sglang/srt/models/stablelm.py
|
||||
sglang/srt/models/step3_vl.py
|
||||
sglang/srt/models/torch_native_llama.py
|
||||
sglang/srt/models/transformers.py
|
||||
sglang/srt/models/vila.py
|
||||
sglang/srt/models/xverse.py
|
||||
sglang/srt/models/xverse_moe.py
|
||||
sglang/srt/models/yivl.py
|
||||
sglang/srt/multimodal/mm_utils.py
|
||||
sglang/srt/multimodal/processors/base_processor.py
|
||||
sglang/srt/multimodal/processors/clip.py
|
||||
sglang/srt/multimodal/processors/deepseek_vl_v2.py
|
||||
sglang/srt/multimodal/processors/gemma3.py
|
||||
sglang/srt/multimodal/processors/gemma3n.py
|
||||
sglang/srt/multimodal/processors/glm4v.py
|
||||
sglang/srt/multimodal/processors/internvl.py
|
||||
sglang/srt/multimodal/processors/janus_pro.py
|
||||
sglang/srt/multimodal/processors/kimi_vl.py
|
||||
sglang/srt/multimodal/processors/llava.py
|
||||
sglang/srt/multimodal/processors/minicpm.py
|
||||
sglang/srt/multimodal/processors/mlama.py
|
||||
sglang/srt/multimodal/processors/mllama4.py
|
||||
sglang/srt/multimodal/processors/phi4mm.py
|
||||
sglang/srt/multimodal/processors/pixtral.py
|
||||
sglang/srt/multimodal/processors/qwen_audio.py
|
||||
sglang/srt/multimodal/processors/qwen_vl.py
|
||||
sglang/srt/multimodal/processors/step3_vl.py
|
||||
sglang/srt/multimodal/processors/vila.py
|
||||
sglang/srt/parser/code_completion_parser.py
|
||||
sglang/srt/parser/conversation.py
|
||||
sglang/srt/parser/harmony_parser.py
|
||||
sglang/srt/parser/jinja_template_utils.py
|
||||
sglang/srt/parser/reasoning_parser.py
|
||||
sglang/srt/sampling/custom_logit_processor.py
|
||||
sglang/srt/sampling/sampling_batch_info.py
|
||||
sglang/srt/sampling/sampling_params.py
|
||||
sglang/srt/sampling/penaltylib/__init__.py
|
||||
sglang/srt/sampling/penaltylib/frequency_penalty.py
|
||||
sglang/srt/sampling/penaltylib/min_new_tokens.py
|
||||
sglang/srt/sampling/penaltylib/orchestrator.py
|
||||
sglang/srt/sampling/penaltylib/presence_penalty.py
|
||||
sglang/srt/speculative/build_eagle_tree.py
|
||||
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
|
||||
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
|
||||
sglang/srt/speculative/eagle_utils.py
|
||||
sglang/srt/speculative/eagle_worker.py
|
||||
sglang/srt/speculative/spec_info.py
|
||||
sglang/srt/tokenizer/tiktoken_tokenizer.py
|
||||
sglang/srt/weight_sync/tensor_bucket.py
|
||||
sglang/srt/weight_sync/utils.py
|
||||
sglang/test/__init__.py
|
||||
sglang/test/doc_patch.py
|
||||
sglang/test/few_shot_gsm8k.py
|
||||
sglang/test/few_shot_gsm8k_engine.py
|
||||
sglang/test/run_eval.py
|
||||
sglang/test/runners.py
|
||||
sglang/test/send_one.py
|
||||
sglang/test/simple_eval_common.py
|
||||
sglang/test/simple_eval_gpqa.py
|
||||
sglang/test/simple_eval_humaneval.py
|
||||
sglang/test/simple_eval_math.py
|
||||
sglang/test/simple_eval_mgsm.py
|
||||
sglang/test/simple_eval_mmlu.py
|
||||
sglang/test/test_activation.py
|
||||
sglang/test/test_block_fp8.py
|
||||
sglang/test/test_block_fp8_deep_gemm_blackwell.py
|
||||
sglang/test/test_block_fp8_ep.py
|
||||
sglang/test/test_custom_ops.py
|
||||
sglang/test/test_cutlass_moe.py
|
||||
sglang/test/test_cutlass_w4a8_moe.py
|
||||
sglang/test/test_deepep_utils.py
|
||||
sglang/test/test_dynamic_grad_mode.py
|
||||
sglang/test/test_fp4_moe.py
|
||||
sglang/test/test_layernorm.py
|
||||
sglang/test/test_marlin_moe.py
|
||||
sglang/test/test_marlin_utils.py
|
||||
sglang/test/test_programs.py
|
||||
sglang/test/test_utils.py
|
||||
sglang/test/attention/__init__.py
|
||||
sglang/test/attention/test_flashattn_backend.py
|
||||
sglang/test/attention/test_flashattn_mla_backend.py
|
||||
sglang/test/attention/test_prefix_chunk_info.py
|
||||
sglang/test/attention/test_trtllm_mla_backend.py
|
||||
1
python/sglang.egg-info/dependency_links.txt
Normal file
1
python/sglang.egg-info/dependency_links.txt
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
165
python/sglang.egg-info/requires.txt
Normal file
165
python/sglang.egg-info/requires.txt
Normal file
@@ -0,0 +1,165 @@
|
||||
aiohttp
|
||||
requests
|
||||
tqdm
|
||||
numpy
|
||||
IPython
|
||||
setproctitle
|
||||
|
||||
[all]
|
||||
sglang[srt]
|
||||
sglang[openai]
|
||||
sglang[anthropic]
|
||||
sglang[torch_memory_saver]
|
||||
sglang[decord]
|
||||
|
||||
[all_cpu]
|
||||
sglang[srt_cpu]
|
||||
sglang[openai]
|
||||
sglang[anthropic]
|
||||
sglang[decord]
|
||||
|
||||
[all_hip]
|
||||
sglang[srt_hip]
|
||||
sglang[openai]
|
||||
sglang[anthropic]
|
||||
sglang[decord]
|
||||
|
||||
[all_hpu]
|
||||
sglang[srt_hpu]
|
||||
sglang[openai]
|
||||
sglang[anthropic]
|
||||
sglang[decord]
|
||||
|
||||
[all_npu]
|
||||
sglang[srt_npu]
|
||||
sglang[openai]
|
||||
sglang[anthropic]
|
||||
sglang[decord]
|
||||
|
||||
[all_xpu]
|
||||
sglang[srt_xpu]
|
||||
sglang[openai]
|
||||
sglang[anthropic]
|
||||
sglang[decord]
|
||||
|
||||
[anthropic]
|
||||
anthropic>=0.20.0
|
||||
|
||||
[blackwell]
|
||||
sglang[runtime_common]
|
||||
sgl-kernel
|
||||
torch==2.8.0
|
||||
torchaudio==2.8.0
|
||||
torchvision
|
||||
cuda-python
|
||||
flashinfer_python==0.3.0
|
||||
|
||||
[decord]
|
||||
decord
|
||||
|
||||
[dev]
|
||||
sglang[all]
|
||||
sglang[test]
|
||||
|
||||
[dev_cpu]
|
||||
sglang[all_cpu]
|
||||
sglang[test]
|
||||
|
||||
[dev_hip]
|
||||
sglang[all_hip]
|
||||
sglang[test]
|
||||
|
||||
[dev_hpu]
|
||||
sglang[all_hpu]
|
||||
sglang[test]
|
||||
|
||||
[dev_xpu]
|
||||
sglang[all_xpu]
|
||||
sglang[test]
|
||||
|
||||
[litellm]
|
||||
litellm>=1.0.0
|
||||
|
||||
[openai]
|
||||
openai==1.99.1
|
||||
tiktoken
|
||||
|
||||
[runtime_common]
|
||||
blobfile==3.0.0
|
||||
build
|
||||
compressed-tensors
|
||||
datasets
|
||||
einops
|
||||
fastapi
|
||||
hf_transfer
|
||||
huggingface_hub
|
||||
interegular
|
||||
llguidance<0.8.0,>=0.7.11
|
||||
modelscope
|
||||
msgspec
|
||||
ninja
|
||||
openai==1.99.1
|
||||
openai-harmony==0.0.4
|
||||
orjson
|
||||
outlines==0.1.11
|
||||
packaging
|
||||
partial_json_parser
|
||||
pillow
|
||||
prometheus-client>=0.20.0
|
||||
psutil
|
||||
pybase64
|
||||
pydantic
|
||||
pynvml
|
||||
python-multipart
|
||||
pyzmq>=25.1.2
|
||||
sentencepiece
|
||||
soundfile==0.13.1
|
||||
scipy
|
||||
timm==1.0.16
|
||||
tiktoken
|
||||
torchao==0.9.0
|
||||
transformers==4.56.0
|
||||
uvicorn
|
||||
uvloop
|
||||
xgrammar==0.1.23
|
||||
|
||||
[srt]
|
||||
sglang[runtime_common]
|
||||
sgl-kernel==0.3.8
|
||||
torch==2.8.0
|
||||
torchaudio==2.8.0
|
||||
torchvision
|
||||
cuda-python
|
||||
flashinfer_python==0.3.0
|
||||
|
||||
[srt_cpu]
|
||||
sglang[runtime_common]
|
||||
|
||||
[srt_hip]
|
||||
sglang[runtime_common]
|
||||
torch
|
||||
petit_kernel==0.0.2
|
||||
wave-lang==1.0.1
|
||||
|
||||
[srt_hpu]
|
||||
sglang[runtime_common]
|
||||
|
||||
[srt_npu]
|
||||
sglang[runtime_common]
|
||||
|
||||
[srt_xpu]
|
||||
sglang[runtime_common]
|
||||
|
||||
[test]
|
||||
accelerate
|
||||
expecttest
|
||||
jsonlines
|
||||
matplotlib
|
||||
pandas
|
||||
peft
|
||||
sentence_transformers
|
||||
pytest
|
||||
tabulate
|
||||
|
||||
[torch_memory_saver]
|
||||
torch_memory_saver==0.0.8
|
||||
1
python/sglang.egg-info/top_level.txt
Normal file
1
python/sglang.egg-info/top_level.txt
Normal file
@@ -0,0 +1 @@
|
||||
sglang
|
||||
16
python/sglang/README.md
Normal file
16
python/sglang/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# Code Structures
|
||||
|
||||
- `eval`: The evaluation utilities.
|
||||
- `lang`: The frontend language.
|
||||
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
|
||||
- `test`: The test utilities.
|
||||
- `api.py`: The public APIs.
|
||||
- `bench_offline_throughput.py`: Benchmark the performance in the offline mode.
|
||||
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
|
||||
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
|
||||
- `bench_serving.py`: Benchmark online serving with dynamic requests.
|
||||
- `check_env.py`: Check the environment variables and dependencies.
|
||||
- `global_config.py`: The global configs and constants.
|
||||
- `launch_server.py`: The entry point for launching the local server.
|
||||
- `utils.py`: Common utilities.
|
||||
- `version.py`: Version info.
|
||||
83
python/sglang/__init__.py
Normal file
83
python/sglang/__init__.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# SGLang public APIs
|
||||
|
||||
# Frontend Language APIs
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.api import (
|
||||
Engine,
|
||||
Runtime,
|
||||
assistant,
|
||||
assistant_begin,
|
||||
assistant_end,
|
||||
flush_cache,
|
||||
function,
|
||||
gen,
|
||||
gen_int,
|
||||
gen_string,
|
||||
get_server_info,
|
||||
image,
|
||||
select,
|
||||
separate_reasoning,
|
||||
set_default_backend,
|
||||
system,
|
||||
system_begin,
|
||||
system_end,
|
||||
user,
|
||||
user_begin,
|
||||
user_end,
|
||||
video,
|
||||
)
|
||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.lang.choices import (
|
||||
greedy_token_selection,
|
||||
token_length_normalized,
|
||||
unconditional_likelihood_normalized,
|
||||
)
|
||||
|
||||
# Lazy import some libraries
|
||||
from sglang.utils import LazyImport
|
||||
from sglang.version import __version__
|
||||
|
||||
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
||||
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
||||
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
||||
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
||||
|
||||
# Runtime Engine APIs
|
||||
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
|
||||
Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")
|
||||
|
||||
__all__ = [
|
||||
"Engine",
|
||||
"Runtime",
|
||||
"assistant",
|
||||
"assistant_begin",
|
||||
"assistant_end",
|
||||
"flush_cache",
|
||||
"function",
|
||||
"gen",
|
||||
"gen_int",
|
||||
"gen_string",
|
||||
"get_server_info",
|
||||
"image",
|
||||
"select",
|
||||
"separate_reasoning",
|
||||
"set_default_backend",
|
||||
"system",
|
||||
"system_begin",
|
||||
"system_end",
|
||||
"user",
|
||||
"user_begin",
|
||||
"user_end",
|
||||
"video",
|
||||
"RuntimeEndpoint",
|
||||
"greedy_token_selection",
|
||||
"token_length_normalized",
|
||||
"unconditional_likelihood_normalized",
|
||||
"ServerArgs",
|
||||
"Anthropic",
|
||||
"LiteLLM",
|
||||
"OpenAI",
|
||||
"VertexAI",
|
||||
"global_config",
|
||||
"__version__",
|
||||
]
|
||||
BIN
python/sglang/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
python/sglang/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/__pycache__/global_config.cpython-310.pyc
Normal file
BIN
python/sglang/__pycache__/global_config.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/__pycache__/launch_server.cpython-310.pyc
Normal file
BIN
python/sglang/__pycache__/launch_server.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/__pycache__/utils.cpython-310.pyc
Normal file
BIN
python/sglang/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/__pycache__/version.cpython-310.pyc
Normal file
BIN
python/sglang/__pycache__/version.cpython-310.pyc
Normal file
Binary file not shown.
452
python/sglang/bench_offline_throughput.py
Normal file
452
python/sglang/bench_offline_throughput.py
Normal file
@@ -0,0 +1,452 @@
|
||||
"""
|
||||
Benchmark the throughput in the offline mode.
|
||||
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
|
||||
|
||||
# Usage
|
||||
## Sharegpt dataset with default args
|
||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
|
||||
|
||||
## Random dataset with default args
|
||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.bench_serving import (
|
||||
DatasetRow,
|
||||
get_dataset,
|
||||
get_tokenizer,
|
||||
sample_random_requests,
|
||||
set_ulimit,
|
||||
)
|
||||
from sglang.lang.backend.runtime_endpoint import Runtime
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BenchArgs:
|
||||
backend: str = "engine"
|
||||
result_filename: str = ""
|
||||
dataset_name: str = "sharegpt"
|
||||
dataset_path: str = ""
|
||||
num_prompts: int = 1000
|
||||
sharegpt_output_len: Optional[int] = None
|
||||
sharegpt_context_len: Optional[int] = None
|
||||
random_input_len: int = 1024
|
||||
random_output_len: int = 1024
|
||||
random_range_ratio: float = 0.0
|
||||
gsp_num_groups: int = 64
|
||||
gsp_prompts_per_group: int = 16
|
||||
gsp_system_prompt_len: int = 2048
|
||||
gsp_question_len: int = 128
|
||||
gsp_output_len: int = 256
|
||||
seed: int = 1
|
||||
disable_ignore_eos: bool = False
|
||||
extra_request_body: Optional[str] = None
|
||||
apply_chat_template: bool = False
|
||||
profile: bool = False
|
||||
skip_warmup: bool = False
|
||||
do_not_exit: bool = False
|
||||
prompt_suffix: str = ""
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--backend", type=str, default=BenchArgs.backend)
|
||||
parser.add_argument(
|
||||
"--result-filename", type=str, default=BenchArgs.result_filename
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
choices=["sharegpt", "random", "generated-shared-prefix"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path", type=str, default="", help="Path to the dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
default=BenchArgs.num_prompts,
|
||||
help="Number of prompts to process. Default is 1000.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sharegpt-output-len",
|
||||
type=int,
|
||||
default=BenchArgs.sharegpt_output_len,
|
||||
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sharegpt-context-len",
|
||||
type=int,
|
||||
default=BenchArgs.sharegpt_context_len,
|
||||
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-input-len",
|
||||
type=int,
|
||||
default=BenchArgs.random_input_len,
|
||||
help="Number of input tokens per request, used only for random dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-output-len",
|
||||
type=int,
|
||||
default=BenchArgs.random_output_len,
|
||||
help="Number of output tokens per request, used only for random dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=BenchArgs.random_range_ratio,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for random dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gsp-num-groups",
|
||||
type=int,
|
||||
default=BenchArgs.gsp_num_groups,
|
||||
help="Number of groups with shared prefix, used"
|
||||
"only for generate-shared-prefix",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gsp-prompts-per-group",
|
||||
type=int,
|
||||
default=BenchArgs.gsp_prompts_per_group,
|
||||
help="Number of prompts per group of shared prefix, used"
|
||||
"only for generate-shared-prefix",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gsp-system-prompt-len",
|
||||
type=int,
|
||||
default=BenchArgs.gsp_system_prompt_len,
|
||||
help="System prompt length, used" "only for generate-shared-prefix",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gsp-question-len",
|
||||
type=int,
|
||||
default=BenchArgs.gsp_question_len,
|
||||
help="Question length, used" "only for generate-shared-prefix",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gsp-output-len",
|
||||
type=int,
|
||||
default=BenchArgs.gsp_output_len,
|
||||
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||
parser.add_argument(
|
||||
"--disable-ignore-eos",
|
||||
action="store_true",
|
||||
help="Disable ignore EOS token",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extra-request-body",
|
||||
metavar='{"key1": "value1", "key2": "value2"}',
|
||||
type=str,
|
||||
default=BenchArgs.extra_request_body,
|
||||
help="Append given JSON object to the request payload. You can use this to specify"
|
||||
"additional generate params like sampling params.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apply-chat-template",
|
||||
action="store_true",
|
||||
help="Apply chat template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Use Torch Profiler. The endpoint must be launched with "
|
||||
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-warmup",
|
||||
action="store_true",
|
||||
help="Skip the warmup batches.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do-not-exit",
|
||||
action="store_true",
|
||||
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-suffix",
|
||||
type=str,
|
||||
default="",
|
||||
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||
|
||||
|
||||
def throughput_test_once(
    backend_name: str,
    backend,
    reqs: List[DatasetRow],
    ignore_eos: bool,
    extra_request_body: Dict,
    profile: bool,
):
    """Run one batch of offline generation and collect throughput metrics.

    Args:
        backend_name: "engine" or "runtime"; "runtime" returns JSON text that
            is decoded below.
        backend: the Engine/Runtime instance used to generate.
        reqs: dataset rows providing prompt text, prompt length, and target
            output length per request.
        ignore_eos: forwarded into every request's sampling params.
        extra_request_body: extra generate/sampling params merged into each
            request (may override the defaults set here).
        profile: if True, wrap generation with the torch profiler; requires
            the SGLANG_TORCH_PROFILER_DIR environment variable.

    Returns:
        Dict of measurement results (latency and throughput numbers).
    """
    # Pre-fill with -1 sentinels so a partial failure still yields a
    # well-formed record.
    measurement_results = {
        "backend": backend_name,
        "successful_requests": len(reqs),
        "total_latency": -1,
        "total_input_tokens": sum(r.prompt_len for r in reqs),
        "total_output_tokens": -1,
        "request_throughput": -1,
        "input_throughput": -1,
        "output_throughput": -1,
        "total_throughput": -1,
    }

    prompt = [r.prompt for r in reqs]
    sampling_params = [
        {
            "temperature": 0,
            "max_new_tokens": r.output_len,
            "ignore_eos": ignore_eos,
            **extra_request_body,
        }
        for r in reqs
    ]

    if profile:
        assert (
            "SGLANG_TORCH_PROFILER_DIR" in os.environ
        ), "Please set SGLANG_TORCH_PROFILER_DIR."
        os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
        backend.start_profile()

    st = time.perf_counter()
    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
    latency = time.perf_counter() - st

    if profile:
        # Renamed from `dir` to avoid shadowing the builtin.
        trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
        # Snapshot existing files so we can detect the newly written trace.
        known_files = set(os.listdir(trace_dir))
        backend.stop_profile()
        monitor_trace_file(known_files, trace_dir)

    if backend_name == "runtime":
        gen_out = json.loads(gen_out)

    server_info = backend.get_server_info()

    measurement_results["total_latency"] = latency
    measurement_results["total_output_tokens"] = sum(
        o["meta_info"]["completion_tokens"] for o in gen_out
    )
    measurement_results["request_throughput"] = (
        measurement_results["successful_requests"] / latency
    )
    measurement_results["input_throughput"] = (
        measurement_results["total_input_tokens"] / latency
    )
    measurement_results["output_throughput"] = (
        measurement_results["total_output_tokens"] / latency
    )
    measurement_results["total_throughput"] = (
        measurement_results["total_input_tokens"]
        + measurement_results["total_output_tokens"]
    ) / latency

    # `get_server_info` may be a coroutine depending on the backend.
    if inspect.isawaitable(server_info):
        server_info = asyncio.run(server_info)

    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
        "last_gen_throughput"
    ]

    return measurement_results
|
||||
|
||||
|
||||
def monitor_trace_file(known_files, directory, interval=1):
    """Poll *directory* until a newly created trace file stops growing.

    `known_files` is the set of names present before profiling stopped; any
    other name is treated as a fresh trace. A file is considered finished
    once its size fails to increase between two polls.
    """
    print(f"Monitoring {directory} for new trace files...")

    finished = False
    while not finished:
        time.sleep(interval)
        fresh_names = set(os.listdir(directory)) - known_files

        for name in fresh_names:
            path = os.path.join(directory, name)
            print(f"New file detected: {name}")

            last_size = 0
            while True:
                try:
                    size_now = os.path.getsize(path)
                except FileNotFoundError:
                    # File vanished mid-write; keep waiting for another one.
                    print(f"File {name} is no longer accessible.")
                    break

                if size_now <= last_size:
                    # Size stopped growing -> the writer is done.
                    finished = True
                    break
                last_size = size_now

                time.sleep(interval)
|
||||
|
||||
|
||||
def throughput_test(
    server_args: ServerArgs,
    bench_args: BenchArgs,
):
    """Run the offline throughput benchmark (optional warmup + measured run).

    Launches the requested backend, samples the dataset, optionally warms up,
    runs `throughput_test_once`, appends the result (jsonlines) to
    `bench_args.result_filename` when set, and prints a summary table.

    Returns:
        The measurement dict produced by the benchmark run.
    """
    if bench_args.backend == "engine":
        backend = Engine(**dataclasses.asdict(server_args))
        if not backend:
            raise ValueError("Please provide valid engine arguments")
    elif bench_args.backend == "runtime":
        backend = Runtime(**dataclasses.asdict(server_args))
    else:
        raise ValueError('Please set backend to either "engine" or "runtime"')

    tokenizer_id = server_args.tokenizer_path or server_args.model_path
    tokenizer = get_tokenizer(tokenizer_id)

    # Set global environments
    set_ulimit()
    random.seed(bench_args.seed)
    np.random.seed(bench_args.seed)

    # Parse args
    extra_request_body = {}
    if bench_args.extra_request_body:
        # Bug fix: read from `bench_args`, not the module-level `args`
        # namespace, which is undefined when this function is imported
        # rather than run as a script.
        extra_request_body = json.loads(bench_args.extra_request_body)

    # Read dataset
    input_requests = get_dataset(bench_args, tokenizer)

    warmup_requests = sample_random_requests(
        input_len=256,
        output_len=16,
        num_prompts=min(bench_args.num_prompts, 16),
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=bench_args.dataset_path,
    )

    # Warm up
    if not bench_args.skip_warmup:
        logging.info("\nWarmup...")
        throughput_test_once(
            backend_name=bench_args.backend,
            backend=backend,
            reqs=warmup_requests,
            ignore_eos=not bench_args.disable_ignore_eos,
            extra_request_body=extra_request_body,
            profile=False,
        )
        time.sleep(0.5)

    logging.info("\nBenchmark...")
    result = throughput_test_once(
        backend_name=bench_args.backend,
        backend=backend,
        reqs=input_requests,
        ignore_eos=not bench_args.disable_ignore_eos,
        extra_request_body=extra_request_body,
        profile=bench_args.profile,
    )
    backend.shutdown()

    if bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    print(
        "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
    )
    print("{:<40} {:<10}".format("Backend:", result["backend"]))
    print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
    print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
    print(
        "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Last generation throughput (tok/s):", result["last_gen_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", result["request_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", result["input_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", result["output_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", result["total_throughput"]
        )
    )
    print("=" * 50)

    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse the combined server + benchmark CLI.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()

    # handling ModelScope model downloads
    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
        if os.path.exists(args.model_path):
            print(f"Using local model path: {args.model_path}")
        else:
            try:
                from modelscope import snapshot_download

                print(f"Using ModelScope to download model: {args.model_path}")

                # download the model and replace args.model_path
                args.model_path = snapshot_download(
                    args.model_path,
                )
                print(f"Model downloaded to: {args.model_path}")
            except Exception as e:
                print(f"ModelScope download failed: {str(e)}")
                raise e

    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    throughput_test(server_args, bench_args)

    # Keep the process alive for external profilers (nsys --duration/--delay).
    # NOTE(review): this is a busy-wait that pins one CPU core; a short
    # time.sleep inside the loop would be gentler — confirm before changing.
    while bench_args.do_not_exit:
        pass
|
||||
665
python/sglang/bench_one_batch.py
Normal file
665
python/sglang/bench_one_batch.py
Normal file
@@ -0,0 +1,665 @@
|
||||
"""
|
||||
Benchmark the latency of running a single static batch without a server.
|
||||
|
||||
This script does not launch a server and uses the low-level APIs.
|
||||
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
|
||||
|
||||
# Usage (latency test)
|
||||
## with dummy weights:
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
|
||||
## sweep through multiple data points and store (append) the results in a jsonl file:
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
|
||||
## run with profiling:
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
|
||||
# Usage (correctness test):
|
||||
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
|
||||
|
||||
## Reference output (of the correctness test above, can be gpu dependent):
|
||||
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
|
||||
|
||||
prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
|
||||
[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
|
||||
[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
|
||||
device='cuda:0')
|
||||
|
||||
prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
|
||||
[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
|
||||
[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
|
||||
device='cuda:0')
|
||||
|
||||
========== Prompt 0 ==========
|
||||
<s> The capital of France is Paris.
|
||||
The capital of the United States is Washington, D.C.
|
||||
|
||||
|
||||
========== Prompt 1 ==========
|
||||
<s> The capital of the United Kindom is London.
|
||||
The capital of the United Kingdom is London.
|
||||
The capital of the
|
||||
|
||||
========== Prompt 2 ==========
|
||||
<s> Today is a sunny day and I like to go for a walk in the park.
|
||||
I'm going to the park
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
|
||||
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.layers.moe import initialize_moe_config
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
get_bool_env_var,
|
||||
kill_process_tree,
|
||||
require_mlp_sync,
|
||||
require_mlp_tp_gather,
|
||||
set_gpu_proc_affinity,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BenchArgs:
    """CLI-configurable parameters for the single-batch benchmark.

    The tuple-valued fields are sweep dimensions: the benchmark runs the
    cartesian product of batch_size x input_len x output_len.
    """

    run_name: str = "default"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    # Optional file of custom prompts (one per line); empty means synthetic.
    prompt_filename: str = ""
    # Results are appended to this file in jsonlines format.
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    # 0 disables per-step decode logging.
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark flags on *parser* (defaults mirror the fields)."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        parser.add_argument(
            "--log-decode-step",
            type=int,
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step, default is set to zero to disable.",
        )
        parser.add_argument(
            "--profile", action="store_true", help="Use Torch Profiler."
        )
        parser.add_argument(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
            help="Prefix of the profiling file names. The full profiling result file(s) be "
            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # use the default value's type to cast the args into correct types.
        # (e.g. nargs="+" yields lists; tuple defaults cast them back to tuples)
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )
|
||||
|
||||
|
||||
def load_model(server_args, port_args, tp_rank):
    """Construct the ModelRunner and tokenizer for this TP rank.

    Returns:
        (model_runner, tokenizer) — the low-level runner (no scheduler) and
        the HF tokenizer for the configured model.
    """
    suppress_other_loggers()
    # Only rank 0 prints; other ranks get a no-op.
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Map the TP rank onto its expert-parallel group.
    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)

    model_config = ModelConfig.from_server_args(server_args)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,  # assumes one GPU per TP rank, indexed by rank
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        moe_ep_rank=moe_ep_rank,
        moe_ep_size=server_args.ep_size,
        pp_rank=0,  # this benchmark does not use pipeline parallelism
        pp_size=1,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    # Ensure all ranks finish loading before any starts running.
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer
|
||||
|
||||
|
||||
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
    """Tokenize the correctness-test prompts and build truncated `Req`s.

    Each request initially carries only the first `bench_args.cut_len` tokens
    of its prompt; `prepare_extend_inputs_for_correctness_test` appends the
    remainder later to exercise the extend-with-prefix path.

    Returns:
        (input_ids, reqs): full token ids per prompt, and the truncated
        scheduler requests.
    """
    prompts = (
        custom_prompts
        if custom_prompts
        else [
            "The capital of France is",
            "The capital of the United Kindom is",
            "Today is a sunny day and I like",
        ]
    )
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        # Bug fix: use the parsed CLI value (an int). The previous
        # `BenchArgs.output_len` was the class-level default — a tuple —
        # which both ignored --output-len and passed the wrong type.
        max_new_tokens=bench_args.output_len[0],
    )

    reqs = []
    for i in range(len(prompts)):
        # The prompt must be longer than the truncation point.
        assert len(input_ids[i]) > bench_args.cut_len

        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return input_ids, reqs
|
||||
|
||||
|
||||
def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    """Attach the remaining prompt tokens to each truncated request.

    The first `cut_len` tokens were already prefilled; their KV-cache slots
    become the request's prefix, and the rest of the prompt becomes the
    extend portion. Mutates *reqs* in place and returns it.
    """
    cut = bench_args.cut_len
    token_table = model_runner.req_to_token_pool.req_to_token
    for idx, (req, ids) in enumerate(zip(reqs, input_ids)):
        req.fill_ids += ids[cut:]
        req.prefix_indices = token_table[idx, :cut]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs
|
||||
|
||||
|
||||
def prepare_synthetic_inputs_for_latency_test(
    batch_size, input_len, custom_inputs=None
):
    """Build a batch of `Req`s from random (or user-supplied) token ids.

    Args:
        batch_size: number of requests when generating random ids.
        input_len: tokens per request when generating random ids.
        custom_inputs: optional pre-tokenized prompts; when given, they are
            used as-is and batch_size/input_len are ignored.

    Returns:
        List of `Req` objects ready for `extend`.
    """
    input_ids = (
        custom_inputs
        if custom_inputs
        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
    )
    sampling_params = SamplingParams(
        temperature=0,
        # NOTE(review): this passes the class-level default tuple
        # `BenchArgs.output_len`, not the parsed CLI value. The decode loop in
        # `latency_test_run_once` controls the step count itself, but confirm
        # SamplingParams tolerates a tuple here.
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        # No shared prefix: the whole prompt is the extend portion.
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return reqs
|
||||
|
||||
|
||||
@torch.no_grad
def extend(reqs, model_runner):
    """Run one prefill/extend forward pass over *reqs*.

    Builds a fresh ScheduleBatch (no radix cache, no overlap, no speculative
    decoding), runs the model, and samples one next token per request.

    Returns:
        (next_token_ids, next_token_logits, batch) — the sampled ids, the
        logits they were sampled from, and the batch for subsequent decodes.
    """
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        tree_cache=None,
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
    )
    batch.prepare_for_extend()
    # Data-parallel attention needs batches synchronized across ranks.
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
@torch.no_grad
def decode(input_token_ids, batch, model_runner):
    """Run one decode step: feed the previous tokens, sample the next ones.

    Mutates *batch* in place (appends the input tokens as its outputs, then
    advances it to decode mode).

    Returns:
        (next_token_ids, next_token_logits) for this step.
    """
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    # Data-parallel attention needs batches synchronized across ranks.
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    """Synchronize *batch* across data-parallel ranks when the config requires it.

    No-op unless `require_mlp_sync` says the server args need MLP-layer
    synchronization (e.g. DP attention). Delegates to the scheduler's raw
    helper with speculative decoding disabled.
    """
    if require_mlp_sync(model_runner.server_args):
        Scheduler.prepare_mlp_sync_batch_raw(
            batch,
            dp_size=model_runner.server_args.dp_size,
            attn_tp_size=1,  # this benchmark runs without attention-TP splitting
            tp_group=model_runner.tp_group,
            get_idle_batch=None,
            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            speculative_num_draft_tokens=None,
            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
            disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
        )
|
||||
|
||||
|
||||
def _read_prompts_from_file(prompt_file, rank_print):
|
||||
"""Read custom prompts from the file specified by `--prompt-filename`."""
|
||||
if not prompt_file:
|
||||
return []
|
||||
if not os.path.exists(prompt_file):
|
||||
rank_print(
|
||||
f"Custom prompt file {prompt_file} not found. Using default inputs..."
|
||||
)
|
||||
return []
|
||||
with open(prompt_file, "r") as pf:
|
||||
return pf.readlines()
|
||||
|
||||
|
||||
def _save_profile_trace_results(profiler, filename):
|
||||
parent_dir = os.path.dirname(os.path.abspath(filename))
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
profiler.export_chrome_trace(filename)
|
||||
print(
|
||||
profiler.key_averages(group_by_input_shape=True).table(
|
||||
sort_by="self_cpu_time_total"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def correctness_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Sanity-check generation: prefill a truncated prompt, extend with the
    rest (exercising the KV-cache prefix path), then greedily decode and
    print the outputs for manual comparison against the reference in the
    module docstring.
    """
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs
    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    input_ids, reqs = prepare_inputs_for_correctness_test(
        bench_args, tokenizer, custom_prompts
    )
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])

    # Print output texts
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")
|
||||
|
||||
|
||||
def synchronize(device):
|
||||
torch.get_device_module(device).synchronize()
|
||||
|
||||
|
||||
def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profile_filename_prefix,
):
    """Measure prefill and per-step decode latency for one configuration.

    Returns a dict of measurements, or None when the configuration does not
    fit within the runner's KV-cache capacity.
    """
    # Skip configs whose total token footprint exceeds the KV cache.
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    profiler = None
    if profile:
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            with_stack=True,
            record_shapes=profile_record_shapes,
        )
        profiler.start()

    # Prefill
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    if profile:
        profiler.stop()
        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
        _save_profile_trace_results(profiler, profile_filename)
        rank_print(
            f"torch profiler chrome trace for prefill saved to {profile_filename}"
        )

    # Decode
    decode_latencies = []
    for i in range(output_len - 1):
        synchronize(device)
        # Profile exactly one decode step, halfway through generation
        # (the int==float comparison only matches for even output_len).
        if profile and i == output_len / 2:
            profiler = None
            profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                record_shapes=profile_record_shapes,
            )
            profiler.start()

        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        # Always log the first few steps; then every log_decode_step steps.
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

        if profile and i == output_len / 2:
            profiler.stop()
            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
            _save_profile_trace_results(profiler, profile_filename)
            rank_print(
                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
            )

    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results
|
||||
|
||||
|
||||
def latency_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Run the latency sweep on one TP rank: warm up once, then benchmark the
    cartesian product of batch sizes, input lengths, and output lengths, and
    (on rank 0) append the results to the jsonlines result file.
    """
    initialize_moe_config(server_args)

    # Set CPU affinity
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank)

    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
        server_args.device,
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profile_filename_prefix="",  # not used
    )

    rank_print("Benchmark ...")

    # Optional user-supplied prompts; tokenized once, reused for every config.
    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
    custom_input_len = len(custom_inputs)

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        # Align the custom prompt count to the batch size: truncate when too
        # many, pad by repeating the last prompt when too few.
        bs_aligned_inputs = []
        if custom_inputs:
            if custom_input_len == bs:
                bs_aligned_inputs = custom_inputs
            elif custom_input_len > bs:
                rank_print(
                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
                    f"Using the first {bs} prompts."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
            else:
                rank_print(
                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
                    f"Pad to the desired batch_size with the last prompt."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs)
                bs_aligned_inputs.extend(
                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
                )

        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
            bench_args.log_decode_step,
            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profile_filename_prefix,
        )
        # None means the config was skipped (exceeds KV-cache capacity).
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            for result in result_list:
                fout.write(json.dumps(result) + "\n")

    if server_args.tp_size > 1:
        destroy_distributed_environment()
|
||||
|
||||
|
||||
def main(server_args, bench_args):
    """Dispatch to the correctness or latency test, spawning one process per
    TP rank when tp_size > 1.
    """
    # Size CUDA graphs for the largest batch in the sweep.
    server_args.cuda_graph_max_bs = max(bench_args.batch_size)

    _set_envs_and_config(server_args)

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    port_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)

        for proc in workers:
            proc.join()

        # NOTE(review): `proc` here is the last worker from the loop above;
        # all workers were already joined, so this terminate is effectively
        # a no-op on an exited process — confirm intent.
        proc.terminate()
|
||||
|
||||
if __name__ == "__main__":
    # Parse the combined server + benchmark CLI.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        main(server_args, bench_args)
    finally:
        # With multiple TP workers, make sure no child process outlives us.
        if server_args.tp_size != 1:
            kill_process_tree(os.getpid(), include_parent=False)
|
||||
430
python/sglang/bench_one_batch_server.py
Normal file
430
python/sglang/bench_one_batch_server.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
Benchmark the latency of running a single batch with a server.
|
||||
|
||||
This script launches a server and uses the HTTP interface.
|
||||
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
|
||||
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.bench_serving import get_tokenizer, sample_random_requests
|
||||
from sglang.profiler import run_profile
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import is_blackwell, kill_process_tree
|
||||
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BenchArgs:
    """Benchmark configuration mirrored one-to-one by the CLI flags below."""

    run_name: str = "default"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    temperature: float = 0.0
    return_logprob: bool = False
    client_stream_interval: int = 1
    input_len_step_percentage: float = 0.0
    result_filename: str = "result.jsonl"
    base_url: str = ""
    skip_warmup: bool = False
    show_report: bool = False
    profile: bool = False
    profile_steps: int = 3
    profile_by_stage: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every benchmark option; defaults come from the fields."""
        add = parser.add_argument
        add("--run-name", type=str, default=BenchArgs.run_name)
        add("--batch-size", type=int, nargs="+", default=BenchArgs.batch_size)
        add("--input-len", type=int, nargs="+", default=BenchArgs.input_len)
        add("--output-len", type=int, nargs="+", default=BenchArgs.output_len)
        add("--temperature", type=float, default=BenchArgs.temperature)
        add("--return-logprob", action="store_true")
        add(
            "--client-stream-interval",
            type=int,
            default=BenchArgs.client_stream_interval,
        )
        add(
            "--input-len-step-percentage",
            type=float,
            default=BenchArgs.input_len_step_percentage,
        )
        add("--result-filename", type=str, default=BenchArgs.result_filename)
        add("--base-url", type=str, default=BenchArgs.base_url)
        add("--skip-warmup", action="store_true")
        add("--show-report", action="store_true")
        add("--profile", action="store_true")
        add("--profile-steps", type=int, default=BenchArgs.profile_steps)
        add("--profile-by-stage", action="store_true")

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from a parsed namespace.

        Each value is cast with the type of the field's default, so e.g.
        list-valued nargs results become tuples again.
        """
        kwargs = {}
        for field in dataclasses.fields(cls):
            cast = type(field.default)
            kwargs[field.name] = cast(getattr(args, field.name))
        return cls(**kwargs)
|
||||
|
||||
|
||||
def launch_server_internal(server_args):
    """Run the HTTP server in this process, killing leftover children on exit.

    The previous ``except Exception as e: raise e`` clause was a no-op
    re-raise and has been removed; the ``finally`` cleanup is what matters.
    """
    try:
        launch_server(server_args)
    finally:
        # Whatever happens, do not leave orphaned worker processes behind.
        kill_process_tree(os.getpid(), include_parent=False)
|
||||
|
||||
|
||||
def launch_server_process(server_args: ServerArgs, timeout: float = 600):
    """Start ``launch_server`` in a subprocess and block until it answers HTTP.

    Polls ``GET /v1/models`` every 10 seconds until it returns 200.

    Args:
        server_args: Configuration forwarded to the server.
        timeout: Seconds to wait for readiness (default 600; previously a
            hard-coded constant).

    Returns:
        A ``(proc, base_url)`` tuple for the ready server.

    Raises:
        TimeoutError: If the server does not come up in time. The subprocess
            is killed first so it is not leaked (the old version leaked it).
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"

    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            response = requests.get(f"{base_url}/v1/models", headers=headers)
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            # Server not accepting connections yet; keep polling.
            pass
        time.sleep(10)

    # Clean up the half-started server before reporting failure.
    kill_process_tree(proc.pid)
    raise TimeoutError("Server failed to start within the timeout period.")
|
||||
|
||||
|
||||
def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    temperature: float,
    return_logprob: bool,
    stream_interval: int,
    input_len_step_percentage: float,
    run_name: str,
    result_filename: str,
    tokenizer,
    profile: bool = False,
    profile_steps: int = 3,
    profile_by_stage: bool = False,
):
    """Run one batched streaming /generate request and measure it.

    Sends ``batch_size`` random prompts of ``input_len`` tokens, consumes the
    SSE stream, and derives latency, TTFT, and throughput numbers. Results
    are printed and, when ``result_filename`` is non-empty, appended as one
    JSON line.

    Returns:
        Tuple of (batch_size, latency, ttft, input_throughput,
        output_throughput, overall_throughput, last_gen_throughput,
        acc_length, profile_link-or-None).

    NOTE(review): ``input_len_step_percentage`` is accepted but never used in
    this body — confirm whether it should feed into sample_random_requests.
    """
    # Clear the server's prefix cache so earlier cases do not skew timings.
    requests.post(url + "/flush_cache")
    input_requests = sample_random_requests(
        input_len=input_len,
        output_len=output_len,
        num_prompts=batch_size,
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path="",
        random_sample=True,
        return_text=False,
    )

    # Debug toggle kept in-code: flip to exercise the structured-output path.
    use_structured_outputs = False
    if use_structured_outputs:
        texts = []
        for _ in range(batch_size):
            texts.append(
                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
                * 50
                + "Assistant:"
            )
        json_schema = "$$ANY$$"
    else:
        json_schema = None

    profile_link = None
    if profile:
        # Kick off server-side profiling before sending the measured request.
        profile_link: str = run_profile(
            url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
        )

    tic = time.perf_counter()
    response = requests.post(
        url + "/generate",
        json={
            "input_ids": [req.prompt for req in input_requests],
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
                "json_schema": json_schema,
                "stream_interval": stream_interval,
            },
            "return_logprob": return_logprob,
            "stream": True,
        },
        stream=True,
    )

    # The TTFT of the last request in the batch
    ttft = 0.0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            # Strip the "data:" SSE prefix before parsing the JSON payload.
            data = json.loads(chunk[5:].strip("\n"))
            if "error" in data:
                raise RuntimeError(f"Request has failed. {data}.")

            # With ignore_eos=True the only legitimate finish reason is "length".
            assert (
                data["meta_info"]["finish_reason"] is None
                or data["meta_info"]["finish_reason"]["type"] == "length"
            )
            if data["meta_info"]["completion_tokens"] == 1:
                ttft = time.perf_counter() - tic

    latency = time.perf_counter() - tic
    # NOTE(review): if no first-token event was seen, ttft stays 0.0 and the
    # division below raises ZeroDivisionError — confirm acceptable.
    input_throughput = batch_size * input_len / ttft
    output_throughput = batch_size * output_len / (latency - ttft)
    overall_throughput = batch_size * (input_len + output_len) / latency

    server_info = requests.get(url + "/get_server_info").json()
    # avg_spec_accept_length only exists when speculative decoding is on.
    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]

    print(f"batch size: {batch_size}")
    print(f"input_len: {input_len}")
    print(f"output_len: {output_len}")
    print(f"latency: {latency:.2f} s")
    print(f"ttft: {ttft:.2f} s")
    print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
    print(f"input throughput: {input_throughput:.2f} tok/s")
    if output_len != 1:
        # With a single output token, decode throughput is meaningless.
        print(f"output throughput: {output_throughput:.2f} tok/s")

    if result_filename:
        # Append one JSON line per case (JSONL format).
        with open(result_filename, "a") as fout:
            res = {
                "run_name": run_name,
                "batch_size": batch_size,
                "input_len": input_len,
                "output_len": output_len,
                "latency": round(latency, 4),
                "output_throughput": round(output_throughput, 2),
                "overall_throughput": round(overall_throughput, 2),
                "last_gen_throughput": round(last_gen_throughput, 2),
            }
            fout.write(json.dumps(res) + "\n")

    return (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        overall_throughput,
        last_gen_throughput,
        acc_length,
        profile_link if profile else None,
    )
|
||||
|
||||
|
||||
def get_report_summary(
    result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
):
    """Render benchmark results as a GitHub-flavored markdown table.

    ``result`` rows follow the tuple layout returned by ``run_one_case``.
    Cost columns are rough $/1M-token estimates from a fixed hourly GPU price.
    """
    # Imported lazily so tabulate is only required when --show-report is used.
    import tabulate

    summary = (
        f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
    )

    headers = [
        "batch size",
        "latency (s)",
        "input throughput (tok/s)",
        "output throughput (tok/s)",
        "acc length",
        "ITL (ms)",
        "input cost ($/1M)",
        "output cost ($/1M)",
    ]
    if bench_args.profile:
        headers.append("profile")
    rows = []

    # Unpack each run_one_case tuple; overall/last-gen throughput are unused.
    for (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        _,
        _,
        acc_length,
        trace_link,
    ) in result:
        if is_blackwell():
            hourly_cost_per_gpu = 4  # $4/hour for one B200
        else:
            hourly_cost_per_gpu = 2  # $2/hour for one H100

        hourly_cost = hourly_cost_per_gpu * server_args.tp_size
        # Assumed fraction of time spent doing useful prefill work.
        input_util = 0.7
        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
        # Inter-token latency: ms per generated token per request.
        itl = 1 / (output_throughput / batch_size) * 1000
        input_cost = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost
        output_cost = 1e6 / output_throughput / 3600 * hourly_cost
        row = [
            batch_size,
            latency,
            input_throughput,
            output_throughput,
            accept_length,
            itl,
            input_cost,
            output_cost,
        ]
        # NOTE(review): rows without a trace_link are one column short when
        # --profile added the "profile" header — confirm tabulate tolerates it.
        if trace_link:
            row.append(f"[Profile]({trace_link})")
        rows.append(row)

    summary += tabulate.tabulate(
        rows, headers=headers, tablefmt="github", floatfmt=".2f"
    )
    return summary
|
||||
|
||||
|
||||
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
    """Run the full benchmark sweep against a local or remote server.

    Launches a server unless ``--base-url`` points at an existing one, runs a
    warmup case, then one measured case per (batch_size, input_len,
    output_len) combination. With ``--profile`` every case is re-run under
    the profiler and the trace link is merged into the results.
    """
    if bench_args.base_url:
        # Reuse an externally managed server; proc=None means "do not kill".
        proc, base_url = None, bench_args.base_url
    else:
        proc, base_url = launch_server_process(server_args)

    server_info = requests.get(base_url + "/get_server_info").json()
    if "tokenizer_path" in server_info:
        tokenizer_path = server_info["tokenizer_path"]
    elif "prefill" in server_info:
        # Disaggregated deployment: read the path from the first prefill node.
        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
    # NOTE(review): if neither key exists, tokenizer_path is unbound and the
    # next line raises UnboundLocalError — confirm the schema guarantees one.
    tokenizer = get_tokenizer(tokenizer_path)

    # warmup
    if not bench_args.skip_warmup:
        print("=" * 8 + " Warmup Begin " + "=" * 8)
        run_one_case(
            base_url,
            batch_size=16,
            input_len=1024,
            output_len=16,
            temperature=bench_args.temperature,
            return_logprob=bench_args.return_logprob,
            stream_interval=bench_args.client_stream_interval,
            input_len_step_percentage=bench_args.input_len_step_percentage,
            run_name="",
            result_filename="",
            tokenizer=tokenizer,
        )
        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

    # benchmark
    result = []
    bench_result = []
    try:
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            result.append(
                run_one_case(
                    base_url,
                    bs,
                    il,
                    ol,
                    temperature=bench_args.temperature,
                    return_logprob=bench_args.return_logprob,
                    stream_interval=bench_args.client_stream_interval,
                    input_len_step_percentage=bench_args.input_len_step_percentage,
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
                    tokenizer=tokenizer,
                )
            )

        if bench_args.profile:
            # Second pass with profiling enabled; only the profile link
            # (last tuple element) of each run is kept.
            try:
                for bs, il, ol in itertools.product(
                    bench_args.batch_size, bench_args.input_len, bench_args.output_len
                ):
                    bench_result.append(
                        (
                            run_one_case(
                                base_url,
                                bs,
                                il,
                                ol,
                                temperature=bench_args.temperature,
                                return_logprob=bench_args.return_logprob,
                                stream_interval=bench_args.client_stream_interval,
                                input_len_step_percentage=bench_args.input_len_step_percentage,
                                run_name=bench_args.run_name,
                                result_filename=bench_args.result_filename,
                                tokenizer=tokenizer,
                                profile=bench_args.profile,
                                profile_steps=bench_args.profile_steps,
                                profile_by_stage=bench_args.profile_by_stage,
                            )[-1],
                        )
                    )
                # Replace each result's trailing None with its profile link.
                result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
            except Exception as e:
                # Profiling is best-effort; the main measurements still stand.
                print(f"Error profiling, there will be no profile trace dump: {e}")
    finally:
        # Only kill the server if we started it ourselves.
        if proc:
            kill_process_tree(proc.pid)

    print(f"\nResults are saved to {bench_args.result_filename}")

    if not bench_args.show_report:
        return

    summary = get_report_summary(result, server_args, bench_args)
    print(summary)

    if is_in_ci():
        write_github_step_summary(summary)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse server + benchmark flags, then run the sweep."""
    cli = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli)
    BenchArgs.add_cli_args(cli)
    namespace = cli.parse_args()
    run_benchmark(
        ServerArgs.from_cli_args(namespace),
        BenchArgs.from_cli_args(namespace),
    )
|
||||
|
||||
|
||||
# Allow running as a script or via ``python -m sglang.bench_one_batch_server``.
if __name__ == "__main__":
    main()
|
||||
2178
python/sglang/bench_serving.py
Normal file
2178
python/sglang/bench_serving.py
Normal file
File diff suppressed because it is too large
Load Diff
305
python/sglang/check_env.py
Normal file
305
python/sglang/check_env.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Check environment configurations and dependency versions."""
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
import resource
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
|
||||
def is_cuda_v2():
    """Return True when this torch build carries a CUDA runtime version."""
    cuda_version = torch.version.cuda
    return cuda_version is not None
|
||||
|
||||
|
||||
# List of packages to check versions.
# Fix: "torchao" used to appear twice; duplicates are pointless because the
# version lookup result is keyed by package name, so the list is deduplicated.
PACKAGE_LIST = [
    "sglang",
    "sgl_kernel",
    "flashinfer_python",
    "triton",
    "transformers",
    "torchao",
    "numpy",
    "aiohttp",
    "fastapi",
    "hf_transfer",
    "huggingface_hub",
    "interegular",
    "modelscope",
    "orjson",
    "outlines",
    "packaging",
    "psutil",
    "pydantic",
    "python-multipart",
    "pyzmq",
    "uvicorn",
    "uvloop",
    "vllm",
    "xgrammar",
    "openai",
    "tiktoken",
    "anthropic",
    "litellm",
    "decord",
]
|
||||
|
||||
|
||||
def get_package_versions(packages):
    """Return a dict mapping package name to its installed version.

    Entries in *packages* may carry version pins ("pkg==1.0", "pkg>=1.0",
    "pkg<=1.0"); the pin is stripped before lookup. Missing packages map to
    the sentinel string "Module Not Found".
    """
    versions = {}
    for package in packages:
        # Strip any version specifier, e.g. "torch==2.8.0" -> "torch".
        package_name = package.split("==")[0].split(">=")[0].split("<=")[0]
        try:
            versions[package_name] = importlib.metadata.version(package_name)
        except importlib.metadata.PackageNotFoundError:
            # importlib.metadata raises PackageNotFoundError for missing
            # distributions; catch the precise type instead of the broader
            # ModuleNotFoundError it happens to subclass.
            versions[package_name] = "Module Not Found"
    return versions
|
||||
|
||||
|
||||
def get_cuda_info():
    """Collect CUDA (or ROCm) availability and version info as a dict.

    Always returns a dict (possibly empty) so callers can safely call
    ``dict.update`` on the result; the previous version returned None on a
    CPU-only torch build, which crashed ``env_info.update(get_cuda_info())``.
    """
    if is_cuda_v2():
        cuda_info = {"CUDA available": torch.cuda.is_available()}

        if cuda_info["CUDA available"]:
            cuda_info.update(_get_gpu_info())
            cuda_info.update(_get_cuda_version_info())

        return cuda_info
    elif is_hip():
        cuda_info = {"ROCM available": torch.cuda.is_available()}

        if cuda_info["ROCM available"]:
            cuda_info.update(_get_gpu_info())
            cuda_info.update(_get_cuda_version_info())

        return cuda_info
    # Neither a CUDA nor a ROCm build: report nothing rather than None.
    return {}
|
||||
|
||||
|
||||
def _get_gpu_info():
    """Describe the visible GPUs: names and compute capabilities by device id.

    Devices sharing a name (or capability) are grouped into a single entry
    keyed by the comma-joined id list, e.g. "GPU 0,1".
    """
    name_to_ids = defaultdict(list)
    cap_to_ids = defaultdict(list)
    for idx in range(torch.cuda.device_count()):
        dev_id = str(idx)
        name_to_ids[torch.cuda.get_device_name(idx)].append(dev_id)
        major, minor = torch.cuda.get_device_capability(idx)
        cap_to_ids[f"{major}.{minor}"].append(dev_id)

    gpu_info = {f"GPU {','.join(ids)}": name for name, ids in name_to_ids.items()}

    # One entry per distinct capability; this covers both the uniform case
    # (single group) and the mixed case identically.
    for cap, ids in cap_to_ids.items():
        gpu_info[f"GPU {','.join(ids)} Compute Capability"] = cap

    return gpu_info
|
||||
|
||||
|
||||
def _get_cuda_version_info():
    """Report the toolkit home (CUDA_HOME / ROCM_HOME) plus compiler and
    driver version details when the toolkit directory exists."""
    if is_cuda_v2():
        from torch.utils.cpp_extension import CUDA_HOME

        info = {"CUDA_HOME": CUDA_HOME}
        # Only probe nvcc / the driver when the toolkit is really installed.
        if CUDA_HOME and os.path.isdir(CUDA_HOME):
            info.update(_get_nvcc_info())
            info.update(_get_cuda_driver_version())
        return info

    if is_hip():
        from torch.utils.cpp_extension import ROCM_HOME

        info = {"ROCM_HOME": ROCM_HOME}
        if ROCM_HOME and os.path.isdir(ROCM_HOME):
            info.update(_get_nvcc_info())
            info.update(_get_cuda_driver_version())
        return info

    # Neither CUDA nor ROCm: keep the key present but empty.
    return {"CUDA_HOME": ""}
|
||||
|
||||
|
||||
def _get_nvcc_info():
    """
    Get NVCC version information.

    Returns {"NVCC": <version text>} on CUDA builds, {"HIPCC": ...} on ROCm
    builds, or a "Not Available" placeholder when the compiler cannot run.
    """
    if is_cuda_v2():
        from torch.utils.cpp_extension import CUDA_HOME

        try:
            nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
            # Path is quoted because CUDA_HOME may contain spaces.
            nvcc_output = (
                subprocess.check_output(f'"{nvcc}" -V', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Keep only the "Cuda compilation tools ..." line, dropping the
            # trailing "Build ..." portion of the banner.
            return {
                "NVCC": nvcc_output[
                    nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind(
                        "Build"
                    )
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"NVCC": "Not Available"}
    elif is_hip():
        from torch.utils.cpp_extension import ROCM_HOME

        try:
            hipcc = os.path.join(ROCM_HOME, "bin/hipcc")
            hipcc_output = (
                subprocess.check_output(f'"{hipcc}" --version', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Extract the "HIP version ..." line from the banner.
            return {
                "HIPCC": hipcc_output[
                    hipcc_output.rfind("HIP version") : hipcc_output.rfind("AMD clang")
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"HIPCC": "Not Available"}
    else:
        return {"NVCC": "Not Available"}
|
||||
|
||||
|
||||
def _get_cuda_driver_version():
    """
    Get CUDA driver version.

    On CUDA builds, queries nvidia-smi per GPU and reports either the single
    shared driver version or all distinct versions comma-separated. On ROCm
    builds, parses rocm-smi CSV output instead.
    """
    versions = set()
    if is_cuda_v2():
        try:
            output = subprocess.check_output(
                [
                    "nvidia-smi",
                    "--query-gpu=driver_version",
                    "--format=csv,noheader,nounits",
                ]
            )
            versions = set(output.decode().strip().split("\n"))
            if len(versions) == 1:
                return {"CUDA Driver Version": versions.pop()}
            else:
                return {"CUDA Driver Versions": ", ".join(sorted(versions))}
        except subprocess.SubprocessError:
            # NOTE(review): a missing nvidia-smi binary raises
            # FileNotFoundError, which this clause does not catch — confirm.
            return {"CUDA Driver Version": "Not Available"}
    elif is_hip():
        try:
            output = subprocess.check_output(
                [
                    "rocm-smi",
                    "--showdriverversion",
                    "--csv",
                ]
            )
            versions = set(output.decode().strip().split("\n"))
            # Drop the CSV header row before reading the value row.
            versions.discard("name, value")
            ver = versions.pop()
            ver = ver.replace('"Driver version", ', "").replace('"', "")

            return {"ROCM Driver Version": ver}
        except subprocess.SubprocessError:
            return {"ROCM Driver Version": "Not Available"}
    else:
        return {"CUDA Driver Version": "Not Available"}
|
||||
|
||||
|
||||
def get_gpu_topology():
    """
    Get GPU topology information.

    Returns the topology tool's stdout prefixed with a newline, or None when
    the tool is missing, fails, or the build targets neither CUDA nor ROCm.
    """
    if is_cuda_v2():
        tool_cmd = ["nvidia-smi", "topo", "-m"]
    elif is_hip():
        tool_cmd = ["rocm-smi", "--showtopotype"]
    else:
        return None

    try:
        result = subprocess.run(
            tool_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
    except (subprocess.SubprocessError, FileNotFoundError):
        # Fix: a missing binary raises FileNotFoundError (an OSError, not a
        # SubprocessError) and previously crashed the whole env check.
        # SubprocessError covers non-zero exit codes via check=True.
        return None
    # check=True already guarantees returncode == 0 at this point.
    return "\n" + result.stdout
|
||||
|
||||
|
||||
def get_hypervisor_vendor():
    """Return the hypervisor vendor reported by lscpu, or None.

    None means bare metal, a missing/failing lscpu, or unparsable output.
    """
    try:
        output = subprocess.check_output(["lscpu"], text=True)
    except (OSError, subprocess.SubprocessError):
        # Fix: the previous bare ``except:`` also swallowed KeyboardInterrupt
        # and SystemExit; only command-related failures should map to None.
        return None
    for line in output.split("\n"):
        if "Hypervisor vendor:" in line:
            return line.split(":")[1].strip()
    return None
|
||||
|
||||
|
||||
def check_env():
    """
    Check and print environment information.

    Prints one "key: value" line per item: Python/torch versions, CUDA or
    ROCm details, tracked package versions, GPU topology, hypervisor vendor,
    and the open-file soft limit.
    """
    env_info = OrderedDict()
    env_info["Python"] = sys.version.replace("\n", "")
    # NOTE(review): get_cuda_info() can return None on a CPU-only torch
    # build, which would make this update() raise — confirm or guard.
    env_info.update(get_cuda_info())
    env_info["PyTorch"] = torch.__version__
    env_info.update(get_package_versions(PACKAGE_LIST))

    gpu_topo = get_gpu_topology()
    if gpu_topo:
        if is_cuda_v2():
            env_info["NVIDIA Topology"] = gpu_topo
        elif is_hip():
            env_info["AMD Topology"] = gpu_topo

    hypervisor_vendor = get_hypervisor_vendor()
    if hypervisor_vendor:
        env_info["Hypervisor vendor"] = hypervisor_vendor

    # Soft limit on open file descriptors for this process.
    ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
    env_info["ulimit soft"] = ulimit_soft

    for k, v in env_info.items():
        print(f"{k}: {v}")
|
||||
|
||||
|
||||
# Allow running as ``python -m sglang.check_env`` or as a script.
if __name__ == "__main__":
    check_env()
|
||||
184
python/sglang/compile_deep_gemm.py
Normal file
184
python/sglang/compile_deep_gemm.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Compile DeepGEMM Kernels for a model with specify server arguments
|
||||
|
||||
This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
|
||||
It accepts server arguments (the same as launch_server.py).
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.srt.warmup import warmup
|
||||
|
||||
# Use "spawn" so the server subprocess starts from a clean interpreter state;
# force=True overrides any start method set earlier in the process.
multiprocessing.set_start_method("spawn", force=True)

# Reduce warning
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
# Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CompileArgs:
    """Options specific to the DeepGEMM pre-compilation run."""

    # Seconds allowed for server startup + kernel compilation.
    timeout: int = 3600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register compile-specific flags; defaults mirror the fields."""
        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a CompileArgs, casting each CLI value with the type of the
        corresponding field default."""
        kwargs = {}
        for field in dataclasses.fields(cls):
            cast = type(field.default)
            kwargs[field.name] = cast(getattr(args, field.name))
        return cls(**kwargs)
|
||||
|
||||
|
||||
@warmup("compile-deep-gemm")
async def warm_up_compile(
    disaggregation_mode: str, tokenizer_manager: TokenizerManager
):
    """Send one tiny greedy generation so DeepGEMM kernel calls are captured.

    Registered under the "compile-deep-gemm" warmup name, which
    refine_server_args enables via ``server_args.warmups``.
    """
    print("\nGenerate warm up request for compiling DeepGEMM...\n")
    # A small fixed request: four token ids, 8 new tokens, greedy sampling.
    generate_req_input = GenerateReqInput(
        input_ids=[0, 1, 2, 3],
        sampling_params={
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": True,
        },
    )
    if disaggregation_mode != "null":
        # Prefill/decode disaggregation requires bootstrap metadata; use the
        # fake host so no real bootstrap exchange happens.
        generate_req_input.bootstrap_room = 0
        generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST

    # Pull a single chunk from the async generator to drive the request.
    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
|
||||
|
||||
|
||||
def launch_server_internal(server_args):
    """Run the HTTP server in this process, killing leftover children on exit.

    The previous ``except Exception as e: raise e`` clause was a no-op
    re-raise and has been removed; the ``finally`` cleanup is what matters.
    """
    try:
        launch_server(server_args)
    finally:
        # Whatever happens, do not leave orphaned worker processes behind.
        kill_process_tree(os.getpid(), include_parent=False)
|
||||
|
||||
|
||||
def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    """Launch the server in a subprocess and drive one request through it.

    On the rank-0 node this polls ``/v1/models`` until the server is up and
    then posts a small ``/generate`` request (which triggers the DeepGEMM
    warmup/compilation). Non-rank-0 nodes poll ``/health`` and then wait for
    the server process to exit on the rank-0 node's signal.

    Returns:
        The server ``multiprocessing.Process``.

    Raises:
        RuntimeError: If the rank-0 sync request fails.
        TimeoutError: If startup or the inter-node wait exceeds the timeout.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = compile_args.timeout

    start_time = time.perf_counter()
    while time.perf_counter() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            if server_args.node_rank == 0:
                response = requests.get(f"{base_url}/v1/models", headers=headers)
            else:
                # This http api is created by launch_dummy_health_check_server for none-rank0 node.
                response = requests.get(f"{base_url}/health", headers=headers)
            if response.status_code == 200:
                # Rank-0 node send a request to sync with other node and then return.
                if server_args.node_rank == 0:
                    response = requests.post(
                        f"{base_url}/generate",
                        json={
                            "input_ids": [0, 1, 2, 3],
                            "sampling_params": {
                                "max_new_tokens": 8,
                                "temperature": 0,
                            },
                        },
                        # Generous timeout: compilation happens during this call.
                        timeout=600,
                    )
                    if response.status_code != 200:
                        error = response.json()
                        raise RuntimeError(f"Sync request failed: {error}")
                # Other nodes should wait for the exit signal from Rank-0 node.
                else:
                    start_time_waiting = time.perf_counter()
                    while proc.is_alive():
                        if time.perf_counter() - start_time_waiting < timeout:
                            time.sleep(10)
                        else:
                            raise TimeoutError("Waiting for main node timeout!")
                return proc
        except requests.RequestException:
            # Server not reachable yet; fall through and retry.
            pass
        time.sleep(10)
    raise TimeoutError(
        "DeepGEMM Kernels compilation timeout."
        "\n\nFeel free and please restart the command."
    )
|
||||
|
||||
|
||||
def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
|
||||
# Disable cuda graph and torch compile to save time
|
||||
server_args.disable_cuda_graph = True
|
||||
server_args.enable_torch_compile = False
|
||||
print(f"Disable CUDA Graph and Torch Compile to save time...")
|
||||
|
||||
# Set watchdog timeout to compile_args.timeout because compilation will take a long time
|
||||
server_args.watchdog_timeout = compile_args.timeout
|
||||
server_args.warmups = "compile-deep-gemm"
|
||||
|
||||
|
||||
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
    """Drive the full compile flow: launch server, send one request, clean up."""
    print(
        "Begin DeepGEMM Kernels compilation...\n"
        "It may take a long time and timeout maybe raised "
        "while the compilation is still in progress.\n"
        "Just feel free to restart the command "
        "until the compilation is fully finished.\n"
    )

    proc = launch_server_process_and_send_one_request(server_args, compile_args)

    print("\nDeepGEMM Kernels compilation finished successfully.")

    # Sleep for safety
    time.sleep(10)
    if proc.is_alive():
        # This is the rank0 node.
        kill_process_tree(proc.pid)
    else:
        # Non-rank-0 path: the process already exited; killing its (dead)
        # pid may fail, which is fine to ignore.
        try:
            kill_process_tree(proc.pid)
        except Exception:
            pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # One parser carries both the server flags and the compile-specific flags.
    cli = argparse.ArgumentParser()
    ServerArgs.add_cli_args(cli)
    CompileArgs.add_cli_args(cli)
    namespace = cli.parse_args()
    server_args = ServerArgs.from_cli_args(namespace)
    compile_args = CompileArgs.from_cli_args(namespace)

    # Tune the server configuration for a fast compile pass, then run it.
    refine_server_args(server_args, compile_args)
    run_compile(server_args, compile_args)
|
||||
315
python/sglang/eval/llama3_eval.py
Normal file
315
python/sglang/eval/llama3_eval.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# Adapt from https://github.com/fw-ai/llm_eval_meta
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from tqdm import tqdm
|
||||
|
||||
# Mapping providers to their clients and models.
# Every provider exposes the same three Llama 3.1 instruct checkpoints; the
# inner key ("8b"/"70b"/"405b") selects the model name sent to the API.
provider_to_models = {
    "b10": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "oai": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "sgl": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
}
|
||||
|
||||
|
||||
async def fetch_responses(
    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
):
    """Fetch one completion and pickle it to ``output_dir/response_{index}.pkl``.

    Skips work if the output file already exists, so interrupted runs can be
    resumed. Concurrency is bounded by *semaphore*.
    """
    output_file = os.path.join(output_dir, f"response_{index}.pkl")
    if os.path.exists(output_file):
        print(f"File {output_file} already exists, skipping.")
        return

    async with semaphore:
        response = await client.completions.create(
            model=provider_to_models[provider][model_size],
            prompt=prompt,
            temperature=0.0,
            max_tokens=max_tokens,
        )
        if isinstance(response, openai.BadRequestError):
            with open(output_file, "wb") as f:
                pickle.dump("bad_response", f)
            # Fix: return here. Previously control fell through to the
            # assert below, which failed after recording the bad response.
            return
        assert isinstance(response, openai.types.completion.Completion)
        # Save response to a file
        with open(output_file, "wb") as f:
            pickle.dump(response, f)
|
||||
|
||||
|
||||
# Max completion tokens per eval set (keys are dataset config name suffixes).
TASK_TO_MAX_TOKENS = {
    "evals__mmlu__details": 1,
    "evals__mmlu__0_shot__cot__details": 1024,
    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
    "evals__mmlu_pro__details": 2048,
    "evals__gsm8k__details": 1024,
}

# Short CLI task name -> dataset config name suffix used above.
TASK_TO_EVAL_SET = {
    "mmlu": "evals__mmlu__details",
    "mmlu_cot": "evals__mmlu__0_shot__cot__details",
    "mmlu_pro": "evals__mmlu_pro__details",
    "gsm8k": "evals__gsm8k__details",
}
|
||||
|
||||
|
||||
class CustomAsyncHTTPXClient(httpx.AsyncClient):
    """AsyncClient that redirects every request to the Baseten predict URL.

    The OpenAI SDK builds standard /v1/... paths; Baseten serves the model at
    a fixed per-model endpoint instead, so the URL is rewritten before send.
    """

    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
        model_id = os.getenv("MODEL_ID")
        target = f"https://model-{model_id}.api.baseten.co/development/predict"
        request.url = httpx.URL(target)
        return await super().send(request, *args, **kwargs)
|
||||
|
||||
|
||||
def get_client(provider):
    """Return an AsyncOpenAI client configured for `provider` ("oai", "b10", "sgl").

    Local providers do not need a real API key, so a dummy one is injected
    when the environment does not supply one.
    """
    # BUG FIX: the original used `provider not in "b10"`, which is a substring
    # test against the string "b10" (so "b", "1", "0", "10" all pass as
    # "b10"). Compare the provider name directly.
    if provider != "b10":
        if os.getenv("OPENAI_API_KEY") is None:
            os.environ["OPENAI_API_KEY"] = "EMPTY"
    return {
        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
        "b10": AsyncOpenAI(
            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
            http_client=CustomAsyncHTTPXClient(),
        ),
        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
    }[provider]
|
||||
|
||||
|
||||
# Define the benchmark function
async def benchmark(args):
    """Fetch completions for the selected eval set and cache them to disk.

    Loads the 405B-instruct eval prompts for args.task, fans out one request
    per prompt (bounded by args.concurrency), and writes each response as a
    pickle into args.output_dir via fetch_responses.
    """
    ds = load_dataset(
        "meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
    )
    semaphore = asyncio.Semaphore(args.concurrency)  # bound in-flight requests

    # Default to the full eval set when --num-examples is not given.
    if args.num_examples is None:
        args.num_examples = len(ds["latest"]["input_final_prompts"])
    prompts = ds["latest"]["input_final_prompts"][: args.num_examples]

    # Create the output directory if it does not exist
    os.makedirs(args.output_dir, exist_ok=True)

    tasks = []
    # Create the tasks with tqdm progress bar
    max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
    client = get_client(args.provider)
    for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
        tasks.append(
            asyncio.create_task(
                fetch_responses(
                    client,
                    # Prompts are pre-templated; only the BOS token is prepended.
                    f"<|begin_of_text|>{prompt[0]}",
                    semaphore,
                    idx,
                    args.provider,
                    args.model_size,
                    args.output_dir,
                    max_tokens=max_tokens,
                )
            )
        )

    # Run the tasks with tqdm progress bar
    for future in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
    ):
        await future
|
||||
|
||||
|
||||
def get_mmlu_answer(response):
    """Extract the answer letter from a plain MMLU completion (None-safe)."""
    if response is None:
        return None
    raw = response.choices[0].text
    return raw.strip().upper().replace(".", "")
|
||||
|
||||
|
||||
def get_mmlu_cot_answer(response):
    """Extract the answer from a chain-of-thought MMLU completion.

    Tries several phrasings the model is known to emit; returns None when
    none of them match.
    """
    text = response.choices[0].text
    patterns = (
        r"The best answer is (.+)\.?",
        r"the best answer is (.+)\.?",
        r"The correct answer is (.+)\.?",
        r"the correct answer is (.+)\.?",
    )
    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            # BUG FIX: the original repeated this block four times and only
            # stripped markdown bold markers ("*") for the first phrasing.
            # Strip "." and "*" uniformly for every phrasing.
            return match.group(1).replace(".", "").replace("*", "")
    return None
|
||||
|
||||
|
||||
def get_answer_gsm8k(response):
    """Extract the final answer from a GSM8K completion.

    Strips "%" and "$" from the matched answer; returns None when the
    expected phrase is absent.
    """
    match = re.search(r"The final answer is (.+)\.?", response.choices[0].text)
    if not match:
        return None
    answer = match.group(1)
    for symbol in ("%", "$"):
        answer = answer.replace(symbol, "")
    return answer
|
||||
|
||||
|
||||
# Maps each eval set to the function that parses the model's answer out of a
# cached completion object.
TASK_TO_ANSWER_EXTRACTOR = {
    "evals__mmlu__details": get_mmlu_answer,
    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
    "evals__gsm8k__details": get_answer_gsm8k,
    "evals__mmlu_pro__details": get_mmlu_cot_answer,
}
|
||||
|
||||
|
||||
def get_dataset_from_task(task, response_path, model_size):
    """Load the reference eval dataset for `task`, aligned to 405B prompt order.

    For 8B/70B runs, loads that model's own eval dataset and reorders its rows
    to match the 405B prompt-hash order (which is the order responses were
    fetched in). `response_path` is unused but kept for interface stability.
    """
    ds_405b = load_dataset(
        "meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{task}",
    )
    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]

    if "70b" in model_size or "8b" in model_size:
        # The two original branches differed only in the size string; pick it
        # once and load the matching reference dataset.
        ref_size = "70B" if "70" in model_size else "8B"
        ref_model_ds = load_dataset(
            f"meta-llama/Llama-3.1-{ref_size}-Instruct-evals",
            f"Llama-3.1-{ref_size}-Instruct-{task}",
        )

        # Reorder rows by prompt hash to line up with the 405B ordering.
        hash_to_row = {
            row["input_final_prompts_hash"][0]: row for row in ref_model_ds["latest"]
        }
        ref_model_ds["latest"] = [hash_to_row[h] for h in ds_405b_hash_order]
        return ref_model_ds

    return ds_405b
|
||||
|
||||
|
||||
def analyze(task, response_path, model_size):
    """Score cached responses against the reference dataset and print accuracy.

    Prints macro/micro averages for this run alongside the "meta" averages
    recorded in the reference dataset (Meta's own grading).
    """
    ds = get_dataset_from_task(task, response_path, model_size)

    responses = []
    total = len(ds["latest"])
    for i in range(total):
        # BUG FIX: the original used pickle.load(open(...)) and leaked the
        # file handle for every response.
        with open(os.path.join(response_path, f"response_{i}.pkl"), "rb") as f:
            responses.append(pickle.load(f))

    @dataclass
    class Stats:
        # Per-subtask accuracy counters; averages are filled in after counting.
        correct: int = 0
        total: int = 0
        meta_correct: int = 0
        average: "float | None" = None
        meta_average: "float | None" = None

    subtask_name_to_stats = defaultdict(lambda: Stats())

    for response, ds_row in zip(responses, ds["latest"]):
        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)

        subtask = ds_row["subtask_name"]

        is_eval_correct = model_answer in ds_row["input_correct_responses"]
        if is_eval_correct:
            subtask_name_to_stats[subtask].correct += 1

        # Meta's own correctness verdict, for comparison.
        if ds_row["is_correct"]:
            subtask_name_to_stats[subtask].meta_correct += 1

        subtask_name_to_stats[subtask].total += 1

    micro_stats = Stats()
    for subtask, stats in subtask_name_to_stats.items():
        stats.average = stats.correct / stats.total
        stats.meta_average = stats.meta_correct / stats.total

        micro_stats.correct += stats.correct
        micro_stats.total += stats.total
        micro_stats.meta_correct += stats.meta_correct

    micro_stats.average = micro_stats.correct / micro_stats.total
    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total

    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
    print(
        "Meta Macro average",
        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
    )
    print("Micro average", micro_stats.average)
    print("Meta Micro average", micro_stats.meta_average)
|
||||
|
||||
|
||||
# Entry point for the script
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Script to run model with specified parameters."
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="8b",
        help="Size of the model (e.g., 8b or 70b)",
    )
    parser.add_argument(
        "--provider",
        type=str,
        default="sgl",
        help="Provider name (e.g., sgl, oai, b10)",
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
    )
    parser.add_argument(
        "--num-examples", type=int, default=None, help="Number of examples to process"
    )
    parser.add_argument("--concurrency", type=int, default=16)
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tmp-output-dir",
        help="Directory to save responses",
    )

    args = parser.parse_args()
    asyncio.run(benchmark(args))
    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
    # BUG FIX: the original always removed the hard-coded "tmp-output-dir",
    # so a custom --output-dir was never cleaned up. Remove the directory
    # that was actually used.
    shutil.rmtree(args.output_dir, ignore_errors=True)
|
||||
164
python/sglang/eval/loogle_eval.py
Normal file
164
python/sglang/eval/loogle_eval.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import openai
|
||||
import torch
|
||||
from bert_score import BERTScorer
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_client(api_url: str) -> openai.AsyncOpenAI:
    """Build an async OpenAI-compatible client for `api_url`.

    Local servers accept any key, so a dummy key is injected when none is
    configured in the environment.
    """
    os.environ.setdefault("OPENAI_API_KEY", "EMPTY")
    return openai.AsyncOpenAI(base_url=api_url)
|
||||
|
||||
|
||||
def get_dataset():
    """Load the LooGLE long-dependency QA benchmark (test split)."""
    dataset_id, config = "bigai-nlco/LooGLE", "longdep_qa"
    return load_dataset(dataset_id, config, split="test")
|
||||
|
||||
|
||||
async def fetch_response(
    client: openai.AsyncOpenAI,
    context: str,
    question: str,
    semaphore: asyncio.Semaphore,
    index: int,
    model: str,
    output_dir: Path,
):
    """Ask the chat endpoint one LooGLE question and cache the reply.

    Skips work when a cached pickle for `index` already exists. A
    BadRequestError (e.g. context exceeds the server limit) is cached as
    {"error": ...} so analyse() can recognise and skip it.
    """
    output_file = output_dir / f"response_{index}.pkl"
    if output_file.exists():
        return

    prompt = (
        "Please answer the question based on the long texts below.\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    async with semaphore:
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
        except openai.BadRequestError as e:
            # Cache the failure so the scoring stage can skip this prompt.
            with open(output_file, "wb") as f:
                pickle.dump({"error": str(e)}, f)
            return

    with open(output_file, "wb") as f:
        pickle.dump(response, f)
|
||||
|
||||
|
||||
async def benchmark(args):
    """Fan out LooGLE requests (bounded by args.max_concurrency) and cache replies."""
    dataset = get_dataset()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    client = get_client(args.api_url)
    semaphore = asyncio.Semaphore(args.max_concurrency)

    tasks: List[asyncio.Task] = []
    for idx, ex in enumerate(dataset):
        # Only the first args.num_prompts examples are benchmarked.
        if idx >= args.num_prompts:
            break
        tasks.append(
            asyncio.create_task(
                fetch_response(
                    client,
                    ex["context"],
                    ex["question"],
                    semaphore,
                    idx,
                    args.model,
                    output_dir,
                )
            )
        )

    # Await in completion order so the progress bar tracks finished requests.
    for _ in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await _
|
||||
|
||||
|
||||
def analyse(args):
    """Score cached LooGLE responses against references with BERTScore (F1)."""
    dataset = get_dataset()
    output_dir = Path(args.output_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)

    hyps: List[str] = []
    refs: List[str] = []
    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        if idx >= args.num_prompts:
            break
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            raise FileNotFoundError(pkl_file)

        # BUG FIX: the original used pickle.load(open(...)) and leaked the
        # file handle for every cached response.
        with pkl_file.open("rb") as f:
            response = pickle.load(f)
        # Entries cached as {"error": ...} were rejected by the server; skip.
        if isinstance(response, dict) and "error" in response:
            continue

        hyps.append(response.choices[0].message.content.strip())
        refs.append(ex["answer"])

    if not hyps:
        print("No valid responses to score!")
        return

    # Score in fixed-size batches to bound scorer memory use.
    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend(float(x) for x in f1_scores)

    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI: run the benchmark, then score whatever responses were cached.
    parser = argparse.ArgumentParser(
        description="Run benchmark and evaluation in one go."
    )
    parser.add_argument(
        "--api-url",
        default="http://127.0.0.1:30000/v1",
        help="OpenAI‑compatible API base URL",
    )
    parser.add_argument(
        "--model",
        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        help="Model name or ID, only used for model name",
    )
    parser.add_argument(
        "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
    )
    parser.add_argument(
        "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
    )
    parser.add_argument(
        "--num-prompts", type=int, default=10000, help="Number of prompts to run"
    )
    args = parser.parse_args()

    asyncio.run(benchmark(args))

    analyse(args)
|
||||
53
python/sglang/global_config.py
Normal file
53
python/sglang/global_config.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Global configurations"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class GlobalConfig:
    """
    Store some global constants.

    See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
    many global runtime arguments as well.
    """

    def __init__(self):
        # Verbosity level
        # 0: do not output anything
        # 2: output final text after every run
        self.verbosity = 0

        # Default backend of the language frontend (set via set_default_backend).
        self.default_backend = None

        # Runtime constants: new generation token ratio estimation,
        # overridable via environment variables.
        self.default_init_new_token_ratio = float(
            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
        )
        self.default_min_new_token_ratio_factor = float(
            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
        )
        self.default_new_token_ratio_decay_steps = float(
            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
        )
        # In seconds. Set if you observe high memory accumulation over a long
        # serving period. (-1 presumably disables it — confirm in the scheduler.)
        self.torch_empty_cache_interval = float(
            os.environ.get("SGLANG_EMPTY_CACHE_INTERVAL", -1)
        )

        # Runtime constants: others
        self.retract_decode_steps = 20
        # BUG FIX: os.environ.get returns a str when the variable is set,
        # while the default is an int; the sibling knobs above coerce with
        # float(). Coerce with int() so the size is always numeric.
        self.flashinfer_workspace_size = int(
            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
        )

        # Output tokenization configs
        self.skip_special_tokens_in_output = True
        self.spaces_between_special_tokens_in_out = True

        # Language frontend interpreter optimization configs
        self.enable_precache_with_tracing = True
        self.enable_parallel_encoding = True


# Module-level singleton used throughout the frontend.
global_config = GlobalConfig()
|
||||
BIN
python/sglang/lang/__pycache__/api.cpython-310.pyc
Normal file
BIN
python/sglang/lang/__pycache__/api.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/lang/__pycache__/chat_template.cpython-310.pyc
Normal file
BIN
python/sglang/lang/__pycache__/chat_template.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/lang/__pycache__/choices.cpython-310.pyc
Normal file
BIN
python/sglang/lang/__pycache__/choices.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/lang/__pycache__/interpreter.cpython-310.pyc
Normal file
BIN
python/sglang/lang/__pycache__/interpreter.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/lang/__pycache__/ir.cpython-310.pyc
Normal file
BIN
python/sglang/lang/__pycache__/ir.cpython-310.pyc
Normal file
Binary file not shown.
286
python/sglang/lang/api.py
Normal file
286
python/sglang/lang/api.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Public APIs of the language."""
|
||||
|
||||
import re
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
|
||||
from sglang.lang.ir import (
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglImage,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglSeparateReasoning,
|
||||
SglVideo,
|
||||
)
|
||||
|
||||
|
||||
def function(
    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
):
    """Decorator turning a plain Python function into an SglFunction.

    Supports both bare use (@function) and parameterized use
    (@function(num_api_spec_tokens=...)).
    """

    def decorator(f):
        return SglFunction(f, num_api_spec_tokens=num_api_spec_tokens)

    # Bare-decorator form: the function itself was passed in directly.
    return decorator(func) if func else decorator
|
||||
|
||||
|
||||
def Runtime(*args, **kwargs):
    """Construct a Runtime; the real class is imported lazily."""
    # Avoid importing unnecessary dependency
    from sglang.lang.backend.runtime_endpoint import Runtime

    return Runtime(*args, **kwargs)
|
||||
|
||||
|
||||
def Engine(*args, **kwargs):
    """Construct an Engine; the real class is imported lazily."""
    # Avoid importing unnecessary dependency
    from sglang.srt.entrypoints.engine import Engine

    return Engine(*args, **kwargs)
|
||||
|
||||
|
||||
def set_default_backend(backend: BaseBackend):
    """Set the process-wide default backend used when none is passed explicitly."""
    global_config.default_backend = backend
|
||||
|
||||
|
||||
def flush_cache(backend: Optional[BaseBackend] = None):
    """Flush the cache of `backend` (default: the global default backend).

    Returns False when no backend is configured.
    """
    target = backend or global_config.default_backend
    if target is None:
        return False

    # Runtime wrappers expose the actual backend behind `.endpoint`.
    target = getattr(target, "endpoint", target)
    return target.flush_cache()
|
||||
|
||||
|
||||
def get_server_info(backend: Optional[BaseBackend] = None):
    """Return server info from `backend` (default: the global default backend).

    Returns None when no backend is configured.
    """
    target = backend or global_config.default_backend
    if target is None:
        return None

    # Runtime wrappers expose the actual backend behind `.endpoint`.
    target = getattr(target, "endpoint", target)
    return target.get_server_info()
|
||||
|
||||
|
||||
def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    min_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
    dtype: Optional[Union[type, str]] = None,
    choices: Optional[List[str]] = None,
    choices_method: Optional[ChoicesSamplingMethod] = None,
    regex: Optional[str] = None,
    json_schema: Optional[str] = None,
):
    """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""

    # `choices` turns the generation into a constrained selection.
    if choices:
        return SglSelect(
            name,
            choices,
            0.0 if temperature is None else temperature,
            token_length_normalized if choices_method is None else choices_method,
        )

    # Validate the regex eagerly so a malformed pattern fails at the call
    # site rather than at generation time. (The original wrapped this in a
    # try/except that immediately re-raised the same exception — a no-op.)
    if regex is not None:
        re.compile(regex)

    # Positional order must match SglGen's constructor.
    return SglGen(
        name,
        max_tokens,
        min_tokens,
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        dtype,
        regex,
        json_schema,
    )
|
||||
|
||||
|
||||
def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """Generate with dtype fixed to int. Thin wrapper over SglGen; positional
    order must match gen()/SglGen."""
    return SglGen(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        int,  # dtype
        None,  # regex
    )
|
||||
|
||||
|
||||
def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """Generate with dtype fixed to str. Thin wrapper over SglGen; positional
    order must match gen()/SglGen."""
    return SglGen(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        str,  # dtype
        None,  # regex
    )
|
||||
|
||||
|
||||
def image(expr: SglExpr):
    """Wrap an expression as an image input."""
    return SglImage(expr)


def video(path: str, num_frames: int):
    """Reference a video by file path with a frame count."""
    return SglVideo(path, num_frames)
|
||||
|
||||
|
||||
def select(
    name: Optional[str] = None,
    choices: Optional[List[str]] = None,
    temperature: float = 0.0,
    choices_method: ChoicesSamplingMethod = token_length_normalized,
):
    """Constrain generation to one of `choices`.

    Raises ValueError when `choices` is missing.
    """
    # BUG FIX: input validation used `assert`, which is stripped under
    # `python -O`; raise an explicit error instead.
    if choices is None:
        raise ValueError("`choices` must be provided to select()")
    return SglSelect(name, choices, temperature, choices_method)
|
||||
|
||||
|
||||
def _role_common(name: str, expr: Optional[SglExpr] = None):
    """Wrap `expr` (if any) between begin/end markers for the given role."""
    parts = [SglRoleBegin(name)]
    if expr is not None:
        parts.append(expr)
    parts.append(SglRoleEnd(name))
    return SglExprList(parts)
|
||||
|
||||
|
||||
def system(expr: Optional[SglExpr] = None):
    """Wrap `expr` in a system-role message."""
    return _role_common("system", expr)


def user(expr: Optional[SglExpr] = None):
    """Wrap `expr` in a user-role message."""
    return _role_common("user", expr)


def assistant(expr: Optional[SglExpr] = None):
    """Wrap `expr` in an assistant-role message."""
    return _role_common("assistant", expr)


def system_begin():
    """Open a system-role region (close with system_end)."""
    return SglRoleBegin("system")


def system_end():
    """Close a system-role region."""
    return SglRoleEnd("system")


def user_begin():
    """Open a user-role region (close with user_end)."""
    return SglRoleBegin("user")


def user_end():
    """Close a user-role region."""
    return SglRoleEnd("user")


def assistant_begin():
    """Open an assistant-role region (close with assistant_end)."""
    return SglRoleBegin("assistant")


def assistant_end():
    """Close an assistant-role region."""
    return SglRoleEnd("assistant")
|
||||
|
||||
|
||||
def separate_reasoning(
    expr: Optional[SglExpr] = None, model_type: Optional[str] = None
):
    """Append a reasoning-separation marker after `expr`.

    NOTE(review): `expr` defaults to None but is placed in the SglExprList
    unconditionally — confirm callers always pass an expression.
    """
    return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])
|
||||
Binary file not shown.
Binary file not shown.
73
python/sglang/lang/backend/anthropic.py
Normal file
73
python/sglang/lang/backend/anthropic.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as e:
|
||||
anthropic = e
|
||||
|
||||
|
||||
class Anthropic(BaseBackend):
    """Backend that serves generation through the Anthropic Messages API."""

    def __init__(self, model_name, *args, **kwargs):
        super().__init__()

        # `anthropic` is the ImportError instance when the import failed.
        if isinstance(anthropic, Exception):
            raise anthropic

        self.model_name = model_name
        self.chat_template = get_chat_template("claude")
        self.client = anthropic.Anthropic(*args, **kwargs)

    def get_chat_template(self):
        """Return the chat template used to format role messages."""
        return self.chat_template

    @staticmethod
    def _system_and_messages(s: StreamExecutor):
        """Split executor state into (system_prompt, messages).

        Deduplicated out of generate/generate_stream, which contained two
        identical copies of this logic.
        NOTE(review): `messages.pop(0)` mutates `s.messages_` in place, same
        as the original code — confirm that is intended.
        """
        if s.messages_:
            messages = s.messages_
        else:
            # Plain-text programs become a single user message.
            messages = [{"role": "user", "content": s.text_}]

        # The Messages API takes the system prompt as a separate argument.
        if messages and messages[0]["role"] == "system":
            system = messages.pop(0)["content"]
        else:
            system = ""
        return system, messages

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming completion; returns (text, meta_info)."""
        system, messages = self._system_and_messages(s)

        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.content[0].text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs from a streaming completion."""
        system, messages = self._system_and_messages(s)

        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        ) as stream:
            for text in stream.text_stream:
                yield text, {}
|
||||
82
python/sglang/lang/backend/base_backend.py
Normal file
82
python/sglang/lang/backend/base_backend.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
|
||||
class BaseBackend:
    """Abstract interface that every language backend implements.

    Generation/selection methods must be overridden; the lifecycle hooks
    default to no-ops so simple backends only implement what they need.
    """

    def __init__(self) -> None:
        # Whether the backend supports concatenate_and_append.
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    def get_model_name(self):
        """Return the served model's name. Must be overridden."""
        raise NotImplementedError()

    def get_chat_template(self):
        """Return the chat template used to format role messages."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Hook: cache a shared prompt prefix. No-op by default."""
        pass

    def uncache_prefix(self, rid: str):
        """Hook: release a cached prefix for request `rid`. No-op by default."""
        pass

    def end_request(self, rid: Union[str, List[str]]):
        """Hook: finalize one or more requests. No-op by default."""
        pass

    def begin_program(self, s: StreamExecutor):
        """Hook: called when a program starts executing. No-op by default."""
        pass

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        """Hook: called when a program (or fork group) finishes. No-op by default."""
        pass

    def commit_lazy_operations(self, s: StreamExecutor):
        """Hook: flush buffered operations for `s`. No-op by default."""
        pass

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Hook: called when `src` forks into `dst` executors. No-op by default."""
        pass

    def fill_image(self, s: StreamExecutor):
        """Hook for image inputs. No-op by default."""
        pass

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Generate a completion for `s`. Must be overridden."""
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion for `s`. Must be overridden."""
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: Optional[ChoicesSamplingMethod] = None,
    ) -> ChoicesDecision:
        """Pick one of `choices` for `s`. Must be overridden."""
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Merge source requests into `dst_rid`. Must be overridden if supported."""
        raise NotImplementedError()

    def shutdown(self):
        """Hook: release backend resources. No-op by default."""
        pass

    def flush_cache(self):
        """Hook: clear server-side caches. No-op by default."""
        pass

    def get_server_info(self):
        """Hook: return server metadata. No-op (returns None) by default."""
        pass
|
||||
90
python/sglang/lang/backend/litellm.py
Normal file
90
python/sglang/lang/backend/litellm.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Mapping, Optional
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
    import litellm
except ImportError as e:
    litellm = e
# NOTE(review): when the import fails, `litellm` is the ImportError instance;
# attribute assignment on an exception object still succeeds, so this line and
# the class-level default `max_retries=litellm.num_retries` below keep the
# module importable. Fragile but apparently deliberate — confirm.
litellm.num_retries = 1
|
||||
|
||||
|
||||
class LiteLLM(BaseBackend):
    """Backend that routes generation through the LiteLLM unified API."""

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        super().__init__()

        # `litellm` is the ImportError instance when the import failed.
        if isinstance(litellm, Exception):
            raise litellm

        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        # Forwarded verbatim to litellm.completion on every call.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        """Return the chat template inferred from the model path (or override)."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming completion; returns (text, meta_info)."""
        if s.messages_:
            messages = s.messages_
        else:
            # Plain-text programs become a single user message.
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs from a streaming completion."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            # Some stream chunks carry no content delta.
            if text is not None:
                yield text, {}
|
||||
475
python/sglang/lang/backend/openai.py
Normal file
475
python/sglang/lang/backend/openai.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import openai
|
||||
import tiktoken
|
||||
except ImportError as e:
|
||||
openai = tiktoken = e
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_logit_bias_int(tokenizer):
    """Build a logit-bias map that restricts sampling to integer-like tokens.

    Scans the tokenizer's mergeable ranks for tokens that decode to all-digit
    strings (or a single space, used as a terminator) and biases them with the
    maximum value (+100). The ``<|endoftext|>`` special token is always
    included so generation can stop. Capped to honor the OpenAI API limit of
    300 logit_bias entries.
    """
    digit_ids = []
    for _token, tid in tokenizer._mergeable_ranks.items():
        piece = tokenizer.decode([tid])
        if piece == " " or all(ch.isdigit() for ch in piece):
            digit_ids.append(tid)
            if len(digit_ids) >= 300:  # OpenAI API limit
                break

    # 299 integer tokens + <|endoftext|> == 300 entries total.
    bias = {tid: 100 for tid in digit_ids[:299]}
    bias[tokenizer._special_tokens["<|endoftext|>"]] = 100
    return bias
|
||||
|
||||
# Models served through the legacy completions (non-chat) endpoint.
# Any model name NOT in this list is treated as a chat model by default.
INSTRUCT_MODEL_NAMES = [
    "gpt-3.5-turbo-instruct",
]
||||
|
||||
@dataclasses.dataclass
class TokenUsage:
    """Mutable accumulator for OpenAI API token accounting."""

    # Running totals of prompt/completion tokens consumed so far.
    prompt_tokens: int
    completion_tokens: int

    def reset(self):
        """Zero both counters."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
|
||||
|
||||
class OpenAI(BaseBackend):
    """Frontend-language backend backed by the OpenAI (or Azure OpenAI) API."""

    def __init__(
        self,
        model_name: str,
        is_chat_model: Optional[bool] = None,
        chat_template: Optional[ChatTemplate] = None,
        is_azure: bool = False,
        *args,
        **kwargs,
    ):
        """Create a client for ``model_name``.

        Args:
            model_name: OpenAI model identifier.
            is_chat_model: Force chat vs. completions mode; inferred from
                the model name when None.
            chat_template: Override the chat template; inferred from the
                model path when None.
            is_azure: Use ``AzureOpenAI`` instead of ``OpenAI``.
            *args, **kwargs: Forwarded verbatim to the client constructor.
        """
        super().__init__()

        # The module-level import stored the ImportError; re-raise it lazily so
        # this module can still be imported without the `openai` package.
        if isinstance(openai, Exception):
            raise openai

        if is_azure:
            self.client = openai.AzureOpenAI(*args, **kwargs)
        else:
            self.client = openai.OpenAI(*args, **kwargs)

        self.model_name = model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model name: fall back to the common GPT-3.5/4 encoding.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        if is_chat_model is not None:
            self.is_chat_model = is_chat_model
        else:
            # Only a few legacy models use the raw completions endpoint.
            self.is_chat_model = model_name not in INSTRUCT_MODEL_NAMES

        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]

        # Token usage accumulated across every call through this backend.
        self.token_usage = TokenUsage(0, 0)

        # API speculative execution state.
        # TODO(ying): This does not support multi-threading (run_batch)
        self.spec_kwargs = {}
        self.spec_format = []
        self.spec_max_num_tries = 3

    def get_chat_template(self):
        """Return the chat template used to render messages."""
        return self.chat_template

    def _prepare_spec_execution(
        self,
        sampling_params: SglSamplingParams,
        num_api_spec_tokens: int,
        spec_var_name: str,
    ):
        """Record one pending ``gen`` for API speculative execution.

        Instead of issuing one API call per ``gen``, the whole assistant turn
        is generated at once in ``role_end_generate`` and pattern-matched back
        onto the placeholders recorded here. All recorded ``gen`` calls must
        use consistent sampling parameters.
        """
        if "max_tokens" not in self.spec_kwargs:
            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
        else:
            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens

        params = sampling_params.to_openai_kwargs()
        for key, value in params.items():
            if key in ["stop"]:
                continue
            if key in ["max_tokens"]:
                warnings.warn(
                    "The parameter max_tokens will be overwritten by speculated number of tokens."
                )
                continue
            if key not in self.spec_kwargs:
                self.spec_kwargs[key] = value
            else:
                assert (
                    value == self.spec_kwargs[key]
                ), "sampling parameters should be consistent if turn on api speculative execution."
        self.spec_format.append(
            {"text": "", "stop": params["stop"], "name": spec_var_name}
        )
        return "", {}

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
        spec_var_name: str = None,
    ):
        """Generate a completion for the current program state.

        Returns a (completion, meta_info) pair; the completion is a list when
        the API returns multiple choices. ``dtype`` on the sampling params
        selects constrained decoding (str/int) on non-chat models.
        """
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if s.num_api_spec_tokens is None:
                    if not s.text_.endswith(self.chat_prefix):
                        raise RuntimeError(
                            "This use case is not supported if api speculative execution is off. "
                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
                        )
                    prompt = s.messages_
                else:
                    return self._prepare_spec_execution(
                        sampling_params, s.num_api_spec_tokens, spec_var_name
                    )
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            # Reasoning (o-series) models only accept max_completion_tokens;
            # other models only accept max_tokens. `"o1" in name` subsumes
            # the former redundant startswith("o1") check.
            if "o1" in self.model_name or self.model_name.startswith("o3"):
                kwargs.pop("max_tokens", None)
            else:
                kwargs.pop("max_completion_tokens", None)

            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            # Keep the returned list (or string) as is.
        elif sampling_params.dtype in [str, "str", "string"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Open a quote and stop at the closing quote to force a quoted string.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_ + '"',
                stop='"',
                **kwargs,
            )
            # Wrap each element in quotes if we have a list.
            if isinstance(comp, list):
                comp = ['"' + x + '"' for x in comp]
            else:
                comp = '"' + comp + '"'
        elif sampling_params.dtype in [int, "int"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Bias sampling toward digit tokens and stop at whitespace.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_,
                logit_bias=self.logit_bias_int,
                stop=[" "],
                **kwargs,
            )
            # Leave as a list if that's what is returned.
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

        return comp, {}

    def spec_fill(self, value: str):
        """Record literal text appended during speculative execution."""
        assert self.is_chat_model
        self.spec_format.append({"text": value, "stop": None, "name": None})

    def spec_pattern_match(self, comp):
        """Align the speculated completion ``comp`` with the recorded format.

        Literal terms must appear verbatim; placeholder terms capture text up
        to their stop string (the final placeholder may consume the rest).
        Fills each placeholder's ``text`` in place; returns False on mismatch.
        """
        for i, term in enumerate(self.spec_format):
            text = term["text"]
            if text != "":
                # Literal segment: must match verbatim.
                if comp.startswith(text):
                    comp = comp[len(text) :]
                else:
                    return False
            else:
                # Placeholder: capture up to the stop string.
                pos = comp.find(term["stop"])
                if pos != -1:
                    term["text"] = comp[:pos]
                    comp = comp[pos:]
                else:
                    if i == len(self.spec_format) - 1:
                        # Last placeholder may take the remainder.
                        term["text"] = comp
                    else:
                        return False
        return True

    def role_end_generate(
        self,
        s: StreamExecutor,
    ):
        """Flush speculative execution at the end of an assistant turn.

        Generates the whole turn at once (retrying pattern matching up to
        ``spec_max_num_tries`` times), then writes the captured placeholder
        values back into the executor's text, variables, and events.
        """
        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
            return

        comp = ""
        if not all(x["name"] is None for x in self.spec_format):
            # TODO(ying): throw errors or warnings
            for i in range(self.spec_max_num_tries):
                comp = openai_completion(
                    client=self.client,
                    token_usage=self.token_usage,
                    is_chat=self.is_chat_model,
                    model=self.model_name,
                    prompt=s.messages_,
                    **self.spec_kwargs,
                )
                # Use a string for pattern matching.
                comp_for_match = comp[0] if isinstance(comp, list) else comp
                if self.spec_pattern_match(comp_for_match):
                    break

        for term in self.spec_format:
            s.text_ += term["text"]
            name = term["name"]
            if name is not None:
                s.variables[name] = term["text"]
                s.meta_info[name] = {}
                s.variable_event[name].set()

        self.spec_kwargs = {}
        self.spec_format = []

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Return a generator yielding (text_delta, meta_info) pairs.

        Constrained dtypes are not supported in streaming mode.
        """
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if not s.text_.endswith(self.chat_prefix):
                    raise RuntimeError(
                        "This use case is not supported. "
                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
                    )
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            generator = openai_completion_stream(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            return generator
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Pick one of ``choices`` by sampling one logit-biased token at a time.

        Note: `choices_method` is not used by the OpenAI backend.
        """
        if self.is_chat_model:
            raise NotImplementedError(
                "select/choices is not supported for chat models. "
                "Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
            )

        n_choices = len(choices)
        token_ids = [self.tokenizer.encode(x) for x in choices]
        scores = [0] * n_choices
        valid = [len(x) > 0 for x in token_ids]
        prompt_tokens = self.tokenizer.encode(s.text_)

        max_len = max([len(x) for x in token_ids])
        for step in range(max_len):
            # Build logit bias restricted to the next token of each live choice.
            logit_bias = {}
            for i in range(n_choices):
                if valid[i]:
                    logit_bias[token_ids[i][step]] = 100

            # Call API
            ret = self.client.completions.create(
                model=self.model_name,
                prompt=prompt_tokens,
                logit_bias=logit_bias,
                max_tokens=1,
                temperature=temperature,
            )
            ret_str = ret.choices[0].text
            ret_token = self.tokenizer.encode(ret_str)[0]
            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
            # Fix: accumulate with += (was `=`, which discarded earlier usage
            # and was inconsistent with the prompt_tokens line above).
            self.token_usage.completion_tokens += ret.usage.completion_tokens

            # TODO:
            # 1. return logits as the scores
            # 2. compute logits of the full choice
            # 3. consider chunk-based decoding

            # Update valid
            hit = False
            for i in range(n_choices):
                if valid[i]:
                    if step == len(token_ids[i]) - 1:
                        valid[i] = False

                    if ret_token == token_ids[i][step]:
                        scores[i] += 1
                        hit = True
                    else:
                        valid[i] = False
            assert hit

            if np.sum(valid) <= 1:
                break

            prompt_tokens.append(ret_token)

        return ChoicesDecision(
            decision=choices[np.argmax(scores)],
            meta_info={"scores": scores},
        )
|
||||
|
||||
def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
) -> Union[str, List[str]]:
    """Call the OpenAI chat or completions endpoint with retries.

    Retries up to ``retries`` times on API/connection/rate-limit errors
    (sleeping 5s between attempts), accumulates usage into ``token_usage``,
    and returns the completion text — a list when multiple choices (or a
    list prompt) are used, otherwise a single string.
    """
    # "ebnf" has no meaning for OpenAI endpoints; drop it with a warning.
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit stop=None.
                if kwargs.get("stop", False) is None:
                    kwargs.pop("stop")
                ret = client.chat.completions.create(messages=prompt, **kwargs)
                contents = [choice.message.content for choice in ret.choices]
                comp = contents[0] if len(contents) == 1 else contents
            else:
                ret = client.completions.create(prompt=prompt, **kwargs)
                if isinstance(prompt, (list, tuple)) or len(ret.choices) > 1:
                    comp = [choice.text for choice in ret.choices]
                else:
                    comp = ret.choices[0].text

            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e

    return comp
||||
|
||||
|
||||
def openai_completion_stream(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    """Stream a completion, yielding (text_delta, meta_info) pairs.

    Requests usage reporting via stream_options; after the stream ends, the
    final chunk's usage is accumulated into ``token_usage``. Retries up to
    ``retries`` times on API/connection/rate-limit errors.
    """
    # "ebnf" has no meaning for OpenAI endpoints; drop it with a warning.
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit stop=None.
                if kwargs.get("stop", False) is None:
                    kwargs.pop("stop")
                stream = client.chat.completions.create(
                    messages=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for last in stream:
                    # The usage-only final chunk has no choices.
                    if len(last.choices) == 0:
                        continue
                    try:
                        piece = last.choices[0].delta.content
                    except IndexError:
                        piece = None
                    yield piece or "", {}
            else:
                stream = client.completions.create(
                    prompt=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for last in stream:
                    if len(last.choices) == 0:
                        continue
                    yield last.choices[0].text or "", {}

            # `last` is the final chunk, which carries the usage numbers.
            token_usage.prompt_tokens += last.usage.prompt_tokens
            token_usage.completion_tokens += last.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e
||||
527
python/sglang/lang/backend/runtime_endpoint.py
Normal file
527
python/sglang/lang/backend/runtime_endpoint.py
Normal file
@@ -0,0 +1,527 @@
|
||||
import atexit
|
||||
import json
|
||||
import multiprocessing
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import (
|
||||
REGEX_BOOL,
|
||||
REGEX_FLOAT,
|
||||
REGEX_INT,
|
||||
REGEX_STR,
|
||||
SglSamplingParams,
|
||||
)
|
||||
from sglang.utils import http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
    """Frontend-language backend that talks to a running SGLang HTTP server."""

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        """Connect to the server at `base_url` and fetch its model info.

        A bad URL or key fails fast here, before any generation is attempted.
        """
        super().__init__()
        self.support_concate_and_append = True

        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        res = http_request(
            self.base_url + "/get_model_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        self.model_info = res.json()

        # Prefer an explicitly named template; otherwise infer from model path.
        if chat_template_name:
            self.chat_template = get_chat_template(chat_template_name)
        else:
            self.chat_template = get_chat_template_by_model_path(
                self.model_info["model_path"]
            )

    def get_model_name(self):
        """Return the model path reported by the server."""
        return self.model_info["model_path"]

    def flush_cache(self):
        """Ask the server to drop its radix/prefix cache."""
        res = http_request(
            self.base_url + "/flush_cache",
            api_key=self.api_key,
            verify=self.verify,
            method="POST",
        )
        self._assert_success(res)

    def get_server_info(self):
        """Return the server's info endpoint payload as a dict."""
        res = http_request(
            self.base_url + "/get_server_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def get_chat_template(self):
        """Return the chat template selected at construction time."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Warm the server cache with `prefix_str` (generate 0 tokens)."""
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def start_profile(self):
        """Start server-side profiling."""
        res = http_request(
            self.base_url + "/start_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def stop_profile(self):
        """Stop server-side profiling."""
        res = http_request(
            self.base_url + "/stop_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def commit_lazy_operations(self, s: StreamExecutor):
        """Send the executor's accumulated text with max_new_tokens=0 so the
        server extends its cache without generating anything."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def fill_image(self, s: StreamExecutor):
        """Push the executor's image (plus text) to the server without generating."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
        """Translate a requested dtype into a constraining regex (in place).

        Also extends stop strings for numeric dtypes so generation terminates
        at whitespace. An explicit regex is overridden (with a warning).
        """
        if sampling_params.dtype is None:
            return

        # Normalize the immutable default so .extend() below works.
        if sampling_params.stop == ():
            sampling_params.stop = []

        dtype_regex = None
        if sampling_params.dtype in ["int", int]:

            dtype_regex = REGEX_INT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["float", float]:

            dtype_regex = REGEX_FLOAT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["str", str]:

            dtype_regex = REGEX_STR
        elif sampling_params.dtype in ["bool", bool]:

            dtype_regex = REGEX_BOOL
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        if dtype_regex is not None and sampling_params.regex is not None:
            warnings.warn(
                f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
            )

        sampling_params.regex = dtype_regex

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation; returns (text, meta_info)."""
        self._handle_dtype_to_regex(sampling_params)
        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Forward logprob-related fields only when explicitly set.
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a generation, yielding (text_delta, meta_info) per SSE chunk."""
        self._handle_dtype_to_regex(sampling_params)

        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Forward logprob-related fields only when explicitly set.
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        data["stream"] = True
        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        # `pos` tracks how much of the cumulative text we have already yielded.
        pos = 0

        for chunk in res.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text, meta_info

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Score each choice by prompt logprobs and delegate the final pick
        to `choices_method`. Requires greedy decoding (temperature ~ 0)."""
        assert temperature <= 1e-5

        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        obj = self._generate_http_request(s, data)
        prompt_len = obj["meta_info"]["prompt_tokens"]
        logprob_start_len = max(prompt_len - 2, 0)  # For token healing

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {
                "max_new_tokens": 0,
                "temperature": 0,
            },
            "return_logprob": True,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        }
        obj = self._generate_http_request(s, data)

        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
        normalized_prompt_logprobs = [
            compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
            for r in obj
        ]

        # Remove extra token if no token healing occurred
        for i in range(len(input_token_logprobs)):
            healed_token_str = input_token_logprobs[i][0][-1]
            if s.text_.endswith(healed_token_str):
                # Recompute the mean without the first (unhealed) token.
                healed_token_logprob = input_token_logprobs[i][0][0]
                normalized_prompt_logprobs[i] = (
                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
                    - healed_token_logprob
                ) / (len(input_token_logprobs[i]) - 1)
                input_token_logprobs[i] = input_token_logprobs[i][1:]

        # Compute unconditional logprobs if required
        if choices_method.requires_unconditional_logprobs:
            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
            data = {
                "input_ids": input_ids,
                "sampling_params": {"max_new_tokens": 0},
                "return_logprob": True,
            }
            obj = self._generate_http_request(s, data)
            unconditional_token_logprobs = [
                r["meta_info"]["input_token_logprobs"] for r in obj
            ]
        else:
            unconditional_token_logprobs = None

        return choices_method(
            choices=choices,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
            input_token_logprobs=input_token_logprobs,
            output_token_logprobs=output_token_logprobs,
            unconditional_token_logprobs=unconditional_token_logprobs,
        )

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Ask the server to concatenate the caches of `src_rids` into `dst_rid`."""
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _generate_http_request(self, s: StreamExecutor, data):
        """POST `data` (plus any images) to /generate and return the JSON body."""
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def _add_images(self, s: StreamExecutor, data):
        """Attach the executor's image payload to the request dict in place."""
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]

    def _assert_success(self, res):
        """Raise RuntimeError with the server's error body on non-200 responses."""
        if res.status_code != 200:
            try:
                content = res.json()
            except json.JSONDecodeError:
                content = res.text
            raise RuntimeError(content)
|
||||
|
||||
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean per-token logprob of a prompt.

    Each item of ``input_logprobs`` is a ``(logprob, token_id, token_text)``
    entry from the server; the first entry's logprob may be None (the first
    token has no conditional probability), so None entries are skipped.

    Fixes vs. the previous truthiness filter (``if x[0]``): a legitimate
    logprob of exactly 0.0 (probability-1 token) is no longer dropped, and an
    all-None input returns 0.0 instead of raising ZeroDivisionError.
    """
    values = [x[0] for x in input_logprobs if x[0] is not None]
    if not values:
        return 0.0
    return sum(values) / len(values)
|
||||
|
||||
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the command line interface.

    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing without the frontend language.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
        # client code without installing SRT server and its dependency if they want.
        from sglang.srt.entrypoints.http_server import launch_server
        from sglang.srt.server_args import ServerArgs
        from sglang.srt.utils import is_port_available

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        # Spawn (not fork) so the server process starts from a clean interpreter.
        ctx = multiprocessing.get_context("spawn")
        proc = ctx.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)

        # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process tree; safe to call more than once."""
        from sglang.srt.utils import kill_process_tree

        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def start_profile(self):
        """Start server-side profiling via the endpoint."""
        self.endpoint.start_profile()

    def stop_profile(self):
        """Stop server-side profiling via the endpoint."""
        self.endpoint.stop_profile()

    def cache_prefix(self, prefix: str):
        """Warm the server cache with `prefix`."""
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        """Load and return the tokenizer matching the server's configuration."""
        from sglang.srt.hf_transformers_utils import get_tokenizer

        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        """Async generator over streamed text chunks (or raw dicts when the
        server runs with skip_tokenizer_init)."""
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        # `pos` tracks how much of the cumulative text was already yielded.
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    # Alias kept for API compatibility with the engine interface.
    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        """Blocking generation; returns the server response as a JSON string."""
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        # Per-prompt LoRA paths must line up with the prompts.
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        """Call the /encode endpoint; returns the response as a JSON string."""
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        """Fetch the server info endpoint; raises RuntimeError on failure."""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()
||||
148
python/sglang/lang/backend/vertexai.py
Normal file
148
python/sglang/lang/backend/vertexai.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import vertexai
|
||||
from vertexai.preview.generative_models import (
|
||||
GenerationConfig,
|
||||
GenerativeModel,
|
||||
Image,
|
||||
)
|
||||
except ImportError as e:
|
||||
GenerativeModel = e
|
||||
|
||||
|
||||
class VertexAI(BaseBackend):
|
||||
def __init__(self, model_name, safety_settings=None):
    """Initialize the VertexAI backend.

    Requires the GCP_PROJECT_ID environment variable (GCP_LOCATION is
    optional). Raises the original ImportError if the vertexai package
    is not installed.
    """
    super().__init__()

    # The module-level import stored the ImportError; re-raise it lazily.
    if isinstance(GenerativeModel, Exception):
        raise GenerativeModel

    project_id = os.environ["GCP_PROJECT_ID"]
    location = os.environ.get("GCP_LOCATION")
    vertexai.init(project=project_id, location=location)

    self.model_name = model_name
    self.chat_template = get_chat_template("default")
    self.safety_settings = safety_settings
||||
def get_chat_template(self):
|
||||
return self.chat_template
|
||||
|
||||
def generate(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if s.messages_:
|
||||
prompt = self.messages_to_vertexai_input(s.messages_)
|
||||
else:
|
||||
# single-turn
|
||||
prompt = (
|
||||
self.text_to_vertexai_input(s.text_, s.cur_images)
|
||||
if s.cur_images
|
||||
else s.text_
|
||||
)
|
||||
ret = GenerativeModel(self.model_name).generate_content(
|
||||
prompt,
|
||||
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
|
||||
comp = ret.text
|
||||
|
||||
return comp, {}
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if s.messages_:
|
||||
prompt = self.messages_to_vertexai_input(s.messages_)
|
||||
else:
|
||||
# single-turn
|
||||
prompt = (
|
||||
self.text_to_vertexai_input(s.text_, s.cur_images)
|
||||
if s.cur_images
|
||||
else s.text_
|
||||
)
|
||||
generator = GenerativeModel(self.model_name).generate_content(
|
||||
prompt,
|
||||
stream=True,
|
||||
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
for ret in generator:
|
||||
yield ret.text, {}
|
||||
|
||||
def text_to_vertexai_input(self, text, images):
|
||||
input = []
|
||||
# split with image token
|
||||
text_segs = text.split(self.chat_template.image_token)
|
||||
for image_path, image_base64_data in images:
|
||||
text_seg = text_segs.pop(0)
|
||||
if text_seg != "":
|
||||
input.append(text_seg)
|
||||
input.append(Image.from_bytes(image_base64_data))
|
||||
text_seg = text_segs.pop(0)
|
||||
if text_seg != "":
|
||||
input.append(text_seg)
|
||||
return input
|
||||
|
||||
def messages_to_vertexai_input(self, messages):
|
||||
vertexai_message = []
|
||||
# from openai message format to vertexai message format
|
||||
for msg in messages:
|
||||
if isinstance(msg["content"], str):
|
||||
text = msg["content"]
|
||||
else:
|
||||
text = msg["content"][0]["text"]
|
||||
|
||||
if msg["role"] == "system":
|
||||
warnings.warn("Warning: system prompt is not supported in VertexAI.")
|
||||
vertexai_message.append(
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": "System prompt: " + text}],
|
||||
}
|
||||
)
|
||||
vertexai_message.append(
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [{"text": "Understood."}],
|
||||
}
|
||||
)
|
||||
continue
|
||||
if msg["role"] == "user":
|
||||
vertexai_msg = {
|
||||
"role": "user",
|
||||
"parts": [{"text": text}],
|
||||
}
|
||||
elif msg["role"] == "assistant":
|
||||
vertexai_msg = {
|
||||
"role": "model",
|
||||
"parts": [{"text": text}],
|
||||
}
|
||||
|
||||
# images
|
||||
if isinstance(msg["content"], list) and len(msg["content"]) > 1:
|
||||
for image in msg["content"][1:]:
|
||||
assert image["type"] == "image_url"
|
||||
vertexai_msg["parts"].append(
|
||||
{
|
||||
"inline_data": {
|
||||
"data": image["image_url"]["url"].split(",")[1],
|
||||
"mime_type": "image/jpeg",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
vertexai_message.append(vertexai_msg)
|
||||
return vertexai_message
|
||||
662
python/sglang/lang/chat_template.py
Normal file
662
python/sglang/lang/chat_template.py
Normal file
@@ -0,0 +1,662 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
|
||||
class ChatTemplateStyle(Enum):
    """Controls how ChatTemplate.get_prefix_and_suffix assembles role markers."""

    PLAIN = auto()  # role prefix/suffix applied independently to each message
    LLAMA2 = auto()  # system prompt is folded into the first user turn
|
||||
|
||||
|
||||
@dataclass
class ChatTemplate:
    """Declarative description of a model family's chat prompt format.

    A template maps each role to a (prefix, suffix) pair and renders an
    OpenAI-style message list into a single prompt string.
    """

    name: str
    # Used when the conversation's system message has content=None.
    default_system_prompt: str
    # role -> (prefix, suffix) wrapped around each message's content.
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    # NOTE(review): annotated List[str] but defaults to a tuple; registrations
    # below pass both lists and tuples — callers should treat it as a sequence.
    stop_str: List[str] = ()
    image_token: str = "<image>"
    audio_token: str = "<audio>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(
        self, role: str, hist_messages: List[Dict]
    ) -> Tuple[str, str]:
        """Return the (prefix, suffix) for `role`, applying LLAMA2 special cases.

        Unknown roles fall back to ("", "").
        """
        prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))

        if self.style == ChatTemplateStyle.LLAMA2:
            if role == "system" and not hist_messages:
                # Llama-2 folds the system prompt into the first user turn:
                # emit the user prefix before the system markers.
                user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
                system_prefix, system_suffix = self.role_prefix_and_suffix.get(
                    "system", ("", "")
                )
                return (user_prefix + system_prefix, system_suffix)
            elif (
                role == "user"
                and len(hist_messages) == 1
                and hist_messages[0]["content"] is not None
            ):
                # The system turn already emitted the user prefix; skip it here.
                return ("", suffix)

        return prefix, suffix

    def get_prompt(self, messages: List[Dict]) -> str:
        """Render a full prompt string from OpenAI-style messages."""
        prompt = ""
        for i, message in enumerate(messages):
            role, content = message["role"], message["content"]
            if role == "system" and content is None:
                # Fall back to the template default; skip the turn entirely
                # when no default exists.
                content = self.default_system_prompt
                if content is None:
                    continue

            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += f"{prefix}{content}{suffix}"
        return prompt
|
||||
|
||||
|
||||
chat_template_registry: Dict[str, ChatTemplate] = {}
|
||||
matching_function_registry: List[Callable] = []
|
||||
|
||||
|
||||
def register_chat_template(template):
    # Index templates by name for lookup via get_chat_template().
    chat_template_registry[template.name] = template
|
||||
|
||||
|
||||
def register_chat_template_matching_function(func):
    # Decorator: collect model-path matchers; they are consulted in
    # registration order by get_chat_template_by_model_path().
    matching_function_registry.append(func)
|
||||
|
||||
|
||||
def get_chat_template(name):
    # Raises KeyError if the template name was never registered.
    return chat_template_registry[name]
|
||||
|
||||
|
||||
def get_chat_template_by_model_path(model_path):
    """Resolve a chat template from a model path.

    Each registered matcher is tried in order; the first non-None template
    name wins. Falls back to the "default" template when nothing matches.
    """
    matched_name = next(
        (
            name
            for name in (fn(model_path) for fn in matching_function_registry)
            if name is not None
        ),
        None,
    )
    return get_chat_template(matched_name if matched_name is not None else "default")
|
||||
|
||||
|
||||
# Plain-text fallback template: "ROLE: content" lines.
register_chat_template(
    ChatTemplate(
        name="default",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("SYSTEM:", "\n"),
            "user": ("USER:", "\n"),
            "assistant": ("ASSISTANT:", "\n"),
        },
    )
)

# Anthropic legacy "Human/Assistant" text format.
register_chat_template(
    ChatTemplate(
        name="claude",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("\n\nHuman: ", ""),
            "assistant": ("\n\nAssistant:", ""),
        },
    )
)

# Standard ChatML (<|im_start|>/<|im_end|>) template.
register_chat_template(
    ChatTemplate(
        name="chatml",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
    )
)

# ChatML variant for LLaVA models: adds a default system prompt and image token.
register_chat_template(
    ChatTemplate(
        name="chatml-llava",
        default_system_prompt="You are a helpful assistant.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
        image_token="<image>\n",
    )
)

# There is default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
register_chat_template(
    ChatTemplate(
        name="qwen",
        default_system_prompt="You are a helpful assistant.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
    )
)

# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_chat_template(
    ChatTemplate(
        name="qwen2-vl",
        default_system_prompt="You are a helpful assistant.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
        image_token="<|vision_start|><|image_pad|><|vision_end|>",
    )
)

# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
register_chat_template(
    ChatTemplate(
        name="vicuna_v1.1",
        default_system_prompt=(
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        ),
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("USER:", " "),
            "assistant": ("ASSISTANT:", "</s>"),
        },
        image_token=" <image>\n",
    )
)

# LLAMA2 style: the system prompt is merged into the first user turn (see
# ChatTemplate.get_prefix_and_suffix).
register_chat_template(
    ChatTemplate(
        name="llama-2-chat",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        style=ChatTemplateStyle.LLAMA2,
    )
)

# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
register_chat_template(
    ChatTemplate(
        name="mistral",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        stop_str=("</s>",),
        image_token="[IMG]",
    )
)

register_chat_template(
    ChatTemplate(
        name="llama-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_header_id|>system<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "user": (
                "<|start_header_id|>user<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "assistant": (
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
        },
        stop_str=("<|eot_id|>",),
        image_token="<|image|>",
    )
)

# https://huggingface.co/openbmb/MiniCPM-V-2_6
register_chat_template(
    ChatTemplate(
        name="minicpmv",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("user:", " "),
            "assistant": ("assistant:", "</s>"),
        },
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
    )
)
|
||||
|
||||
register_chat_template(
    ChatTemplate(
        name="janus-pro",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            # BUG FIX: the key was "User" (capital U), which never matched the
            # lowercase "user" lookup in ChatTemplate.get_prefix_and_suffix, so
            # the <|User|> prefix was silently dropped. The sibling "janus"
            # template already uses the lowercase key.
            "user": (
                "<|User|>",
                "",
            ),
            "assistant": (
                "<|Assistant|>",
                "<|end▁of▁sentence|>",
            ),
        },
        stop_str=("<|end▁of▁sentence|>",),
        image_token="<image_placeholder>\n",
    )
)
|
||||
|
||||
# https://huggingface.co/openbmb/MiniCPM-o-2_6
register_chat_template(
    ChatTemplate(
        name="minicpmo",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("user:", " "),
            "assistant": ("assistant:", "</s>"),
        },
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
        audio_token="(<audio>./</audio>)",
    )
)

register_chat_template(
    ChatTemplate(
        name="janus",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            "user": (
                "<|User|>",
                "",
            ),
            "assistant": (
                "<|Assistant|>",
                "<|end▁of▁sentence|>",
            ),
        },
        stop_str=("<|end▁of▁sentence|>",),
        image_token="<image_placeholder>\n",
    )
)

# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
    ChatTemplate(
        name="llama-3-instruct-llava",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_header_id|>system<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "user": (
                "<|start_header_id|>user<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "assistant": (
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
        },
        stop_str=("<|eot_id|>",),
        image_token="<image>\n",
    )
)

# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_chat_template(
    ChatTemplate(
        name="llama-4",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|header_start|>system<|header_end|>\n\n",
                "<|eot|>",
            ),
            "user": (
                "<|header_start|>user<|header_end|>\n\n",
                "<|eot|>",
            ),
            "assistant": (
                "<|header_start|>assistant<|header_end|>\n\n",
                "<|eot|>",
            ),
        },
        stop_str=("<|eot|>",),
        image_token="<|image|>",
    )
)

# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
    ChatTemplate(
        name="yi-1.5",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
            "assistant": ("", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
    )
)

# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
register_chat_template(
    ChatTemplate(
        name="yi-vl",
        default_system_prompt=(
            "This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
            "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
        ),
        role_prefix_and_suffix={
            "system": ("", "\n\n"),
            "user": ("### Human:", "\n"),
            "assistant": ("### Assistant:", "\n"),
        },
        image_token=" <image_placeholder>\n",
    )
)

# Gemma has no system role; system content passes through unwrapped.
register_chat_template(
    ChatTemplate(
        name="gemma-it",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
            "assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)

register_chat_template(
    ChatTemplate(
        name="dbrx-instruct",
        default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>"),
            "user": ("\n<|im_start|>user\n", "<|im_end|>"),
            "assistant": ("\n<|im_start|>assistant\n", "<|im_end|>"),
        },
        stop_str=("<|im_end|>",),
    )
)

register_chat_template(
    ChatTemplate(
        name="c4ai-command-r",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
                "<|END_OF_TURN_TOKEN|>",
            ),
            "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
            "assistant": (
                "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
                "<|END_OF_TURN_TOKEN|>",
            ),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)

# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
register_chat_template(
    ChatTemplate(
        name="internvl-2-5",
        default_system_prompt="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)

register_chat_template(
    ChatTemplate(
        name="interns1",
        default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)

register_chat_template(
    ChatTemplate(
        name="granite-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_of_role|>system<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "user": (
                "<|start_of_role|>user<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "assistant": (
                "<|start_of_role|>assistant<|end_of_role|>",
                "<|end_of_text|>",
            ),
        },
        stop_str=("<|end_of_text|>",),
    )
)

register_chat_template(
    ChatTemplate(
        name="deepseek-v3",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            "user": (
                "<|User|>",
                "",
            ),
            "assistant": (
                "<|Assistant|>",
                "<|end▁of▁sentence|>",
            ),
        },
        stop_str=("<|end▁of▁sentence|>",),
    )
)

# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
register_chat_template(
    ChatTemplate(
        name="glm-4v",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|system|>\n", "\n"),
            "user": ("<|user|>\n", "\n"),
            "assistant": ("<|assistant|>\n", "\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
        image_token="<|image|>",
    )
)
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_deepseek(model_path: str):
    """Match DeepSeek V3/R1 chat checkpoints, excluding base models."""
    is_deepseek = re.search(r"deepseek-(v3|r1)", model_path, re.IGNORECASE)
    is_base = re.search(r"base", model_path, re.IGNORECASE)
    if is_deepseek and not is_base:
        return "deepseek-v3"
    return None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_deepseek_janus_pro(model_path: str):
    """Match DeepSeek Janus(-Pro) checkpoints."""
    return "janus-pro" if re.search(r"janus", model_path, re.IGNORECASE) else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_dbrx(model_path: str):
    """Match Databricks DBRX instruct checkpoints."""
    lowered = model_path.lower()
    if re.search(r"dbrx", lowered) and re.search(r"instruct", lowered):
        return "dbrx-instruct"
    return None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_vicuna(model_path: str):
    """Match Vicuna and LLaVA-v1.5-family checkpoints."""
    pattern = r"vicuna|llava-v1\.5|llava-next-video-7b"
    return "vicuna_v1.1" if re.search(pattern, model_path, re.IGNORECASE) else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_llama2_chat(model_path: str):
    """Match Llama-2 chat and CodeLlama instruct checkpoints."""
    lowered = model_path.lower()
    for pattern in (r"llama-2.*chat", r"codellama.*instruct"):
        if re.search(pattern, lowered):
            return "llama-2-chat"
    return None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_mistral(model_path: str):
    """Match Pixtral and Mistral/Mixtral instruct checkpoints."""
    matched = re.search(
        r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE
    )
    return "mistral" if matched else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_llama3_instruct(model_path: str):
    """Match Llama-3 instruct checkpoints."""
    return (
        "llama-3-instruct"
        if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE)
        else None
    )
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_ml(model_path: str):
    """Route ChatML-family checkpoints to their specific template.

    Order matters: VL/GLM variants are checked before plain Qwen, and
    Qwen-LLaVA hybrids are excluded from the "qwen" template.
    """
    lowered = model_path.lower()
    if re.search(r"tinyllama", lowered):
        return "chatml"
    if re.search(r"qwen.*vl", lowered):
        return "qwen2-vl"
    if re.search(r"glm[-_]?4(\.\d+)?v", lowered):
        return "glm-4v"
    if re.search(r"qwen.*(chat|instruct)", lowered) and not re.search(
        r"llava", lowered
    ):
        return "qwen"
    if re.search(
        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
        lowered,
    ):
        return "chatml-llava"
    return None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_yi(model_path: str):
    """Match Yi-VL (non-LLaVA) and Yi-1.5 chat checkpoints."""
    lowered = model_path.lower()
    if re.search(r"yi-vl", lowered) and not re.search(r"llava", lowered):
        return "yi-vl"
    if re.search(r"yi-1\.5.*chat", lowered):
        return "yi-1.5"
    return None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_gemma_it(model_path: str):
    """Match Gemma instruction-tuned (-it) checkpoints."""
    return "gemma-it" if re.search(r"gemma.*it", model_path, re.IGNORECASE) else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_openbmb_minicpm(model_path: str):
    """Match OpenBMB MiniCPM-V / MiniCPM-o checkpoints."""
    lowered = model_path.lower()
    if re.search(r"minicpm-v", lowered):
        return "minicpmv"
    if re.search(r"minicpm-o", lowered):
        return "minicpmo"
    return None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_c4ai_command_r(model_path: str):
    """Match Cohere c4ai-command-r checkpoints."""
    matched = re.search(r"c4ai-command-r", model_path, re.IGNORECASE)
    return "c4ai-command-r" if matched else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_granite_instruct(model_path: str):
    """Match IBM Granite instruct checkpoints."""
    return (
        "granite-3-instruct"
        if re.search(r"granite.*instruct", model_path, re.IGNORECASE)
        else None
    )
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_gemma3_instruct(model_path: str):
    """Match Gemma-3 checkpoints; they share the gemma-it template."""
    return "gemma-it" if re.search(r"gemma-3", model_path, re.IGNORECASE) else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_internvl_chat(model_path: str):
    """Match InternVL 2.5 checkpoints."""
    return (
        "internvl-2-5"
        if re.search(r"internvl2_5", model_path, re.IGNORECASE)
        else None
    )
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_interns1_chat(model_path: str):
    """Match Intern-S1 checkpoints ("intern-s1" or "interns1").

    The original ran two sequential searches that returned the same value;
    a single pattern with an optional hyphen covers both spellings.
    """
    if re.search(r"intern-?s1", model_path, re.IGNORECASE):
        return "interns1"
    return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: render a short conversation with the llama-2 template.
    messages = [
        {"role": "system", "content": None},  # None means default
        # {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "What can you do?"},
        {"role": "assistant", "content": "I can chat with you."},
    ]

    template = get_chat_template("llama-2-chat")
    print(template.get_prompt(messages))
|
||||
164
python/sglang/lang/choices.py
Normal file
164
python/sglang/lang/choices.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class ChoicesDecision:
    """Result returned by a choices sampling method."""

    # The selected choice string.
    decision: str
    # Optional debugging details (logprobs etc.) from the sampling method.
    meta_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChoicesSamplingMethod(ABC):
    """Strategy interface for picking one option from gen(choices=...)."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # Subclasses override this to request each choice's logprobs computed
        # without the prompt (see UnconditionalLikelihoodNormalized).
        return False

    @abstractmethod
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision: ...
|
||||
|
||||
|
||||
class TokenLengthNormalized(ChoicesSamplingMethod):
    """Pick the choice whose length-normalized prompt logprob is largest."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest token length normalized prompt logprob."""
        winner = int(np.argmax(normalized_prompt_logprobs))
        return ChoicesDecision(
            decision=choices[winner],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
            },
        )
|
||||
|
||||
|
||||
token_length_normalized = TokenLengthNormalized()
|
||||
|
||||
|
||||
class GreedyTokenSelection(ChoicesSamplingMethod):
    """Greedy per-token elimination over the choices' prompt logprobs."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option based on greedy logprob selection. For overlapping options
        where one option is a subset of a longer option, extend the shorter option using
        its average logprob for comparison against the longer option."""

        num_options = len(choices)
        # Pad every option to the longest token length so the matrix is
        # rectangular; short options are filled with their mean logprob.
        max_tokens = max(len(option) for option in input_token_logprobs)
        logprob_matrix = self._build_logprob_matrix(
            input_token_logprobs, max_tokens, num_options
        )
        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)

        # If several options are still tied after all positions, the first wins.
        best_choice = choices[remaining[0]]
        meta_info = {
            "normalized_prompt_logprobs": normalized_prompt_logprobs,
            "input_token_logprobs": input_token_logprobs,
            "output_token_logprobs": output_token_logprobs,
            "greedy_logprob_matrix": logprob_matrix.tolist(),
        }
        return ChoicesDecision(decision=best_choice, meta_info=meta_info)

    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
        # Rows = options, columns = token positions. token[0] holds the
        # logprob of that token (same layout as in the other methods here).
        logprob_matrix = np.zeros((num_options, max_tokens))
        for i, option in enumerate(input_token_logprobs):
            actual_logprobs = [token[0] for token in option]
            avg_logprob = np.mean(actual_logprobs)
            logprob_matrix[i, : len(option)] = actual_logprobs
            if len(option) < max_tokens:
                # Extend short options with their average so they remain
                # comparable against longer overlapping options.
                logprob_matrix[i, len(option) :] = avg_logprob
        return logprob_matrix

    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
        # At each token position keep only the options tied for the highest
        # logprob; stop as soon as a single option remains.
        remaining = np.arange(num_options)
        for j in range(max_tokens):
            max_logprob = np.max(logprob_matrix[remaining, j])
            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]
            if len(remaining) == 1:
                break
        return remaining
|
||||
|
||||
|
||||
greedy_token_selection = GreedyTokenSelection()
|
||||
|
||||
|
||||
class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
    """Rank choices by prompt-conditioned logprob minus unconditional logprob."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # This method needs each choice's logprobs computed without the prompt.
        return True

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest average token logprob once normalized by
        the unconditional token logprobs.

        The first unconditional token logprob is assumed to be None. If so, it is
        replaced with 0 for the purposes of normalization."""

        if unconditional_token_logprobs is None:
            raise ValueError(
                "Unconditional token logprobs are required for this method."
            )

        normalized_unconditional_prompt_logprobs = self._normalize_logprobs(
            input_token_logprobs, unconditional_token_logprobs
        )

        best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)]
        meta_info = {
            "normalized_prompt_logprobs": normalized_prompt_logprobs,
            "input_token_logprobs": input_token_logprobs,
            "output_token_logprobs": output_token_logprobs,
            "unconditional_token_logprobs": unconditional_token_logprobs,
            "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs,
        }
        return ChoicesDecision(decision=best_choice, meta_info=meta_info)

    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
        # For each choice, subtract the unconditional token logprobs from the
        # prompt-conditioned ones and average the difference.
        normalized_unconditional_prompt_logprobs = []
        for inputs, unconditionals in zip(
            input_token_logprobs, unconditional_token_logprobs
        ):
            inputs_logprobs = np.array([token[0] for token in inputs])
            unconditionals_logprobs = np.array([token[0] for token in unconditionals])
            # The first unconditional logprob may be None; treat it as 0.
            unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0
            normalized_unconditional_prompt_logprobs.append(
                float(np.mean(inputs_logprobs - unconditionals_logprobs))
            )
        return normalized_unconditional_prompt_logprobs
|
||||
|
||||
|
||||
# Module-level singleton instance of the unconditional-likelihood method.
unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()
|
||||
231
python/sglang/lang/compiler.py
Normal file
231
python/sglang/lang/compiler.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import multiprocessing
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
|
||||
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
|
||||
|
||||
|
||||
def compile_func(function, backend):
    """Trace *function* against *backend* and wrap the trace in a CompiledFunction."""
    traced = function.trace(backend=backend)
    return CompiledFunction(traced, function)
|
||||
|
||||
|
||||
class CompiledFunction:
    """A compiled computation graph built from a traced SGL function.

    The tracer yields a linked structure of SglExpr nodes; this class wraps
    each reachable expr in a CompGraphNode, sorts the graph topologically,
    and replays it on per-stream StreamExecutors when run.
    """

    def __init__(self, tracer, function):
        self.function = function

        # The final node of the traced program is the graph's sink.
        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}  # SglExpr -> CompGraphNode
        self.build_graph(tracer)
        self.topological_sort()

    def build_graph(self, tracer):
        """Discover all nodes reachable from the last node and wire
        prev/source/next edges; also rename stream pids to small integers."""
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.nodes[-1]

        # Maps original tracer pids to consecutive small integers.
        rename_pid = {}

        visited = set([tracer.last_node])
        head = 0
        # Worklist traversal: newly discovered dependencies are appended to
        # self.nodes and processed in turn.
        while head < len(self.nodes):
            cur_node = self.nodes[head]

            # add prev node
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)

            # add source node (the expr a variable reads from, possibly in
            # another stream; prefer the tracer's registered variable source)
            if isinstance(cur_node.expr, SglVariable):
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)
            head += 1

            # rename pid
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]

    def topological_sort(self):
        """Kahn's algorithm: reorder self.nodes so dependencies come first."""
        prevd = {}  # node -> number of unprocessed predecessors (prev + source)
        cand = Queue()
        for x in self.nodes:
            prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
            if prevd[x] == 0:
                cand.put(x)
        new_list = []
        while cand.qsize() > 0:
            head = cand.get()
            new_list.append(head)
            for x in head.next_nodes:
                prevd[x] -= 1
                if prevd[x] == 0:
                    cand.put(x)
        self.nodes = new_list

    def print_graph(
        self,
    ):
        # Debug helper: dump every node in topological order.
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Replay the graph: one StreamExecutor per stream pid; submit each
        node to its stream in topological order and return the main state."""
        stream_executor_ids = set([x.expr.pid for x in self.nodes])
        stream_executors = {}
        for x in stream_executor_ids:
            # Only the main (last-node) stream receives the user arguments.
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )
        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Make a copy for SglVariable
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute SglArgument
                expr = kwargs[expr.name]
            stream_executors[se_id].submit(expr)
        for stream_executor in stream_executors.values():
            stream_executor.end()
        return ProgramState(stream_executors[self.last_node.expr.pid])

    def run(
        self,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled program once.

        Keyword-only sampling parameters are forwarded to SglSamplingParams;
        **kwargs are the program's own arguments. Returns a ProgramState.
        """
        backend = backend or global_config.default_backend

        # Pre-bound arguments take precedence over call-site kwargs.
        kwargs.update(self.function.bind_arguments)

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled program once per kwargs dict in *batch_kwargs*,
        optionally across threads. Returns a list of ProgramState."""
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Extract prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            cache_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = []
            for arguments in batch_kwargs:
                rets.append(
                    self.run_internal(backend, arguments, default_sampling_para)
                )
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = []
                for arguments in batch_kwargs:
                    futures.append(
                        executor.submit(
                            self.run_internal, backend, arguments, default_sampling_para
                        )
                    )
                rets = [f.result() for f in futures]
            # Synchronize on the last program before returning.
            rets[-1].sync()

        return rets
|
||||
|
||||
|
||||
class CompGraphNode:
    """A node in the compiled computation graph wrapping one traced SglExpr.

    Tracks the sequential predecessor (`prev_node`), the cross-stream data
    dependency (`source_node`), and downstream consumers (`next_nodes`).
    """

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        self.prev_node = prev_node
        self.source_node = source_node
        self.next_nodes = next_nodes or []

    def add_next_node(self, other):
        """Register *other* as a consumer of this node."""
        self.next_nodes.append(other)

    def __repr__(self):
        pieces = [f"stream {self.expr.pid:2d}: ", f"%{self.expr.node_id} = "]
        if self.prev_node is not None:
            pieces.append(f"%{self.prev_node.expr.node_id} + ")
        pieces.append(repr(self.expr))
        return "".join(pieces)
|
||||
1060
python/sglang/lang/interpreter.py
Normal file
1060
python/sglang/lang/interpreter.py
Normal file
File diff suppressed because it is too large
Load Diff
635
python/sglang/lang/ir.py
Normal file
635
python/sglang/lang/ir.py
Normal file
@@ -0,0 +1,635 @@
|
||||
"""The intermediate representation."""
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.choices import ChoicesSamplingMethod
|
||||
|
||||
# Regex patterns for primitive value types, presumably used for
# dtype-constrained generation (SglGen accepts a dtype) — confirm in interpreter.
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
REGEX_BOOL = r"(True|False)"
REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SglSamplingParams:
    """Sampling parameters shared by all frontend-language backends.

    Holds generation controls (lengths, stop conditions, temperature/top-p/
    top-k/min-p, penalties, logprob options) plus constrained-generation
    fields (``dtype``, ``regex``, ``json_schema``). The ``to_*_kwargs``
    methods translate the subset each backend understands.
    """

    max_new_tokens: int = 128
    min_new_tokens: int = 0
    n: int = 1
    stop: Union[str, List[str]] = ()
    stop_token_ids: Optional[List[int]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    min_p: float = 0.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    ignore_eos: bool = False
    return_logprob: Optional[bool] = None
    # BUG FIX: these three previously defaulted to the tuple ``(None,)`` (a
    # stray trailing comma), which is truthy and violates the declared
    # Optional types; a bare ``None`` matches the annotations and the
    # defaults used everywhere else (e.g. return_logprob above).
    logprob_start_len: Optional[int] = None
    top_logprobs_num: Optional[int] = None
    return_text_in_logprobs: Optional[bool] = None
    json_schema: Optional[str] = None

    # for constrained generation, not included in to_xxx_kwargs
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return a copy of the generation parameters.

        NOTE(review): ``dtype`` and ``regex`` are not copied and fall back to
        their defaults in the clone — presumably intentional (they are marked
        as excluded from to_xxx_kwargs above), but confirm.
        """
        return SglSamplingParams(
            self.max_new_tokens,
            self.min_new_tokens,
            self.n,
            self.stop,
            self.stop_token_ids,
            self.temperature,
            self.top_p,
            self.top_k,
            self.min_p,
            self.frequency_penalty,
            self.presence_penalty,
            self.ignore_eos,
            self.return_logprob,
            self.logprob_start_len,
            self.top_logprobs_num,
            self.return_text_in_logprobs,
            self.json_schema,
        )

    def to_openai_kwargs(self):
        """Translate to OpenAI chat/completions keyword arguments."""
        # OpenAI does not support top_k, so we drop it here
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "max_completion_tokens": self.max_new_tokens,
            "n": self.n,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_vertexai_kwargs(self):
        """Translate to VertexAI generation keyword arguments."""
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the VertexAI backend."
            )
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        """Translate to Anthropic messages keyword arguments."""
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the Anthropic backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            # Anthropic requires a list of stop sequences.
            "stop_sequences": (
                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
            ),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        """Translate to LiteLLM keyword arguments."""
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_srt_kwargs(self):
        """Translate to the native SRT runtime's sampling keyword arguments."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "n": self.n,
            "stop": self.stop,
            "stop_token_ids": self.stop_token_ids,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "min_p": self.min_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
            "json_schema": self.json_schema,
        }
|
||||
|
||||
|
||||
class SglFunction:
    """A callable SGL program (typically created by a decorator) that can be
    run against a backend, traced, batched, cached, or compiled.

    The wrapped ``func`` must take the program state as its first parameter,
    named ``s``; the remaining parameters become the program's arguments.
    """

    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
        self.func = func
        self.num_api_spec_tokens = num_api_spec_tokens
        self.bind_arguments = bind_arguments or {}
        # Request id of a pinned prefix, if any (set elsewhere).
        self.pin_prefix_rid = None

        # Parse arguments
        argspec = inspect.getfullargspec(func)
        assert argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]
        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []

    def bind(self, **kwargs):
        """Return a new SglFunction with the given arguments pre-bound."""
        assert all(key in self.arg_names for key in kwargs)

        # New bindings override previous ones with the same name.
        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(self.func, bind_arguments=new_bind_dict)

    def run(
        self,
        *args,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        stream: bool = False,
        backend=None,
        use_thread: bool = True,
    **kwargs,
    ):
        """Run the program once. Positional *args*/**kwargs are the program's
        arguments; keyword-only parameters configure sampling. Returns the
        interpreter's program state."""
        from sglang.lang.interpreter import run_program

        # avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
        if stop is None:
            stop = []
        if stop_token_ids is None:
            stop_token_ids = []

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
        )
        backend = backend or global_config.default_backend
        return run_program(
            self,
            backend,
            args,
            kwargs,
            default_sampling_para,
            stream,
            use_thread=use_thread,
        )

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
        generator_style: bool = False,
    ):
        """Run the program once per entry of *batch_kwargs* (a list of dicts,
        or a list of positional-argument sequences which is converted to
        dicts). Returns a list of program states (or a generator when
        ``generator_style`` is set)."""
        from sglang.lang.interpreter import run_program_batch

        # Avoid mutable default arguments.
        if stop is None:
            stop = []
        if stop_token_ids is None:
            stop_token_ids = []

        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        if not isinstance(batch_kwargs[0], dict):
            num_programs = len(batch_kwargs)
            # change the list of argument values to dict of arg_name -> arg_value
            batch_kwargs = [
                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
                for arg_values in batch_kwargs
                if isinstance(arg_values, (list, tuple))
                and len(self.arg_names) - len(self.arg_defaults)
                <= len(arg_values)
                <= len(self.arg_names)
            ]
            # Ensure to raise an exception if the number of arguments mismatch
            if len(batch_kwargs) != num_programs:
                raise Exception("Given arguments mismatch the SGL function signature")

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
        )
        backend = backend or global_config.default_backend
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
            generator_style=generator_style,
        )

    def trace(self, *, backend=None, **kwargs):
        """Symbolically trace this program (no real generation)."""
        from sglang.lang.tracer import trace_program

        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)

    def cache(self, backend=None):
        """Extract and cache the program's constant prompt prefix on the backend."""
        from sglang.lang.interpreter import cache_program

        backend = backend or global_config.default_backend
        return cache_program(self, backend)

    def compile(self, *, backend=None):
        """Compile the traced program into a CompiledFunction graph."""
        from sglang.lang.compiler import compile_func

        return compile_func(self, backend)

    def __call__(self, *args, **kwargs):
        """Run normally, unless called inside an active tracing scope — in
        which case trace with the scope's backend instead."""
        from sglang.lang.tracer import TracingScope

        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)
        else:
            kwargs["backend"] = tracing_scope.tracer_state.backend
            return self.trace(*args, **kwargs)
|
||||
|
||||
|
||||
class SglExpr:
    """Base class of all IR expressions.

    Supports ``+`` / string concatenation (producing SglExprList) and carries
    graph bookkeeping used by the tracer and compiler.
    """

    # Global counter used to hand out unique node ids.
    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct  # unique id (used in graph printing)
        self.prev_node = None  # sequential predecessor, set when appended
        self.pid = None  # id of the stream/tracer that executes this node
        SglExpr.node_ct += 1

    def __add__(self, other):
        # Allow `expr + "text"` by wrapping plain strings.
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr)

        return self.concatenate_ir(self, other)

    def __radd__(self, other):
        # Allow `"text" + expr`.
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr), f"{other}"

        return self.concatenate_ir(other, self)

    def concatenate_ir(self, a, b):
        """Concatenate two expressions, flattening existing SglExprLists so
        the result stays a single flat list."""
        if isinstance(a, SglExprList):
            if isinstance(b, SglExprList):
                return SglExprList(a.expr_list + b.expr_list)
            else:
                return SglExprList(a.expr_list + [b])
        elif isinstance(b, SglExprList):
            return SglExprList([a] + b.expr_list)

        return SglExprList([a, b])

    def print_graph_dfs(self):
        """Return a textual dump of this expression's dependency graph,
        printing each node once in dependency-first (DFS) order."""
        # ret is a one-element list so the nested function can mutate it.
        ret = [""]
        visited = set()

        def dfs_print(x):
            if x is None or x in visited:
                return
            visited.add(x)

            # Print dependency
            if x.prev_node is not None:
                dfs_print(x.prev_node)

            if isinstance(x, SglExprList):
                for y in x.expr_list:
                    dfs_print(y)
            # elif isinstance(x, SglRole):
            #     dfs_print(x.expr)
            elif isinstance(x, SglVariable):
                dfs_print(x.source)

            # Print the node itself
            if isinstance(x, (SglFork, SglGetForkItem)):
                ret[0] += f"%{x.node_id} = {x}\n"
            else:
                if x.prev_node is not None:
                    ret[0] += (
                        f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
                    )
                else:
                    ret[0] += f"%{x.node_id} = " + str(x) + "\n"

        dfs_print(self)
        return ret[0]
|
||||
|
||||
|
||||
class SglExprList(SglExpr):
    """An expression that is simply an ordered sequence of sub-expressions."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        self.expr_list = expr_list

    def __repr__(self):
        return "ExprList({})".format(self.expr_list)
|
||||
|
||||
|
||||
class SglArgument(SglExpr):
    """A placeholder for a user-supplied argument, used during tracing."""

    def __init__(self, name: str, value: str):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return f"Argument(name={self.name}, value={repr(self.value)})"

    def __len__(self):
        # Delegate to the wrapped value.
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        # NOTE(review): returns self.value unconverted; when the value is a
        # dummy (e.g. None) this raises TypeError — presumably relied on by
        # the tracer, which catches TypeError during prefix tracing. Confirm.
        return self.value

    def __bool__(self):
        # NOTE(review): same pattern as __int__ — a non-bool value makes
        # truth-testing raise TypeError instead of silently succeeding.
        return self.value

    def __format__(self, *args):
        # Explicitly reject f-string interpolation of traced arguments.
        raise TypeError(
            "Cannot put argument inside a f-string. "
            "This is not compatible with the tracer. "
        )
|
||||
|
||||
|
||||
class SglImage(SglExpr):
    """An image input in a multimodal prompt.

    Args:
        path: Location of the image; interpretation is backend-specific.
    """

    def __init__(self, path: str):
        # BUG FIX: SglExpr.__init__ was never called, so instances lacked
        # node_id / prev_node / pid — unlike every other SglExpr subclass in
        # this file — which breaks graph printing and pid renaming.
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"
|
||||
|
||||
|
||||
class SglVideo(SglExpr):
    """A video input in a multimodal prompt.

    Args:
        path: Location of the video; interpretation is backend-specific.
        num_frames: Number of frames to sample from the video.
    """

    def __init__(self, path: str, num_frames: int):
        # BUG FIX: SglExpr.__init__ was never called, so instances lacked
        # node_id / prev_node / pid — unlike every other SglExpr subclass in
        # this file — which breaks graph printing and pid renaming.
        super().__init__()
        self.path = path
        self.num_frames = num_frames

    def __repr__(self) -> str:
        return f"SglVideo({self.path}, {self.num_frames})"
|
||||
|
||||
|
||||
class SglGen(SglExpr):
    """IR node for a model generation call; stores the capture name and the
    per-call sampling parameters."""

    def __init__(
        self,
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        min_p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        ignore_eos: Optional[bool] = None,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        dtype: Optional[type] = None,
        regex: Optional[str] = None,
        json_schema: Optional[str] = None,
    ):
        """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
        super().__init__()
        # Name under which the generated text is captured (may be None).
        self.name = name
        # All sampling options are bundled into one SglSamplingParams; None
        # values mean "use the program-level default at run time".
        self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
            dtype=dtype,
            regex=regex,
            json_schema=json_schema,
        )

    def __repr__(self):
        return f"Gen('{self.name}')"
|
||||
|
||||
|
||||
class SglConstantText(SglExpr):
    """A literal piece of prompt text."""

    def __init__(self, value: str):
        super().__init__()
        self.value = value

    def __repr__(self):
        return "Constant({!r})".format(self.value)
|
||||
|
||||
|
||||
class SglRoleBegin(SglExpr):
    """Marks the opening of a chat role section (e.g. user/assistant)."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleBegin({})".format(self.role)
|
||||
|
||||
|
||||
class SglRoleEnd(SglExpr):
    """Marks the closing of a chat role section."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleEnd({})".format(self.role)
|
||||
|
||||
|
||||
class SglSelect(SglExpr):
    """IR node that asks the model to choose one of several fixed strings."""

    def __init__(
        self,
        name: str,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ):
        super().__init__()
        self.name = name
        self.choices_method = choices_method
        self.temperature = temperature
        self.choices = choices

    def __repr__(self):
        return "Select({}, choices={}, choices_method={})".format(
            self.name, self.choices, self.choices_method
        )
|
||||
|
||||
|
||||
class SglFork(SglExpr):
    """IR node that forks the program state into *number* parallel branches."""

    def __init__(self, number: int, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        return "Fork(%{}, number={}, position_ids_offset={})".format(
            self.prev_node.node_id, self.number, self.position_ids_offset
        )
|
||||
|
||||
|
||||
class SglGetForkItem(SglExpr):
    """IR node selecting the *index*-th branch produced by a preceding fork."""

    def __init__(self, index: int):
        super().__init__()
        self.index = index

    def __repr__(self):
        return "GetForkItem(%{}, index={})".format(self.prev_node.node_id, self.index)
|
||||
|
||||
|
||||
class SglVariable(SglExpr):
    """A named reference to the value produced by another expression."""

    def __init__(self, name: str, source):
        super().__init__()
        self.source = source
        self.name = name

    def __repr__(self):
        return "Variable('{}', source=%{})".format(self.name, self.source.node_id)
|
||||
|
||||
|
||||
class SglVarScopeBegin(SglExpr):
    """Marks the start of a named variable-capture scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeBegin('{}')".format(self.name)
|
||||
|
||||
|
||||
class SglVarScopeEnd(SglExpr):
    """Marks the end of a named variable-capture scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeEnd('{}')".format(self.name)
|
||||
|
||||
|
||||
class SglConcateAndAppend(SglExpr):
    """IR node that concatenates several program states and appends the result."""

    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)
|
||||
|
||||
|
||||
class SglCommitLazy(SglExpr):
    """IR marker node that commits lazily buffered operations."""

    def __init__(self):
        super().__init__()

    def __repr__(self):
        return "CommitLazy()"
|
||||
|
||||
|
||||
class SglSeparateReasoning(SglExpr):
    """Wraps a gen/select expression so its reasoning content is captured
    separately under ``<name>_reasoning_content``."""

    def __init__(self, model_type: str, expr: SglExpr):
        super().__init__()
        self.model_type = model_type

        self.expr = expr
        self.name = None
        self._process_expr(expr)

    def process_name_for_reasoning(self, name):
        """Derive the capture name for reasoning content from *name*."""
        if not name:
            raise ValueError("name must be provided")
        return f"{name}_reasoning_content"

    def _process_expr(self, expr):
        # Pick up the capture name from the wrapped gen/select, recursing
        # into expression lists; the last named child wins.
        if isinstance(expr, (SglGen, SglSelect)):
            self.name = self.process_name_for_reasoning(expr.name)
        elif isinstance(expr, SglExprList):
            for child in expr.expr_list:
                self._process_expr(child)

    def __repr__(self):
        return f"SeparateReasoning(model_type={self.model_type}, name={self.name})"
|
||||
279
python/sglang/lang/tracer.py
Normal file
279
python/sglang/lang/tracer.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Tracing a program."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFork,
|
||||
SglGen,
|
||||
SglGetForkItem,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
)
|
||||
|
||||
|
||||
class StopTracing(Exception):
    """Raised to abort tracing early (e.g. once the constant prefix has been
    captured during prefix-only tracing)."""

    pass
|
||||
|
||||
|
||||
def extract_prefix_by_tracing(program, backend):
    """Trace *program* with dummy arguments and return its leading constant text.

    Prefix-only tracing aborts as soon as something other than constant text
    is executed, so only the shared prompt prefix is collected.
    """
    # Dummy placeholders for every declared argument; bound values override them.
    arguments = {name: SglArgument(name, None) for name in program.arg_names}
    arguments.update(program.bind_arguments)

    # Trace in prefix-only mode.
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except (StopTracing, TypeError, AttributeError):
        # Dummy arguments make many operations fail; that is expected here.
        pass

    # Concatenate the leading run of constant-text nodes.
    prefix = ""
    for expr in tracer.flatten_nodes():
        if not isinstance(expr, SglConstantText):
            break
        prefix += expr.value
    return prefix
|
||||
|
||||
|
||||
def trace_program(program, arguments, backend):
    """Symbolically trace *program* and return the populated tracer state."""
    # Fall back to a dummy backend when none is supplied.
    backend = backend if backend is not None else BaseBackend()

    # Fill in placeholders for parameters the caller did not provide.
    missing = {
        name: SglArgument(name, None)
        for name in program.arg_names
        if name not in arguments
    }
    arguments.update(missing)
    arguments.update(program.bind_arguments)

    # Run the function against the tracer state instead of a real interpreter.
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer
|
||||
|
||||
|
||||
class TracerProgramState(ProgramState):
|
||||
def __init__(self, backend, arguments, only_trace_prefix):
    """Program state used while tracing symbolically.

    Args:
        backend: Backend (or an object with an ``endpoint`` attribute, which
            is unwrapped) used e.g. for chat-template lookup.
        arguments: Mapping of argument name -> value (possibly SglArgument
            dummies).
        only_trace_prefix: If True, tracing raises StopTracing at the first
            operation that is not constant text.
    """
    self.pid = uuid.uuid4().hex  # unique stream id for this traced state
    self.backend = backend
    self.arguments: Dict[str, Any] = arguments
    self.only_trace_prefix = only_trace_prefix

    # Unwrap endpoint-style backend wrappers.
    if hasattr(backend, "endpoint"):
        self.backend = backend.endpoint

    self.nodes = []  # every executed expression, in order
    self.last_node = None  # most recently appended expression
    self.variables = {}  # name -> SglVariable for gen/select results
    self.ret_value = None  # return value of the traced function

    # For completion

    # For chat
    self.messages_ = []
    self.cur_role = None
    self.chat_template = self.backend.get_chat_template()

    # For multi states
    self.child_states = []

    # Register with the enclosing tracing scope, if any.
    cur_scope = TracingScope.get_current_scope()
    if cur_scope is not None:
        cur_scope.add_child_state(self)
|
||||
|
||||
##################################
|
||||
########### Public API ###########
|
||||
##################################
|
||||
|
||||
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
    """Create *size* traced child states branching off the current node."""
    assert size >= 1

    # A fork is never part of the constant prefix, so prefix-only tracing stops.
    if self.only_trace_prefix:
        raise StopTracing()

    fork_node = SglFork(size)
    fork_node.prev_node = self.last_node

    children = [
        TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
        for _ in range(size)
    ]
    for idx, child in enumerate(children):
        # Each child continues from its own fork item, with copies of the
        # parent's variables and chat state.
        item = SglGetForkItem(idx)
        item.prev_node = fork_node
        child.last_node = item
        child.variables = dict(self.variables)
        child.messages_ = list(self.messages_)
        child.cur_role = self.cur_role
        child.chat_template = self.chat_template

    return ProgramStateGroup(children, self)
|
||||
|
||||
##################################
|
||||
########## Internal API ##########
|
||||
##################################
|
||||
|
||||
def _append_node(self, other: SglExpr):
|
||||
self.nodes.append(other)
|
||||
other.prev_node = self.last_node
|
||||
self.last_node = other
|
||||
|
||||
def _execute(self, other: SglExpr):
|
||||
if isinstance(other, str):
|
||||
other = SglConstantText(other)
|
||||
|
||||
other.pid = self.pid
|
||||
|
||||
if isinstance(other, SglConstantText):
|
||||
self._execute_fill(other)
|
||||
elif isinstance(other, SglGen):
|
||||
self._execute_gen(other)
|
||||
elif isinstance(other, SglSelect):
|
||||
self._execute_select(other)
|
||||
elif isinstance(other, SglExprList):
|
||||
for x in other.expr_list:
|
||||
self._execute(x)
|
||||
elif isinstance(other, SglRoleBegin):
|
||||
self._execute_role_begin(other)
|
||||
elif isinstance(other, SglRoleEnd):
|
||||
self._execute_role_end(other)
|
||||
elif isinstance(other, SglVarScopeBegin):
|
||||
self._execute_var_scope_begin(other)
|
||||
elif isinstance(other, SglVarScopeEnd):
|
||||
self._execute_var_scope_end(other)
|
||||
else:
|
||||
if self.only_trace_prefix:
|
||||
raise StopTracing()
|
||||
else:
|
||||
self._append_node(other)
|
||||
|
||||
return self
|
||||
|
||||
def __iadd__(self, other):
|
||||
self._execute(other)
|
||||
return self
|
||||
|
||||
def _execute_fill(self, expr: SglConstantText):
|
||||
if isinstance(expr, str):
|
||||
expr = SglConstantText(expr)
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_gen(self, expr: SglGen):
|
||||
name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
|
||||
new_node = SglVariable(name, source=expr)
|
||||
self.variables[name] = new_node
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_select(self, expr: SglSelect):
|
||||
name = (
|
||||
expr.name if expr.name is not None else "select_" + str(len(self.variables))
|
||||
)
|
||||
new_node = SglVariable(name, source=expr)
|
||||
self.variables[name] = new_node
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_role_begin(self, expr: SglRoleBegin):
|
||||
assert self.cur_role is None, "Nested roles are not allowed."
|
||||
|
||||
if len(self.messages_) == 0 and expr.role != "system":
|
||||
# Insert default system message
|
||||
default_system = self.chat_template.default_system_prompt
|
||||
if default_system:
|
||||
self._execute_role_begin(SglRoleBegin("system"))
|
||||
self._execute_fill(default_system)
|
||||
self._execute_role_end(SglRoleEnd("system"))
|
||||
|
||||
self.cur_role = expr.role
|
||||
|
||||
prefix, suffix = self.chat_template.get_prefix_and_suffix(
|
||||
expr.role, self.messages_
|
||||
)
|
||||
|
||||
self._execute_fill(prefix)
|
||||
|
||||
def _execute_role_end(self, expr: SglRoleEnd):
|
||||
prefix, suffix = self.chat_template.get_prefix_and_suffix(
|
||||
expr.role, self.messages_
|
||||
)
|
||||
|
||||
self._execute_fill(suffix)
|
||||
|
||||
self.messages_.append({"role": expr.role, "content": ""})
|
||||
|
||||
self.cur_role = None
|
||||
|
||||
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
|
||||
new_node = SglVariable(expr.name, source=self.last_node)
|
||||
self.variables[expr.name] = new_node
|
||||
|
||||
def get_var(self, name):
|
||||
ret = self.arguments.get(name, None)
|
||||
if ret is not None:
|
||||
return ret
|
||||
|
||||
v = self.variables[name]
|
||||
return SglVariable(v.name, v.source)
|
||||
|
||||
def flatten_nodes(self):
|
||||
def traverse(cur):
|
||||
if isinstance(cur, SglExprList):
|
||||
for child in cur.expr_list:
|
||||
traverse(child)
|
||||
else:
|
||||
ret.append(cur)
|
||||
|
||||
ret = []
|
||||
for x in self.nodes:
|
||||
traverse(x)
|
||||
return ret
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
|
||||
|
||||
class TracingScope:
    """Context manager tracking the innermost active tracing scope.

    The active scope is kept in the class attribute ``cur_scope`` so that
    tracer states created anywhere inside a ``with TracingScope(...)`` body
    can register themselves with every enclosing scope.
    """

    # The innermost active scope, or None when no scope is open.
    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        self.tracer_state = tracer_state
        # Remember whichever scope was active at construction time so that
        # __exit__ can restore it, giving proper nesting semantics.
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        # Make this scope the innermost one for the duration of the block.
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the previously active (outer) scope.
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        """Return the innermost active scope, or None if no scope is open."""
        return TracingScope.cur_scope

    def add_child_state(self, state: TracerProgramState):
        """Register ``state`` as a child of this scope and all enclosing ones."""
        scope = self
        while scope is not None:
            scope.tracer_state.child_states.append(state)
            scope = scope.last_scope
|
||||
16
python/sglang/launch_server.py
Normal file
16
python/sglang/launch_server.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Launch the inference server."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import prepare_server_args
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
if __name__ == "__main__":
|
||||
server_args = prepare_server_args(sys.argv[1:])
|
||||
|
||||
try:
|
||||
launch_server(server_args)
|
||||
finally:
|
||||
kill_process_tree(os.getpid(), include_parent=False)
|
||||
166
python/sglang/profiler.py
Normal file
166
python/sglang/profiler.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Run live profiling.
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.profiler
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
PARENT_FOLDER = "/tmp/sglang-profile"
|
||||
|
||||
|
||||
def _run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
) -> str:
    """Ask a running SGLang server to profile ``num_steps`` forward steps.

    Creates a unique trace directory, snapshots the server's arguments into
    ``server_args.json``, then blocks on ``POST /start_profile`` until the
    server reports the traces are written. Returns the trace directory path.

    Raises:
        requests.HTTPError: if either server request fails.
    """
    if output_dir is None:
        output_dir = PARENT_FOLDER

    # Normalize to an absolute Path before appending subdirectories.
    output_dir = os.path.normpath(output_dir)
    output_dir = os.path.abspath(output_dir)
    output_dir = Path(output_dir)

    # Add "profile_name/timestamp" to the path.
    if profile_name:
        output_dir = output_dir / profile_name
    # Timestamp keeps repeated runs from overwriting each other.
    output_dir = output_dir / str(time.time())
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"Dump profiling traces to {output_dir}")
    print(
        f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})"
    )

    # Dump server args.
    file_path = Path(output_dir) / "server_args.json"
    if not file_path.exists():
        response = requests.get(url + "/get_server_info")
        response.raise_for_status()
        server_args_data = response.json()
        with open(file_path, "w") as file:
            file.write(json.dumps(server_args_data))

    # Start profiler. The API replies when all steps are processed
    # and files are generated.
    # NOTE(review): num_steps is serialized as a string -- presumably what
    # the /start_profile endpoint expects; confirm against the server API.
    json_data = {
        "output_dir": str(output_dir),
        "num_steps": str(num_steps),
        "activities": activities,
        "profile_by_stage": profile_by_stage,
    }

    response = requests.post(url=url + "/start_profile", json=json_data)
    response.raise_for_status()

    trace_link = str(output_dir)
    return trace_link
|
||||
|
||||
|
||||
def run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
):
    """Public wrapper around :func:`_run_profile`.

    A step-based profile terminates on its own once ``num_steps`` forward
    steps have been captured; the returned value is the trace directory.
    """
    return _run_profile(
        url, num_steps, activities, output_dir, profile_name, profile_by_stage
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(description="Benchmark the online serving throughput.")
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="http://localhost:30000",
|
||||
help="Server or API base url if not using http host and port.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Profile directory to dump profile traces.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of this profile run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-steps",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The number of forward steps to profile.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-by-stage",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=False,
|
||||
help="The number of forward steps to profile.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether to profile CPU activity",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether to profile GPU activity",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to memory usage (https://pytorch.org/memory_viz)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rpd",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
activities = []
|
||||
if args.cpu:
|
||||
activities.append("CPU")
|
||||
if args.gpu:
|
||||
activities.append("GPU")
|
||||
if args.mem:
|
||||
activities.append("MEM")
|
||||
if args.rpd:
|
||||
activities.append("RPD")
|
||||
run_profile(
|
||||
args.url,
|
||||
args.num_steps,
|
||||
activities,
|
||||
args.output_dir,
|
||||
args.profile_name,
|
||||
args.profile_by_stage,
|
||||
)
|
||||
BIN
python/sglang/srt/__pycache__/_custom_ops.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/_custom_ops.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/aio_rwlock.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/aio_rwlock.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/constants.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/constants.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/custom_op.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/custom_op.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
python/sglang/srt/__pycache__/host_shared_memory.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/host_shared_memory.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/offloader.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/offloader.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/operations.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/operations.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
python/sglang/srt/__pycache__/patch_torch.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/patch_torch.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/poll_based_barrier.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/poll_based_barrier.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/server_args.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/server_args.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
python/sglang/srt/__pycache__/two_batch_overlap.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/two_batch_overlap.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/utils.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/__pycache__/warmup.cpython-310.pyc
Normal file
BIN
python/sglang/srt/__pycache__/warmup.cpython-310.pyc
Normal file
Binary file not shown.
177
python/sglang/srt/_custom_ops.py
Normal file
177
python/sglang/srt/_custom_ops.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
|
||||
|
||||
logger = logging.getLogger(__name__)
# When set, route custom allreduce through vllm's _C extension instead of
# sgl_kernel (CUDA only; ROCm never uses the vllm path).
use_vllm_custom_allreduce = get_bool_env_var(
    "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
)

if not is_hpu():
    # ROCm does not use vllm custom allreduce
    if use_vllm_custom_allreduce and not is_hip():
        try:
            import vllm._C
        except ImportError as e:
            logger.warning("Failed to import from vllm._C with %r", e)
    else:
        try:
            import sgl_kernel
        except ImportError as e:
            logger.warning("Failed to import from custom_ar with %r", e)


if not is_hip() and not is_npu():
    # CUDA path: thin wrappers over either vllm's or sgl_kernel's
    # custom-allreduce ops, selected once at import time.
    if use_vllm_custom_allreduce:
        custom_op = torch.ops._C_custom_ar
    else:
        custom_op = sgl_kernel.allreduce

    # custom allreduce
    def init_custom_ar(
        ipc_tensors: List[torch.Tensor],
        rank_data: torch.Tensor,
        rank: int,
        full_nvlink: bool,
    ) -> int:
        """Create a custom-allreduce context; returns an opaque handle."""
        return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)

    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        """All-reduce ``inp`` into ``out`` using context handle ``fa``."""
        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

    def dispose(fa: int) -> None:
        """Destroy the custom-allreduce context ``fa``."""
        custom_op.dispose(fa)

    def meta_size() -> int:
        """Size in bytes of the metadata buffer the kernel requires."""
        return custom_op.meta_size()

    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
        return custom_op.register_buffer(fa, ipc_tensors)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
        return custom_op.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[List[int]], offsets: List[List[int]]
    ) -> None:
        custom_op.register_graph_buffers(fa, handles, offsets)

else:
    # ROCM custom allreduce
    # Note the different signatures: the ROCm kernels take IPC handles and
    # offsets directly instead of pre-registered IPC tensors.

    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        return sgl_kernel.allreduce.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        sgl_kernel.allreduce.dispose(fa)

    def meta_size() -> int:
        return sgl_kernel.allreduce.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return sgl_kernel.allreduce.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)

    # ROCM custom quick allreduce

    def init_custom_qr(
        rank: int, world_size: int, qr_max_size: Optional[int] = None
    ) -> int:
        # NOTE(review): the wrapper takes (rank, world_size) but forwards
        # (world_size, rank) -- presumably matching the kernel's argument
        # order; confirm against sgl_kernel.allreduce.init_custom_qr.
        return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)

    def qr_get_handle(fa: int) -> torch.Tensor:
        return sgl_kernel.allreduce.qr_get_handle(fa)

    def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
        sgl_kernel.allreduce.qr_open_handles(fa, handles)

    def qr_all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        quant_level: int,
        cast_bf2half: bool,
    ) -> None:
        sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)

    def qr_destroy(fa: int) -> None:
        sgl_kernel.allreduce.qr_destroy(fa)

    def qr_max_size() -> int:
        return sgl_kernel.allreduce.qr_max_size()


def mscclpp_generate_unique_id() -> bytes:
    """Generate a unique id used to bootstrap an mscclpp context."""
    return sgl_kernel.allreduce.mscclpp_generate_unique_id()


def mscclpp_init_context(
    unique_id: bytes,
    rank: int,
    world_size: int,
    scratch: torch.Tensor,
    put_buffer: torch.Tensor,
    nranks_per_node: int,
    rank_to_node: List[int],
    rank_to_ib: List[int],
    context_selection: int,
) -> int:
    """Initialize an mscclpp allreduce context; returns an opaque handle."""
    return sgl_kernel.allreduce.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        context_selection,
    )


def mscclpp_allreduce(
    context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
    """Run an mscclpp allreduce of ``inp`` into ``out`` on ``context``."""
    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
|
||||
100
python/sglang/srt/aio_rwlock.py
Normal file
100
python/sglang/srt/aio_rwlock.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
class RWLock:
    """An asyncio readers-writer lock with writer preference.

    Multiple readers may hold the lock concurrently; a writer holds it
    exclusively. New readers block as soon as any writer is waiting, which
    prevents writer starvation. Use via the ``reader_lock`` / ``writer_lock``
    async context managers.
    """

    def __init__(self):
        # Protects internal state
        self._lock = asyncio.Lock()

        # Condition variable used to wait for state changes
        self._cond = asyncio.Condition(self._lock)

        # Number of readers currently holding the lock
        self._readers = 0

        # Whether a writer is currently holding the lock
        self._writer_active = False

        # How many writers are queued waiting for a turn
        self._waiting_writers = 0

    @property
    def reader_lock(self):
        """
        A context manager for acquiring a shared (reader) lock.

        Example:
            async with rwlock.reader_lock:
                # read-only access
        """
        return _ReaderLock(self)

    @property
    def writer_lock(self):
        """
        A context manager for acquiring an exclusive (writer) lock.

        Example:
            async with rwlock.writer_lock:
                # exclusive access
        """
        return _WriterLock(self)

    async def acquire_reader(self):
        """Block until shared access is available, then take a reader slot."""
        async with self._lock:
            # Wait until there is no active writer or waiting writer
            # to ensure fairness.
            while self._writer_active or self._waiting_writers > 0:
                await self._cond.wait()
            self._readers += 1

    async def release_reader(self):
        """Drop one reader slot; wake waiters when the last reader leaves."""
        async with self._lock:
            self._readers -= 1
            # If this was the last reader, wake up anyone waiting
            # (potentially a writer or new readers).
            if self._readers == 0:
                self._cond.notify_all()

    async def acquire_writer(self):
        """Block until exclusive access is available, then take it."""
        async with self._lock:
            # Increment the count of writers waiting
            self._waiting_writers += 1
            try:
                # Wait while either a writer is active or readers are present
                while self._writer_active or self._readers > 0:
                    await self._cond.wait()
                self._writer_active = True
            finally:
                # Decrement waiting writers only after we've acquired the writer lock
                self._waiting_writers -= 1

    async def release_writer(self):
        """Release exclusive access and wake all waiters."""
        async with self._lock:
            self._writer_active = False
            # Wake up anyone waiting (readers or writers)
            self._cond.notify_all()
|
||||
|
||||
|
||||
class _ReaderLock:
    """Async context manager granting shared (read) access to an RWLock."""

    def __init__(self, rwlock: RWLock):
        # The RWLock whose reader side this context manager drives.
        self._rwlock = rwlock

    async def __aenter__(self):
        # Waits until no writer is active or queued, then registers a reader.
        await self._rwlock.acquire_reader()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Release unconditionally, even if the body raised.
        await self._rwlock.release_reader()
|
||||
|
||||
|
||||
class _WriterLock:
    """Async context manager granting exclusive (write) access to an RWLock."""

    def __init__(self, rwlock: RWLock):
        # The RWLock whose writer side this context manager drives.
        self._rwlock = rwlock

    async def __aenter__(self):
        # Waits for all readers and any active writer to finish.
        await self._rwlock.acquire_writer()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Release unconditionally, even if the body raised.
        await self._rwlock.release_writer()
|
||||
137
python/sglang/srt/bench_utils.py
Normal file
137
python/sglang/srt/bench_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# NOTE copied and modified from DeepGEMM
class suppress_stdout_stderr:
    """Context manager that silences both Python-level and fd-level output.

    Redirects sys.stdout/sys.stderr AND duplicates os.devnull over the
    underlying file descriptors, so output from C extensions and child
    code that writes to fd 1/2 directly is also suppressed. Everything is
    restored on exit. Not reentrant or thread-safe (swaps process-global
    descriptors).
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        # The live fd numbers currently backing stdout/stderr.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates of the original fds so they can be restored later.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the original fd numbers at /dev/null (affects C-level writes).
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Also swap the Python-level stream objects.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        # Restore the Python-level streams first, then the raw fds.
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        # Close the duplicated fds and the /dev/null handles.
        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
|
||||
|
||||
|
||||
# NOTE copied and modified from DeepGEMM
def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: str = None,
    flush_l2: bool = True,
    with_multiple_kernels: bool = False,
):
    """Benchmark CUDA kernels launched by ``fn`` via the Kineto profiler.

    Runs ``fn`` under torch.profiler and parses the resulting table to
    extract the average time of each kernel in ``kernel_names`` (a string
    or a tuple of strings, matched by substring against table rows).

    Returns a single float (seconds) when ``kernel_names`` is a string, or
    a tuple of floats when it is a tuple. Returns 1 when running under
    Nsight Systems (SGLANG_NSYS_PROFILING set), where Kineto is disabled.
    """
    # Conflict with Nsight Systems
    using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))

    # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = (
        suppress_stdout_stderr
        if suppress_kineto_output and not using_nsys
        else nullcontext
    )
    with suppress():
        # Skip the first profiler step as warmup; only the second is recorded.
        schedule = (
            torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
            if not using_nsys
            else None
        )
        profiler = (
            torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
            )
            if not using_nsys
            else nullcontext()
        )
        with profiler:
            # Two passes: warmup step + measured step (see schedule above).
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        # Evict L2 between runs so each test starts cold.
                        torch.empty(
                            flush_l2_size, dtype=torch.int, device="cuda"
                        ).zero_()
                    fn()

                if not using_nsys:
                    profiler.step()

    # Return 1 if using Nsight Systems
    if using_nsys:
        return 1

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)
    prof_lines = (
        profiler.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    if not with_multiple_kernels:
        # Each requested kernel must match exactly one table row, otherwise
        # the substring is ambiguous (or the kernel never ran).
        for name in kernel_names:
            assert (
                sum([name in line for line in prof_lines]) == 1
            ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    # Scale factors convert the table's "ms"/"us" columns to seconds.
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                # Table row format: ... <total time> <num calls> (last two columns).
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += (
                            float(time_str.replace(unit, "")) / scale * int(num_str)
                        )
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
|
||||
27
python/sglang/srt/configs/__init__.py
Normal file
27
python/sglang/srt/configs/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from sglang.srt.configs.chatglm import ChatGLMConfig
|
||||
from sglang.srt.configs.dbrx import DbrxConfig
|
||||
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
|
||||
from sglang.srt.configs.exaone import ExaoneConfig
|
||||
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
||||
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
|
||||
from sglang.srt.configs.step3_vl import (
|
||||
Step3TextConfig,
|
||||
Step3VisionEncoderConfig,
|
||||
Step3VLConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ExaoneConfig",
|
||||
"ChatGLMConfig",
|
||||
"DbrxConfig",
|
||||
"DeepseekVL2Config",
|
||||
"LongcatFlashConfig",
|
||||
"MultiModalityConfig",
|
||||
"KimiVLConfig",
|
||||
"MoonViTConfig",
|
||||
"Step3VLConfig",
|
||||
"Step3TextConfig",
|
||||
"Step3VisionEncoderConfig",
|
||||
]
|
||||
BIN
python/sglang/srt/configs/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/configs/__pycache__/chatglm.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/chatglm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/configs/__pycache__/dbrx.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/dbrx.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
python/sglang/srt/configs/__pycache__/exaone.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/exaone.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/configs/__pycache__/internvl.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/internvl.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/configs/__pycache__/janus_pro.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/janus_pro.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/configs/__pycache__/kimi_vl.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/kimi_vl.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
python/sglang/srt/configs/__pycache__/step3_vl.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/step3_vl.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
python/sglang/srt/configs/__pycache__/utils.cpython-310.pyc
Normal file
BIN
python/sglang/srt/configs/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
78
python/sglang/srt/configs/chatglm.py
Normal file
78
python/sglang/srt/configs/chatglm.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py
|
||||
|
||||
# ChatGLM2 and ChatGLM3 share the same config.
|
||||
# ChatGLM4 is officially supported by Huggingface
|
||||
# transformers >= 4.46.0 is required
|
||||
# https://huggingface.co/docs/transformers/en/model_doc/glm
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class ChatGLMConfig(PretrainedConfig):
    """Configuration for ChatGLM models (ChatGLM2/ChatGLM3 share this config).

    Mirrors the upstream ChatGLM configuration; attribute_map lets
    transformers-standard names resolve to the ChatGLM-specific fields.
    """

    model_type = "chatglm"
    # Map standard HF config attribute names onto ChatGLM's field names.
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        interleaved_qkv=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        # vocab_size mirrors padded_vocab_size for HF compatibility.
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # It is to be compatible with long lora.
        self.max_position_embeddings = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        # PretrainedConfig consumes the remaining kwargs (tie_word_embeddings, etc.).
        super().__init__(**kwargs)
|
||||
279
python/sglang/srt/configs/dbrx.py
Normal file
279
python/sglang/srt/configs/dbrx.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# Adapted from
|
||||
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py
|
||||
"""Dbrx configuration."""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore
|
||||
|
||||
|
||||
class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for Dbrx Attention.

    [`DbrxAttention`] class. It is used to instantiate attention layers
    according to the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to None):
            If not `None`, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
        rope_theta (float): The base frequency for rope.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        # super() consumes the standard PretrainedConfig kwargs first; whatever
        # it leaves in `kwargs` (minus "model_type") is treated as unknown below.
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        # "model_type" is injected when this sub-config is built from the parent
        # dbrx config dict, so it is not an error here.
        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        # Fail loudly on any leftover key: a typo in a config file would
        # otherwise be silently ignored.
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Load only the attention sub-config from a pretrained dbrx checkpoint."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A full dbrx checkpoint nests this config under "attn_config".
        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        # Warn (but do not fail) when instantiating from a mismatched model type.
        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all configurations of "
                "models and can yield errors.",
                config_dict["model_type"],
                cls.model_type,
            )

        return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for Dbrx FFN.

    [`DbrxFFN`] class. It is used to instantiate feedforward layers according to
    the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
            The dict should have a key 'name' with the value being the name of
            the activation function along with any additional keyword arguments.
        ffn_hidden_size (int, optional): The hidden size of the feedforward network.
        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
            This should only be used for benchmarking purposes.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        # NOTE: kwargs are deliberately NOT forwarded to super() here (unlike
        # DbrxAttentionConfig) — this matches the upstream dbrx config, which
        # rejects any unknown key below instead.
        super().__init__()
        # Mutable-default workaround: build the default activation dict per call.
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        # "model_type" is injected when building from the parent dbrx config dict.
        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Load only the FFN sub-config from a pretrained dbrx checkpoint."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A full dbrx checkpoint nests this config under "ffn_config".
        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all "
                "configurations of models and can yield errors.",
                config_dict["model_type"],
                cls.model_type,
            )

        return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class DbrxConfig(PretrainedConfig):
    """Configuration class for Dbrx.

    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        d_model (`int`, *optional*, defaults to 6144):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 48):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 32768):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.


    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    # Map the HF-standard attribute names onto dbrx's native field names so
    # generic code (e.g. config.num_attention_heads) keeps working.
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        # Accept the attention sub-config as None (defaults), a plain dict
        # (deserialized JSON), or an already-built config object.
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        # Same three-way coercion for the FFN sub-config.
        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        # Dbrx never ties input/output embeddings; reject it explicitly rather
        # than letting PretrainedConfig accept the flag silently.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
||||
688
python/sglang/srt/configs/deepseekvl2.py
Normal file
688
python/sglang/srt/configs/deepseekvl2.py
Normal file
@@ -0,0 +1,688 @@
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
LlamaTokenizerFast,
|
||||
PretrainedConfig,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
|
||||
def select_best_resolution(image_size, candidate_resolutions):
    """Pick the candidate (width, height) that best fits the image for cropping.

    A candidate is preferred when it preserves more effective pixels after
    aspect-ratio-preserving downscaling; ties are broken by wasting fewer
    padding pixels. Returns None for an empty candidate list.
    """
    orig_w, orig_h = image_size
    best_fit = None
    best_effective = 0
    least_wasted = float("inf")

    for cand_w, cand_h in candidate_resolutions:
        # Aspect-preserving scale factor that fits the image into the candidate.
        ratio = min(cand_w / orig_w, cand_h / orig_h)
        scaled_w = int(orig_w * ratio)
        scaled_h = int(orig_h * ratio)
        # Effective pixels are capped at the original resolution (no upscaling credit).
        effective = min(scaled_w * scaled_h, orig_w * orig_h)
        wasted = cand_w * cand_h - effective

        improves = effective > best_effective or (
            effective == best_effective and wasted < least_wasted
        )
        if improves:
            best_effective = effective
            least_wasted = wasted
            best_fit = (cand_w, cand_h)

    return best_fit
|
||||
|
||||
|
||||
class DictOutput(object):
    """Expose an object's instance attributes through a dict-like interface.

    Supports keys()/items(), subscript get/set, and `in` membership tests, all
    backed directly by the instance __dict__.
    """

    def keys(self):
        return vars(self).keys()

    def items(self):
        return vars(self).items()

    def __getitem__(self, item):
        return vars(self)[item]

    def __setitem__(self, key, value):
        vars(self)[key] = value

    def __contains__(self, key):
        return key in vars(self)
|
||||
|
||||
|
||||
@dataclass
class VLChatProcessorOutput(DictOutput):
    # Token ids of the rendered conversation, including image placeholder tokens.
    input_ids: torch.LongTensor
    # Training labels; prompt/image positions are masked to ignore_id upstream.
    target_ids: torch.LongTensor
    pixel_values: (
        torch.Tensor
    )  # rename from "images" to "pixel_values" for compatibility
    # True at positions occupied by image tokens in input_ids.
    images_seq_mask: torch.BoolTensor
    # Per-image (num_width_tiles, num_height_tiles) crop grid.
    images_spatial_crop: torch.LongTensor

    def __len__(self):
        # Length in tokens, not number of images.
        return len(self.input_ids)
|
||||
|
||||
|
||||
class ImageTransform(object):
    """PIL-to-tensor transform with optional per-channel normalization."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True,
    ):
        self.mean = mean
        self.std = std
        self.normalize = normalize

        # only load torchvision.transforms when needed
        try:
            import torchvision.transforms as T

            # FIXME: add version check for gguf
        except ImportError as err:
            raise ImportError(
                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
            ) from err

        # ToTensor converts HWC uint8 PIL images to CHW float tensors in [0, 1].
        transform_pipelines = [T.ToTensor()]

        if normalize:
            transform_pipelines.append(T.Normalize(mean, std))

        self.transform = T.Compose(transform_pipelines)

    def __call__(self, pil_img: Image.Image):
        # Returns a CHW float tensor (normalized when self.normalize is True).
        x = self.transform(pil_img)
        return x
|
||||
|
||||
|
||||
class DeepseekVLV2Processor(ProcessorMixin):
    """Processor for Deepseek-VL2: tokenizes text with <image> tags and crops
    images into a global view plus local tiles, emitting matching image tokens.
    """

    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
    attributes = ["tokenizer"]

    def __init__(
        self,
        tokenizer: LlamaTokenizerFast,
        candidate_resolutions: Tuple[Tuple[int, int]],
        patch_size: int,
        downsample_ratio: int,
        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        normalize: bool = True,
        image_token: str = "<image>",
        pad_token: str = "<|▁pad▁|>",
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):

        self.candidate_resolutions = candidate_resolutions
        # Base tile size is taken from the first candidate resolution (square).
        self.image_size = candidate_resolutions[0][0]
        self.patch_size = patch_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.normalize = normalize
        self.downsample_ratio = downsample_ratio

        self.image_transform = ImageTransform(
            mean=image_mean, std=image_std, normalize=normalize
        )
        self.tokenizer = tokenizer
        # Must set this: the padding side makes a difference in batch inference.
        self.tokenizer.padding_side = "left"

        # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
        if tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": pad_token})

        # add image token (only if the tokenizer does not already have one)
        image_token_id = self.tokenizer.vocab.get(image_token)
        if image_token_id is None:
            special_tokens = [image_token]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
        # Re-read the id: it may have just been assigned by add_special_tokens.
        self.image_token_id = self.tokenizer.vocab.get(image_token)

        # add five special tokens for grounding-related tasks
        # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
        special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)

        # add special tokens for SFT data
        special_tokens = ["<|User|>", "<|Assistant|>"]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)

        self.image_token = image_token
        self.pad_token = pad_token
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            tokenizer,
            **kwargs,
        )

    def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
        """play the role of format_messages_v2 and get_images_info in the last version"""
        tokenized_data = []
        masked_tokenized_data = []  # labels
        images_list = []
        images_seq_mask = []
        images_spatial_crop = []

        image_index = 0
        # One pil_image is consumed per occurrence of the image token in the text.
        image_token_cnt = messages.count(self.image_token)
        tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
            messages,
            pil_images[image_index : image_index + image_token_cnt],
            bos=True,
            eos=True,
            # Tile-cropping is only worthwhile for few images; with >2 images
            # every image is processed at the base resolution.
            cropping=len(pil_images) <= 2,
            max_req_input_len=max_req_input_len,
        )

        image_index = image_token_cnt
        tokenized_data += tokenized_str
        if self.mask_prompt:
            # Mask the whole prompt out of the labels.
            masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
        else:
            masked_tokenized_data += tokenized_str
        images_list += images
        images_seq_mask += seq_mask
        images_spatial_crop += spatial_crop

        assert len(tokenized_data) == len(
            images_seq_mask
        ), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

        return (
            tokenized_data,
            masked_tokenized_data,
            images_list,
            images_seq_mask,
            images_spatial_crop,
        )

    @property
    def bos_id(self):
        # Beginning-of-sequence token id from the wrapped tokenizer.
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        # End-of-sequence token id from the wrapped tokenizer.
        return self.tokenizer.eos_token_id

    @property
    def pad_id(self):
        # Padding token id from the wrapped tokenizer.
        return self.tokenizer.pad_token_id

    def encode(self, text: str, bos: bool = True, eos: bool = False):
        """Tokenize `text`, optionally wrapping with bos/eos ids."""
        t = self.tokenizer.encode(text, add_special_tokens=False)

        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]

        return t

    def decode(self, t: List[int], **kwargs) -> str:
        """Inverse of encode(); delegates to the wrapped tokenizer."""
        return self.tokenizer.decode(t, **kwargs)

    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image.Image] = None,
        apply_sft_format: bool = False,
        inference_mode: bool = True,
        system_prompt: str = "",
        max_req_input_len: int = -1,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
                if conversations is not None, then it will always apply the SFT format to conversations;
            inference_mode (bool): if True, then remove the last eos token;
            system_prompt (str): the system prompt;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."

        # NOTE(review): only `conversations` is forwarded here; `prompt`,
        # `apply_sft_format` and `system_prompt` are accepted but unused in
        # this path — confirm against upstream callers.
        (
            tokenized_str,
            masked_tokenized_str,
            images_list,
            images_seq_mask,
            images_spatial_crop,
        ) = self.format_messages_v2(conversations, images, max_req_input_len)

        assert (
            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
        ), (
            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
        )

        input_ids = torch.LongTensor(tokenized_str)
        target_ids = torch.LongTensor(masked_tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
            self.ignore_id
        )
        input_ids[input_ids < 0] = self.pad_id

        if inference_mode:
            # Drop the trailing eos so generation can continue from the prompt.
            assert input_ids[-1] == self.eos_id
            input_ids = input_ids[:-1]
            target_ids = target_ids[:-1]
            images_seq_mask = images_seq_mask[:-1]

        if len(images_list) == 0:
            # Text-only request: supply a single dummy image and empty crop grid.
            images = torch.zeros((1, 3, self.image_size, self.image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
        else:
            images = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)

        images_spatial_crop = torch.stack(
            [images_spatial_crop], dim=0
        )  # stack the tensor to make it a batch of 1

        prepare = VLChatProcessorOutput(
            input_ids=input_ids,
            target_ids=target_ids,
            pixel_values=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image.Image] = None,
        apply_sft_format: bool = False,
        inference_mode: bool = True,
        system_prompt: str = "",
        max_req_input_len: int = -1,
        **kwargs,
    ):
        # Thin keyword-only wrapper around process_one().
        prepare = self.process_one(
            prompt=prompt,
            conversations=conversations,
            images=images,
            apply_sft_format=apply_sft_format,
            inference_mode=inference_mode,
            system_prompt=system_prompt,
            max_req_input_len=max_req_input_len,
        )

        return prepare

    def find_all_indices(self, messages, target_value):
        """Return every index at which `target_value` occurs in `messages`."""
        indices = []
        for index, item in enumerate(messages):
            if item == target_value:
                indices.append(index)
        return indices

    def tokenize_with_images(
        self,
        conversation: str,
        images: List[Image.Image],
        bos: bool = True,
        eos: bool = True,
        cropping: bool = True,
        max_req_input_len: int = -1,
    ):
        """Tokenize text with <image> tags."""
        images_list, images_seq_mask, images_spatial_crop = [], [], []
        # Each split boundary corresponds to one image occurrence in the text.
        text_splits = conversation.split(self.image_token)
        tokenized_str = []
        for text_sep, image in zip(text_splits, images):
            """encode text_sep"""
            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            """select best resolution for anyres"""
            if cropping:
                best_width, best_height = select_best_resolution(
                    image.size, self.candidate_resolutions
                )
            else:
                best_width, best_height = self.image_size, self.image_size
            # print(image.size, (best_width, best_height)) # check the select_best_resolutions func

            """process the global view"""
            # Pad (not crop) the whole image to a square base-resolution view.
            global_view = ImageOps.pad(
                image,
                (self.image_size, self.image_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            images_list.append(self.image_transform(global_view))

            """process the local views"""
            # Pad to the chosen resolution, then slice into image_size tiles.
            local_view = ImageOps.pad(
                image,
                (best_width, best_height),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            for i in range(0, best_height, self.image_size):
                for j in range(0, best_width, self.image_size):
                    images_list.append(
                        self.image_transform(
                            local_view.crop(
                                (j, i, j + self.image_size, i + self.image_size)
                            )
                        )
                    )

            """record height / width crop num"""
            num_width_tiles, num_height_tiles = (
                best_width // self.image_size,
                best_height // self.image_size,
            )
            images_spatial_crop.append([num_width_tiles, num_height_tiles])

            """add image tokens"""
            # Tokens per view side after patchify + downsampling.
            h = w = math.ceil(
                (self.image_size // self.patch_size) / self.downsample_ratio
            )
            # global views tokens h * (w + 1), 1 is for line separator
            tokenized_image = [self.image_token_id] * h * (w + 1)
            # add a separator between global and local views
            tokenized_image += [self.image_token_id]
            # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
            tokenized_image += (
                [self.image_token_id]
                * (num_height_tiles * h)
                * (num_width_tiles * w + 1)
            )

            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)
            # print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens

        """process the last text split"""
        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
        # deal with video, limit with request len
        if max_req_input_len > -1:
            if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
                # Truncate the image/text prefix, reserving 1024 tokens of slack.
                rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
                tokenized_str = tokenized_str[:rest]
                images_seq_mask = images_seq_mask[:rest]
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        """add the bos and eos tokens"""
        if bos:
            tokenized_str = [self.bos_id] + tokenized_str
            images_seq_mask = [False] + images_seq_mask
        if eos:
            tokenized_str = tokenized_str + [self.eos_id]
            images_seq_mask = images_seq_mask + [False]

        assert len(tokenized_str) == len(
            images_seq_mask
        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

        return tokenized_str, images_list, images_seq_mask, images_spatial_crop
|
||||
|
||||
|
||||
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
    """Config for the Deepseek-VL2 SigLIP vision tower.

    Class-level values are defaults; __init__ overrides the subset it accepts.
    """

    model_type: str = "vision"

    model_name: str = "siglip_large_patch16_384"
    image_size: int = 384
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: int = 384,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 24,
        heads: int = 16,
        mlp_ratio: int = 4,
        global_pool: str = "map",
        ignore_head: bool = True,
        class_token: bool = False,
        num_classes: int = 0,
        use_checkpoint: bool = False,
        **kwargs,
    ):
        self.model_name = model_name
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.global_pool = global_pool
        self.ignore_head = ignore_head
        self.class_token = class_token
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint

        # Remaining kwargs (e.g. weight_init, deterministic) are handled by
        # PretrainedConfig or fall back to the class-level defaults above.
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
    """Config for the MLP projector that maps vision features into the LM
    embedding space. Class-level values are defaults; __init__ overrides them.
    """

    model_type = "mlp_projector"
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs,
    ):
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.downsample_ratio = downsample_ratio

        # token_pooling is not an __init__ parameter; it can only arrive via
        # kwargs and is consumed by PretrainedConfig.
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class DeepseekV2Config(PretrainedConfig):
    """Configuration for the DeepseekV2 language model (MoE + MLA attention),
    used here as the language backbone of Deepseek-VL2.
    """

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
        n_routed_experts=None,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="gready",  # NOTE: "gready" (sic) matches the upstream default spelling
        n_group=None,
        topk_group=None,
        num_experts_per_tok=None,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_mla=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        # MLA (multi-head latent attention) low-rank projection dims.
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        # Coerce to float: some checkpoints serialize this as a string.
        self.rms_norm_eps = float(rms_norm_eps)
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_mla = use_mla

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
||||
|
||||
|
||||
class DeepseekVL2Config(PretrainedConfig):
    """Top-level configuration for DeepseekVL2 vision-language models.

    Aggregates three sub-configurations — the vision encoder, the MLP
    projector and the DeepseekV2 language model — together with the
    image-tiling options used by the multimodal pipeline.
    """

    model_type = "deepseek_vl_v2"
    vision_config: DeepseekVL2VisionEncoderConfig
    projector_config: DeepseekVL2MlpProjectorConfig
    language_config: DeepseekV2Config

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)

    def __init__(
        self,
        # NOTE(review): this default ("tile_tag") differs from the class-level
        # default ("2D") — confirm against upstream before relying on it.
        tile_tag: str = "tile_tag",
        global_view_pos: str = "head",
        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
        **kwargs,
    ):
        super().__init__(**kwargs)

        # The sub-configs arrive as plain dicts inside **kwargs (e.g. from a
        # serialized checkpoint); rebuild the typed config objects here.
        self.vision_config = DeepseekVL2VisionEncoderConfig(
            **kwargs.get("vision_config", {})
        )
        self.projector_config = DeepseekVL2MlpProjectorConfig(
            **kwargs.get("projector_config", {})
        )

        # The language config may already be a constructed DeepseekV2Config;
        # only rebuild it when a raw dict was supplied.
        language_config = kwargs.get("language_config", {})
        self.language_config = (
            language_config
            if isinstance(language_config, DeepseekV2Config)
            else DeepseekV2Config(**language_config)
        )

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
        self.architectures = ["DeepseekVL2ForCausalLM"]
|
||||
|
||||
|
||||
# Make HF's AutoProcessor aware of the DeepseekVL2 config -> processor mapping.
AutoProcessor.register(DeepseekVL2Config, DeepseekVLV2Processor)
|
||||
17
python/sglang/srt/configs/device_config.py
Normal file
17
python/sglang/srt/configs/device_config.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceConfig:
    """Validated wrapper around a torch device selection.

    Accepts one of the supported device-type strings and exposes both the
    raw string (``device_type``) and the corresponding ``torch.device``
    object (``device``).

    Raises:
        RuntimeError: if the requested device type is not supported.
    """

    device: Optional[torch.device]

    def __init__(self, device: str = "cuda") -> None:
        # Guard clause: reject anything outside the supported backends.
        if device not in ("cuda", "xpu", "hpu", "cpu", "npu"):
            raise RuntimeError(f"Not supported device type: {device}")
        self.device_type = device
        self.device = torch.device(device)
|
||||
195
python/sglang/srt/configs/exaone.py
Normal file
195
python/sglang/srt/configs/exaone.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The LG AI Research EXAONE Lab. All rights reserved.
|
||||
# Copyright 2024 The LG CNS AI Engineering Team.
|
||||
# Copyright 2023-2024 SGLang Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" EXAONE model configuration """
|
||||
from typing import Any, Dict
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, Any] = {}
|
||||
|
||||
|
||||
# ruff: noqa: E501
|
||||
# ruff: noqa: E501
class ExaoneConfig(PretrainedConfig):
    r"""Configuration class for EXAONE models.

    Stores the hyperparameters of a :class:`~transformers.ExaoneModel` and is
    used to instantiate an EXAONE model according to the specified arguments.
    Configuration objects inherit from :class:`~transformers.PretrainedConfig`
    and can be used to control the model outputs; read its documentation for
    the generic options.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 102400):
            Vocabulary size of the EXAONE model.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        hidden_size (:obj:`int`, `optional`, defaults to 2048):
            Dimensionality of the encoder layers and the pooler layer.
        num_layers (:obj:`int`, `optional`, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 32):
            Number of attention heads for each attention layer in the decoder.
        num_key_value_heads (:obj:`int`, `optional`):
            Number of key/value heads for Grouped Query Attention. Defaults to
            ``num_attention_heads`` (plain multi-head attention) when unset;
            ``1`` selects Multi Query Attention.
        intermediate_size (:obj:`int`, `optional`, defaults to ``hidden_size * 4``):
            Dimensionality of the feed-forward layer.
        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"silu"`):
            The non-linear activation function in the decoder.
        rope_theta (:obj:`float`, `optional`, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (:obj:`Dict`, `optional`):
            Scaling configuration for the RoPE embeddings (``rope_type``,
            ``factor``, ``original_max_position_embeddings``, ... as understood
            by transformers' RoPE utilities).
        embed_dropout (:obj:`float`, `optional`, defaults to 0.0):
            Dropout probability for fully connected layers in the embeddings,
            encoder, and pooler.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
            Dropout ratio for the attention probabilities.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
            Epsilon used by the layer normalization layers.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            Std of the truncated-normal initializer for weight matrices.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether the model returns the last key/value attentions; only
            relevant for decoders.
        bos_token_id (:obj:`int`, `optional`, defaults to 0):
            Beginning of stream token id.
        eos_token_id (:obj:`int`, `optional`, defaults to 2):
            End of stream token id.
        tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to tie input and output word embeddings.

    Example::

        >>> from transformers import EXAONEModel, ExaoneConfig

        >>> # Initializing a EXAONE configuration
        >>> configuration = ExaoneConfig()

        >>> # Initializing a model from configuration
        >>> model = EXAONEModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.configs
    """

    model_type = "exaone"
    keys_to_ignore_at_inference = ["past_key_values"]
    # HF generic code reads `num_hidden_layers`; EXAONE natively calls it
    # `num_layers`, and PretrainedConfig maps between the two names.
    attribute_map = {"num_hidden_layers": "num_layers"}

    def __init__(
        self,
        vocab_size=102400,
        max_position_embeddings=2048,
        hidden_size=2048,
        num_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        intermediate_size=None,
        activation_function="silu",
        rope_theta=10000.0,
        rope_scaling=None,
        embed_dropout=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=0,
        eos_token_id=2,
        tie_word_embeddings=True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        # Redirected to `num_layers` through `attribute_map`; kept so code
        # expecting the HF attribute name keeps working.
        self.num_hidden_layers = num_layers
        # Backward compatibility: unset key/value heads means plain MHA.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        # Any falsy value (None / 0) falls back to the conventional 4x expansion.
        self.intermediate_size = (
            intermediate_size if intermediate_size else hidden_size * 4
        )
        self.activation_function = activation_function
        self.embed_dropout = embed_dropout
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )
|
||||
706
python/sglang/srt/configs/internvl.py
Normal file
706
python/sglang/srt/configs/internvl.py
Normal file
@@ -0,0 +1,706 @@
|
||||
import copy
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
from transformers import (
|
||||
TOKENIZER_MAPPING,
|
||||
GptOssConfig,
|
||||
LlamaConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
Qwen2Config,
|
||||
Qwen3Config,
|
||||
Qwen3MoeConfig,
|
||||
)
|
||||
|
||||
from sglang.utils import logger
|
||||
|
||||
# Copied from: https://github.com/OpenGVLab/InternVL/blob/34a81000402bf8f716bab8c9b57aff1f6b436bd0/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py#L21
|
||||
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {}
|
||||
|
||||
|
||||
# Modified from transformers.model.llama.configuration_llama.LlamaConfig
|
||||
class InternLM2Config(PretrainedConfig):
    r"""Configuration class for [`InternLM2Model`].

    Stores the hyperparameters that define an InternLM2 architecture; the
    defaults roughly correspond to InternLM2-7B.  Inherits from
    [`PretrainedConfig`], which handles serialization and generic options.

    Args:
        vocab_size (`int`, *optional*, defaults to 103168):
            Vocabulary size of the InternLM2 model.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads per layer.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads for Grouped Query Attention; defaults
            to `num_attention_heads` (plain MHA) when unset, `1` means MQA.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            Non-linear activation function of the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            Maximum sequence length the model supports.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Std of the truncated-normal weight initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to return the last key/value attentions; only relevant
            if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie input and output word embeddings.
        bias (`bool`, *optional*, defaults to `True`):
            Whether linear layers carry a bias term.
        rope_theta (`float`, *optional*, defaults to 10000):
            Base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            RoPE scaling configuration; must contain exactly the keys
            `type` (one of `"linear"`, `"dynamic"`) and `factor` (>= 1).
        attn_implementation (`str`, *optional*, defaults to `"eager"`):
            Attention backend selector; `None` falls back to `"eager"`.
    """

    model_type = "internlm2"
    _auto_class = "AutoConfig"

    def __init__(  # pylint: disable=W0102
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation="eager",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.bias = bias

        # Backward compatibility: unset key/value heads means plain MHA.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Fail fast on a malformed rope_scaling dict before the model is built.
        self._rope_scaling_validation()

        self.attn_implementation = attn_implementation
        if self.attn_implementation is None:
            self.attn_implementation = "eager"
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """Validate the `rope_scaling` configuration.

        Raises:
            ValueError: if `rope_scaling` is not a two-key dict with a
                supported `type` and a numeric `factor` >= 1.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                # Fixed duplicated word ("with with") in the original message.
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        # Integer factors are accepted as-is; they compare correctly against
        # 1.0, so no conversion is needed (the original converted a local copy
        # to float, which had no effect).
        if (
            rope_scaling_factor is None
            or not isinstance(rope_scaling_factor, (float, int))
            or rope_scaling_factor < 1.0
        ):
            raise ValueError(
                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
            )
|
||||
|
||||
|
||||
class InternVisionConfig(PretrainedConfig):
    r"""Configuration for [`InternVisionModel`].

    Defines the vision-encoder architecture: patching geometry, transformer
    width and depth, attention behaviour, and regularization settings.
    Inherits from [`PretrainedConfig`]; see its documentation for generic
    options.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            Number of color channels in the input images (e.g., 3 for RGB).
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        qkv_bias (`bool`, *optional*, defaults to `False`):
            Whether the queries and values carry a bias in self-attention.
        hidden_size (`int`, *optional*, defaults to 3200):
            Dimensionality of the encoder layers and the pooler layer.
        num_attention_heads (`int`, *optional*, defaults to 25):
            Number of attention heads per encoder layer.
        intermediate_size (`int`, *optional*, defaults to 12800):
            Dimensionality of the feed-forward layer.
        qk_normalization (`bool`, *optional*, defaults to `True`):
            Whether queries and keys are normalized in self-attention.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        use_flash_attn (`bool`, *optional*, defaults to `True`):
            Whether to use the flash attention mechanism.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            Non-linear activation in the encoder and pooler.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for fully connected layers.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Dropout rate for stochastic depth.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Std of the truncated-normal weight initializer.
        initializer_factor (`float`, *optional*, defaults to 0.1):
            A factor for layer scale.
    """

    model_type = "intern_vit_6b"

    def __init__(
        self,
        num_channels=3,
        patch_size=14,
        image_size=224,
        qkv_bias=False,
        hidden_size=3200,
        num_attention_heads=25,
        intermediate_size=12800,
        qk_normalization=True,
        num_hidden_layers=48,
        use_flash_attn=True,
        hidden_act="gelu",
        layer_norm_eps=1e-6,
        dropout=0.0,
        drop_path_rate=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Image / patch geometry.
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        # Transformer width and depth.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Attention behaviour.
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.use_flash_attn = use_flash_attn
        # Activation, normalization and regularization.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.attention_dropout = attention_dropout
        # Weight initialization.
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load the config, unwrapping a nested ``vision_config`` if present."""
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # Composite checkpoints store the vision settings under `vision_config`.
        if "vision_config" in config_dict:
            config_dict = config_dict["vision_config"]

        mismatched_type = (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        )
        if mismatched_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class InternVLChatConfig(PretrainedConfig):
    """Composite configuration for InternVL chat models.

    Bundles a vision-encoder config (`InternVisionConfig`) with one of the
    supported language-model configs, plus the image-tiling and LoRA options
    used by the chat pipeline.
    """

    model_type = "internvl_chat"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        pad2square=False,
        select_layer=-1,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
        ps_version="v1",
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {"architectures": ["InternVisionModel"]}
            logger.info(
                "vision_config is None. Initializing the InternVisionConfig with default values."
            )

        if llm_config is None:
            llm_config = {"architectures": ["InternLM2ForCausalLM"]}
            logger.info(
                "llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
            )

        self.vision_config = InternVisionConfig(**vision_config)

        # Dispatch on the declared LLM architecture via a lookup table instead
        # of a long if/elif chain; the architecture name is read only once.
        llm_arch_to_config = {
            "LlamaForCausalLM": LlamaConfig,
            "InternLM2ForCausalLM": InternLM2Config,
            "Qwen2ForCausalLM": Qwen2Config,
            "Qwen3MoeForCausalLM": Qwen3MoeConfig,
            "Qwen3ForCausalLM": Qwen3Config,
            "GptOssForCausalLM": GptOssConfig,
        }
        llm_arch = llm_config.get("architectures")[0]
        llm_config_cls = llm_arch_to_config.get(llm_arch)
        if llm_config_cls is None:
            raise ValueError("Unsupported architecture: {}".format(llm_arch))
        self.llm_config = llm_config_cls(**llm_config)

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        # Mirror the LLM hidden size at the top level for convenience.
        self.hidden_size = self.llm_config.hidden_size
        # By default, we use tie_word_embeddings=False for models of all sizes.
        self.tie_word_embeddings = False
        self.llm_config.tie_word_embeddings = self.tie_word_embeddings

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        # Sub-configs are serialized recursively so the result is pure-dict.
        output["vision_config"] = self.vision_config.to_dict()
        output["llm_config"] = self.llm_config.to_dict()
        output["model_type"] = self.__class__.model_type
        output["use_backbone_lora"] = self.use_backbone_lora
        output["use_llm_lora"] = self.use_llm_lora
        output["select_layer"] = self.select_layer
        output["force_image_size"] = self.force_image_size
        output["downsample_ratio"] = self.downsample_ratio
        output["template"] = self.template
        output["dynamic_image_size"] = self.dynamic_image_size
        output["use_thumbnail"] = self.use_thumbnail
        output["ps_version"] = self.ps_version
        output["min_dynamic_patch"] = self.min_dynamic_patch
        output["max_dynamic_patch"] = self.max_dynamic_patch

        return output
|
||||
|
||||
|
||||
# # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
|
||||
# class InternLM2TokenizerFast(PreTrainedTokenizerFast):
|
||||
# vocab_files_names = VOCAB_FILES_NAMES
|
||||
# slow_tokenizer_class = InternLM2Tokenizer
|
||||
# padding_side = 'left'
|
||||
# model_input_names = ['input_ids', 'attention_mask']
|
||||
# _auto_class = 'AutoTokenizer'
|
||||
#
|
||||
# def __init__(
|
||||
# self,
|
||||
# vocab_file,
|
||||
# unk_token='<unk>',
|
||||
# bos_token='<s>',
|
||||
# eos_token='</s>',
|
||||
# pad_token='</s>',
|
||||
# sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
# add_bos_token=True,
|
||||
# add_eos_token=False,
|
||||
# decode_with_prefix_space=False,
|
||||
# clean_up_tokenization_spaces=False,
|
||||
# **kwargs,
|
||||
# ):
|
||||
# super().__init__(
|
||||
# vocab_file=vocab_file,
|
||||
# unk_token=unk_token,
|
||||
# bos_token=bos_token,
|
||||
# eos_token=eos_token,
|
||||
# pad_token=pad_token,
|
||||
# sp_model_kwargs=sp_model_kwargs,
|
||||
# add_bos_token=add_bos_token,
|
||||
# add_eos_token=add_eos_token,
|
||||
# decode_with_prefix_space=decode_with_prefix_space,
|
||||
# clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
# **kwargs,
|
||||
# )
|
||||
# self._add_bos_token = add_bos_token
|
||||
# self._add_eos_token = add_eos_token
|
||||
# self.update_post_processor()
|
||||
# self.vocab_file = vocab_file
|
||||
#
|
||||
# @property
|
||||
# def can_save_slow_tokenizer(self) -> bool:
|
||||
# return os.path.isfile(self.vocab_file) if self.vocab_file else False
|
||||
#
|
||||
# def update_post_processor(self):
|
||||
# """
|
||||
# Updates the underlying post processor with the current `bos_token` and `eos_token`.
|
||||
# """
|
||||
# bos = self.bos_token
|
||||
# bos_token_id = self.bos_token_id
|
||||
# if bos is None and self.add_bos_token:
|
||||
# raise ValueError('add_bos_token = True but bos_token = None')
|
||||
#
|
||||
# eos = self.eos_token
|
||||
# eos_token_id = self.eos_token_id
|
||||
# if eos is None and self.add_eos_token:
|
||||
# raise ValueError('add_eos_token = True but eos_token = None')
|
||||
#
|
||||
# single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
|
||||
# pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
|
||||
#
|
||||
# special_tokens = []
|
||||
# if self.add_bos_token:
|
||||
# special_tokens.append((bos, bos_token_id))
|
||||
# if self.add_eos_token:
|
||||
# special_tokens.append((eos, eos_token_id))
|
||||
# self._tokenizer.post_processor = processors.TemplateProcessing(
|
||||
# single=single, pair=pair, special_tokens=special_tokens
|
||||
# )
|
||||
#
|
||||
# @property
|
||||
# def add_eos_token(self):
|
||||
# return self._add_eos_token
|
||||
#
|
||||
# @property
|
||||
# def add_bos_token(self):
|
||||
# return self._add_bos_token
|
||||
#
|
||||
# @add_eos_token.setter
|
||||
# def add_eos_token(self, value):
|
||||
# self._add_eos_token = value
|
||||
# self.update_post_processor()
|
||||
#
|
||||
# @add_bos_token.setter
|
||||
# def add_bos_token(self, value):
|
||||
# self._add_bos_token = value
|
||||
# self.update_post_processor()
|
||||
#
|
||||
# def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
# if not self.can_save_slow_tokenizer:
|
||||
# raise ValueError(
|
||||
# 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
|
||||
# 'tokenizer.'
|
||||
# )
|
||||
#
|
||||
# if not os.path.isdir(save_directory):
|
||||
# logger.error(f'Vocabulary path ({save_directory}) should be a directory')
|
||||
# return
|
||||
# out_vocab_file = os.path.join(
|
||||
# save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
|
||||
# )
|
||||
#
|
||||
# if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
# copyfile(self.vocab_file, out_vocab_file)
|
||||
#
|
||||
# return (out_vocab_file,)
|
||||
|
||||
|
||||
# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
|
||||
class InternLM2Tokenizer(PreTrainedTokenizer):
    """
    Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding
    over a SentencePiece model.

    Args:
        vocab_file (`str`):
            Path to the vocabulary (SentencePiece model) file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    _auto_class = "AutoTokenizer"

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="</s>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        decode_with_prefix_space=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        # Fix: removed stray `print("register succeed")` debug output that
        # polluted stdout on every tokenizer construction.
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.decode_with_prefix_space = decode_with_prefix_space
        # Load the SentencePiece model backing this tokenizer.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        self._no_prefix_space_tokens = None
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def no_prefix_space_tokens(self):
        # Lazily computed set of ids whose pieces do NOT start with the
        # SentencePiece whitespace marker "▁" (i.e. no leading space needed).
        if self._no_prefix_space_tokens is None:
            vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
            self._no_prefix_space_tokens = {
                i for i, tok in enumerate(vocab) if not tok.startswith("▁")
            }
        return self._no_prefix_space_tokens

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    @property
    def bos_token_id(self) -> Optional[int]:
        # Read from the SentencePiece model, not the configured token string.
        return self.sp_model.bos_id()

    @property
    def eos_token_id(self) -> Optional[int]:
        return self.sp_model.eos_id()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def _maybe_add_prefix_space(self, tokens, decoded):
        # NOTE(review): `tokens` holds token strings while
        # `no_prefix_space_tokens` holds integer ids, so this membership test
        # is always True and a space is always prepended (the caller then
        # strips one leading character) — confirm upstream intent before
        # changing.
        if tokens and tokens[0] not in self.no_prefix_space_tokens:
            return " " + decoded
        else:
            return decoded

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        out_string = self.clean_up_tokenization(out_string)
        out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
        # Drop the artificial leading character introduced above.
        return out_string[1:]

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )

        # Copy the original model file when it exists; otherwise serialize the
        # in-memory SentencePiece model.
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # Prepend BOS (if enabled), concatenate both sequences, append EOS
        # (if enabled).
        if self.add_bos_token:
            bos_token_ids = [self.bos_token_id]
        else:
            bos_token_ids = []

        output = bos_token_ids + token_ids_0

        if token_ids_1 is not None:
            output = output + token_ids_1

        if self.add_eos_token:
            output = output + [self.eos_token_id]

        return output

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
        use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||
|
||||
|
||||
# Register InternLM2Tokenizer as the (slow-only) tokenizer for InternVL chat
# models so AutoTokenizer resolves InternVLChatConfig to it; exist_ok avoids
# failing on repeated registration.
TOKENIZER_MAPPING.register(
    InternVLChatConfig, (InternLM2Tokenizer, None), exist_ok=True
)
|
||||
634
python/sglang/srt/configs/janus_pro.py
Normal file
634
python/sglang/srt/configs/janus_pro.py
Normal file
@@ -0,0 +1,634 @@
|
||||
# Adapted from:
|
||||
# https://github.com/deepseek-ai/Janus/tree/main/janus/models
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
LlamaConfig,
|
||||
LlamaTokenizerFast,
|
||||
PretrainedConfig,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from transformers.image_utils import to_numpy_array
|
||||
|
||||
from sglang.srt.configs.utils import register_image_processor, register_processor
|
||||
from sglang.srt.multimodal.mm_utils import expand2square
|
||||
|
||||
|
||||
class DictToObject(dict):
    """A dict whose keys are also accessible as attributes, recursively.

    Nested dict values are wrapped in ``DictToObject`` so that ``obj.a.b``
    works alongside ``obj["a"]``.
    """

    def __init__(self, dictionary):
        # Bug fix: the original called ``super(self).__init__(dictionary)``,
        # which raises TypeError because super() takes a type, not an
        # instance. The zero-argument form correctly initializes the
        # underlying dict from the source mapping.
        super().__init__(dictionary)

        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)
|
||||
|
||||
|
||||
class VisionConfig(PretrainedConfig):
    """Config stub for the vision module: a target class name plus the
    keyword params used to build it."""

    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may arrive as a string or as an actual class; store the name.
        cls_value = kwargs.get("cls", "")
        self.cls = cls_value if isinstance(cls_value, str) else cls_value.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
class GenAlignerConfig(PretrainedConfig):
    """Config stub for the generation aligner module: a target class name
    plus the keyword params used to build it."""

    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may arrive as a string or as an actual class; store the name.
        cls_value = kwargs.get("cls", "")
        self.cls = cls_value if isinstance(cls_value, str) else cls_value.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
class GenHeadConfig(PretrainedConfig):
    """Config stub for the generation head module: a target class name plus
    the keyword params used to build it."""

    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may arrive as a string or as an actual class; store the name.
        cls_value = kwargs.get("cls", "")
        self.cls = cls_value if isinstance(cls_value, str) else cls_value.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
class AlignerConfig(PretrainedConfig):
    """Config stub for the understanding aligner module: a target class name
    plus the keyword params used to build it."""

    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may arrive as a string or as an actual class; store the name.
        cls_value = kwargs.get("cls", "")
        self.cls = cls_value if isinstance(cls_value, str) else cls_value.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
class GenVisionConfig(PretrainedConfig):
    """Config stub for the generation vision module: a target class name
    plus the keyword params used to build it."""

    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may arrive as a string or as an actual class; store the name.
        cls_value = kwargs.get("cls", "")
        self.cls = cls_value if isinstance(cls_value, str) else cls_value.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
@dataclass
class SigLIPVisionCfg:
    """Hyperparameters for the SigLIP vision transformer backbone."""

    width: int = 1152  # transformer embedding width
    layers: Union[Tuple[int, int, int, int], int] = 27  # depth (int) or per-stage depths (tuple)
    heads: int = 16  # attention heads per layer
    patch_size: int = 14  # square patch side in pixels
    image_size: Union[Tuple[int, int], int] = 336  # input resolution (side or (H, W))
    global_pool: str = "map"  # pooling mode for final features — presumably attention-map pooling; confirm against backbone impl
    mlp_ratio: float = 3.7362  # MLP hidden dim as a multiple of ``width``
    class_token: bool = False  # whether a class token is used
    num_classes: int = 0  # classifier output size (0 = no head, presumably)
    use_checkpoint: bool = False  # gradient-checkpointing toggle — TODO confirm consumer
|
||||
|
||||
|
||||
class MultiModalityConfig(PretrainedConfig):
    """Top-level Janus config bundling vision, aligner, generation, and
    language sub-configs, each built from its kwargs payload."""

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Each sub-config is constructed from its (possibly absent) dict payload.
        self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
        self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
        self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
        self.gen_aligner_config = GenAlignerConfig(
            **kwargs.get("gen_aligner_config", {})
        )
        self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))

        # The language config may arrive already constructed; only wrap dicts.
        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)
|
||||
|
||||
|
||||
class VLMImageProcessor(BaseImageProcessor):
    """Janus VLM image preprocessor.

    Resizes so the longer side matches ``image_size``, pads to a square with
    a background color derived from ``image_mean``, then rescales and
    optionally normalizes, producing CHW ``pixel_values``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        # Padding color: mid-gray fallback; otherwise the normalization mean
        # mapped back to the 0-255 range so padding is near-zero after
        # normalization.
        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """
        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """

        width, height = pil_img.size
        max_size = max(width, height)

        # Scale so the longer side becomes image_size, clamping each side to
        # at least min_size. Note the (height, width) ordering.
        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            # print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        # NOTE: this inner helper intentionally shadows the method name inside
        # this scope. It performs an aspect-preserving bicubic resize; an int
        # ``size`` means "shorter side", a pair is treated as (H, W).
        def resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        ):
            if isinstance(size, int):
                w, h = pil_img.size
                # Already at the target shorter side: nothing to do.
                if (w <= h and w == size) or (h <= w and h == size):
                    return pil_img
                if w < h:
                    ow = size
                    oh = int(size * h / w)
                else:
                    oh = size
                    ow = int(size * w / h)
                size = (ow, oh)
            else:
                # PIL expects (W, H); our ``size`` is (H, W).
                size = (size[1], size[0])

            return pil_img.resize(
                size, resample=interpolation, reducing_gap=None if antialias else 3.0
            )

        pil_img = resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        )

        # Pad to a square using the configured background color.
        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        """Run the full pipeline over one image or a list of images and
        return a ``BatchFeature`` with ``pixel_values``."""
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        if not isinstance(images, list):
            images = [images]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        # Keep at most three channels (drops e.g. an alpha channel).
        images = [image[:3, ...] for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]
        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        # CHW shape of a single processed image.
        return [3, self.image_size, self.image_size]
|
||||
|
||||
|
||||
class DictOutput(object):
    """Lightweight mapping facade over an instance's ``__dict__``.

    Gives subclasses dict-style read/write access (``obj["k"]``,
    ``"k" in obj``, ``obj.keys()``, ``obj.items()``) backed directly by
    the instance attributes.
    """

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def __contains__(self, name):
        return name in self.__dict__

    def __getitem__(self, name):
        return self.__dict__[name]

    def __setitem__(self, name, value):
        self.__dict__[name] = value
|
||||
|
||||
|
||||
@dataclass
class VLChatProcessorOutput(DictOutput):
    """Single-sample processor output (dict-style access via DictOutput)."""

    sft_format: str  # the formatted prompt text
    input_ids: torch.Tensor  # token ids, with image placeholders expanded
    pixel_values: torch.Tensor  # preprocessed image tensor(s)
    num_image_tokens: torch.IntTensor  # per-image count of image tokens

    def __len__(self):
        # Length of the (expanded) token sequence.
        return len(self.input_ids)
|
||||
|
||||
|
||||
@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    """Batched, left-padded processor output (see VLChatProcessor.batchify)."""

    sft_format: List[str]  # formatted prompt per sample
    input_ids: torch.Tensor  # [batch, max_len], pad id on the left
    pixel_values: torch.Tensor  # [batch, max_n_images, C, H, W]
    attention_mask: torch.Tensor  # 1 on real tokens, 0 on padding
    images_seq_mask: torch.BoolTensor  # True at image-token positions in input_ids
    images_emb_mask: torch.BoolTensor  # [batch, max_n_images, num_image_tokens], True on valid image tokens
|
||||
|
||||
|
||||
# FIXME: had to place Official Processor here, since image_processor module would not be imported in all threads,
|
||||
# hence AutoProcessor registration would not be affective in some cases
|
||||
class VLChatProcessor(ProcessorMixin):
    """Combined text + image processor for Janus VL chat models.

    Tokenizes the prompt, expands each image placeholder into
    ``<begin_of_image>`` + ``num_image_tokens`` image tokens +
    ``<end_of_image>``, preprocesses the images, and (by default)
    left-pads everything into a batch.
    """

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<|▁pad▁|>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        # Make sure the image placeholder exists in the tokenizer vocabulary.
        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            # print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag

        # NOTE(review): ``mask_prompt`` is accepted but never stored or used
        # in this class — confirm whether it should feed into target masking.
        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

    @property
    def image_token(self):
        # The placeholder string marking image positions in the prompt.
        return self.image_tag

    @property
    def image_id(self) -> int:
        # Vocabulary id of the image placeholder (None if not registered).
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        # Vocabulary id of the pad token (None if not in vocab).
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """
        Expand each image placeholder position into begin/image*N/end tokens.

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            # When add_special_token is set, keep the placeholder itself in
            # the preceding text slice; otherwise drop it.
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """
        Process a single (already formatted) prompt and its images.

        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        # NOTE(review): ``conversations`` passed by __call__ lands in **kwargs
        # and is ignored here; the prompt is used as-is.
        sft_format = prompt
        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """
        Process one sample and optionally wrap it into a batch of one.

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        # Pad every sample to the longest sequence / image count in the batch.
        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
|
||||
|
||||
|
||||
class VLMImageProcessorConfig(PretrainedConfig):
    """Serializable config mirroring VLMImageProcessor's constructor
    parameters (CLIP-style mean/std defaults)."""

    model_type = "deepseek_vlm"
    image_size: int  # target square side length
    min_size: int  # lower bound for either resized side
    image_mean: Union[Tuple[float, float, float], List[float]]  # per-channel mean
    image_std: Union[Tuple[float, float, float], List[float]]  # per-channel std
    rescale_factor: float  # multiplier mapping [0, 255] -> [0, 1]
    do_normalize: bool  # whether mean/std normalization is applied

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        # Record the preprocessing hyperparameters before delegating the
        # remaining kwargs to PretrainedConfig.
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)
|
||||
|
||||
|
||||
# Hook the Janus processors into the Auto* registries so AutoProcessor /
# AutoImageProcessor resolve MultiModalityConfig to these classes.
register_processor(MultiModalityConfig, VLChatProcessor)
register_image_processor(MultiModalityConfig, VLMImageProcessor)
|
||||
38
python/sglang/srt/configs/kimi_vl.py
Normal file
38
python/sglang/srt/configs/kimi_vl.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
|
||||
|
||||
class KimiVLConfig(PretrainedConfig):
    """Kimi-VL model config: a MoonViT vision tower plus a DeepseekV2 text
    backbone, with multimodal bookkeeping ids."""

    model_type = "kimi_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        **kwargs
    ):
        # Normalize the vision sub-config: dict payload -> wrap, None ->
        # default, already-built config -> keep as-is.
        if isinstance(vision_config, dict):
            vision_config = MoonViTConfig(**vision_config)
        elif vision_config is None:
            vision_config = MoonViTConfig()
        self.vision_config = vision_config

        # Same normalization for the text sub-config.
        if isinstance(text_config, dict):
            text_config = DeepseekV2Config(**text_config)
        elif text_config is None:
            text_config = DeepseekV2Config()
        self.text_config = text_config

        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id

        super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
32
python/sglang/srt/configs/kimi_vl_moonvit.py
Normal file
32
python/sglang/srt/configs/kimi_vl_moonvit.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class MoonViTConfig(PretrainedConfig):
    """Configuration for the MoonViT vision tower used by Kimi-VL."""

    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Patch embedding and merger
        self.patch_size = patch_size
        self.merge_kernel_size = merge_kernel_size
        # Positional-embedding grid at initialization
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer trunk
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
|
||||
89
python/sglang/srt/configs/load_config.py
Normal file
89
python/sglang/srt/configs/load_config.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoadFormat(str, enum.Enum):
    # Checkpoint/weight formats accepted by the model loader. Subclassing
    # ``str`` lets members compare equal to their plain string values
    # (e.g. LoadFormat.AUTO == "auto").
    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"
|
||||
|
||||
|
||||
@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    decryption_key_file: If set, decrypts the output files with a password read
        from this file (after PBKDF2).
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None

    def __post_init__(self):
        # Normalize ``model_loader_extra_config``: accept a JSON string, a
        # dict, or None. Bug fix: previously the normalized value was only
        # written back in the JSON-string case, so an explicitly passed
        # ``None`` survived as ``None`` instead of becoming ``{}``.
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        self.model_loader_extra_config = extra_config
        self._verify_load_format()

        # An empty/None ignore list falls back to the llama-checkpoint default.
        if self.ignore_patterns:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Coerce a string ``load_format`` into a LoadFormat member and
        reject formats unsupported on ROCm."""
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        # Raises ValueError for unknown format strings.
        self.load_format = LoadFormat(load_format)

        # Currently empty: no format is blocked on ROCm, but the scaffolding
        # is kept so entries can be added without touching the logic.
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f
                for f in LoadFormat.__members__
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )
|
||||
104
python/sglang/srt/configs/longcat_flash.py
Normal file
104
python/sglang/srt/configs/longcat_flash.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||
|
||||
|
||||
class LongcatFlashConfig(PretrainedConfig):
    """HuggingFace-style configuration for the LongCat-Flash MoE model.

    Defaults mirror the released checkpoint's config.json. MLA (multi-head
    latent attention) parameters (``kv_lora_rank``, ``q_lora_rank``,
    ``qk_*_head_dim``, ``v_head_dim``) and MoE parameters
    (``n_routed_experts``, ``moe_topk``, ``zero_expert_*``) are stored as
    plain attributes for the model implementation to read.
    """

    model_type = "longcat_flash"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=6144,
        intermediate_size=None,
        ffn_hidden_size=12288,
        expert_ffn_hidden_size=2048,
        num_layers=28,
        num_hidden_layers=None,
        num_attention_heads=64,
        ep_size=1,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=128,
        qk_nope_head_dim=128,
        v_head_dim=128,
        n_routed_experts=512,
        moe_topk=12,
        norm_topk_prob=False,
        max_position_embeddings=131072,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mla_scale_q_lora=True,
        mla_scale_kv_lora=True,
        torch_dtype="bfloat16",
        params_dtype="bfloat16",
        # NOTE(review): "rounter" looks like a typo for "router", but the key
        # must match the released checkpoint config — do not rename.
        rounter_params_dtype="float32",
        router_bias=False,
        topk_method=None,
        routed_scaling_factor=6.0,
        zero_expert_num=256,
        zero_expert_type="identity",
        nextn_use_scmoe=False,
        num_nextn_predict_layers=1,
        **kwargs,
    ):
        # Token ids, dtypes and MTP/NextN settings are handled by the base
        # PretrainedConfig (unknown kwargs are stored as attributes there).
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            torch_dtype=torch_dtype,
            params_dtype=params_dtype,
            rounter_params_dtype=rounter_params_dtype,
            topk_method=topk_method,
            router_bias=router_bias,
            nextn_use_scmoe=nextn_use_scmoe,
            num_nextn_predict_layers=num_nextn_predict_layers,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        # Accept either the HF-conventional name (num_hidden_layers /
        # intermediate_size) or the checkpoint's native name (num_layers /
        # ffn_hidden_size); the explicit HF name wins when both are given.
        self.num_hidden_layers = (
            num_hidden_layers if num_hidden_layers is not None else num_layers
        )
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else ffn_hidden_size
        )
        self.moe_intermediate_size = expert_ffn_hidden_size
        self.num_attention_heads = num_attention_heads
        self.ep_size = ep_size
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.n_routed_experts = n_routed_experts
        self.moe_topk = moe_topk
        self.norm_topk_prob = norm_topk_prob
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mla_scale_q_lora = mla_scale_q_lora
        self.mla_scale_kv_lora = mla_scale_kv_lora
        self.zero_expert_num = zero_expert_num
        self.zero_expert_type = zero_expert_type
        self.routed_scaling_factor = routed_scaling_factor
        # Activation is fixed for this architecture.
        self.hidden_act = "silu"
|
||||
794
python/sglang/srt/configs/model_config.py
Normal file
794
python/sglang/srt/configs/model_config.py
Normal file
@@ -0,0 +1,794 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from enum import Enum, IntEnum, auto
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_config,
|
||||
get_context_length,
|
||||
get_generation_config,
|
||||
get_hf_text_config,
|
||||
get_sparse_attention_config,
|
||||
)
|
||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AttentionArch(IntEnum):
    """Attention architecture family used by the KV cache / attention backends."""

    MLA = auto()  # multi-head latent attention (DeepSeek-style compressed KV)
    MHA = auto()  # standard multi-head attention (incl. GQA/MQA variants)
|
||||
|
||||
|
||||
class ModelImpl(str, Enum):
    """Which model implementation to run; inherits ``str`` so CLI strings
    compare equal to the members."""

    AUTO = "auto"  # prefer the native sglang implementation, fall back otherwise
    SGLANG = "sglang"
    TRANSFORMERS = "transformers"
|
||||
|
||||
|
||||
class ModelConfig:
    """Resolved, engine-facing view of a model's configuration.

    Loads the HuggingFace config for ``model_path`` and derives everything the
    serving engine needs: attention architecture (MLA vs MHA), head counts,
    context length, dtype, quantization method, multimodal capabilities, and
    EOS token ids. Also rewrites draft-model architectures for speculative
    decoding (NextN/MTP variants) by mutating ``hf_config`` in place.
    """

    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: str = "{}",  # JSON string merged into the HF config
        is_embedding: Optional[bool] = None,
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.model_impl = model_impl

        # May rewrite self.model_path to a local dir for remote:// paths.
        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()

        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )

        self.hf_generation_config = get_generation_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

        # For multimodal configs, the text sub-config carries the LM shape.
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
        )
        # is_hybrid_model / get_hybrid_layer_ids are defined later in this file.
        self.is_hybrid = is_hybrid_model(
            self.hf_config.architectures,
            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
            context_length=context_length,
            attention_chunk_size=self.attention_chunk_size,
        )
        if self.is_hybrid is not None:
            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
                get_hybrid_layer_ids(
                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
                )
            )

        # Default multimodal on, except for architectures known to need the
        # explicit --enable-multimodal opt-in.
        if enable_multimodal is None:
            mm_disabled_models = [
                "Gemma3ForConditionalGeneration",
                "Llama4ForConditionalGeneration",
                "Step3VLForConditionalGeneration",
            ]
            if self.hf_config.architectures[0] in mm_disabled_models:
                enable_multimodal = False
                logger.info(
                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                )
            else:
                enable_multimodal = True

        # Speculative decoding: rewrite the architecture to the draft (NextN /
        # MTP) variant. Mutates hf_config in place.
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
        ):
            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
            # The draft model only has the NextN prediction layers.
            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers

        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
        ):
            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = enable_multimodal and is_multimodal_model(
            self.hf_config.architectures
        )
        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
            self.hf_config.architectures
        )
        self.is_image_gen = enable_multimodal and is_image_gen_model(
            self.hf_config.architectures
        )
        self.is_audio_model = enable_multimodal and is_audio_model(
            self.hf_config.architectures
        )
        self.is_multimodal_chunked_prefill_supported = (
            enable_multimodal
            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
        )
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                # Overriding beyond the model's derived limit is only allowed
                # with an explicit env-var opt-in (or in CI).
                reason = "Target model's" if is_draft_model else "User-specified"
                msg = (
                    f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                    f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
                )
                if (
                    get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
                    or is_in_ci()  # FIXME: fix this special case
                ):
                    logger.warning(msg)
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special judge for MLA architecture
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
            or "LongcatFlashForCausalLM" in self.hf_config.architectures
            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                # yarn_get_mscale is defined elsewhere in this module.
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
            self.hf_text_config, "use_mla", True
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
                or "MixtralForCausalLM" in self.hf_config.architectures
                or "MistralForCausalLM" in self.hf_config.architectures
            ):
                if getattr(self, "head_dim", None) is None:
                    self.head_dim = (
                        self.hf_config.hidden_size // self.hf_config.num_attention_heads
                    )

                # In transformers==4.52.3, the head_dim is null in MistralConfig
                if (
                    not hasattr(self.hf_text_config, "head_dim")
                    or self.hf_text_config.head_dim is None
                ):
                    setattr(self.hf_text_config, "head_dim", self.head_dim)

            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.num_attention_layers = self.num_hidden_layers
        # LongCat-Flash interleaves two attention blocks per hidden layer.
        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
            self.num_attention_layers = self.num_hidden_layers * 2
        self.num_nextn_predict_layers = getattr(
            self.hf_text_config, "num_nextn_predict_layers", None
        )
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()

        # multimodal
        self.image_token_id = getattr(
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

    @staticmethod
    def from_server_args(
        server_args: ServerArgs, model_path: Optional[str] = None, **kwargs
    ):
        """Build a ModelConfig from ServerArgs; ``model_path``/kwargs override."""
        return ModelConfig(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            **kwargs,
        )

    def get_total_num_attention_heads(self) -> int:
        """Total attention heads across all tensor-parallel ranks."""
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        """Attention heads per TP rank (at least 1)."""
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )
        if self.hf_config.model_type in ["nemotron-nas"]:
            # Per-block KV head counts; all non-no-op blocks must agree.
            nkvh = {
                self.hf_config.num_attention_heads // block.attention.n_heads_in_group
                for block in self.hf_config.block_configs
                if not block.attention.no_op
            }
            if len(nkvh) == 0:
                raise RuntimeError("Couldn't determine number of kv heads")
            if len(nkvh) > 1:
                raise ValueError(
                    "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
                )
            return next(iter(nkvh))

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
            # For Step3
            "num_attention_groups",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        """Locate the quantization config, checking the HF config first and
        then the standalone ``hf_quant_config.json`` used by modelopt /
        mixed-precision checkpoints. Returns None if not quantized."""
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # check if is modelopt or mixed-precision model -- Both of them don't have corresponding field
            # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                from huggingface_hub import HfApi

                hf_api = HfApi()
                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                    quant_cfg = modelopt_quant_config
            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_config_file = os.path.join(
                    self.model_path, "hf_quant_config.json"
                )
                with open(quant_config_file) as f:
                    quant_config_dict = json.load(f)
                json_quant_configs = quant_config_dict["quantization"]
                quant_algo = json_quant_configs.get("quant_algo", None)
                if quant_algo == "MIXED_PRECISION":
                    quant_cfg = {"quant_method": "w4afp8"}
                else:
                    quant_cfg = modelopt_quant_config
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        """Reconcile the user-requested quantization with the checkpoint's,
        and reject combinations unsupported on this platform.

        Raises:
            ValueError: on unknown methods, config/argument mismatch, or
                methods unsupported on ROCm.
        """
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
            "petit_nvfp4",
            "quark",
            "mxfp4",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
            "w4afp8",
            "petit_nvfp4",
        ]
        # Pairs where an argument/config mismatch is still acceptable.
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()

            # Detect which checkpoint is it
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def _verify_dual_chunk_attention_config(self) -> None:
        """Attach the sparse-attention config (if present on disk) to the
        dual-chunk attention config, enabling sparse attention by default."""
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        """Union of EOS token ids from the model config and the generation
        config, normalized to a set (both sources may be int or list)."""
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        if eos_ids is None:
            eos_ids = set()
        if self.hf_generation_config:
            generation_eos_ids = getattr(
                self.hf_generation_config, "eos_token_id", None
            )
            if generation_eos_ids:
                generation_eos_ids = (
                    {generation_eos_ids}
                    if isinstance(generation_eos_ids, int)
                    else set(generation_eos_ids)
                )
                eos_ids = eos_ids | generation_eos_ids
        return eos_ids

    def maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.

        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since config files need to exist all the time, so we DO NOT use
            # with statement to avoid closing the client.
            client = create_remote_connector(self.model_path)
            if is_remote_url(self.model_path):
                client.pull_files(allow_pattern=["*config.json"])
                # Keep the original remote URL for weight loading; point
                # model_path at the pulled local config dir.
                self.model_weights = self.model_path
                self.model_path = client.get_local_dir()
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
# String dtype names accepted on the CLI / in HF configs mapped to torch
# dtypes; "half" and "float" are aliases for float16 and float32.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested ``dtype`` against the checkpoint's dtype.

    ``dtype`` may be a ``torch.dtype``, a dtype name string, or ``"auto"``.
    With ``"auto"``, float32 checkpoints are downcast to float16 (bfloat16
    for Gemma models); any other checkpoint dtype is kept as-is. A log
    message is emitted whenever the chosen dtype differs from the
    checkpoint's.

    Raises:
        ValueError: if ``dtype`` is not "auto", a known dtype name, or a
            ``torch.dtype``.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    elif isinstance(dtype, str):
        requested = dtype.lower()
        if requested != "auto":
            # An explicit dtype name must be one we know.
            if requested not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {requested}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[requested]
        elif config_dtype != torch.float32:
            # "auto" keeps the checkpoint dtype unless it is float32.
            torch_dtype = config_dtype
        elif config.model_type.startswith("gemma"):
            # Gemma float32 checkpoints are downcast to bfloat16, not float16.
            gemma_version = "" if config.model_type == "gemma" else config.model_type[5]
            logger.info(
                f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                "of float16 by default. Please specify `dtype` if you "
                "want to use float16."
            )
            torch_dtype = torch.bfloat16
        else:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Report any mismatch against the checkpoint dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
|
||||
|
||||
|
||||
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    """Return True if the model should be served as a generative model.

    A model counts as generative unless one of its architectures is in the
    known embedding/reward/classification set, or the server was launched
    with ``is_embedding=True``.
    """
    # Architectures that are embedding, reward, or classification models
    # rather than autoregressive generators.
    non_generation_archs = {
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
        "InternLM2ForRewardModel",
        "Qwen2ForRewardModel",
        "Qwen2ForSequenceClassification",
        "Qwen3ForSequenceClassification",
        "CLIPModel",
        "BertModel",
        "Contriever",
        "BertForSequenceClassification",
        "XLMRobertaModel",
        "XLMRobertaForSequenceClassification",
    }
    if non_generation_archs.intersection(model_architectures):
        return False
    return not is_embedding
|
||||
|
||||
|
||||
# Architectures that accept multimodal inputs (image/video/audio) in addition
# to text; consulted by is_multimodal_model() below.
multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Glm4vForConditionalGeneration",
    "Glm4vMoeForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
]
|
||||
|
||||
|
||||
def is_multimodal_model(model_architectures: List[str]):
    """Return True if any of the architectures is a known multimodal model."""
    return any(arch in multimodal_model_archs for arch in model_architectures)
|
||||
|
||||
|
||||
def is_multimodal_gen_model(model_architectures: List[str]):
    """Return True if the model generates multimodal outputs.

    No such architectures are supported yet, so this is always False.
    """
    return False
|
||||
|
||||
|
||||
def is_image_gen_model(model_architectures: List[str]):
    """Return True if the model generates images.

    No such architectures are supported yet, so this is always False.
    """
    return False
|
||||
|
||||
|
||||
def is_audio_model(model_architectures: List[str]):
    """Return True if the model consumes/produces raw audio.

    No such architectures are supported yet, so this is always False.
    """
    return False
|
||||
|
||||
|
||||
def is_encoder_decoder_model(model_architectures: List[str]):
    """Return True for encoder-decoder architectures (currently only Mllama)."""
    return any(
        arch == "MllamaForConditionalGeneration" for arch in model_architectures
    )
|
||||
|
||||
|
||||
def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    # Multimodal architectures whose prefill cannot be chunked.
    unsupported_archs = frozenset(
        (
            "Grok1VForCausalLM",
            "Grok1AForCausalLM",
            "LlavaLlamaForCausalLM",
            "MllamaForConditionalGeneration",
            "CLIPModel",
        )
    )
    # Supported iff none of the model's architectures is in the deny list.
    return unsupported_archs.isdisjoint(model_architectures)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention scaling factor for a given RoPE scale.

    For ``scale <= 1`` (no context extension) no correction is applied.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    """Decide whether to enable the hybrid (SWA + full) KV cache.

    Returns ``hybrid_kvcache_ratio`` when hybrid caching applies
    (Llama-4 with a context longer than its attention chunk), else None.

    Args:
        model_architectures: HF ``architectures`` list; only the first entry
            is checked.
        hybrid_kvcache_ratio: User-requested ratio; None disables the feature.
        context_length: Model context length (may be None).
        attention_chunk_size: Chunked-attention window size (may be None).
    """
    if hybrid_kvcache_ratio is None:
        return None
    # Fix: the original compared context_length > attention_chunk_size without
    # None checks, raising TypeError when either is None for a Llama-4 model
    # with a positive ratio. Treat a missing value as "not hybrid".
    if (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length is not None
        and attention_chunk_size is not None
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    return None
|
||||
|
||||
|
||||
def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    """Split layer indices into sliding-window and full-attention groups.

    For Llama-4, every 4th layer (1-indexed) uses full attention and the rest
    use sliding-window attention. Returns ``(None, None)`` for other models.
    """
    if "Llama4ForConditionalGeneration" not in model_architectures:
        return None, None
    layer_ids = range(num_hidden_layers)
    swa_attention_layer_ids = [i for i in layer_ids if (i + 1) % 4 != 0]
    full_attention_layer_ids = [i for i in layer_ids if (i + 1) % 4 == 0]
    return swa_attention_layer_ids, full_attention_layer_ids
|
||||
172
python/sglang/srt/configs/step3_vl.py
Normal file
172
python/sglang/srt/configs/step3_vl.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class Step3VisionEncoderConfig(PretrainedConfig):
    """HF-style configuration for the Step3 vision encoder tower.

    Defaults mirror the released Step3 checkpoint; any field can be
    overridden through kwargs like a normal ``PretrainedConfig``.
    """

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # Transformer tower dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_hidden_size = output_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Input / patch-embedding geometry.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        # Activation and normalization choices.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        # Let PretrainedConfig consume the remaining kwargs last, after the
        # explicit fields above are set (matches the original order).
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Step3TextConfig(PretrainedConfig):
    """HF-style configuration for the Step3 text (MoE) decoder.

    Defaults mirror the released Step3 checkpoint. ``moe_layers_enum`` lists
    the layer indices that use MoE FFNs (layers 4..59 by default).
    """

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        moe_layers_enum: tuple[int] = (
            4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
            20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
            52, 53, 54, 55, 56, 57, 58, 59,
        ),
        **kwargs,
    ) -> None:
        # Dense transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        # Mixture-of-experts parameters.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum
        # Rotary position embedding parameters.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embedding = max_position_embedding
        # Shared-expert / attention projection sizes.
        self.share_expert_dim = share_expert_dim
        self.share_q_dim = share_q_dim
        self.head_dim = head_dim

        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Step3VLConfig(PretrainedConfig):
    """Composite configuration for Step3-VL: a vision encoder plus a text
    decoder, joined by an "understand" projector.

    ``vision_config`` / ``text_config`` may be given as dicts (deserialized
    HF JSON) or as ready config objects; missing ones get defaults.
    """

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        # Normalize the sub-configs: default when absent, build from dict
        # when deserialized from JSON.
        if vision_config is None:
            vision_config = Step3VisionEncoderConfig()
        if isinstance(vision_config, dict):
            vision_config = Step3VisionEncoderConfig(**vision_config)
        self.vision_config = vision_config

        if text_config is None:
            text_config = Step3TextConfig()
        if isinstance(text_config, dict):
            text_config = Step3TextConfig(**text_config)
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Expose the text hidden size at the top level for downstream code.
        self.hidden_size = text_config.hidden_size
        self.image_token_id = image_token_id

        super().__init__(**kwargs)
|
||||
156
python/sglang/srt/configs/update_config.py
Normal file
156
python/sglang/srt/configs/update_config.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
DEFAULT_MOE_PADDING_SIZE = 32
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
|
||||
|
||||
def may_get_weight_block_size(model_config, load_config):
    """Return the quantization ``weight_block_size`` if the model's quant
    config defines one, else None.

    Imports are deferred to avoid an import cycle with the model loader.
    """
    from sglang.srt.model_loader.loader import _get_quantization_config
    from sglang.srt.model_loader.utils import get_model_architecture

    model_class, _ = get_model_architecture(model_config)
    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})

    quant_config = _get_quantization_config(
        model_config, load_config, packed_modules_mapping
    )
    if quant_config is None:
        return None
    # Single getattr with a default replaces the original hasattr/getattr pair.
    return getattr(quant_config, "weight_block_size", None)
|
||||
|
||||
|
||||
def get_moe_padding_size(weight_block_size):
    """Return the MoE intermediate-size alignment for the given quant block.

    With block-wise quantization the gate/up output sizes must be divisible
    by the block size; otherwise fall back to DEFAULT_MOE_PADDING_SIZE.
    """
    if weight_block_size is None:
        return DEFAULT_MOE_PADDING_SIZE

    # NOTE(HandH1998): To ensure proper alignment of the block-wise
    # quantization scales, the output_size of the weights for both the gate
    # and up layers must be divisible by block_n.
    assert (
        len(weight_block_size) == 2
    ), "Only len(weight_block_size) == 2 is supported"
    assert (
        weight_block_size[0] == weight_block_size[1]
    ), "Only weight_block_size[0] == weight_block_size[1] is supported"
    return weight_block_size[0]
|
||||
|
||||
|
||||
def get_num_heads_padding_size(tp_size, weight_block_size):
    """Return the granularity to which head counts are padded for this TP size.

    With block-quantized weights and an odd TP size, pad to ``2 * tp_size``
    so per-rank splits stay even; otherwise pad to ``tp_size``.
    """
    if weight_block_size is not None and tp_size % 2 == 1:
        return tp_size * 2
    return tp_size
|
||||
|
||||
|
||||
def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
    """Pad the named intermediate-size attribute up to a multiple of
    ``intermediate_padding_size`` and write it back onto the config(s).

    Looks the attribute up on ``model_config.hf_config`` first, then on
    ``model_config`` itself; returns the (possibly mutated) model_config.
    """
    # Default when the attribute exists on neither object; then the modulo
    # check below is trivially 0 and nothing is written back.
    attr_value = intermediate_padding_size
    if hasattr(model_config, "hf_config") and hasattr(
        model_config.hf_config, attr_name
    ):
        attr_value = getattr(model_config.hf_config, attr_name)
    elif hasattr(model_config, attr_name):
        attr_value = getattr(model_config, attr_name)

    if attr_value % intermediate_padding_size != 0:
        # Deferred import to avoid an import cycle with the layers package.
        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        # pad_vocab_size rounds up to the next multiple of the padding size.
        attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
        # Write the padded value back to wherever the config lives.
        # NOTE(review): if hf_config exists but the value was read from
        # model_config directly, this still writes to hf_config — confirm
        # that is the intended precedence.
        if hasattr(model_config, "hf_config"):
            setattr(model_config.hf_config, attr_name, attr_value)
            if hasattr(model_config, "hf_text_config"):
                setattr(model_config.hf_text_config, attr_name, attr_value)
        else:
            setattr(model_config, attr_name, attr_value)

    return model_config
|
||||
|
||||
|
||||
def adjust_config_with_unaligned_cpu_tp(
    model_config: ModelConfig, load_config: LoadConfig, tp_size: int
) -> ModelConfig:
    """Pad attention-head counts and intermediate sizes so the model can be
    sharded across a TP size that does not evenly divide them.

    Mutates and returns ``model_config``; the pre-padding values are stashed
    as ``original_*`` attributes so weight loading can still address the
    original checkpoint layout.
    """
    # Support the case where the num_attention_heads is not divisible by the TP size.
    weight_block_size = may_get_weight_block_size(model_config, load_config)

    # Remember the unpadded head counts on both hf_config and hf_text_config.
    model_config.hf_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )
    model_config.hf_text_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )

    model_config.hf_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )
    model_config.hf_text_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )

    if (
        model_config.num_attention_heads % tp_size != 0
        or model_config.get_total_num_kv_heads() % tp_size != 0
    ):
        # Compute the head_dim using the model_config.num_attention_heads before padding
        if not hasattr(model_config.hf_config, "head_dim"):
            model_config.hf_config.head_dim = (
                model_config.hidden_size // model_config.num_attention_heads
            )

        # GQA ratio is preserved: pad KV heads, then scale query heads.
        query_heads_per_kv = (
            model_config.num_attention_heads // model_config.get_total_num_kv_heads()
        )
        total_kv_heads = model_config.get_total_num_kv_heads()
        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)

        model_config.num_key_value_heads = num_key_value_heads
        model_config.hf_config.num_key_value_heads = num_key_value_heads
        model_config.hf_text_config.num_key_value_heads = num_key_value_heads

        num_attention_heads = num_key_value_heads * query_heads_per_kv
        model_config.num_attention_heads = num_attention_heads
        model_config.hf_config.num_attention_heads = num_attention_heads
        model_config.hf_text_config.num_attention_heads = num_attention_heads

        # Pad every FFN width that must split evenly across ranks.
        intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
        model_config = update_intermediate_size(
            model_config, "moe_intermediate_size", intermediate_padding_size
        )
        model_config = update_intermediate_size(
            model_config, "intermediate_size", intermediate_padding_size
        )
        model_config = update_intermediate_size(
            model_config, "intermediate_size_mlp", intermediate_padding_size
        )
        # Apply the same treatment to a SigLIP vision tower, if present.
        if (
            hasattr(model_config.hf_config, "vision_config")
            and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
        ):
            # NOTE(review): this stores the TEXT model's (already padded)
            # num_attention_heads as the vision tower's "original" count —
            # looks like it should be vision_config.num_attention_heads;
            # confirm against the weight-loading code.
            model_config.hf_config.vision_config.original_num_attention_heads = (
                model_config.num_attention_heads
            )
            if model_config.hf_config.vision_config.num_attention_heads % tp_size != 0:
                model_config.hf_config.vision_config.head_dim = (
                    model_config.hf_config.vision_config.hidden_size
                    // model_config.hf_config.vision_config.num_attention_heads
                )
                from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

                pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
                model_config.hf_config.vision_config.num_attention_heads = pad_vocab_size(
                    model_config.hf_config.vision_config.num_attention_heads, pad_size
                )
            model_config.hf_config.vision_config = update_intermediate_size(
                model_config.hf_config.vision_config,
                "intermediate_size",
                intermediate_padding_size,
            )

    return model_config
|
||||
25
python/sglang/srt/configs/utils.py
Normal file
25
python/sglang/srt/configs/utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Type
|
||||
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoProcessor,
|
||||
BaseImageProcessor,
|
||||
PretrainedConfig,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
|
||||
def register_image_processor(
    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
    """Register a customized HF image processor for *config*, removing the
    HF implementation.

    The two ``None`` arguments clear the slow/fast processor slots;
    ``exist_ok=True`` allows overriding an existing registration.
    """
    AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
|
||||
|
||||
|
||||
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
    """Register a customized HF processor for *config*, removing the HF
    implementation.

    ``exist_ok=True`` allows overriding an existing registration.
    """
    AutoProcessor.register(config, processor, exist_ok=True)
|
||||
51
python/sglang/srt/connector/__init__.py
Normal file
51
python/sglang/srt/connector/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
import logging
|
||||
|
||||
from sglang.srt.connector.base_connector import (
|
||||
BaseConnector,
|
||||
BaseFileConnector,
|
||||
BaseKVConnector,
|
||||
)
|
||||
from sglang.srt.connector.redis import RedisConnector
|
||||
from sglang.srt.connector.s3 import S3Connector
|
||||
from sglang.srt.utils import parse_connector_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectorType(str, enum.Enum):
    """Kind of remote connector; str-mixin so members compare/serialize as
    plain strings."""

    # Filesystem-like connectors (e.g. S3).
    FS = "filesystem"
    # Key-value store connectors (e.g. Redis).
    KV = "KV"
|
||||
|
||||
|
||||
def create_remote_connector(url, **kwargs) -> BaseConnector:
    """Instantiate the connector matching the URL scheme.

    Supports ``redis`` and ``s3`` schemes; raises ValueError otherwise.
    (``kwargs`` is currently accepted but unused.)
    """
    scheme = parse_connector_type(url)
    if scheme == "redis":
        return RedisConnector(url)
    if scheme == "s3":
        return S3Connector(url)
    raise ValueError(f"Invalid connector type: {url}")
|
||||
|
||||
|
||||
def get_connector_type(client: BaseConnector) -> ConnectorType:
    """Map a connector instance to its ConnectorType.

    KV is checked before FS (matching the original order); raises ValueError
    for unknown connector classes.
    """
    for base_cls, conn_type in (
        (BaseKVConnector, ConnectorType.KV),
        (BaseFileConnector, ConnectorType.FS),
    ):
        if isinstance(client, base_cls):
            return conn_type

    raise ValueError(f"Invalid connector type: {client}")
|
||||
|
||||
|
||||
# Public API of the connector package.
__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]
|
||||
BIN
python/sglang/srt/connector/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
python/sglang/srt/connector/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
python/sglang/srt/connector/__pycache__/redis.cpython-310.pyc
Normal file
BIN
python/sglang/srt/connector/__pycache__/redis.cpython-310.pyc
Normal file
Binary file not shown.
BIN
python/sglang/srt/connector/__pycache__/s3.cpython-310.pyc
Normal file
BIN
python/sglang/srt/connector/__pycache__/s3.cpython-310.pyc
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user