Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8152e24cb2 |
188
README.md
188
README.md
@@ -50,166 +50,28 @@
|
|||||||
```
|
```
|
||||||
|
|
||||||
模型测试结果
|
模型测试结果
|
||||||
|
| 模型名称 | A100出字速度(字/秒) | 沐曦卡出字速度(字/秒) | 备注 |
|
||||||
| 模型名称 | A100出字速度(字/秒) | 曦云C500出字速度(字/秒) | A100输出质量 | 曦云C500输出质量 | A100首字延迟(秒) | 曦云C500首字延迟(秒) | 备注 |
|
|---------|-----|-----|---------------------|
|
||||||
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
|
| unsloth/gpt-oss-20b-BF16 | 80.1 | 52.9 | |
|
||||||
| 01ai/Yi-1.5-6B-Chat | 109.5427 | 109.5463 | 85.0000 | 80.0000 | 0.0921 | 0.1178 | |
|
| Qwen/Qwen3-4B | 171.8 | 112.3 | |
|
||||||
| 01ai/Yi-6B-Chat | 113.3089 | 83.3881 | 85.0000 | 85.0000 | 0.0663 | 0.0986 | |
|
| Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8 | 168.5 | 135.3 | |
|
||||||
| AI-ModelScope/CausalLM-7B | 108.3807 | 86.3852 | 77.5000 | 85.0000 | 0.0798 | 0.1090 | |
|
| Qwen/Qwen-1_8B-Chat-Int4 | 536.4 | 192.9 | |
|
||||||
| AI-ModelScope/granite-3.1-3b-a800m-instruct | 39.7904 | 41.1946 | 53.7500 | 63.7500 | 0.2164 | 0.2640 | |
|
| Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4 | 153.1 | 132.4 | |
|
||||||
| AI-ModelScope/granite-7b-instruct | 106.5448 | 92.1454 | 32.5000 | 27.5000 | 0.1004 | 0.1307 | |
|
| deepseek-ai/deepseek-moe-16b-chat | 68.1 | 60.7 | |
|
||||||
| AI-ModelScope/Hermes-3-Llama-3.1-8B | 104.0769 | 69.4578 | 85.0000 | 85.0000 | 0.0827 | 0.1063 | |
|
| Qwen/Qwen2-7B-Instruct-GPTQ-Int4 | 129.2 | 127.5 | |
|
||||||
| AI-ModelScope/mathstral-7B-v0.1 | 81.6464 | 56.6757 | 41.2500 | 47.5000 | 0.1011 | 0.1811 | |
|
| Qwen2.5-7B-Instruct-GPTQ-Int4 | 118.1 | 133.1 | |
|
||||||
| AI-ModelScope/Ministral-8B-Instruct-2410 | 73.7056 | 66.5971 | 85.0000 | 71.2500 | 0.0997 | 0.1629 | |
|
| tclf90/glm-4-9b-chat-GPTQ-Int4 | 92.2 | 97.7 | |
|
||||||
| AI-ModelScope/Mistral-7B-v0.1 | 74.3966 | 51.7199 | 12.5000 | 15.0000 | 0.1456 | 0.1329 | |
|
| Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4 | 79.6 | 72.3 | |
|
||||||
| AI-ModelScope/TinyLlama-1.1B-Chat-v0.4 | 58.0479 | 102.3987 | 22.5000 | 22.5000 | 0.0747 | 0.1156 | |
|
| Qwen/Qwen-14B-Chat-Int8 | 103.1 | 56.4 | 该模型在沐曦卡上生成质量要差于A100 |
|
||||||
| allenai/OLMoE-1B-7B-0924-Instruct | 66.9544 | 59.7696 | 37.5000 | 42.5000 | 0.2127 | 0.2610 | |
|
| Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8 | 81.4 | 70.4 | |
|
||||||
| BAAI/Finance-llama3_1_8B_instruct | 95.4692 | 65.3383 | 58.7500 | 57.5000 | 0.0812 | 0.1262 | |
|
| tclf90/qwq-32b-gptq-int4 | 60.3 | 53.5 | |
|
||||||
| BAAI/Hospitality-llama3_1_8B_instruct | 91.7153 | 62.6340 | 61.2500 | 56.2500 | 0.0650 | 0.1210 | |
|
| Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4 | 60.1 | 54.6 | |
|
||||||
| BAAI/Technology-llama3_1_8B_instruct | 80.6653 | 60.4259 | 32.5000 | 36.2500 | 0.0844 | 0.1147 | |
|
| Qwen/Qwen1.5-32B-Chat-GPTQ-Int4 | 58.2 | 51.3 | |
|
||||||
| ByteDance-Seed/Seed-OSS-36B-Instruct | 38.2767 | 17.4623 | 86.7500 | 88.5000 | 0.1400 | 0.3135 | |
|
| tclf90/Qwen3-32B-GPTQ-Int8 | 54.0 | 41.3 | |
|
||||||
| codefuse-ai/TestGPT-7B | 68.6996 | 49.0781 | 15.0000 | 15.0000 | 0.0756 | 0.1193 | |
|
| tclf90/deepseek-r1-distill-qwen-32b-gptq-int8 | 59.9 | 45.8 | |
|
||||||
| CohereLabs/aya-expanse-8B | 88.7216 | 78.2155 | 86.7500 | 86.7500 | 0.0687 | 3.8936 | |
|
| Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4 | 46.6|29.5 | |
|
||||||
| Cylingo/Xinyuan-LLM-14B-0428 | 71.7868 | 49.5901 | 86.7500 | 89.2500 | 0.0860 | 0.2501 | |
|
| Qwen/Qwen2-72B-Instruct-GPTQ-Int4 | 48.2| 29.7| |
|
||||||
| deepseek-ai/deepseek-llm-7b-base | 141.4714 | 93.0921 | 20.0000 | 22.5000 | 0.1231 | 0.1760 | |
|
| Qwen/Qwen3-4B-Instruct-2507 | 65.4| 71.8| |
|
||||||
| deepseek-ai/deepseek-llm-7b-chat | 124.9381 | 99.4562 | 85.0000 | 81.2500 | 0.0581 | 0.5611 | |
|
| Qwen/Qwen3-4B-Thinking-2507 |73.4 |52.6 | |
|
||||||
| deepseek-ai/deepseek-moe-16b-chat | 68.0789 | 60.7028 | 85.0000 | 68.7500 | 0.2346 | 0.2880 | |
|
| tclf90/Qwen3-32B-GPTQ-Int4 | 54.4| 38.4 | |
|
||||||
| Fengshenbang/Ziya-LLaMA-13B-v1 | 56.6471 | 39.3946 | 61.2500 | 58.7500 | 0.0616 | 0.1163 | |
|
| Qwen/Qwen3-0.6B-GPTQ-Int8 |117.4 | 95.0 | |
|
||||||
| FlagAlpha/Llama3-Chinese-8B-Instruct | 100.3664 | 70.7829 | 38.7500 | 30.0000 | 0.1091 | 0.1442 | |
|
|
||||||
| HuggingFaceH4/zephyr-7b-beta | 97.4954 | 51.7445 | 66.2500 | 47.5000 | 0.0802 | 0.1135 | |
|
|
||||||
| iic/WritingBench-Critic-Model-Qwen-7B | 125.5657 | 96.5148 | 87.5000 | 87.5000 | 0.0593 | 0.1159 | |
|
|
||||||
| InfiniAI/Megrez-3b-Instruct | 142.4497 | 157.9848 | 85.0000 | 85.0000 | 0.0815 | 0.0946 | |
|
|
||||||
| JunHowie/Qwen3-0.6B-GPTQ-Int4 | 108.2950 | 67.6596 | 58.7500 | 21.2500 | 0.0711 | 0.1206 | |
|
|
||||||
| JunHowie/Qwen3-1.7B-GPTQ-Int8 | 124.1477 | 131.5458 | 71.2500 | 38.7500 | 0.0900 | 0.1848 | |
|
|
||||||
| JunHowie/Qwen3-8B-GPTQ-Int4 | 82.8148 | 109.8259 | 86.7500 | 52.5000 | 0.0800 | 0.1575 | |
|
|
||||||
| Kedreamix/Xinjing-LM | 137.3682 | 150.4469 | 80.0000 | 66.2500 | 0.0710 | 0.1214 | |
|
|
||||||
| LLM-Research/gemma-2-9b-it | 47.0337 | 47.1786 | 85.0000 | 66.2500 | 0.1125 | 0.2271 | |
|
|
||||||
| LLM-Research/gemma-3-1b-it | 39.6315 | 43.8421 | 47.5000 | 71.2500 | 0.2256 | 0.2829 | |
|
|
||||||
| LLM-Research/gemma-3-1b-it | 38.3524 | 33.8285 | 47.5000 | 36.2500 | 0.2080 | 0.1673 | |
|
|
||||||
| LLM-Research/Llama-3.2-3B | 137.6931 | 145.1816 | 15.0000 | 15.0000 | 0.0891 | 0.1399 | |
|
|
||||||
| LLM-Research/Llama-3.2-3B-Instruct | 107.1550 | 95.4333 | 63.7500 | 68.7500 | 0.0533 | 0.1868 | |
|
|
||||||
| LLM-Research/Llama-Guard-3-8B | 74.5048 | 73.5737 | 33.7500 | 22.5000 | 0.0861 | 0.1248 | |
|
|
||||||
| LLM-Research/Llama3-8B-Chinese-Chat | 69.0446 | 66.6263 | 85.0000 | 85.0000 | 0.0823 | 0.3008 | |
|
|
||||||
| LLM-Research/Meta-Llama-3-8B-Instruct | 167.5121 | 27.8093 | 85.0000 | 92.7500 | 0.0778 | 0.2426 | |
|
|
||||||
| LLM-Research/Meta-Llama-3-8B-Instruct-GPTQ | 199.0842 | 175.9663 | 57.5000 | 71.2500 | 0.1028 | 0.1520 | |
|
|
||||||
| LLM-Research/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 | 34.9493 | 18.1700 | 87.5000 | 85.0000 | 0.1327 | 0.2035 | |
|
|
||||||
| LLM-Research/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 | 93.1869 | 77.1093 | 85.0000 | 85.0000 | 0.0665 | 0.1753 | |
|
|
||||||
| LLM-Research/OpenHermes-2.5-Mistral-7B | 78.2992 | 51.0257 | 75.0000 | 75.0000 | 0.0797 | 0.1161 | |
|
|
||||||
| LLM-Research/Phi-3-mini-4k-instruct | 65.0090 | 59.8164 | 45.0000 | 40.0000 | 0.0632 | 0.1004 | |
|
|
||||||
| LLM-Research/Qwen2-7B | 170.2887 | 126.6331 | 77.5000 | 78.2500 | 0.0787 | 0.1345 | |
|
|
||||||
| LLM-Research/Starling-LM-7B-beta | 73.7987 | 49.5925 | 66.2500 | 71.2500 | 0.0952 | 0.5098 | |
|
|
||||||
| LLM-Research/tulu-2-dpo-7b | 60.8343 | 47.6062 | 70.0000 | 72.5000 | 0.0618 | 0.0997 | |
|
|
||||||
| m-a-p/neo_7b_instruct_v0.1 | 178.2949 | 137.2681 | 71.2500 | 72.5000 | 0.0839 | 0.1358 | |
|
|
||||||
| mistralai/Devstral-Small-2507 | 56.7985 | 42.3061 | 88.7500 | 91.0000 | 0.0698 | 0.1193 | |
|
|
||||||
| mistralai/Mistral-Small-24B-Instruct-2501 | 57.3654 | 42.7280 | 86.7500 | 87.5000 | 0.0899 | 0.1340 | |
|
|
||||||
| mlabonne/EvolCodeLlama-7b | 53.7677 | 43.3601 | 5.0000 | 5.0000 | 0.1003 | 0.1287 | |
|
|
||||||
| modelscope/zephyr-7b-beta | 81.4112 | 52.4249 | 65.0000 | 47.5000 | 0.0835 | 0.2036 | |
|
|
||||||
| neuralmagic/Meta-Llama-3.1-8B-quantized.w8a8 | 157.2308 | 130.4350 | 10.0000 | 10.0000 | 0.0638 | 0.2007 | |
|
|
||||||
| neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 | 133.1439 | 136.2916 | 80.0000 | 80.0000 | 0.0564 | 0.0941 | |
|
|
||||||
| neuralmagic/SmolLM-1.7B-Instruct-quantized.w8a8 | 105.3004 | 107.6412 | 10.0000 | 10.0000 | 0.0886 | 0.0876 | |
|
|
||||||
| nv-community/Minitron-4B-Base | 63.8192 | 69.2841 | 5.0000 | 5.0000 | 0.0706 | 0.0917 | |
|
|
||||||
| openai-mirror/gpt-oss-safeguard-20b | 112.4384 | 86.9659 | 61.2500 | 72.0000 | 0.2067 | 0.4129 | |
|
|
||||||
| OpenBMB/MiniCPM-1B-sft-bf16 | 61.4045 | 71.6924 | 66.2500 | 75.0000 | 0.0748 | 0.1010 | |
|
|
||||||
| OpenBMB/MiniCPM-2B-dpo-fp16 | 61.8623 | 82.9386 | 80.0000 | 71.2500 | 0.1056 | 0.1071 | |
|
|
||||||
| OpenBMB/MiniCPM-2B-sft-fp32 | 74.4678 | 89.2962 | 65.0000 | 55.0000 | 0.0996 | 0.1408 | |
|
|
||||||
| OpenBMB/MiniCPM3-4B | 28.2210 | 21.3110 | 86.7500 | 86.7500 | 0.1020 | 0.1905 | |
|
|
||||||
| OpenBMB/MiniCPM4-0.5B | 82.2793 | 93.0776 | 23.7500 | 15.0000 | 0.0606 | 0.1143 | |
|
|
||||||
| OpenBMB/MiniCPM4-8B | 60.5837 | 57.0162 | 87.5000 | 52.5000 | 0.0921 | 0.1482 | |
|
|
||||||
| OpenBMB/MiniCPM4-8B-marlin-vLLM | 57.5381 | 58.2351 | 85.0000 | 72.0000 | 0.1033 | 0.1107 | |
|
|
||||||
| OpenBMB/MiniCPM4.1-8B | 61.4557 | 51.5120 | 89.2500 | 86.7500 | 0.0784 | 0.1290 | |
|
|
||||||
| PaddlePaddle/ERNIE-4.5-0.3B-PT | 144.5465 | 185.2054 | 52.5000 | 47.5000 | 0.0837 | 0.0902 | |
|
|
||||||
| prithivMLmods/Llama-Sentient-3.2-3B-Instruct | 147.0239 | 168.6519 | 40.0000 | 40.0000 | 0.0661 | 0.1230 | |
|
|
||||||
| prithivMLmods/Qwen-UMLS-7B-Instruct | 122.8389 | 89.8736 | 55.0000 | 55.0000 | 0.0584 | 0.1061 | |
|
|
||||||
| QuantTrio/Qwen3-30B-A3B-Instruct-2507-GPTQ-Int8 | 40.1016 | 22.6581 | 91.7500 | 68.0000 | 0.0915 | 0.2704 | |
|
|
||||||
| Qwen/Qwen-1_8B-Chat-Int4 | 536.3839 | 192.9054 | 28.7500 | 25.0000 | 0.0517 | 0.0936 | |
|
|
||||||
| Qwen/Qwen-1_8B-Chat-Int8 | 425.9918 | 95.9929 | 42.5000 | 85.0000 | 0.0549 | 0.1754 | |
|
|
||||||
| Qwen/Qwen-14B | 64.3123 | 157.3503 | 38.7500 | 33.7500 | 0.0912 | 0.0858 | |
|
|
||||||
| Qwen/Qwen-14B-Chat-Int4 | 84.9682 | 148.9078 | 68.7500 | 85.0000 | 0.0971 | 0.1298 | |
|
|
||||||
| Qwen/Qwen-14B-Chat-Int8 | 103.0940 | 56.3855 | 85.0000 | 21.2500 | 0.0758 | 0.1443 | 该模型在沐曦卡上生成质量要差于A100 |
|
|
||||||
| Qwen/Qwen-72B-Chat-Int4 | 55.1679 | 260.1285 | 85.0000 | 47.5000 | 0.1360 | 0.1411 | |
|
|
||||||
| Qwen/Qwen-72B-Chat-Int8 | 45.9768 | 83.0130 | 68.0000 | 58.7500 | 0.1172 | 0.1511 | |
|
|
||||||
| Qwen/Qwen-7B-Chat-Int4 | 119.8944 | 102.6521 | 55.0000 | 55.0000 | 0.0831 | 0.1092 | |
|
|
||||||
| Qwen/Qwen-7B-Chat-Int4 | 117.8753 | 89.0941 | 55.0000 | 37.5000 | 0.0761 | 0.1086 | |
|
|
||||||
| Qwen/Qwen-7B-Chat-Int8 | 128.7233 | 107.4733 | 40.0000 | 80.0000 | 0.0545 | 0.1578 | |
|
|
||||||
| Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8 | 168.5100 | 135.2830 | 33.7500 | 33.7500 | 0.0530 | 0.0869 | |
|
|
||||||
| Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4 | 153.0596 | 132.3559 | 33.7500 | 63.7500 | 0.0557 | 0.1316 | |
|
|
||||||
| Qwen/Qwen1.5-14B-Chat-AWQ | 98.3736 | 74.0728 | 91.5000 | 88.5000 | 0.0725 | 0.1188 | |
|
|
||||||
| Qwen/Qwen1.5-14B-Chat-GPTQ-Int4 | 77.7489 | 92.8893 | 88.5000 | 88.5000 | 0.0863 | 0.1030 | |
|
|
||||||
| Qwen/Qwen1.5-14B-Chat-GPTQ-Int8 | 90.9851 | 79.5202 | 88.5000 | 90.2500 | 0.0716 | 0.1136 | |
|
|
||||||
| Qwen/Qwen1.5-32B-Chat-GPTQ-Int4 | 58.2218 | 51.3081 | 92.2500 | 91.0000 | 0.0775 | 0.1392 | |
|
|
||||||
| Qwen/Qwen1.5-72B-Chat-GPTQ-Int4 | 47.0482 | 28.5930 | 92.7500 | 92.7500 | 0.1314 | 0.3305 | |
|
|
||||||
| Qwen/Qwen1.5-7B-Chat-GPTQ-Int4 | 111.7481 | 110.7277 | 88.0000 | 86.7500 | 0.0857 | 0.0978 | |
|
|
||||||
| Qwen/Qwen2-72B-Instruct-GPTQ-Int4 | 48.2245 | 29.6541 | 92.7500 | 92.7500 | 0.1091 | 0.2308 | |
|
|
||||||
| Qwen/Qwen2-7B-Instruct-GPTQ-Int4 | 129.1925 | 127.4951 | 88.5000 | 88.5000 | 0.0632 | 0.1595 | |
|
|
||||||
| Qwen/Qwen2-7B-Instruct-GPTQ-Int8 | 120.7039 | 120.5521 | 89.2500 | 91.0000 | 0.0589 | 0.0928 | |
|
|
||||||
| Qwen/Qwen2.5-0.5B | 190.6378 | 143.7020 | 12.5000 | 80.0000 | 0.0569 | 0.0966 | |
|
|
||||||
| Qwen/Qwen2.5-0.5B-Instruct | 157.0171 | 74.5845 | 66.2500 | 90.0000 | 0.0524 | 0.1228 | |
|
|
||||||
| Qwen/Qwen2.5-1.5B-Instruct | 132.3709 | 108.0231 | 85.0000 | 85.0000 | 0.0784 | 0.1907 | |
|
|
||||||
| Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4 | 99.1006 | 111.6715 | 80.0000 | 85.0000 | 0.0832 | 0.0998 | |
|
|
||||||
| Qwen/Qwen2.5-14B-Instruct-1M | 73.2996 | 50.5570 | 92.7500 | 91.0000 | 0.0736 | 0.1316 | |
|
|
||||||
| Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4 | 79.5701 | 72.2741 | 91.0000 | 91.0000 | 0.0738 | 0.1219 | |
|
|
||||||
| Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8 | 81.3902 | 70.3944 | 91.0000 | 91.0000 | 0.0588 | 0.3164 | |
|
|
||||||
| Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4 | 60.0678 | 54.5765 | 90.5000 | 91.0000 | 0.1023 | 0.1855 | |
|
|
||||||
| Qwen/Qwen2.5-3B | 102.8492 | 124.4283 | 47.5000 | 50.0000 | 0.0830 | 0.1048 | |
|
|
||||||
| Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4 | 46.5678 | 29.5256 | 91.0000 | 70.5000 | 0.1013 | 0.2163 | |
|
|
||||||
| Qwen/Qwen2.5-7B-Instruct-1M | 97.2979 | 92.4331 | 89.2500 | 89.2500 | 0.0685 | 0.1017 | |
|
|
||||||
| Qwen/Qwen2.5-7B-Instruct-AWQ | 115.5071 | 113.5233 | 88.5000 | 91.0000 | 0.0599 | 0.0982 | |
|
|
||||||
| Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 | 118.1169 | 133.1432 | 87.5000 | 87.5000 | 0.0603 | 0.1551 | |
|
|
||||||
| Qwen/Qwen2.5-Coder-14B-Instruct-AWQ | 47.8161 | 52.9300 | 87.5000 | 87.5000 | 0.0854 | 0.1162 | |
|
|
||||||
| Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4 | 70.0316 | 68.4585 | 87.5000 | 73.7500 | 0.0694 | 0.1246 | |
|
|
||||||
| Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4 | 60.2992 | 55.2420 | 91.0000 | 91.0000 | 0.1108 | 0.2650 | |
|
|
||||||
| Qwen/Qwen2.5-Coder-7B | 135.1152 | 115.0369 | 46.7500 | 63.7500 | 0.0979 | 0.2146 | |
|
|
||||||
| Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8 | 108.4975 | 100.5769 | 87.5000 | 87.5000 | 0.0646 | 0.1794 | |
|
|
||||||
| Qwen/Qwen3-0.6B | 115.7984 | 128.5600 | 56.2500 | 72.5000 | 0.0890 | 0.1073 | |
|
|
||||||
| Qwen/Qwen3-0.6B-GPTQ-Int8 | 117.4333 | 95.0022 | 75.0000 | 33.7500 | 0.0891 | 0.1146 | |
|
|
||||||
| Qwen/Qwen3-1___7B | 123.6276 | 55.1110 | 81.7500 | 86.7500 | 0.0931 | 0.2069 | |
|
|
||||||
| Qwen/Qwen3-30B-A3B-Base | 36.5579 | 145.1734 | 86.0000 | 65.0000 | 0.1890 | 0.0983 | |
|
|
||||||
| Qwen/Qwen3-30B-A3B-GPTQ-Int4 | 41.5809 | 95.4118 | 88.5000 | 66.2500 | 0.0818 | 0.1298 | |
|
|
||||||
| Qwen/Qwen3-32B-GPTQ-Int8 | 53.9982 | 41.3361 | 86.7500 | 86.7500 | 0.1501 | 0.2713 | |
|
|
||||||
| Qwen/Qwen3-4B | 167.1038 | 83.0732 | 87.5000 | 66.2500 | 0.0587 | 0.1159 | |
|
|
||||||
| Qwen/Qwen3-4B | 171.1984 | 105.4759 | 89.2500 | 84.2500 | 0.0598 | 0.1105 | |
|
|
||||||
| Qwen/Qwen3-4B-AWQ | 76.4882 | 91.4280 | 86.7500 | 88.5000 | 0.0922 | 0.1245 | |
|
|
||||||
| Qwen/Qwen3-4B-Instruct-2507 | 65.4432 | 71.7888 | 91.7500 | 91.7500 | 0.0966 | 0.1285 | |
|
|
||||||
| Qwen/Qwen3-4B-SafeRL | 78.8314 | 96.6840 | 87.5000 | 88.5000 | 0.0930 | 0.1174 | |
|
|
||||||
| Qwen/Qwen3-4B-Thinking-2507 | 73.3587 | 52.6143 | 81.7500 | 88.5000 | 0.0799 | 0.2663 | |
|
|
||||||
| Qwen/Qwen3-8B-AWQ | 94.2630 | 67.3837 | 88.5000 | 86.7500 | 0.0990 | 0.1425 | |
|
|
||||||
| QwenCollection/Hercules-Mini-1.8B | 202.5339 | 202.1224 | 28.7500 | 28.7500 | 0.0517 | 0.0855 | |
|
|
||||||
| QwenCollection/Ragas-critic-llm-Qwen1.5-GPTQ | 170.6923 | 136.0126 | 35.0000 | 25.0000 | 0.0667 | 0.1142 | |
|
|
||||||
| RUC-DataLab/DeepAnalyze-8B | 84.5026 | 84.9213 | 86.7500 | 83.5000 | 0.1118 | 0.1673 | |
|
|
||||||
| shakechen/Llama-2-7b-chat-hf | 260.9870 | 34.1061 | 61.2500 | 91.0000 | 0.0892 | 0.1960 | |
|
|
||||||
| Shanghai_AI_Laboratory/internlm-20b | 54.0570 | 40.9779 | 49.2500 | 71.2500 | 0.1376 | 0.2653 | |
|
|
||||||
| Shanghai_AI_Laboratory/internlm-chat-20b | 59.3568 | 44.0583 | 85.0000 | 85.0000 | 0.0789 | 0.1403 | |
|
|
||||||
| Shanghai_AI_Laboratory/internlm2-chat-1_8b | 137.9855 | 150.8975 | 22.5000 | 31.2500 | 0.0874 | 0.1612 | |
|
|
||||||
| tclf90/Codestral-22B-v0.1-hf-GPTQ-Int4 | 69.1961 | 71.7657 | 80.0000 | 70.0000 | 0.1810 | 0.2461 | |
|
|
||||||
| tclf90/deepseek-r1-distill-qwen-14b-gptq-int4 | 61.5092 | 59.0025 | 85.0000 | 85.0000 | 0.0786 | 0.1791 | |
|
|
||||||
| tclf90/deepseek-r1-distill-qwen-32b-gptq-int4 | 63.7344 | 54.3297 | 88.0000 | 86.7500 | 0.1675 | 0.2525 | |
|
|
||||||
| tclf90/deepseek-r1-distill-qwen-32b-gptq-int8 | 59.9314 | 45.7905 | 86.7500 | 86.7500 | 0.1491 | 0.2839 | |
|
|
||||||
| tclf90/deepseek-r1-distill-qwen-7b-gptq-int4 | 134.7569 | 141.3986 | 80.0000 | 85.0000 | 0.0892 | 0.2778 | |
|
|
||||||
| tclf90/glm-4-9b-chat-GPTQ-Int4 | 92.2061 | 97.6775 | 89.2500 | 91.7500 | 0.0735 | 0.4389 | |
|
|
||||||
| tclf90/glm-4-9b-chat-GPTQ-Int8 | 67.4918 | 93.5178 | 90.5000 | 86.7500 | 0.1037 | 0.1350 | |
|
|
||||||
| tclf90/Qwen2-14B-merge-GPTQ-Int8 | 68.0706 | 73.0019 | 88.5000 | 88.5000 | 0.0741 | 0.1217 | |
|
|
||||||
| tclf90/qwen2.5-14b-instruct-1m-gptq-int4 | 80.4909 | 71.7343 | 91.0000 | 91.0000 | 0.0720 | 0.1329 | |
|
|
||||||
| tclf90/qwen2.5-72b-instruct-gptq-int4 | 38.6166 | 34.9558 | 91.7500 | 52.5000 | 0.1053 | 0.5529 | |
|
|
||||||
| tclf90/Qwen3-32B-GPTQ-Int4 | 54.3526 | 38.4391 | 86.7500 | 88.5000 | 0.1470 | 0.2276 | |
|
|
||||||
| tclf90/Qwen3-32B-GPTQ-Int4 | 60.6483 | 42.8220 | 86.7500 | 86.7500 | 0.1651 | 0.2262 | |
|
|
||||||
| tclf90/Qwen3-32B-GPTQ-Int8 | 48.2039 | 148.4290 | 86.7500 | 70.0000 | 0.1526 | 0.0950 | |
|
|
||||||
| tclf90/qwq-32b-gptq-int4 | 60.3297 | 53.4619 | 87.5000 | 86.7500 | 0.1581 | 0.4379 | |
|
|
||||||
| tclf90/qwq-32b-gptq-int8 | 52.0826 | 41.4700 | 86.7500 | 87.5000 | 0.1853 | 0.2502 | |
|
|
||||||
| TheBloke/Kimiko-7B-fp16 | 52.3951 | 41.4855 | 15.0000 | 15.0000 | 0.0947 | 0.1223 | |
|
|
||||||
| tiiuae/falcon-7b-instruct | 177.4328 | 147.2327 | 20.0000 | 20.0000 | 0.0906 | 0.1180 | |
|
|
||||||
| TongyiFinance/Tongyi-Finance-14B-Chat | 161.2871 | 55.1869 | 52.5000 | 63.7500 | 0.0970 | 0.2396 | |
|
|
||||||
| TongyiFinance/Tongyi-Finance-14B-Chat-Int4 | 144.2085 | 114.2191 | 57.5000 | 38.7500 | 0.0592 | 0.1210 | |
|
|
||||||
| TongyiFinance/Tongyi-Finance-14B-Chat-Int4 | 147.6310 | 129.5033 | 57.5000 | 57.5000 | 0.0694 | 0.1123 | |
|
|
||||||
| UnicomAI/Unichat-llama3.2-Chinese-1B | 170.9294 | 181.6792 | 26.2500 | 25.0000 | 0.0594 | 0.1311 | |
|
|
||||||
| unsloth/Phi-3.5-mini-instruct | 39.4028 | 38.0348 | 80.0000 | 77.5000 | 0.0716 | 0.1253 | |
|
|
||||||
| XGenerationLab/XiYanSQL-QwenCoder-3B-2502 | 111.9181 | 105.4054 | 80.0000 | 85.0000 | 0.0593 | 0.1060 | |
|
|
||||||
| Xunzillm4cc/Xunzi-Qwen1.5-4B | 99.2706 | 102.2685 | 22.5000 | 25.0000 | 0.0811 | 0.0982 | |
|
|
||||||
| Xunzillm4cc/Xunzi-Qwen2-1.5B | 125.9543 | 137.9787 | 20.0000 | 17.5000 | 0.0588 | 0.0903 | |
|
|
||||||
| ZhipuAI/chatglm3-6b-base | 129.0337 | 40.0233 | 38.7500 | 71.2500 | 0.1349 | 0.2365 | |
|
|
||||||
| ZhipuAI/glm-4-9b-chat-1m | 102.4445 | 77.4381 | 89.7500 | 88.7500 | 0.0911 | 0.1357 | |
|
|
||||||
| ZhipuAI/glm-4-9b-chat-hf | 109.1148 | 71.8195 | 85.0000 | 85.0000 | 0.0900 | 0.1838 | |
|
|
||||||
| ZhipuAI/GLM-Z1-32B-0414 | 55.9116 | 41.5245 | 88.5000 | 88.5000 | 0.1048 | 0.2995 | |
|
|
||||||
| ZhipuAI/GLM-Z1-9B-0414 | 90.9342 | 71.1734 | 85.0000 | 86.7500 | 0.0879 | 0.9642 | |
|
|
||||||
| ZhipuAI/LongWriter-glm4-9b | 100.4160 | 81.4277 | 86.7500 | 86.7500 | 0.1221 | 0.1758 | |
|
|
||||||
| zpeng1989/TCM_DeepSeek_LLM | 98.9219 | 64.9526 | 82.5000 | 80.0000 | 0.0721 | 0.1533 | |
|
|
||||||
|
|
||||||
84
vllm/_utils/__init__.py
Normal file
84
vllm/_utils/__init__.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
import warnings
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from utils import *
|
||||||
|
|
||||||
|
_DEPRECATED_MAPPINGS = {
|
||||||
|
"cprofile": "profiling",
|
||||||
|
"cprofile_context": "profiling",
|
||||||
|
# Used by lm-eval
|
||||||
|
"get_open_port": "network_utils",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring
|
||||||
|
"""Module-level getattr to handle deprecated utilities."""
|
||||||
|
if name in _DEPRECATED_MAPPINGS:
|
||||||
|
submodule_name = _DEPRECATED_MAPPINGS[name]
|
||||||
|
warnings.warn(
|
||||||
|
f"vllm.utils.{name} is deprecated and will be removed in a future version. "
|
||||||
|
f"Use vllm.utils.{submodule_name}.{name} instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
module = __import__(f"vllm.utils.{submodule_name}", fromlist=[submodule_name])
|
||||||
|
return getattr(module, name)
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def __dir__() -> list[str]:
|
||||||
|
# expose deprecated names in dir() for better UX/tab-completion
|
||||||
|
return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
|
||||||
|
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Constants related to forcing the attention backend selection
|
||||||
|
|
||||||
|
# String name of register which may be set in order to
|
||||||
|
# force auto-selection of attention backend by Attention
|
||||||
|
# wrapper
|
||||||
|
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
|
||||||
|
|
||||||
|
# Possible string values of STR_BACKEND_ENV_VAR
|
||||||
|
# register, corresponding to possible backends
|
||||||
|
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
|
||||||
|
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
|
||||||
|
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||||
|
STR_INVALID_VAL: str = "INVALID"
|
||||||
|
|
||||||
|
|
||||||
|
def random_uuid() -> str:
|
||||||
|
return str(uuid.uuid4().hex)
|
||||||
|
|
||||||
|
|
||||||
|
def length_from_prompt_token_ids_or_embeds(
|
||||||
|
prompt_token_ids: list[int] | None,
|
||||||
|
prompt_embeds: torch.Tensor | None,
|
||||||
|
) -> int:
|
||||||
|
"""Calculate the request length (in number of tokens) give either
|
||||||
|
prompt_token_ids or prompt_embeds.
|
||||||
|
"""
|
||||||
|
prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids)
|
||||||
|
prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds)
|
||||||
|
|
||||||
|
if prompt_token_len is None:
|
||||||
|
if prompt_embeds_len is None:
|
||||||
|
raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
|
||||||
|
return prompt_embeds_len
|
||||||
|
else:
|
||||||
|
if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len:
|
||||||
|
raise ValueError(
|
||||||
|
"Prompt token ids and prompt embeds had different lengths"
|
||||||
|
f" prompt_token_ids={prompt_token_len}"
|
||||||
|
f" prompt_embeds={prompt_embeds_len}"
|
||||||
|
)
|
||||||
|
return prompt_token_len
|
||||||
487
vllm/_utils/argparse_utils.py
Normal file
487
vllm/_utils/argparse_utils.py
Normal file
@@ -0,0 +1,487 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Argument parsing utilities for vLLM."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import textwrap
|
||||||
|
from argparse import (
|
||||||
|
Action,
|
||||||
|
ArgumentDefaultsHelpFormatter,
|
||||||
|
ArgumentParser,
|
||||||
|
ArgumentTypeError,
|
||||||
|
Namespace,
|
||||||
|
RawDescriptionHelpFormatter,
|
||||||
|
_ArgumentGroup,
|
||||||
|
)
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
|
||||||
|
"""SortedHelpFormatter that sorts arguments by their option strings."""
|
||||||
|
|
||||||
|
def _split_lines(self, text, width):
|
||||||
|
"""
|
||||||
|
1. Sentences split across lines have their single newlines removed.
|
||||||
|
2. Paragraphs and explicit newlines are split into separate lines.
|
||||||
|
3. Each line is wrapped to the specified width (width of terminal).
|
||||||
|
"""
|
||||||
|
# The patterns also include whitespace after the newline
|
||||||
|
single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*")
|
||||||
|
multiple_newlines = re.compile(r"\n{2,}\s*")
|
||||||
|
text = single_newline.sub(" ", text)
|
||||||
|
lines = re.split(multiple_newlines, text)
|
||||||
|
return sum([textwrap.wrap(line, width) for line in lines], [])
|
||||||
|
|
||||||
|
def add_arguments(self, actions):
|
||||||
|
actions = sorted(actions, key=lambda x: x.option_strings)
|
||||||
|
super().add_arguments(actions)
|
||||||
|
|
||||||
|
|
||||||
|
class FlexibleArgumentParser(ArgumentParser):
|
||||||
|
"""ArgumentParser that allows both underscore and dash in names."""
|
||||||
|
|
||||||
|
_deprecated: set[Action] = set()
|
||||||
|
_json_tip: str = (
|
||||||
|
"When passing JSON CLI arguments, the following sets of arguments "
|
||||||
|
"are equivalent:\n"
|
||||||
|
' --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
|
||||||
|
" --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
|
||||||
|
"Additionally, list elements can be passed individually using +:\n"
|
||||||
|
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
|
||||||
|
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
|
||||||
|
)
|
||||||
|
_search_keyword: str | None = None
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# Set the default "formatter_class" to SortedHelpFormatter
|
||||||
|
if "formatter_class" not in kwargs:
|
||||||
|
kwargs["formatter_class"] = SortedHelpFormatter
|
||||||
|
# Pop kwarg "add_json_tip" to control whether to add the JSON tip
|
||||||
|
self.add_json_tip = kwargs.pop("add_json_tip", True)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
if sys.version_info < (3, 13):
|
||||||
|
# Enable the deprecated kwarg for Python 3.12 and below
|
||||||
|
|
||||||
|
def parse_known_args(self, args=None, namespace=None):
|
||||||
|
if args is not None and "--disable-log-requests" in args:
|
||||||
|
# Special case warning because the warning below won't trigger
|
||||||
|
# if –-disable-log-requests because its value is default.
|
||||||
|
logger.warning_once(
|
||||||
|
"argument '--disable-log-requests' is deprecated and "
|
||||||
|
"replaced with '--enable-log-requests'. This will be "
|
||||||
|
"removed in v0.12.0."
|
||||||
|
)
|
||||||
|
namespace, args = super().parse_known_args(args, namespace)
|
||||||
|
for action in FlexibleArgumentParser._deprecated:
|
||||||
|
if (
|
||||||
|
hasattr(namespace, dest := action.dest)
|
||||||
|
and getattr(namespace, dest) != action.default
|
||||||
|
):
|
||||||
|
logger.warning_once("argument '%s' is deprecated", dest)
|
||||||
|
return namespace, args
|
||||||
|
|
||||||
|
def add_argument(self, *args, **kwargs):
|
||||||
|
deprecated = kwargs.pop("deprecated", False)
|
||||||
|
action = super().add_argument(*args, **kwargs)
|
||||||
|
if deprecated:
|
||||||
|
FlexibleArgumentParser._deprecated.add(action)
|
||||||
|
return action
|
||||||
|
|
||||||
|
class _FlexibleArgumentGroup(_ArgumentGroup):
|
||||||
|
def add_argument(self, *args, **kwargs):
|
||||||
|
deprecated = kwargs.pop("deprecated", False)
|
||||||
|
action = super().add_argument(*args, **kwargs)
|
||||||
|
if deprecated:
|
||||||
|
FlexibleArgumentParser._deprecated.add(action)
|
||||||
|
return action
|
||||||
|
|
||||||
|
def add_argument_group(self, *args, **kwargs):
|
||||||
|
group = self._FlexibleArgumentGroup(self, *args, **kwargs)
|
||||||
|
self._action_groups.append(group)
|
||||||
|
return group
|
||||||
|
|
||||||
|
def format_help(self):
|
||||||
|
# Only use custom help formatting for bottom level parsers
|
||||||
|
if self._subparsers is not None:
|
||||||
|
return super().format_help()
|
||||||
|
|
||||||
|
formatter = self._get_formatter()
|
||||||
|
|
||||||
|
# Handle keyword search of the args
|
||||||
|
if (search_keyword := self._search_keyword) is not None:
|
||||||
|
# Normalise the search keyword
|
||||||
|
search_keyword = search_keyword.lower().replace("_", "-")
|
||||||
|
# Return full help if searching for 'all'
|
||||||
|
if search_keyword == "all":
|
||||||
|
self.epilog = self._json_tip
|
||||||
|
return super().format_help()
|
||||||
|
|
||||||
|
# Return group help if searching for a group title
|
||||||
|
for group in self._action_groups:
|
||||||
|
if group.title and group.title.lower() == search_keyword:
|
||||||
|
formatter.start_section(group.title)
|
||||||
|
formatter.add_text(group.description)
|
||||||
|
formatter.add_arguments(group._group_actions)
|
||||||
|
formatter.end_section()
|
||||||
|
formatter.add_text(self._json_tip)
|
||||||
|
return formatter.format_help()
|
||||||
|
|
||||||
|
# Return matched args if searching for an arg name
|
||||||
|
matched_actions = []
|
||||||
|
for group in self._action_groups:
|
||||||
|
for action in group._group_actions:
|
||||||
|
# search option name
|
||||||
|
if any(
|
||||||
|
search_keyword in opt.lower() for opt in action.option_strings
|
||||||
|
):
|
||||||
|
matched_actions.append(action)
|
||||||
|
if matched_actions:
|
||||||
|
formatter.start_section(f"Arguments matching '{search_keyword}'")
|
||||||
|
formatter.add_arguments(matched_actions)
|
||||||
|
formatter.end_section()
|
||||||
|
formatter.add_text(self._json_tip)
|
||||||
|
return formatter.format_help()
|
||||||
|
|
||||||
|
# No match found
|
||||||
|
formatter.add_text(
|
||||||
|
f"No group or arguments matching '{search_keyword}'.\n"
|
||||||
|
"Use '--help' to see available groups or "
|
||||||
|
"'--help=all' to see all available parameters."
|
||||||
|
)
|
||||||
|
return formatter.format_help()
|
||||||
|
|
||||||
|
# usage
|
||||||
|
formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)
|
||||||
|
|
||||||
|
# description
|
||||||
|
formatter.add_text(self.description)
|
||||||
|
|
||||||
|
# positionals, optionals and user-defined groups
|
||||||
|
formatter.start_section("Config Groups")
|
||||||
|
config_groups = ""
|
||||||
|
for group in self._action_groups:
|
||||||
|
if not group._group_actions:
|
||||||
|
continue
|
||||||
|
title = group.title
|
||||||
|
description = group.description or ""
|
||||||
|
config_groups += f"{title: <24}{description}\n"
|
||||||
|
formatter.add_text(config_groups)
|
||||||
|
formatter.end_section()
|
||||||
|
|
||||||
|
# epilog
|
||||||
|
formatter.add_text(self.epilog)
|
||||||
|
|
||||||
|
# determine help from format above
|
||||||
|
return formatter.format_help()
|
||||||
|
|
||||||
|
def parse_args(  # type: ignore[override]
    self,
    args: list[str] | None = None,
    namespace: Namespace | None = None,
):
    """Parse CLI args after applying vLLM-specific pre-processing.

    Pre-processing steps, in order:
      1. For `serve`, rewrite a deprecated `--model <tag>` / `--model=<tag>`
         option into the positional model argument (with a warning).
      2. Expand `--config <file>` into inline args via
         `_pull_args_from_config`.
      3. Normalize underscores to dashes in option names, handle
         `--help=<keyword>` search, and the `-O<n>` / `-O=<n>` / `-O <n>`
         optimization-level shorthands.
      4. Collect dotted options (e.g. `-O.mode=3`, `--x.y.z v`) into
         nested dicts and re-append them as JSON values.

    The processed list is then handed to `argparse`'s own `parse_args`.
    """
    if args is None:
        args = sys.argv[1:]

    # Check for --model in command line arguments first
    if args and args[0] == "serve":
        try:
            model_idx = next(
                i
                for i, arg in enumerate(args)
                if arg == "--model" or arg.startswith("--model=")
            )
            logger.warning(
                "With `vllm serve`, you should provide the model as a "
                "positional argument or in a config file instead of via "
                "the `--model` option. "
                "The `--model` option will be removed in v0.13."
            )

            # Two spellings: "--model <tag>" consumes two tokens,
            # "--model=<tag>" consumes one.
            if args[model_idx] == "--model":
                model_tag = args[model_idx + 1]
                rest_start_idx = model_idx + 2
            else:
                model_tag = args[model_idx].removeprefix("--model=")
                rest_start_idx = model_idx + 1

            # Move <model> to the front, e,g:
            # [Before]
            # vllm serve -tp 2 --model <model> --enforce-eager --port 8001
            # [After]
            # vllm serve <model> -tp 2 --enforce-eager --port 8001
            args = [
                "serve",
                model_tag,
                *args[1:model_idx],
                *args[rest_start_idx:],
            ]
        except StopIteration:
            # No --model option present; nothing to rewrite.
            pass

    if "--config" in args:
        args = self._pull_args_from_config(args)

    def repl(match: re.Match) -> str:
        """Replaces underscores with dashes in the matched string."""
        return match.group(0).replace("_", "-")

    # Everything between the first -- and the first .
    pattern = re.compile(r"(?<=--)[^\.]*")

    # Convert underscores to dashes and vice versa in argument names
    processed_args = list[str]()
    for i, arg in enumerate(args):
        if arg.startswith("--help="):
            # "--help=<keyword>" filters help output by keyword.
            FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
            processed_args.append("--help")
        elif arg.startswith("--"):
            if "=" in arg:
                key, value = arg.split("=", 1)
                # Only the option name is normalized; the value (and any
                # dotted suffix) is left untouched.
                key = pattern.sub(repl, key, count=1)
                processed_args.append(f"{key}={value}")
            else:
                key = pattern.sub(repl, arg, count=1)
                processed_args.append(key)
        elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
            # allow -O flag to be used without space, e.g. -O3 or -Odecode
            # -O.<...> handled later
            # also handle -O=<mode> here
            mode = arg[3:] if arg[2] == "=" else arg[2:]
            processed_args.append(f"-O.mode={mode}")
        elif (
            arg == "-O"
            and i + 1 < len(args)
            and args[i + 1] in {"0", "1", "2", "3"}
        ):
            # Convert -O <n> to -O.mode <n>
            # (the <n> token itself is appended on the next iteration)
            processed_args.append("-O.mode")
        else:
            processed_args.append(arg)

    def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
        """Creates a nested dictionary from a list of keys and a value.

        For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
        `{"a": {"b": {"c": 1}}}`
        """
        nested_dict: Any = value
        for key in reversed(keys):
            nested_dict = {key: nested_dict}
        return nested_dict

    def recursive_dict_update(
        original: dict[str, Any],
        update: dict[str, Any],
    ) -> set[str]:
        """Recursively updates a dictionary with another dictionary.

        Returns a set of duplicate keys that were overwritten.
        """
        duplicates = set[str]()
        for k, v in update.items():
            if isinstance(v, dict) and isinstance(original.get(k), dict):
                nested_duplicates = recursive_dict_update(original[k], v)
                duplicates |= {f"{k}.{d}" for d in nested_duplicates}
            elif isinstance(v, list) and isinstance(original.get(k), list):
                # Lists are concatenated rather than replaced.
                original[k] += v
            else:
                if k in original:
                    duplicates.add(k)
                original[k] = v
        return duplicates

    # Indices of tokens consumed by dotted-arg handling (removed below).
    delete = set[int]()
    # Accumulates nested dict values per top-level option name.
    dict_args = defaultdict[str, dict[str, Any]](dict)
    duplicates = set[str]()
    for i, processed_arg in enumerate(processed_args):
        if i in delete:  # skip if value from previous arg
            continue

        if processed_arg.startswith("-") and "." in processed_arg:
            if "=" in processed_arg:
                processed_arg, value_str = processed_arg.split("=", 1)
                if "." not in processed_arg:
                    # False positive, '.' was only in the value
                    continue
            else:
                # Space-separated form: the value is the next token.
                value_str = processed_args[i + 1]
                delete.add(i + 1)

            # Trailing "+" means "append": split on commas into a JSON list.
            if processed_arg.endswith("+"):
                processed_arg = processed_arg[:-1]
                value_str = json.dumps(list(value_str.split(",")))

            key, *keys = processed_arg.split(".")
            try:
                value = json.loads(value_str)
            except json.decoder.JSONDecodeError:
                # Not valid JSON; keep the raw string.
                value = value_str

            # Merge all values with the same key into a single dict
            arg_dict = create_nested_dict(keys, value)
            arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
            duplicates |= {f"{key}.{d}" for d in arg_duplicates}
            delete.add(i)
    # Filter out the dict args we set to None
    processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
    if duplicates:
        logger.warning("Found duplicate keys %s", ", ".join(duplicates))

    # Add the dict args back as if they were originally passed as JSON
    for dict_arg, dict_value in dict_args.items():
        processed_args.append(dict_arg)
        processed_args.append(json.dumps(dict_value))

    return super().parse_args(processed_args, namespace)
|
||||||
|
|
||||||
|
def check_port(self, value):
    """Argparse type-checker for TCP port numbers.

    Converts ``value`` to ``int`` and requires it to lie in the
    non-privileged range 1024-65535.

    Returns:
        The port as an ``int``.

    Raises:
        ArgumentTypeError: if ``value`` is not an integer or out of range.
    """
    try:
        port = int(value)
    except ValueError:
        # Suppress the original ValueError context for a cleaner message.
        raise ArgumentTypeError("Port must be an integer") from None

    if port < 1024 or port > 65535:
        raise ArgumentTypeError("Port must be between 1024 and 65535")

    return port
|
||||||
|
|
||||||
|
def _pull_args_from_config(self, args: list[str]) -> list[str]:
    """Method to pull arguments specified in the config file
    into the command-line args variable.

    The arguments in config file will be inserted between
    the argument list.

    example:
    ```yaml
    port: 12323
    tensor-parallel-size: 4
    ```
    ```python
    $: vllm {serve,chat,complete} "facebook/opt-12B" \
        --config config.yaml -tp 2
    $: args = [
        "serve,chat,complete",
        "facebook/opt-12B",
        '--config', 'config.yaml',
        '-tp', '2'
    ]
    $: args = [
        "serve,chat,complete",
        "facebook/opt-12B",
        '--port', '12323',
        '--tensor-parallel-size', '4',
        '-tp', '2'
    ]
    ```

    Please note how the config args are inserted after the sub command.
    this way the order of priorities is maintained when these are args
    parsed by super().

    Raises:
        ValueError: if no file follows `--config`, or (for `serve`) if no
            model is given either positionally or in the config file.
    """
    # Caller guarantees "--config" is present; reject multiple occurrences.
    assert args.count("--config") <= 1, "More than one config file specified!"

    index = args.index("--config")
    if index == len(args) - 1:
        raise ValueError(
            "No config file specified! \
            Please check your command-line arguments."
        )

    file_path = args[index + 1]

    # Flattened "--key value ..." tokens from the YAML file.
    config_args = self.load_config_file(file_path)

    # 0th index might be the sub command {serve,chat,complete,...}
    # optionally followed by model_tag (only for serve)
    # followed by config args
    # followed by rest of cli args.
    # maintaining this order will enforce the precedence
    # of cli > config > defaults
    if args[0].startswith("-"):
        # No sub command (e.g., api_server entry point)
        args = config_args + args[0:index] + args[index + 2 :]
    elif args[0] == "serve":
        # A positional model tag is any non-option token right after "serve".
        model_in_cli = len(args) > 1 and not args[1].startswith("-")
        model_in_config = any(arg == "--model" for arg in config_args)

        if not model_in_cli and not model_in_config:
            raise ValueError(
                "No model specified! Please specify model either "
                "as a positional argument or in a config file."
            )

        if model_in_cli:
            # Model specified as positional arg, keep CLI version
            args = (
                [args[0]]
                + [args[1]]
                + config_args
                + args[2:index]
                + args[index + 2 :]
            )
        else:
            # No model in CLI, use config if available
            args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
    else:
        # Other sub commands: insert config args right after the sub command.
        args = [args[0]] + config_args + args[1:index] + args[index + 2 :]

    return args
|
||||||
|
|
||||||
|
def load_config_file(self, file_path: str) -> list[str]:
    """Load a YAML config file and flatten it into argparse-style tokens.

    ```yaml
    port: 12323
    tensor-parallel-size: 4
    ```
    returns:
        processed_args: list[str] = [
            '--port': '12323',
            '--tensor-parallel-size': '4'
        ]

    Booleans become bare flags (emitted only when true); lists become one
    flag followed by each item; everything else becomes a flag/value pair.

    Raises:
        ValueError: if the file does not have a yaml/yml extension.
    """
    extension: str = file_path.split(".")[-1]
    if extension not in ("yaml", "yml"):
        raise ValueError(
            f"Config file must be of a yaml/yml type. {extension} supplied"
        )

    # Only a flat dictionary of atomic types is expected.
    config: dict[str, int | str] = {}
    try:
        with open(file_path) as config_file:
            config = yaml.safe_load(config_file)
    except Exception as ex:
        logger.error(
            "Unable to read the config file at %s. Check path correctness",
            file_path,
        )
        raise ex

    processed_args: list[str] = []
    for key, value in config.items():
        flag = "--" + key
        if isinstance(value, bool):
            # Boolean switch: present only when truthy.
            if value:
                processed_args.append(flag)
        elif isinstance(value, list):
            # Non-empty list: flag followed by each stringified item.
            if value:
                processed_args.append(flag)
                processed_args.extend(str(item) for item in value)
        else:
            processed_args.append(flag)
            processed_args.append(str(value))

    return processed_args
|
||||||
303
vllm/_utils/async_utils.py
Normal file
303
vllm/_utils/async_utils.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Contains helpers related to asynchronous code.
|
||||||
|
|
||||||
|
This is similar in concept to the `asyncio` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
|
||||||
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from transformers.tokenization_utils_base import BatchEncoding
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMicrobatchTokenizer:
    """Asynchronous tokenizer with micro-batching.

    Pulls pending encode/decode requests from a queue and batches them
    up to reduce overhead. A single-thread ThreadPoolExecutor is used
    so the event loop stays responsive.
    """

    def __init__(
        self,
        tokenizer,
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        # NOTE(review): `tokenizer` is assumed to be an HF-style tokenizer
        # (callable, with `batch_decode` and `model_max_length`) — confirm
        # against callers.
        self.tokenizer = tokenizer
        # Maximum number of requests merged into a single tokenizer call.
        self.max_batch_size = max_batch_size
        # How long to wait for more requests before flushing a partial batch.
        self.batch_wait_timeout_s = batch_wait_timeout_s

        # Must be constructed inside a running event loop.
        self._loop = asyncio.get_running_loop()
        # One queue (and one batcher task) per operation key, so requests
        # with incompatible kwargs are never batched together.
        self._queues: dict[
            tuple,
            asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]],
        ] = {}
        self._batcher_tasks: list[Task] = []

        # Single-thread executor for blocking tokenizer calls.
        self._executor = ThreadPoolExecutor(max_workers=1)

    # === Public async API ===
    async def __call__(self, prompt, **kwargs):
        # Enqueue an encode request and await its result.
        result_future: Future = self._loop.create_future()
        key = self._queue_key("encode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((prompt, kwargs, result_future))
        return await result_future

    async def decode(self, token_ids, **kwargs):
        # Enqueue a decode request and await its result.
        # NOTE(review): `kwargs` only affect queue selection here; they are
        # not forwarded to `batch_decode` — confirm this is intended.
        result_future: Future = self._loop.create_future()
        key = self._queue_key("decode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((token_ids, result_future))
        return await result_future

    # === Internal helpers ===
    def _get_queue(
        self, loop: asyncio.AbstractEventLoop, key: tuple
    ) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]:
        """Get the request queue for the given operation key, creating a new
        queue and batcher task if needed."""
        queue = self._queues.get(key)
        if queue is None:
            self._queues[key] = queue = asyncio.Queue()
            if key[0] == "encode":
                # "other" keys carry non-uniform kwargs and cannot use the
                # single batched-call fast path.
                can_batch = key[1] != "other"
                coro = self._batch_encode_loop(queue, can_batch)
            else:
                assert key[0] == "decode", f"Unknown operation type: {key[0]}."
                coro = self._batch_decode_loop(queue)
            self._batcher_tasks.append(loop.create_task(coro))
        return queue

    async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
        """Batch incoming encode requests for efficiency."""
        while True:
            # Block for the first request, then opportunistically gather
            # more until the batch is full or the wait deadline passes.
            prompt, kwargs, result_future = await queue.get()
            prompts = [prompt]
            kwargs_list = [kwargs]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            while len(prompts) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    prompt, kwargs, result_future = await asyncio.wait_for(
                        queue.get(), timeout
                    )
                    prompts.append(prompt)
                    result_futures.append(result_future)
                    if not can_batch:
                        # Per-request kwargs only matter on the fallback path.
                        kwargs_list.append(kwargs)
                except asyncio.TimeoutError:
                    break

            try:
                # If every request uses identical kwargs we can run a single
                # batched tokenizer call for a big speed-up.
                if can_batch and len(prompts) > 1:
                    batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
                    results = await self._loop.run_in_executor(
                        self._executor, batch_encode_fn
                    )

                    # Split the batched result back into per-request encodings.
                    for i, fut in enumerate(result_futures):
                        if not fut.done():
                            data = {k: v[i] for k, v in results.items()}
                            fut.set_result(BatchEncoding(data))
                else:
                    # Fallback: one tokenizer call per request (still off-loop).
                    encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
                        self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
                    ]
                    results = await self._loop.run_in_executor(
                        self._executor, encode_fn
                    )

                    for fut, res in zip(result_futures, results):
                        if not fut.done():
                            fut.set_result(res)
            except Exception as e:
                # Propagate the failure to every waiter in the batch.
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    async def _batch_decode_loop(self, queue: asyncio.Queue):
        """Batch incoming decode requests for efficiency."""
        while True:
            token_ids, result_future = await queue.get()
            token_ids_list = [token_ids]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            # Gather more requests until the batch fills or the deadline hits.
            while len(token_ids_list) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    token_ids, result_future = await asyncio.wait_for(
                        queue.get(), timeout
                    )
                    token_ids_list.append(token_ids)
                    result_futures.append(result_future)
                except asyncio.TimeoutError:
                    break

            try:
                # Perform a single batched decode call for all requests
                results = await self._loop.run_in_executor(
                    self._executor, self.tokenizer.batch_decode, token_ids_list
                )
                for fut, res in zip(result_futures, results):
                    if not fut.done():
                        fut.set_result(res)
            except Exception as e:
                # Propagate the failure to every waiter in the batch.
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    def _queue_key(self, op: str, kwargs: dict) -> tuple:
        """
        Return a normalized key describing operation + kwargs.

        - `add_special_tokens`: {True/False}
        - `truncation`: {True/False}
        - If `truncation` is False (`max_length` is None),
          returns a key for a can_batch queue.
        - If `truncation` is True and `max_length` is None or equals
          `tokenizer.model_max_length`, returns a key for a can_batch queue.
        - Otherwise, returns a key for a cannot_batch queue.

        Examples:
        - Decode: ("decode",)
        - Encode typical:
          ("encode", add_special_tokens, bool_truncation, max_length_label)
        - Fallback: ("encode", "other")
        """

        if op == "decode":
            return ("decode",)

        add_special_tokens = kwargs.get("add_special_tokens", True)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        if not truncation:
            return "encode", add_special_tokens, False, None

        model_max = getattr(self.tokenizer, "model_max_length", None)
        if max_length is None or (model_max is not None and max_length == model_max):
            return "encode", add_special_tokens, True, "model_max"

        # Custom max_length: requests are not interchangeable, so no batching.
        return "encode", "other"

    def __del__(self):
        # Best-effort cleanup: cancel batcher tasks if the loop still exists
        # and is not closed. getattr guards against partially-initialized
        # instances (e.g. when __init__ raised).
        if (
            (tasks := getattr(self, "_batcher_tasks", None))
            and (loop := getattr(self, "_loop", None))
            and not loop.is_closed()
        ):

            def cancel_tasks():
                for task in tasks:
                    task.cancel()

            loop.call_soon_threadsafe(cancel_tasks)
|
||||||
|
|
||||||
|
|
||||||
|
def cancel_task_threadsafe(task: Task):
    """Cancel ``task`` from any thread by scheduling the cancellation
    on the task's own event loop (no-op for a missing or finished task)."""
    if not task or task.done():
        return
    run_in_loop(task.get_loop(), task.cancel)
|
||||||
|
|
||||||
|
|
||||||
|
def make_async(
    func: Callable[P, T],
    executor: Executor | None = None,
) -> Callable[P, Awaitable[T]]:
    """Wrap a blocking callable so it runs on an executor thread.

    The returned wrapper schedules ``func`` on ``executor`` (the event
    loop's default thread pool when ``None``), keeping the asyncio event
    loop responsive. ``func`` must therefore be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]:
        # Bind the arguments now; execution happens on the executor.
        bound = partial(func, *args, **kwargs)
        event_loop = asyncio.get_event_loop()
        return event_loop.run_in_executor(executor=executor, func=bound)

    return _async_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
    """Invoke ``function(*args)`` on ``loop``.

    Called from within the loop it runs synchronously; from another
    thread it is scheduled via ``call_soon_threadsafe``. Silently does
    nothing when the loop is already closed.
    """
    if in_loop(loop):
        # Already on the target loop; call directly.
        function(*args)
        return
    if not loop.is_closed():
        # Foreign thread: hand off to the loop thread-safely.
        loop.call_soon_threadsafe(function, *args)
|
||||||
|
|
||||||
|
|
||||||
|
def in_loop(event_loop: AbstractEventLoop) -> bool:
    """Return True iff the caller is running inside ``event_loop``."""
    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop running in this thread.
        return False
    return running == event_loop
|
||||||
|
|
||||||
|
|
||||||
|
async def merge_async_iterators(
|
||||||
|
*iterators: AsyncGenerator[T, None],
|
||||||
|
) -> AsyncGenerator[tuple[int, T], None]:
|
||||||
|
"""Merge multiple asynchronous iterators into a single iterator.
|
||||||
|
|
||||||
|
This method handle the case where some iterators finish before others.
|
||||||
|
When it yields, it yields a tuple (i, item) where i is the index of the
|
||||||
|
iterator that yields the item.
|
||||||
|
"""
|
||||||
|
if len(iterators) == 1:
|
||||||
|
# Fast-path single iterator case.
|
||||||
|
async for item in iterators[0]:
|
||||||
|
yield 0, item
|
||||||
|
return
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
|
||||||
|
try:
|
||||||
|
while awaits:
|
||||||
|
done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
|
||||||
|
for d in done:
|
||||||
|
pair = awaits.pop(d)
|
||||||
|
try:
|
||||||
|
item = await d
|
||||||
|
i, it = pair
|
||||||
|
awaits[loop.create_task(anext(it))] = pair
|
||||||
|
yield i, item
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
# Cancel any remaining iterators
|
||||||
|
for f, (_, it) in awaits.items():
|
||||||
|
with contextlib.suppress(BaseException):
|
||||||
|
f.cancel()
|
||||||
|
await it.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
    """Drain an async generator, returning every yielded item as a list."""
    return [item async for item in iterator]
|
||||||
214
vllm/_utils/cache.py
Normal file
214
vllm/_utils/cache.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections import UserDict
|
||||||
|
from collections.abc import Callable, Hashable, Iterator, KeysView, Mapping
|
||||||
|
from types import MappingProxyType
|
||||||
|
from typing import NamedTuple, TypeVar, cast, overload
|
||||||
|
|
||||||
|
import cachetools
|
||||||
|
|
||||||
|
_K = TypeVar("_K", bound=Hashable)
|
||||||
|
_V = TypeVar("_V")
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
class _Sentinel: ...
|
||||||
|
|
||||||
|
|
||||||
|
ALL_PINNED_SENTINEL = _Sentinel()
|
||||||
|
|
||||||
|
|
||||||
|
class _MappingOrderCacheView(UserDict[_K, _V]):
    """Dict view over ``data`` whose iteration order follows
    ``ordered_keys`` (the cache's LRU order) instead of the underlying
    mapping's insertion order."""

    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
        super().__init__(data)
        # Authoritative key ordering used by __iter__ and keys().
        self.ordered_keys = ordered_keys

    def __iter__(self) -> Iterator[_K]:
        return iter(self.ordered_keys)

    def keys(self) -> KeysView[_K]:
        return KeysView(self.ordered_keys)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheInfo(NamedTuple):
    """Immutable snapshot of cache lookup statistics."""

    # Number of lookups that found an entry.
    hits: int
    # Total number of lookups performed.
    total: int

    @property
    def hit_ratio(self) -> float:
        """Fraction of lookups that hit; 0 when no lookups happened."""
        return self.hits / self.total if self.total else 0

    def __sub__(self, other: "CacheInfo"):
        """Element-wise difference, used to compute delta statistics."""
        return CacheInfo(self.hits - other.hits, self.total - other.total)
|
||||||
|
|
||||||
|
|
||||||
|
class LRUCache(cachetools.LRUCache[_K, _V]):
    """LRU cache extending `cachetools.LRUCache` with hit/total statistics,
    pinned (eviction-exempt) entries, and an `_on_remove` hook.

    NOTE(review): several members reach into cachetools' name-mangled
    internals (`_Cache__data`, `_LRUCache__order`), so this class is
    coupled to the installed cachetools implementation.
    """

    def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
        super().__init__(capacity, getsizeof)

        # Keys exempt from LRU eviction (see pin/_unpin/popitem).
        self.pinned_items = set[_K]()

        # Cumulative statistics; `_last_info` is the snapshot taken at the
        # last `stat(delta=True)` call.
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
        """Look up `key`; counts as a hit unless `update_info=False`
        (used internally to avoid double-counting)."""
        value = super().__getitem__(key)

        if update_info:
            self._hits += 1
            self._total += 1

        return value

    def __delitem__(self, key: _K) -> None:
        """Delete `key`, unpinning it if needed and firing `_on_remove`."""
        run_on_remove = key in self
        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
        super().__delitem__(key)
        if key in self.pinned_items:
            # Todo: add warning to inform that del pinned item
            self._unpin(key)
        if run_on_remove:
            self._on_remove(key, value)

    @property
    def cache(self) -> Mapping[_K, _V]:
        """Return the internal cache dictionary in order (read-only)."""
        return _MappingOrderCacheView(
            self._Cache__data,  # type: ignore
            self.order,
        )

    @property
    def order(self) -> Mapping[_K, None]:
        """Return the internal order dictionary (read-only)."""
        return MappingProxyType(self._LRUCache__order)  # type: ignore

    @property
    def capacity(self) -> float:
        # Alias for cachetools' `maxsize`.
        return self.maxsize

    @property
    def usage(self) -> float:
        """Fraction of capacity currently used (0 for a zero-size cache)."""
        if self.maxsize == 0:
            return 0

        return self.currsize / self.maxsize

    def stat(self, *, delta: bool = False) -> CacheInfo:
        """
        Gets the cumulative number of hits and queries against this cache.

        If `delta=True`, instead gets these statistics
        since the last call that also passed `delta=True`.
        """
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    def touch(self, key: _K) -> None:
        """Mark `key` as most recently used (inserting it into the order
        map if absent) without touching its value or the statistics."""
        try:
            self._LRUCache__order.move_to_end(key)  # type: ignore
        except KeyError:
            self._LRUCache__order[key] = None  # type: ignore

    @overload
    def get(self, key: _K, /) -> _V | None: ...

    @overload
    def get(self, key: _K, /, default: _V | _T) -> _V | _T: ...

    def get(self, key: _K, /, default: _V | _T | None = None) -> _V | _T | None:
        """Return `key`'s value or `default`; always counts one query,
        counting a hit only when the key is present."""
        value: _V | _T | None
        if key in self:
            value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]

            self._hits += 1
        else:
            value = default

        self._total += 1
        return value

    @overload
    def pop(self, key: _K) -> _V: ...

    @overload
    def pop(self, key: _K, default: _V | _T) -> _V | _T: ...

    def pop(self, key: _K, default: _V | _T | None = None) -> _V | _T | None:
        """Remove `key` and return its value, or `default` when absent.
        Does not update hit/total statistics."""
        value: _V | _T | None
        if key not in self:
            return default

        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
        self.__delitem__(key)
        return value

    def put(self, key: _K, value: _V) -> None:
        # Convenience alias for item assignment.
        self.__setitem__(key, value)

    def pin(self, key: _K) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: _K) -> None:
        """
        Unpins a key in the cache allowing it to be
        evicted in the LRU order.
        """
        self.pinned_items.remove(key)

    def _on_remove(self, key: _K, value: _V | None) -> None:
        # Subclass hook invoked after an entry is removed.
        pass

    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
        """Evict the least recently used entry, optionally allowing pinned
        entries to be evicted too. No-op on an empty cache."""
        if len(self) == 0:
            return

        self.popitem(remove_pinned=remove_pinned)

    def _remove_old_if_needed(self) -> None:
        # Evict (unpinned) entries until the cache fits its capacity.
        while self.currsize > self.capacity:
            self.remove_oldest()

    def popitem(self, remove_pinned: bool = False):
        """Remove and return the `(key, value)` pair least recently used."""
        if not remove_pinned:
            # pop the oldest item in the cache that is not pinned
            lru_key = next(
                (key for key in self.order if key not in self.pinned_items),
                ALL_PINNED_SENTINEL,
            )
            if lru_key is ALL_PINNED_SENTINEL:
                raise RuntimeError(
                    "All items are pinned, cannot remove oldest from the cache."
                )
        else:
            lru_key = next(iter(self.order))
        value = self.pop(cast(_K, lru_key))
        return (lru_key, value)

    def clear(self) -> None:
        """Remove every entry (including pinned ones) and reset statistics."""
        while len(self) > 0:
            self.remove_oldest(remove_pinned=True)

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)
|
||||||
139
vllm/_utils/collection_utils.py
Normal file
139
vllm/_utils/collection_utils.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Contains helpers that are applied to collections.
|
||||||
|
|
||||||
|
This is similar in concept to the `collections` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import UserDict, defaultdict
|
||||||
|
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
|
||||||
|
from typing import Generic, Literal, TypeVar
|
||||||
|
|
||||||
|
from typing_extensions import TypeIs, assert_never
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
U = TypeVar("U")
|
||||||
|
|
||||||
|
_K = TypeVar("_K", bound=Hashable)
|
||||||
|
_V = TypeVar("_V")
|
||||||
|
|
||||||
|
|
||||||
|
class ClassRegistry(UserDict[type[T], _V]):
    """
    Dict keyed by class that falls back to MRO lookup: when the exact
    class is not registered, the value of the closest registered
    ancestor (in method-resolution order) is returned instead.
    """

    def __getitem__(self, key: type[T]) -> _V:
        for base in key.mro():
            if base in self.data:
                return self.data[base]

        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        return self.contains(key)

    def contains(self, key: object, *, strict: bool = False) -> bool:
        """Whether ``key`` is registered; with ``strict=True`` only an
        exact class match counts (no MRO fallback)."""
        if not isinstance(key, type):
            # Only classes can be registry keys.
            return False

        if strict:
            return key in self.data

        return any(base in self.data for base in key.mro())
|
||||||
|
|
||||||
|
|
||||||
|
class LazyDict(Mapping[str, T], Generic[T]):
    """
    Mapping whose values are produced on first access by zero-argument
    factory callables and memoized thereafter.

    Adapted from: https://stackoverflow.com/a/47212782/5082708
    """

    def __init__(self, factory: dict[str, Callable[[], T]]):
        # Key -> factory callable; evaluated lazily.
        self._factory = factory
        # Memoized results of factories that have already run.
        self._dict: dict[str, T] = {}

    def __getitem__(self, key: str) -> T:
        try:
            return self._dict[key]
        except KeyError:
            pass
        if key not in self._factory:
            raise KeyError(key)
        # First access: run the factory and cache the result.
        value = self._factory[key]()
        self._dict[key] = value
        return value

    def __setitem__(self, key: str, value: Callable[[], T]):
        # Note: assignment registers a *factory*, not a concrete value.
        self._factory[key] = value

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
|
||||||
|
|
||||||
|
|
||||||
|
def as_list(maybe_list: Iterable[T]) -> list[T]:
    """Convert iterable to list, unless it's already a list."""
    if isinstance(maybe_list, list):
        return maybe_list
    return list(maybe_list)
|
||||||
|
|
||||||
|
|
||||||
|
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
    """Return ``obj`` unchanged if it is a (non-string) iterable, else wrap it.

    Strings are iterable but are deliberately treated as scalar values here.
    """
    if isinstance(obj, Iterable) and not isinstance(obj, str):
        return obj
    return [obj]  # type: ignore[list-item]
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_of(
|
||||||
|
value: object,
|
||||||
|
typ: type[T] | tuple[type[T], ...],
|
||||||
|
*,
|
||||||
|
check: Literal["first", "all"] = "first",
|
||||||
|
) -> TypeIs[list[T]]:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if check == "first":
|
||||||
|
return len(value) == 0 or isinstance(value[0], typ)
|
||||||
|
elif check == "all":
|
||||||
|
return all(isinstance(v, typ) for v in value)
|
||||||
|
|
||||||
|
assert_never(check)
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
    """Yield successive chunk_size chunks from lst (last chunk may be short)."""
    for offset in range(0, len(lst), chunk_size):
        yield lst[offset : offset + chunk_size]
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
    """Flatten a list of lists to a single list."""
    flat: list[T] = []
    for sub in lists:
        flat.extend(sub)
    return flat
|
||||||
|
|
||||||
|
|
||||||
|
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
    """
    Unlike [`itertools.groupby`][], groups are not broken by
    non-contiguous data.
    """
    grouped: defaultdict[_K, list[_V]] = defaultdict(list)
    for item in values:
        grouped[key(item)].append(item)
    return grouped.items()
|
||||||
|
|
||||||
|
|
||||||
|
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
    """Swap values between two keys.

    NOTE: a key whose value is ``None`` is treated the same as a missing key:
    its counterpart is removed from the dict rather than set to ``None``.
    """
    old1, old2 = obj.get(key1), obj.get(key2)
    if old1 is None:
        obj.pop(key2, None)
    else:
        obj[key2] = old1
    if old2 is None:
        obj.pop(key1, None)
    else:
        obj[key1] = old2
|
||||||
45
vllm/_utils/counter.py
Normal file
45
vllm/_utils/counter.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class Counter:
    """A simple monotonically increasing counter (not thread-safe).

    Each call to ``next()`` returns the current value and advances by one.
    """

    def __init__(self, start: int = 0) -> None:
        super().__init__()
        # Next value to be handed out by __next__.
        self.counter = start

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i

    def __iter__(self) -> "Counter":
        # The class already implements __next__; providing __iter__ makes it
        # a proper iterator usable in for/zip constructs (note: infinite).
        return self

    def reset(self) -> None:
        """Restart counting from zero (regardless of the original start)."""
        self.counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AtomicCounter:
    """An atomic, thread-safe counter"""

    def __init__(self, initial: int = 0) -> None:
        """Initialize a new atomic counter to given initial value"""
        super().__init__()
        self._value = initial
        # Guards every mutation of ``_value``.
        self._lock = threading.Lock()

    @property
    def value(self) -> int:
        # Plain read; int reads are atomic in CPython, no lock required.
        return self._value

    def inc(self, num: int = 1) -> int:
        """Atomically increment the counter by num and return the new value"""
        with self._lock:
            self._value += num
            return self._value

    def dec(self, num: int = 1) -> int:
        """Atomically decrement the counter by num and return the new value"""
        # Decrement is just a negative increment.
        return self.inc(-num)
|
||||||
391
vllm/_utils/deep_gemm.py
Normal file
391
vllm/_utils/deep_gemm.py
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Compatibility wrapper for DeepGEMM API changes.
|
||||||
|
|
||||||
|
Users of vLLM should always import **only** these wrappers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
from collections.abc import Callable
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.import_utils import has_deep_gemm
|
||||||
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
|
|
||||||
|
class DeepGemmQuantScaleFMT(Enum):
    """Layout of the quantization scales handed to DeepGEMM kernels."""

    # Float32 scales in Float32 tensor
    FLOAT32 = 0
    # Compute float32 scales and ceil the scales to UE8M0.
    # Keep the scales in Float32 tensor.
    FLOAT32_CEIL_UE8M0 = 1
    # Compute float32 scales and ceil the scales to UE8M0.
    # Pack the scales into a int32 tensor where each int32
    # element contains 4 scale values.
    UE8M0 = 2

    @staticmethod
    def from_oracle() -> "DeepGemmQuantScaleFMT":
        """Pick the scale format for the current configuration.

        FLOAT32 when E8M0 is disabled; otherwise packed UE8M0 on SM100
        (Blackwell) and float32-kept-but-ceiled elsewhere (e.g. Hopper).
        """
        if not is_deep_gemm_e8m0_used():
            return DeepGemmQuantScaleFMT.FLOAT32
        return (
            DeepGemmQuantScaleFMT.UE8M0
            if current_platform.is_device_capability(100)
            else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
        )
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Return `True` if DeepGEMM is supported on the current platform.
    Currently, only Hopper and Blackwell GPUs are supported.
    """
    if not envs.VLLM_USE_DEEP_GEMM:
        return False
    if not has_deep_gemm():
        return False
    # Hopper (SM90) or Blackwell (SM100) on CUDA only.
    return current_platform.is_cuda() and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability(100)
    )
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return `True` if vLLM is configured to use DeepGEMM
    E8M0 scale on a Hopper or Blackwell-class GPU.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
        )
        return False

    # Resolve the DeepGEMM symbols before probing them below.
    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    if envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on current platform.")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _missing(*_args: Any, **_kwargs: Any) -> NoReturn:
    """Placeholder invoked when the DeepGEMM backend cannot be resolved.

    Accepts (and ignores) any call signature so it can stand in for any
    missing kernel entry point.
    """
    raise RuntimeError(
        "DeepGEMM backend is not available or outdated. Please install or "
        "update the `deep_gemm` to a newer version to enable FP8 kernels."
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Lazily-resolved DeepGEMM entry points. `_lazy_init()` fills these in on
# first use; None means the installed `deep_gemm` (if any) does not export
# that symbol, in which case callers fall back to `_missing`.
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use.

    Idempotent: if any symbol has already been resolved, resolution is not
    attempted again. When `deep_gemm` is absent, all symbols stay None and
    the wrappers fall back to `_missing`.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl
    global _get_mk_alignment_for_contiguous_layout_impl
    global _transform_sf_into_required_layout_impl

    # Fast path: resolution was already attempted if any symbol is set.
    # (Fix: the original check omitted _get_mn_major_tma_aligned_tensor_impl.)
    if any(
        impl is not None
        for impl in (
            _fp8_gemm_nt_impl,
            _grouped_impl,
            _grouped_masked_impl,
            _fp8_mqa_logits_impl,
            _fp8_paged_mqa_logits_impl,
            _get_paged_mqa_logits_metadata_impl,
            _get_mn_major_tma_aligned_tensor_impl,
            _get_mk_alignment_for_contiguous_layout_impl,
            _transform_sf_into_required_layout_impl,
        )
    ):
        return

    if not has_deep_gemm():
        return

    # Point DeepGEMM's JIT cache at vLLM's cache root unless the user has
    # already configured it.
    DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm"
        )

    _dg = importlib.import_module("deep_gemm")

    # Missing attributes resolve to None; the wrappers handle that case.
    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None
    )
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None
    )
    _get_mk_alignment_for_contiguous_layout_impl = getattr(
        _dg, "get_mk_alignment_for_contiguous_layout", None
    )
    _transform_sf_into_required_layout_impl = getattr(
        _dg, "transform_sf_into_required_layout", None
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_sms() -> int:
    """Return the SM count reported by DeepGEMM."""
    _lazy_init()
    return int(importlib.import_module("deep_gemm").get_num_sms())
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def get_mk_alignment_for_contiguous_layout() -> list[int]:
    """Return the [M, K] alignment DeepGEMM requires for contiguous layout."""
    _lazy_init()
    if _get_mk_alignment_for_contiguous_layout_impl is None:
        return _missing()
    # DeepGEMM reports a single alignment used for both M and K.
    align = _get_mk_alignment_for_contiguous_layout_impl()
    return [align, align]
|
||||||
|
|
||||||
|
|
||||||
|
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
    _lazy_init()
    impl = _get_mn_major_tma_aligned_tensor_impl
    if impl is None:
        return _missing()
    return impl(x)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_gemm_nt(*args, **kwargs):
    """Call DeepGEMM's FP8 NT GEMM, translating the UE8M0 flag.

    An explicit ``is_deep_gemm_e8m0_used`` kwarg overrides the module-level
    oracle; it is consumed here and never forwarded to DeepGEMM.
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    # Sentinel so that an explicitly passed value (even None) wins over
    # the oracle, matching the original presence check.
    _unset = object()
    use_ue8m0 = kwargs.pop("is_deep_gemm_e8m0_used", _unset)
    if use_ue8m0 is _unset:
        use_ue8m0 = is_deep_gemm_e8m0_used()
    return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Call DeepGEMM's contiguous m-grouped FP8 GEMM (UE8M0 flag from oracle)."""
    _lazy_init()
    impl = _grouped_impl
    if impl is None:
        return _missing(*args, **kwargs)
    return impl(*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Call DeepGEMM's masked m-grouped FP8 GEMM (UE8M0 flag from oracle)."""
    _lazy_init()
    impl = _grouped_masked_impl
    if impl is None:
        return _missing(*args, **kwargs)
    return impl(*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_sf_into_required_layout(*args, **kwargs):
    """Call DeepGEMM's scale-factor layout transform (UE8M0 flag from oracle)."""
    _lazy_init()
    impl = _transform_sf_into_required_layout_impl
    if impl is None:
        return _missing(*args, **kwargs)
    return impl(*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]; already cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N]
            (or [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query
            position, shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query
            position, shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    _lazy_init()
    impl = _fp8_mqa_logits_impl
    if impl is None:
        return _missing()
    return impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
|
||||||
|
|
||||||
|
|
||||||
|
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build scheduling metadata for paged MQA logits.

    Args:
        context_lens: Tensor of shape [B], dtype int32; effective context
            length per batch element.
        block_size: KV-cache block size in tokens (e.g., 64).
        num_sms: Number of SMs available. 132 for Hopper

    Returns:
        Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
        schedule work across SMs.
    """
    _lazy_init()
    impl = _get_paged_mqa_logits_metadata_impl
    if impl is None:
        return _missing()
    return impl(context_lens, block_size, num_sms)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]; already cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    _lazy_init()
    impl = _fp8_paged_mqa_logits_impl
    if impl is None:
        return _missing()
    # clean_logits=True zeroes out-of-range positions in the output.
    return impl(
        q_fp8,
        kv_cache_fp8,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _ceil_to_ue8m0(x: torch.Tensor):
    """Round each |x| up to the nearest power of two (UE8M0-representable)."""
    return torch.exp2(torch.log2(x.abs()).ceil())
|
||||||
|
|
||||||
|
|
||||||
|
def _align(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y``."""
    return y * cdiv(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
# Default [block_m, block_n] quantization tile used by per_block_cast_to_fp8.
DEFAULT_BLOCK_SIZE = [128, 128]
|
||||||
|
|
||||||
|
|
||||||
|
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
    x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per block.

    Args:
        x: 2-D input tensor of shape [m, n].
        block_size: [block_m, block_n] quantization tile.
            NOTE(review): mutable (list) default is shared across calls;
            harmless since it is never mutated here, but a tuple would be
            safer.
        use_ue8m0: If True, round each scale up to a power of two so it is
            representable as UE8M0.

    Returns:
        Tuple of (FP8 tensor of shape [m, n], float32 scale tensor with one
        entry per block).
    """
    assert x.dim() == 2
    m, n = x.shape
    block_m, block_n = block_size
    # Pad up to whole blocks so the 4-D view below is exact.
    x_padded = torch.zeros(
        (_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    # Per-block amax, clamped to avoid zero scales; 448 is the e4m3 max.
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2)
    )
|
||||||
|
|
||||||
|
|
||||||
|
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable
    per-element error, causing `torch.testing.assert_close` to fail.
    Instead of checking every element, we compute a cosine-style similarity
    over the whole tensor and report `1 - sim`. Once kernel accuracy
    improves this helper can be removed.
    """
    xd = x.double()
    yd = y.double()
    # sim = 2*sum(x*y) / sum(x^2 + y^2); equals 1 exactly when x == y.
    similarity = 2 * (xd * yd).sum() / (xd * xd + yd * yd).sum()
    return 1 - similarity
|
||||||
|
|
||||||
|
|
||||||
|
def should_use_deepgemm_for_fp8_linear(
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
supports_deep_gemm: bool | None = None,
|
||||||
|
):
|
||||||
|
if supports_deep_gemm is None:
|
||||||
|
supports_deep_gemm = is_deep_gemm_supported()
|
||||||
|
return (
|
||||||
|
supports_deep_gemm
|
||||||
|
and output_dtype == torch.bfloat16
|
||||||
|
and weight.shape[0] % 128 == 0
|
||||||
|
and weight.shape[1] % 128 == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Public API of this DeepGEMM compatibility module.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "fp8_mqa_logits",
    "fp8_paged_mqa_logits",
    "get_paged_mqa_logits_metadata",
    "per_block_cast_to_fp8",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
    "get_mk_alignment_for_contiguous_layout",
]
|
||||||
492
vllm/_utils/flashinfer.py
Normal file
492
vllm/_utils/flashinfer.py
Normal file
@@ -0,0 +1,492 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Compatibility wrapper for FlashInfer API changes.
|
||||||
|
|
||||||
|
Users of vLLM should always import **only** these wrappers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.batch_invariant import (
|
||||||
|
vllm_is_batch_invariant,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
logger = init_logger(__name__)

# This is the storage path for the cubins, it can be replaced
# with a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
# The environment variable of the same name takes precedence over the
# default NVIDIA artifactory URL.
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/", # noqa: E501
)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def has_flashinfer_cubin() -> bool:
    """Return `True` if flashinfer-cubin package is available."""
    if envs.VLLM_HAS_FLASHINFER_CUBIN:
        return True
    found = importlib.util.find_spec("flashinfer_cubin") is not None
    if not found:
        logger.debug_once("flashinfer-cubin package was not found")
    return found
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def has_flashinfer() -> bool:
    """Return `True` if flashinfer-python package is available."""
    # find_spec checks for the module without importing it, avoiding any
    # potential CUDA initialization side effects.
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False
    # Without pre-downloaded cubins, FlashInfer JIT-compiles its kernels and
    # therefore needs nvcc on the PATH.
    nvcc_missing = shutil.which("nvcc") is None
    if nvcc_missing and not has_flashinfer_cubin():
        logger.debug_once(
            "FlashInfer unavailable since nvcc was not found "
            "and not using pre-downloaded cubins"
        )
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def _missing(*_args: Any, **_kwargs: Any) -> NoReturn:
    """Placeholder invoked when the FlashInfer backend is unavailable.

    Accepts any call signature so it can stand in for any missing symbol.
    """
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
        "https://github.com/flashinfer-ai/flashinfer"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        mod = importlib.import_module(module_name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so this also
        # covers a missing top-level package.
        return None
    return mod
|
||||||
|
|
||||||
|
|
||||||
|
# General lazy import wrapper
|
||||||
|
# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
    """Create a lazy import wrapper for a specific function.

    The returned callable resolves ``module_name.attr_name`` on first call
    (memoized) and falls back to ``fallback_fn`` when FlashInfer or the
    symbol is unavailable.
    """

    @functools.cache
    def _resolve():
        # Resolve the real implementation at most once.
        if not has_flashinfer():
            return None
        module = _get_submodule(module_name)
        if module is None:
            return None
        return getattr(module, attr_name, None)

    def wrapper(*args, **kwargs):
        target = _resolve()
        if target is None:
            return fallback_fn(*args, **kwargs)
        return target(*args, **kwargs)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
# Create lazy wrappers for each function
# Each wrapper resolves its FlashInfer symbol on first call and raises a
# helpful RuntimeError (via `_missing`) when FlashInfer is unavailable.
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def has_flashinfer_comm() -> bool:
    """Return `True` if FlashInfer comm module is available."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.comm") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return `True` if FlashInfer mnnvl all2all is available."""
    if not has_flashinfer_comm():
        return False

    # Every one of these symbols must resolve for all2all to work.
    required_functions = (
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    )
    return all(
        (mod := _get_submodule(module_name)) is not None and hasattr(mod, attr_name)
        for module_name, attr_name in required_functions
    )
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if FlashInfer MoE module is available."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.fused_moe") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available."""
    if not has_flashinfer_moe():
        return False

    # Every one of these symbols must resolve for the CUTLASS path to work.
    required_functions = (
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    )
    return all(
        (mod := _get_submodule(module_name)) is not None and hasattr(mod, attr_name)
        for module_name, attr_name in required_functions
    )
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    which is required for downloading certain cubin kernels like TRTLLM FHMA.
    """
    # Pre-downloaded cubins make the network check unnecessary.
    if has_flashinfer_cubin():
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
    except Exception as e:
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False

    if response.status_code == 200:
        logger.debug_once("NVIDIA artifactory is accessible")
        return True
    logger.warning_once(
        "NVIDIA artifactory returned failed status code: %d",
        response.status_code,
    )
    return False
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def supports_trtllm_attention() -> bool:
    """
    TRTLLM attention is supported if the platform is SM100,
    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
    """
    # Batch-invariant mode disables TRTLLM attention
    if vllm_is_batch_invariant():
        return False
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    if not current_platform.is_device_capability(100):
        return False
    return has_nvidia_artifactory()
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
    if env_value is None:
        return None
    # Log the override exactly once.
    logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value
|
||||||
|
|
||||||
|
|
||||||
|
def force_use_trtllm_attention() -> bool | None:
    """
    Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
    return `True` if TRTLLM attention is forced to be used,
    return `False` if TRTLLM attention is forced to be not used.
    """
    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
    return _force_use_trtllm_attention(env_value)
|
||||||
|
|
||||||
|
|
||||||
|
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention."""
    # Explicitly disabled via the environment variable.
    if force_use_trtllm_attention() is False:
        return False
    if not supports_trtllm_attention():
        return False
    # The query head count must be a multiple of the KV head count.
    return num_qo_heads % num_kv_heads == 0
|
||||||
|
|
||||||
|
|
||||||
|
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used.

    Decision order matters: an explicit env-var opt-out and hard blockers
    (DCP, unsupported platform, bad head ratio) are checked first, then
    features that *require* TRTLLM (speculative decode, FP8 query, sinks),
    and finally size-based auto-detection when the env var is unset.

    Args:
        num_qo_heads: Number of query heads.
        num_kv_heads: Number of KV heads; must divide num_qo_heads.
        num_tokens: Tokens in this batch (used for decode auto-detection).
        max_seq_len: Maximum sequence length for this run.
        dcp_world_size: Decode context parallel world size; >1 disables TRTLLM.
        kv_cache_dtype: KV-cache dtype string; auto-detection requires "auto".
        q_dtype: Query dtype; FP8 forces TRTLLM.
        is_prefill: Whether this call is for the prefill phase.
        has_sinks: Whether attention sinks are used (forces TRTLLM).
        has_spec: Whether speculative decoding is active (forces TRTLLM for
            decodes).
    """
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # Decode context parallel is not supported
    if dcp_world_size > 1:
        logger.warning_once(
            "Trtllm does not support returning LSE and as a result "
            "does not support DCP, reverting to FlashInfer"
        )
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = (
                num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
            )
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True
|
||||||
|
|
||||||
|
|
||||||
|
# Register flashinfer-backed GEMM kernels as torch custom ops so they compose
# with torch.compile. Registration only happens when flashinfer is importable.
if has_flashinfer():

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """FP4 GEMM via flashinfer's ``mm_fp4`` (block_size fixed at 16)."""
        # Lazy import: flashinfer is only required when the op actually runs.
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Shape-only implementation for tracing: output is (A.shape[0], B.shape[1]).
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Batched FP8 matmul via flashinfer's ``bmm_fp8``."""
        # Lazy import, mirroring flashinfer_mm_fp4 above.
        from flashinfer import bmm_fp8 as bmm_fp8_

        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Shape-only implementation: batched (A.shape[0], A.shape[1], B.shape[2]).
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """Scaled FP4 matmul dispatched through the ``vllm::flashinfer_mm_fp4`` op.

    Both ``a`` and ``b`` must be 2-D and row-contiguous with matching inner
    dimension (``a.shape[1] == b.shape[1]``); ``b`` is transposed internally.
    """
    for operand in (a, b):
        assert operand.ndim == 2
        assert operand.stride(-1) == 1
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.shape[1] == b.shape[1]

    # The cutlass backend consumes block scales as raw uint8 data.
    if backend == "cutlass":
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_scaled_fp8_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Scaled FP8 matmul via the ``vllm::bmm_fp8`` op, plus optional bias.

    ``a`` (M, K) and ``b`` (K, N) must be FP8 (e4m3fn) CUDA tensors; the
    scales are scalar float32 CUDA tensors. Returns an (M, N) tensor of
    ``out_dtype``.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    for operand in (a, b):
        assert operand.dtype == torch.float8_e4m3fn
        assert operand.device.type == "cuda"
    for scale in (scale_a, scale_b):
        assert scale.numel() == 1
        assert scale.dtype == torch.float32
        assert scale.device.type == "cuda"

    # bmm_fp8 is batched; add a unit batch dim and flatten it back afterwards.
    result = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    return result if bias is None else result + bias
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Whether FlashInfer query quantization is disabled.

    Cache result which only depends on the environment
    (``VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION``); the value is read once per
    process.
    """
    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
|
||||||
|
|
||||||
|
|
||||||
|
# Explicit public API of this module (consumed by `from ... import *` and
# documentation tooling).
__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]
|
||||||
236
vllm/_utils/func_utils.py
Normal file
236
vllm/_utils/func_utils.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Contains helpers that are applied to functions.
|
||||||
|
|
||||||
|
This is similar in concept to the `functools` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import threading
|
||||||
|
import warnings
|
||||||
|
from collections.abc import Callable, Mapping
|
||||||
|
from functools import lru_cache, partial, wraps
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
def identity(value: T, **kwargs) -> T:
    """Return ``value`` unchanged, ignoring any extra keyword arguments.

    The discarded kwargs make this usable as a drop-in callable with
    ``functools.partial``.
    """
    del kwargs  # accepted but intentionally unused
    return value
|
||||||
|
|
||||||
|
|
||||||
|
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Wrap ``f`` so that its body executes at most once per process.

    Thread-safe via double-checked locking: after the first call, subsequent
    calls return immediately without taking the lock.
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path: already ran, skip the lock entirely.
        if not wrapper.has_run:  # type: ignore[attr-defined]
            with wrapper.lock:  # type: ignore[attr-defined]
                # Re-check under the lock to avoid a duplicate run.
                if not wrapper.has_run:  # type: ignore[attr-defined]
                    wrapper.has_run = True  # type: ignore[attr-defined]
                    return f(*args, **kwargs)
        return None

    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """Decorator factory that warns when positional arguments at or past
    ``start_index`` are passed positionally.

    Args:
        start_index: Index (into the positional parameter list) of the first
            deprecated positional argument.
        is_deprecated: Flag or zero-arg predicate gating the warning.
        additional_message: Optional extra sentence appended to the warning.
    """
    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        signature_params = inspect.signature(fn).parameters
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        positional_names = [
            name
            for name, param in signature_params.items()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                # Names of the positional params actually filled by *args
                # beyond the allowed prefix.
                affected = positional_names[start_index : len(args)]
                if affected:
                    msg = (
                        f"The positional arguments {affected} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """Decorator factory that warns when any of the named keyword arguments
    ``kws`` are passed to the decorated function.

    Args:
        kws: Names of the deprecated keyword arguments.
        is_deprecated: Flag or zero-arg predicate gating the warning.
        additional_message: Optional extra sentence appended to the warning.
    """
    deprecated_kws = set(kws)

    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                # Intersection of the kwargs actually passed with the
                # deprecated names.
                hit = kwargs.keys() & deprecated_kws
                if hit:
                    msg = (
                        f"The keyword arguments {hit} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check if a keyword is a valid kwarg for a callable.

    With ``requires_kw_only``, names that could also be passed positionally
    are rejected. With ``allow_var_kwargs``, a trailing ``**kwargs`` makes any
    otherwise-unknown name acceptable.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    # Kinds where the keyword may be valid: explicitly declared, non-variadic.
    explicit_kinds = frozenset(
        {
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        }
    )

    param = params.get(kw_name)
    if param is not None:
        named_nonvariadic = param.kind in explicit_kinds
        if requires_kw_only:
            # Declared, but passable positionally -> rejected.
            if named_nonvariadic and param.kind != inspect.Parameter.KEYWORD_ONLY:
                return False
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                return True
        elif named_nonvariadic:
            return True

    # If we're okay with var-kwargs, the name is supported as long as the
    # signature ends in **kwargs and kw_name isn't the **kwargs name itself.
    if allow_var_kwargs:
        # params preserves declaration order, so **kwargs (if any) is last.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        final_param = params[next(reversed(params))]  # type: ignore
        return (
            final_param.kind == inspect.Parameter.VAR_KEYWORD
            and final_param.name != kw_name
        )

    return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object] | None,
    overrides: Mapping[str, object] | None,
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Given a callable which has one or more keyword only params and a dict
    mapping param names to values, drop values that can not be kwarg
    expanded to overwrite one or more keyword-only args. This is used in a
    few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable which takes 0 or more keyword only arguments.
            If None is provided, all overrides names are allowed.
        overrides: Potential overrides to be used when invoking the callable.
        requires_kw_only: If True, only keyword-only parameters are eligible.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
        to overwrite one or more keyword only arguments when invoking the
        callable.
    """
    if not overrides:
        return {}

    # Honor the documented contract: with no callable to check against,
    # every override name is allowed. (Previously this path crashed in
    # inspect.signature(None).)
    if callable is None:
        return dict(overrides)

    # Drop any mm_processor_kwargs provided by the user that
    # are not kwargs, unless it can fit it var_kwargs param
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if supports_kw(
            callable,
            kwarg_name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        )
    }

    # If anything is dropped, log a warning
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        # Single consolidated warning; the wording matches the original two
        # branches exactly.
        logger.warning(
            "The following intended overrides are not %s "
            "and will be dropped: %s",
            "keyword-only args" if requires_kw_only else "keyword args",
            dropped_keys,
        )

    return filtered_overrides
|
||||||
147
vllm/_utils/gc_utils.py
Normal file
147
vllm/_utils/gc_utils.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
from contextlib import suppress
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GCDebugConfig:
    """
    Config for GC Debugger.

    Accepted values (typically sourced from ``VLLM_GC_DEBUG``):
    - 0: disable GC debugger
    - 1: enable GC debugger with gc.collect elapsed times
    - '{"top_objects":5}': enable GC debugger with top 5 collected objects
    """

    def __init__(self, gc_debug_conf: str | None = None) -> None:
        self.enabled: bool = False
        # Number of top collected object types to report; -1 means disabled.
        self.top_objects: int = -1

        if not gc_debug_conf or gc_debug_conf == "0":
            pass
        elif gc_debug_conf == "1":
            self.enabled = True
        else:
            try:
                json_conf = json.loads(gc_debug_conf)
                self.enabled = True
                self.top_objects = json_conf.get("top_objects", -1)
            except Exception:
                self.enabled = False
                # Log the string we actually failed to parse rather than the
                # (possibly different) current environment value.
                logger.error("Failed to parse VLLM_GC_DEBUG(%s)", gc_debug_conf)
        logger.debug("GC Debug Config. %s", str(self))

    def __repr__(self) -> str:
        return f"enabled:{self.enabled},top_objects:{self.top_objects}"
|
||||||
|
|
||||||
|
|
||||||
|
class GCDebugger:
    """
    Debugger for GC which logs helpful information for GC understanding.
    To enable, you should call maybe_attach_gc_debug_callback in the process.
    """

    def __init__(self, config: GCDebugConfig) -> None:
        self.config = config
        # Monotonic timestamp (ns) captured when the current GC cycle began.
        self.start_time_ns: int = time.monotonic_ns()
        # Pre-rendered summary of the top collected object types, computed at
        # GC start when config.top_objects is positive; empty otherwise.
        self.gc_top_collected_objects: str = ""

    def handle(self, phase: str, info: dict[str, int]) -> None:
        """
        Handles a GC event (e.g. GC start or GC finish)
        """
        generation = info.get("generation")
        if generation is None:
            return

        if phase == "start":
            # Record when this cycle began and snapshot candidate objects
            # before the collector frees them.
            self.start_time_ns = time.monotonic_ns()
            self.gc_top_collected_objects = _compute_top_gc_collected_objects(
                gc.get_objects(generation), self.config.top_objects
            )
        elif phase == "stop":
            # Report elapsed time and, if configured, the top collected types.
            elapsed_ms = (time.monotonic_ns() - self.start_time_ns) / 1e6
            top_objects_suffix = (
                f" Top collected objects: \n{self.gc_top_collected_objects}"
                if self.gc_top_collected_objects
                else ""
            )
            logger.info(
                "GC took %.3fms to complete. "
                "Collected %s objects in GC generation %d.%s",
                elapsed_ms,
                str(info.get("collected", "?")),
                generation,
                top_objects_suffix,
            )
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_gc_heap() -> None:
    """
    Freeze all objects tracked by the garbage collector. It should be invoked
    after server init / warmup, to reduce GC overhead from static objects
    during serving time.
    """
    # Collect each generation first so surviving static objects all reach the
    # oldest generation before being frozen.
    for generation in range(3):
        gc.collect(generation)
    # Move everything the collector still tracks into the permanent set.
    gc.freeze()
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_attach_gc_debug_callback() -> None:
    """
    Attached a callback for GC debug when VLLM_GC_DEBUG is enabled.
    """
    config = GCDebugConfig(envs.VLLM_GC_DEBUG)
    if not config.enabled:
        return

    debugger = GCDebugger(config)

    def gc_callback(phase: str, info: dict[str, int]) -> None:
        debugger.handle(phase, info)

    gc.callbacks.append(gc_callback)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_detailed_type(o: Any) -> str:
|
||||||
|
"""
|
||||||
|
Detailed object type.
|
||||||
|
|
||||||
|
TODO(Jialin): Further enhance the detailed type with element types for
|
||||||
|
easier debugging. We tried but occasionally it would run into signals
|
||||||
|
which kills the engine.
|
||||||
|
"""
|
||||||
|
size_str: str = ""
|
||||||
|
# Object doesn't support len() - this can happen with type objects
|
||||||
|
# or other objects that don't implement __len__ properly
|
||||||
|
with suppress(Exception):
|
||||||
|
size_str = f"(size:{len(o)})"
|
||||||
|
return f"{str(type(o))}{size_str}"
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_top_gc_collected_objects(objects: list[Any], top: int) -> str:
    """
    Group collected objects by types and render the ``top`` most common as
    one right-aligned "count:type" line each. Returns "" when disabled.
    """
    if top <= 0:
        return ""
    type_counts = Counter(_compute_detailed_type(o) for o in objects)
    lines = [
        f"{count:>5}:{object_type}"
        for object_type, count in type_counts.most_common(top)
    ]
    return "\n".join(lines)
|
||||||
63
vllm/_utils/hashing.py
Normal file
63
vllm/_utils/hashing.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import pickle
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import cbor2
|
||||||
|
|
||||||
|
|
||||||
|
def sha256(input: Any) -> bytes:
    """Hash any picklable Python object using SHA-256.

    The object is serialized with pickle (highest protocol) first, so any
    picklable value works. No hash seed is applied; prepend one to the input
    yourself if you need it.

    Args:
        input: Any picklable Python object.

    Returns:
        Bytes representing the SHA-256 hash of the serialized input.
    """
    serialized = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    return hashlib.sha256(serialized).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def sha256_cbor(input: Any) -> bytes:
    """Hash objects using canonical CBOR serialization and SHA-256.

    Useful when the hash must be reproducible outside Python.

    Args:
        input: Object to be serialized and hashed. Supported types include
            basic Python types and complex structures like lists, tuples,
            and dictionaries.
            Custom classes must implement CBOR serialization methods.

    Returns:
        Bytes representing the SHA-256 hash of the CBOR serialized input.
    """
    encoded = cbor2.dumps(input, canonical=True)
    return hashlib.sha256(encoded).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
|
||||||
|
"""Get a hash function by name, or raise an error if the function is not found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hash_fn_name: Name of the hash function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A hash function.
|
||||||
|
"""
|
||||||
|
if hash_fn_name == "sha256":
|
||||||
|
return sha256
|
||||||
|
if hash_fn_name == "sha256_cbor":
|
||||||
|
return sha256_cbor
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
|
||||||
411
vllm/_utils/import_utils.py
Normal file
411
vllm/_utils/import_utils.py
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Contains helpers related to importing modules.
|
||||||
|
|
||||||
|
This is similar in concept to the `importlib` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.metadata
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from functools import cache
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
from typing_extensions import Never
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This function can be removed if transformer_modules classes are
|
||||||
|
# serialized by value when communicating between processes
|
||||||
|
def init_cached_hf_modules() -> None:
    """
    Lazy initialization of the Hugging Face modules.
    """
    # Imported here so transformers is only required when this is called.
    from transformers.dynamic_module_utils import init_hf_modules

    init_hf_modules()
|
||||||
|
|
||||||
|
|
||||||
|
def import_pynvml():
    """Return vLLM's vendored copy of the official ``pynvml`` module.

    ``libnvml.so`` is the library behind nvidia-smi; ``pynvml`` wraps it so
    GPU status can be queried without initializing a CUDA context in the
    current process. Historically two PyPI packages provided a ``pynvml``
    module: the official ``nvidia-ml-py`` (a vLLM dependency) and the
    unofficial ``pynvml`` package, which prior to version 12.0 shipped a
    conflicting ``pynvml`` package that shadowed the official standalone
    file (it moved to ``pynvml_utils`` in 12.0). Environments with both
    installed — e.g. ``nvcr.io/nvidia/pytorch:24.12-py3`` — broke imports;
    see https://github.com/vllm-project/vllm/issues/12847.

    To sidestep the conflict entirely, vLLM vendors the official module and
    imports it directly.
    """
    import vllm.third_party.pynvml as pynvml

    return pynvml
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_path(module_name: str, file_path: str | os.PathLike):
|
||||||
|
"""
|
||||||
|
Import a Python file according to its file path.
|
||||||
|
|
||||||
|
Based on the official recipe:
|
||||||
|
https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
||||||
|
"""
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
|
if spec is None:
|
||||||
|
raise ModuleNotFoundError(f"No module named {module_name!r}")
|
||||||
|
|
||||||
|
assert spec.loader is not None
|
||||||
|
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_obj_by_qualname(qualname: str) -> Any:
    """
    Resolve an object by its fully-qualified class name, e.g.
    ``"math.sqrt"`` -> the ``sqrt`` function.
    """
    module_name, obj_name = qualname.rsplit(".", 1)
    target_module = importlib.import_module(module_name)
    return getattr(target_module, obj_name)
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def get_vllm_optional_dependencies():
    """Map each vllm optional-extra name to the package names it pulls in,
    parsed from the installed distribution's metadata. Cached per process."""
    metadata = importlib.metadata.metadata("vllm")
    requirements = metadata.get_all("Requires-Dist", [])
    extras = metadata.get_all("Provides-Extra", [])

    deps_by_extra = {}
    for extra in extras:
        marker = f'extra == "{extra}"'
        # Strip environment markers and version pins from each requirement.
        deps_by_extra[extra] = [
            re.split(r";|>=|<=|==", req)[0]
            for req in requirements
            if req.endswith(marker)
        ]
    return deps_by_extra
|
||||||
|
|
||||||
|
|
||||||
|
class _PlaceholderBase:
    """
    Disallows downstream usage of placeholder modules.

    We need to explicitly override each dunder method because
    [`__getattr__`][vllm.utils.import_utils._PlaceholderBase.__getattr__]
    is not called when they are accessed.

    Info:
        [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
    """

    def __getattr__(self, key: str) -> Never:
        """
        The main class should implement this to throw an error
        for attribute accesses representing downstream usage.
        """
        raise NotImplementedError

    # [Basic customization]
    # Every dunder below funnels into __getattr__, so any use of a
    # placeholder in an expression raises the subclass's error.

    def __lt__(self, other: object):
        return self.__getattr__("__lt__")

    def __le__(self, other: object):
        return self.__getattr__("__le__")

    def __eq__(self, other: object):
        return self.__getattr__("__eq__")

    def __ne__(self, other: object):
        return self.__getattr__("__ne__")

    def __gt__(self, other: object):
        return self.__getattr__("__gt__")

    def __ge__(self, other: object):
        return self.__getattr__("__ge__")

    def __hash__(self):
        return self.__getattr__("__hash__")

    def __bool__(self):
        return self.__getattr__("__bool__")

    # [Callable objects]

    def __call__(self, *args: object, **kwargs: object):
        return self.__getattr__("__call__")

    # [Container types]

    def __len__(self):
        return self.__getattr__("__len__")

    def __getitem__(self, key: object):
        return self.__getattr__("__getitem__")

    def __setitem__(self, key: object, value: object):
        return self.__getattr__("__setitem__")

    def __delitem__(self, key: object):
        return self.__getattr__("__delitem__")

    # __missing__ is optional according to __getitem__ specification,
    # so it is skipped

    # __iter__ and __reversed__ have a default implementation
    # based on __len__ and __getitem__, so they are skipped.

    # [Numeric Types]

    def __add__(self, other: object):
        return self.__getattr__("__add__")

    def __sub__(self, other: object):
        return self.__getattr__("__sub__")

    def __mul__(self, other: object):
        return self.__getattr__("__mul__")

    def __matmul__(self, other: object):
        return self.__getattr__("__matmul__")

    def __truediv__(self, other: object):
        return self.__getattr__("__truediv__")

    def __floordiv__(self, other: object):
        return self.__getattr__("__floordiv__")

    def __mod__(self, other: object):
        return self.__getattr__("__mod__")

    def __divmod__(self, other: object):
        return self.__getattr__("__divmod__")

    def __pow__(self, other: object, modulo: object = ...):
        return self.__getattr__("__pow__")

    def __lshift__(self, other: object):
        return self.__getattr__("__lshift__")

    def __rshift__(self, other: object):
        return self.__getattr__("__rshift__")

    def __and__(self, other: object):
        return self.__getattr__("__and__")

    def __xor__(self, other: object):
        return self.__getattr__("__xor__")

    def __or__(self, other: object):
        return self.__getattr__("__or__")

    # r* and i* methods have lower priority than
    # the methods for left operand so they are skipped

    def __neg__(self):
        return self.__getattr__("__neg__")

    def __pos__(self):
        return self.__getattr__("__pos__")

    def __abs__(self):
        return self.__getattr__("__abs__")

    def __invert__(self):
        return self.__getattr__("__invert__")

    # __complex__, __int__ and __float__ have a default implementation
    # based on __index__, so they are skipped.

    def __index__(self):
        return self.__getattr__("__index__")

    def __round__(self, ndigits: object = ...):
        return self.__getattr__("__round__")

    def __trunc__(self):
        return self.__getattr__("__trunc__")

    def __floor__(self):
        return self.__getattr__("__floor__")

    def __ceil__(self):
        return self.__getattr__("__ceil__")

    # [Context managers]

    def __enter__(self):
        return self.__getattr__("__enter__")

    def __exit__(self, *args: object, **kwargs: object):
        return self.__getattr__("__exit__")
|
||||||
|
|
||||||
|
|
||||||
|
class PlaceholderModule(_PlaceholderBase):
    """
    A placeholder object to use when a module does not exist.

    This enables more informative errors when trying to access attributes
    of a module that does not exist.
    """

    def __init__(self, name: str) -> None:
        super().__init__()

        # Apply name mangling to avoid conflicting with module attributes
        self.__name = name

    def placeholder_attr(self, attr_path: str):
        # Return a placeholder for a (dotted) attribute path under this
        # module, so `mod.sub.attr` accesses also produce helpful errors.
        return _PlaceholderModuleAttr(self, attr_path)

    def __getattr__(self, key: str) -> Never:
        # Always raises: either a targeted ImportError (if the module is a
        # known optional dependency), the original ImportError, or an
        # AssertionError if the module is unexpectedly importable.
        name = self.__name

        try:
            importlib.import_module(name)
        except ImportError as exc:
            # If the missing module belongs to a known vllm extra, point the
            # user at the right `pip install vllm[extra]` target.
            for extra, names in get_vllm_optional_dependencies().items():
                if name in names:
                    msg = f"Please install vllm[{extra}] for {extra} support"
                    raise ImportError(msg) from exc

            raise exc

        raise AssertionError(
            "PlaceholderModule should not be used "
            "when the original module can be imported"
        )
|
||||||
|
|
||||||
|
|
||||||
|
class _PlaceholderModuleAttr(_PlaceholderBase):
    """Placeholder for an attribute path under a `PlaceholderModule`."""

    def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
        super().__init__()

        # Apply name mangling to avoid conflicting with module attributes
        self.__module = module
        self.__attr_path = attr_path

    def placeholder_attr(self, attr_path: str):
        # Chain another attribute segment onto the dotted path.
        return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}")

    def __getattr__(self, key: str) -> Never:
        # Delegates to the parent module's __getattr__, which always raises
        # with a helpful message; the trailing AssertionError is unreachable
        # unless the module unexpectedly imports.
        getattr(self.__module, f"{self.__attr_path}.{key}")

        raise AssertionError(
            "PlaceholderModule should not be used "
            "when the original module can be imported"
        )
|
||||||
|
|
||||||
|
|
||||||
|
class LazyLoader(ModuleType):
    """
    `LazyLoader` module borrowed from [Tensorflow]
    (https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py)
    with an addition of "module caching".

    Lazily import a module, mainly to avoid pulling in large dependencies.
    Modules such as `xgrammar` might do additional side effects, so we
    only want to use this when it is needed, delaying all eager effects.
    """

    def __init__(
        self,
        local_name: str,
        parent_module_globals: dict[str, Any],
        name: str,
    ):
        # Name the module is bound to in the parent namespace.
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        # Cached module object; stays None until the first failed attribute
        # lookup triggers _load().
        self._module: ModuleType | None = None

        super().__init__(str(name))

    def _load(self) -> ModuleType:
        # Import the target module and insert it into the parent's namespace
        try:
            module = importlib.import_module(self.__name__)
            self._parent_module_globals[self._local_name] = module
            # The additional add to sys.modules
            # ensures library is actually loaded.
            sys.modules[self._local_name] = module
        except ModuleNotFoundError as err:
            # `from None` drops the chained traceback for a cleaner error.
            raise err from None

        # Update this object's dict so that if someone keeps a
        # reference to the LazyLoader, lookups are efficient
        # (__getattr__ is only called on lookups that fail).
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: Any) -> Any:
        # First failed lookup imports the real module, then delegates.
        if self._module is None:
            self._module = self._load()
        return getattr(self._module, item)

    def __dir__(self) -> list[str]:
        # dir() also forces the import so completion works as expected.
        if self._module is None:
            self._module = self._load()
        return dir(self._module)
|
||||||
|
|
||||||
|
|
||||||
|
# Optional dependency detection utilities
|
||||||
|
@cache
|
||||||
|
def _has_module(module_name: str) -> bool:
|
||||||
|
"""Return True if *module_name* can be found in the current environment.
|
||||||
|
|
||||||
|
The result is cached so that subsequent queries for the same module incur
|
||||||
|
no additional overhead.
|
||||||
|
"""
|
||||||
|
return importlib.util.find_spec(module_name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def has_pplx() -> bool:
    """True when the optional `pplx_kernels` package can be imported."""
    return _has_module("pplx_kernels")
|
||||||
|
|
||||||
|
|
||||||
|
def has_deep_ep() -> bool:
    """True when the optional `deep_ep` package can be imported."""
    return _has_module("deep_ep")
|
||||||
|
|
||||||
|
|
||||||
|
def has_deep_gemm() -> bool:
    """True when the optional `deep_gemm` package can be imported."""
    return _has_module("deep_gemm")
|
||||||
|
|
||||||
|
|
||||||
|
def has_triton_kernels() -> bool:
    """True when the optional `triton_kernels` package can be imported."""
    return _has_module("triton_kernels")
|
||||||
|
|
||||||
|
|
||||||
|
def has_tilelang() -> bool:
    """True when the optional `tilelang` package can be imported."""
    return _has_module("tilelang")
|
||||||
|
|
||||||
|
|
||||||
|
def has_arctic_inference() -> bool:
    """True when the optional `arctic_inference` package can be imported."""
    return _has_module("arctic_inference")
|
||||||
165
vllm/_utils/jsontree.py
Normal file
165
vllm/_utils/jsontree.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Helper functions to work with nested JSON structures."""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
from functools import reduce
|
||||||
|
from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast, overload
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.multimodal.inputs import BatchedTensorInputs
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
_U = TypeVar("_U")
|
||||||
|
|
||||||
|
JSONTree: TypeAlias = (
|
||||||
|
dict[str, "JSONTree[_T]"] | list["JSONTree[_T]"] | tuple["JSONTree[_T]", ...] | _T
|
||||||
|
)
|
||||||
|
"""A nested JSON structure where the leaves need not be JSON-serializable."""
|
||||||
|
|
||||||
|
_JSONTree: TypeAlias = (
|
||||||
|
dict[str, "JSONTree[_T]"]
|
||||||
|
| list["JSONTree[_T]"]
|
||||||
|
| tuple["JSONTree[_T]", ...]
|
||||||
|
| dict[str, _T]
|
||||||
|
| list[_T]
|
||||||
|
| tuple[_T, ...]
|
||||||
|
| _T
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
Same as `JSONTree` but with additional `Union` members to satisfy overloads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
    """Iterate through each leaf in a nested JSON structure.

    Traversal is depth-first, left-to-right; dicts, lists and tuples are
    treated as containers, everything else is yielded as a leaf.
    """
    if isinstance(value, dict):
        children = value.values()
    elif isinstance(value, (list, tuple)):
        children = value
    else:
        # Not a container: the value itself is a leaf.
        yield value
        return

    for child in children:
        yield from json_iter_leaves(child)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
def json_map_leaves(
    func: Callable[["torch.Tensor"], "torch.Tensor"],
    value: "BatchedTensorInputs",
) -> "BatchedTensorInputs": ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: _T | dict[str, _T],
) -> _U | dict[str, _U]: ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: _T | list[_T],
) -> _U | list[_U]: ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: _T | tuple[_T, ...],
) -> _U | tuple[_U, ...]: ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: JSONTree[_T],
) -> JSONTree[_U]: ...


def json_map_leaves(
    func: Callable[[_T], _U],
    value: "BatchedTensorInputs" | _JSONTree[_T],
) -> "BatchedTensorInputs" | _JSONTree[_U]:
    """Apply a function to each leaf in a nested JSON structure.

    Containers (dict/list/tuple) are rebuilt with the same concrete type;
    any other value is treated as a leaf and passed to *func* directly.
    """
    if isinstance(value, dict):
        mapped: dict = {}
        for key, subtree in value.items():
            mapped[key] = json_map_leaves(func, subtree)  # type: ignore[arg-type]
        return mapped
    if isinstance(value, list):
        return [json_map_leaves(func, item) for item in value]
    if isinstance(value, tuple):
        return tuple(json_map_leaves(func, item) for item in value)
    return func(value)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: _T | dict[str, _T],
    /,
) -> _T: ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: _T | list[_T],
    /,
) -> _T: ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: _T | tuple[_T, ...],
    /,
) -> _T: ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: JSONTree[_T],
    /,
) -> _T: ...


@overload
def json_reduce_leaves(
    func: Callable[[_U, _T], _U],
    value: JSONTree[_T],
    initial: _U,
    /,
) -> _U: ...


def json_reduce_leaves(
    func: Callable[..., _T | _U],
    value: _JSONTree[_T],
    initial: _U = cast(_U, ...),  # noqa: B008
    /,
) -> _T | _U:
    """
    Apply a function of two arguments cumulatively to each leaf in a
    nested JSON structure, from left to right, so as to reduce the
    sequence to a single value.
    """
    leaves = json_iter_leaves(value)
    # `...` is the "no initial value" sentinel, mirroring functools.reduce.
    if initial is ...:
        return reduce(func, leaves)  # type: ignore[arg-type]
    return reduce(func, leaves, initial)  # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def json_count_leaves(value: JSONTree[_T]) -> int:
    """Return how many leaves a nested JSON structure contains."""
    count = 0
    for _ in json_iter_leaves(value):
        count += 1
    return count
|
||||||
32
vllm/_utils/math_utils.py
Normal file
32
vllm/_utils/math_utils.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Math utility functions for vLLM."""
|
||||||
|
|
||||||
|
|
||||||
|
def cdiv(a: int, b: int) -> int:
    """Ceiling division: smallest integer >= a / b."""
    # divmod-based form: floor quotient, bumped by one when there is a
    # remainder. Equivalent to -(a // -b) for all integer signs.
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)
|
||||||
|
|
||||||
|
|
||||||
|
def next_power_of_2(n: int) -> int:
    """The smallest power of 2 that is >= n (returns 1 for n < 1)."""
    # (n - 1).bit_length() gives the exponent of the next power of two.
    return 1 if n < 1 else 1 << (n - 1).bit_length()
|
||||||
|
|
||||||
|
|
||||||
|
def prev_power_of_2(n: int) -> int:
    """The largest power of 2 that is <= n (returns 0 for n <= 0)."""
    # bit_length() - 1 is the exponent of the highest set bit.
    return 0 if n <= 0 else 1 << (n.bit_length() - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def round_up(x: int, y: int) -> int:
    """Round *x* up to the nearest multiple of *y*."""
    # Ceiling quotient times the step gives the next multiple.
    quotient = (x + y - 1) // y
    return quotient * y
|
||||||
|
|
||||||
|
|
||||||
|
def round_down(x: int, y: int) -> int:
    """Round *x* down to the nearest multiple of *y*."""
    # Identity: (x // y) * y == x - x % y for integers.
    return x - (x % y)
|
||||||
13
vllm/_utils/mem_constants.py
Normal file
13
vllm/_utils/mem_constants.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Byte-count unit constants: decimal (MB/GB) vs binary (MiB/GiB) prefixes.
MB_bytes = 1_000_000
"""The number of bytes in one megabyte (MB)."""

MiB_bytes = 1 << 20
"""The number of bytes in one mebibyte (MiB)."""

GB_bytes = 1_000_000_000
"""The number of bytes in one gigabyte (GB)."""

GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB)."""
|
||||||
232
vllm/_utils/mem_utils.py
Normal file
232
vllm/_utils/mem_utils.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
|
import gc
|
||||||
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
import torch.types
|
||||||
|
|
||||||
|
from .mem_constants import GiB_bytes
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # Imported lazily: _custom_ops loads compiled kernels, which may be
    # unavailable or expensive at module import time.
    from vllm import _custom_ops as ops

    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
    # will fail
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    vmem = psutil.virtual_memory()
    return vmem.total
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceMemoryProfiler:
    """Context manager measuring device memory consumed inside its scope."""

    def __init__(self, device: torch.types.Device | None = None):
        # Device to query; None lets the platform pick its current device.
        self.device = device

    def current_memory_usage(self) -> float:
        # Return the memory usage in bytes.
        from vllm.platforms import current_platform

        # Collect garbage first so freed tensors do not inflate the reading.
        gc.collect()
        return current_platform.get_current_memory_usage(self.device)

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        # Net bytes consumed while the `with` block was active.
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MemorySnapshot:
    """Memory snapshot."""

    # Peak bytes allocated by torch since the last reset_peak_memory_stats().
    torch_peak: int = 0
    # Free/total device memory as reported by cudaMemGetInfo (bytes).
    free_memory: int = 0
    total_memory: int = 0
    # Bytes in use on the device (total - free).
    cuda_memory: int = 0
    # Bytes reserved by the torch caching allocator.
    torch_memory: int = 0
    # Device bytes in use that torch does not account for (e.g. NCCL).
    non_torch_memory: int = 0
    # Wall-clock time of the measurement (time.time()).
    timestamp: float = 0.0
    # When True, measure() runs immediately on construction.
    auto_measure: bool = True

    def __post_init__(self):
        if self.auto_measure:
            self.measure()

    def measure(self):
        from vllm.platforms import current_platform

        # we measure the torch peak memory usage via allocated_bytes,
        # rather than `torch.cuda.memory_reserved()` .
        # After `torch.cuda.reset_peak_memory_stats()`,
        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
        # when we call `torch.cuda.empty_cache()` or OOM happens.
        self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)

        self.free_memory, self.total_memory = torch.cuda.mem_get_info()
        shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1))  # Orin, Thor, Spark
        if (
            current_platform.is_cuda()
            and current_platform.get_device_capability() in shared_sysmem_device_mem_sms
        ):
            # On UMA (Orin, Thor and Spark) platform,
            # where both CPU and GPU rely on system memory,
            # the cudaMemGetInfo function shows the amount of free system memory
            # rather than what’s actually available.
            # In the case,
            # torch.cuda.mem_get_info() only reports "free" memory,
            # which can be lower than what is actually
            # available due to not including cache memory.
            # There’s also a comprehensive reference page
            # that explains how you can compute the proper value yourself.
            # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
            self.free_memory = psutil.virtual_memory().available

        self.cuda_memory = self.total_memory - self.free_memory

        # torch.cuda.memory_reserved() is how many bytes
        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
        # this is used to measure the non-torch memory usage
        self.torch_memory = torch.cuda.memory_reserved()

        self.non_torch_memory = self.cuda_memory - self.torch_memory
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
        # Field-wise delta between two snapshots; auto_measure=False so the
        # result holds the differences instead of re-measuring.
        return MemorySnapshot(
            torch_peak=self.torch_peak - other.torch_peak,
            free_memory=self.free_memory - other.free_memory,
            total_memory=self.total_memory - other.total_memory,
            cuda_memory=self.cuda_memory - other.cuda_memory,
            torch_memory=self.torch_memory - other.torch_memory,
            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
            timestamp=self.timestamp - other.timestamp,
            auto_measure=False,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MemoryProfilingResult:
    """Memory profiling result. All numbers are in bytes."""

    # Total memory not usable for the KV cache (weights + peak activations
    # + non-torch allocations).
    non_kv_cache_memory: int = 0
    # Increase of torch's peak allocated bytes during profiling.
    torch_peak_increase: int = 0
    # Increase of non-torch device memory since instance creation.
    non_torch_increase: int = 0
    # Bytes used by model weights (supplied by the caller).
    weights_memory: float = 0
    # Snapshots taken before instance creation / before / after profiling.
    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    # Wall-clock duration of the profiling run, in seconds.
    profile_time: float = 0.0

    def __repr__(self) -> str:
        return (
            f"Memory profiling takes {self.profile_time:.2f} seconds. "
            f"Total non KV cache memory: "
            f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
            f"torch peak memory increase: "
            f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
            f"non-torch forward increase memory: "
            f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
            f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB."
        )
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def memory_profiling(
    baseline_snapshot: MemorySnapshot, weights_memory: int
) -> Generator[MemoryProfilingResult, None, None]:
    """Memory profiling context manager.
    baseline_snapshot: the memory snapshot before the current vLLM instance.
    weights_memory: memory used by PyTorch when loading the model weights.
    Note that, before loading the model weights, we also initialize the device
    and distributed environment, which may consume some memory. This part is not
    included in the weights_memory because PyTorch does not control it.

    The memory in one GPU can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitive example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.

    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).

    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
    """  # noqa
    # Start from a clean allocator state so peaks reflect only this scope.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    result = MemoryProfilingResult()

    result.before_create = baseline_snapshot
    # the part of memory used for holding the model weights
    result.weights_memory = weights_memory

    result.before_profile.measure()

    yield result

    # Drop activation garbage before the "after" measurement.
    gc.collect()
    torch.cuda.empty_cache()

    result.after_profile.measure()

    diff_profile = result.after_profile - result.before_profile
    diff_from_create = result.after_profile - result.before_create
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    result.profile_time = diff_profile.timestamp

    non_torch_memory = result.non_torch_increase
    peak_activation_memory = result.torch_peak_increase
    result.non_kv_cache_memory = (
        non_torch_memory + peak_activation_memory + result.weights_memory
    )  # noqa
|
||||||
64
vllm/_utils/nccl.py
Normal file
64
vllm/_utils/nccl.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def find_nccl_library() -> str:
    """Return NCCL/RCCL shared library name to load.

    Uses `VLLM_NCCL_SO_PATH` if set; otherwise chooses by torch backend.
    """
    so_file = envs.VLLM_NCCL_SO_PATH
    if so_file:
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file
        )
    else:
        # No override: pick the library matching the torch build
        # (NCCL for CUDA, RCCL for ROCm).
        if torch.version.cuda is not None:
            so_file = "libnccl.so.2"
        elif torch.version.hip is not None:
            so_file = "librccl.so.1"
        else:
            raise ValueError("NCCL only supports CUDA and ROCm backends.")
        logger.debug_once("Found nccl from library %s", so_file)
    return so_file
|
||||||
|
|
||||||
|
|
||||||
|
def find_nccl_include_paths() -> list[str] | None:
    """Return possible include paths containing `nccl.h`.

    Considers `VLLM_NCCL_INCLUDE_PATH` and the `nvidia-nccl-cuXX` package.
    """
    candidates: list[str] = []

    # Explicit override via environment variable, if it points at a directory.
    env_inc = envs.VLLM_NCCL_INCLUDE_PATH
    if env_inc and os.path.isdir(env_inc):
        candidates.append(env_inc)

    # Best-effort probe of the pip-installed nvidia-nccl package.
    try:
        spec = importlib.util.find_spec("nvidia.nccl")
        if spec and getattr(spec, "submodule_search_locations", None):
            for location in spec.submodule_search_locations:
                header_dir = os.path.join(location, "include")
                if os.path.exists(os.path.join(header_dir, "nccl.h")):
                    candidates.append(header_dir)
    except Exception as e:
        logger.debug("Failed to find nccl include path from nvidia.nccl package: %s", e)

    # De-duplicate while preserving discovery order.
    deduped = list(dict.fromkeys(p for p in candidates if p))
    return deduped or None
|
||||||
331
vllm/_utils/network_utils.py
Normal file
331
vllm/_utils/network_utils.py
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
|
import ipaddress
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from collections.abc import (
|
||||||
|
Iterator,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
    """Close every non-None socket immediately (linger=0 drops queued messages)."""
    for candidate in sockets:
        if candidate is not None:
            candidate.close(linger=0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ip() -> str:
    """Best-effort IP address for inter-process communication.

    Prefers VLLM_HOST_IP, then the interface used to reach a public IPv4/IPv6
    address, falling back to "0.0.0.0" with a warning.
    """
    host_ip = envs.VLLM_HOST_IP
    if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
        logger.warning(
            "The environment variable HOST_IP is deprecated and ignored, as"
            " it is often used by Docker and other software to"
            " interact with the container's network stack. Please "
            "use VLLM_HOST_IP instead to set the IP address for vLLM processes"
            " to communicate with each other."
        )
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface

    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
    except Exception:
        pass

    # try ipv6
    try:
        with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s:
            # Google's public DNS server, see
            # https://developers.google.com/speed/public-dns/docs/using#addresses
            s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
    except Exception:
        pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default."
        "The value can be set by the environment variable"
        " VLLM_HOST_IP or HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_loopback_bind(address, family):
|
||||||
|
try:
|
||||||
|
s = socket.socket(family, socket.SOCK_DGRAM)
|
||||||
|
s.bind((address, 0)) # Port 0 = auto assign
|
||||||
|
s.close()
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_loopback_ip() -> str:
    """Loopback address for local-only sockets, honoring VLLM_LOOPBACK_IP."""
    configured = envs.VLLM_LOOPBACK_IP
    if configured:
        return configured

    # VLLM_LOOPBACK_IP is not set: probe which loopback family is bindable.
    for address, family in (
        ("127.0.0.1", socket.AF_INET),
        ("::1", socket.AF_INET6),
    ):
        if test_loopback_bind(address, family):
            return address

    raise RuntimeError(
        "Neither 127.0.0.1 nor ::1 are bound to a local interface. "
        "Set the VLLM_LOOPBACK_IP environment variable explicitly."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_ipv6_address(address: str) -> bool:
    """Return True iff *address* parses as an IPv6 address."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def split_host_port(host_port: str) -> tuple[str, int]:
    """Split "host:port" (or bracketed "[v6addr]:port") into (host, port)."""
    # Bracketed IPv6 form: "[::1]:8000".
    if host_port.startswith("["):
        bracketed, remainder = host_port.rsplit("]", 1)
        port_str = remainder.split(":")[1]
        return bracketed[1:], int(port_str)

    # Plain "host:port" form.
    host, port_str = host_port.split(":")
    return host, int(port_str)
|
||||||
|
|
||||||
|
|
||||||
|
def join_host_port(host: str, port: int) -> str:
    """Format a host/port pair, bracketing IPv6 hosts."""
    return f"[{host}]:{port}" if is_valid_ipv6_address(host) else f"{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_distributed_init_method(ip: str, port: int) -> str:
    """Build the tcp:// init-method URI for the given host and port."""
    return get_tcp_uri(ip, port)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tcp_uri(ip: str, port: int) -> str:
    """Format a tcp:// URI, bracketing IPv6 literals."""
    host = f"[{ip}]" if is_valid_ipv6_address(ip) else ip
    return f"tcp://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_zmq_ipc_path() -> str:
    """Return a fresh, unique ipc:// socket path under VLLM_RPC_BASE_PATH."""
    base = envs.VLLM_RPC_BASE_PATH
    return f"ipc://{base}/{uuid4()}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_zmq_inproc_path() -> str:
    """Return a fresh, unique zmq inproc:// address."""
    return "inproc://" + str(uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_port() -> int:
    """
    Get an open port for the vLLM process to listen on.
    An edge case to handle, is when we run data parallel,
    we need to avoid ports that are potentially used by
    the data parallel master process.
    Right now we reserve 10 ports for the data parallel master
    process. Currently it uses 2 ports.
    """
    if "VLLM_DP_MASTER_PORT" in os.environ:
        dp_master_port = envs.VLLM_DP_MASTER_PORT
        reserved_port_range = range(dp_master_port, dp_master_port + 10)
        # Keep sampling until we draw a port outside the reserved DP range.
        while True:
            candidate_port = _get_open_port()
            if candidate_port not in reserved_port_range:
                return candidate_port
    return _get_open_port()
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_ports_list(count: int = 5) -> list[int]:
    """Get a list of open ports."""
    # A set de-duplicates: get_open_port may return the same port twice
    # since nothing holds the port between calls.
    unique_ports: set[int] = set()
    while len(unique_ports) < count:
        unique_ports.add(get_open_port())
    return list(unique_ports)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_open_port() -> int:
    """Find a bindable port.

    If VLLM_PORT is set, probe sequentially upward from it until a port
    binds; otherwise ask the OS for an ephemeral port (IPv4 first, then
    IPv6 as a fallback).
    """
    port = envs.VLLM_PORT
    if port is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", port))
                    return port
            except OSError:
                port += 1  # Increment port number if already in use
                logger.info("Port %d is already in use, trying port %d", port - 1, port)
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
def find_process_using_port(port: int) -> psutil.Process | None:
    """Return a process (other than ours) bound to `port`, if one is found.

    Returns None on macOS (lookup not possible), when no such process
    exists, or when the owning process exits mid-lookup.
    """
    # TODO: We can not check for running processes with network
    # port on macOS. Therefore, we can not have a full graceful shutdown
    # of vLLM. For now, let's not look for processes in this case.
    # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
    if sys.platform.startswith("darwin"):
        return None

    our_pid = os.getpid()
    for conn in psutil.net_connections():
        # Fix: conn.laddr can be an empty tuple for address-less sockets,
        # in which case .port would raise AttributeError.
        if not conn.laddr or conn.laddr.port != port:
            continue
        if conn.pid is None or conn.pid == our_pid:
            continue
        try:
            return psutil.Process(conn.pid)
        except psutil.NoSuchProcess:
            # Process vanished between enumeration and lookup.
            return None
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def split_zmq_path(path: str) -> tuple[str, str, str]:
    """Split a zmq path into its parts (scheme, host, port as strings)."""
    parsed = urlparse(path)
    if not parsed.scheme:
        raise ValueError(f"Invalid zmq path: {path}")

    scheme = parsed.scheme
    host = parsed.hostname or ""
    port = str(parsed.port or "")

    # tcp requires both a host and a port...
    tcp_missing_parts = scheme == "tcp" and not (host and port)
    # ...while a port is meaningless for any other scheme.
    non_tcp_with_port = scheme != "tcp" and bool(port)
    if tcp_missing_parts or non_tcp_with_port:
        raise ValueError(f"Invalid zmq path: {path}")

    return scheme, host, port
|
||||||
|
|
||||||
|
|
||||||
|
def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
    """Make a ZMQ path from its parts.

    Args:
        scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
        host: The host - can be an IPv4 address, IPv6 address, or hostname.
        port: Optional port number, only used for TCP sockets.

    Returns:
        A properly formatted ZMQ path string.
    """
    if port is None:
        return f"{scheme}://{host}"
    # IPv6 literals need brackets so the port separator is unambiguous.
    host_part = f"[{host}]" if is_valid_ipv6_address(host) else host
    return f"{scheme}://{host_part}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
    ctx: zmq.asyncio.Context | zmq.Context,  # type: ignore[name-defined]
    path: str,
    socket_type: Any,
    bind: bool | None = None,
    identity: bytes | None = None,
    linger: int | None = None,
) -> zmq.Socket | zmq.asyncio.Socket:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics.

    Args:
        ctx: ZMQ (or asyncio ZMQ) context to create the socket from.
        path: Transport path, e.g. ``tcp://host:port`` or ``ipc://...``.
        socket_type: A zmq socket type constant (``zmq.PULL``, ``zmq.PUSH``, ...).
        bind: Force bind (True) or connect (False); inferred from the
            socket type when None.
        identity: Optional ``zmq.IDENTITY`` to set before bind/connect.
        linger: Optional ``zmq.LINGER`` (ms) to set on the socket.
    """

    mem = psutil.virtual_memory()
    socket = ctx.socket(socket_type)

    # Calculate buffer size based on system memory
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1

    if bind is None:
        # Receiving-style sockets bind by default; PUSH/SUB/XSUB connect.
        bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)

    # Unbounded high-water marks plus the large buffer on sockets that
    # receive and/or send heavy traffic.
    if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
        socket.setsockopt(zmq.RCVHWM, 0)
        socket.setsockopt(zmq.RCVBUF, buf_size)

    if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
        socket.setsockopt(zmq.SNDHWM, 0)
        socket.setsockopt(zmq.SNDBUF, buf_size)

    if identity is not None:
        socket.setsockopt(zmq.IDENTITY, identity)

    if linger is not None:
        socket.setsockopt(zmq.LINGER, linger)

    if socket_type == zmq.XPUB:
        socket.setsockopt(zmq.XPUB_VERBOSE, True)

    # Determine if the path is a TCP socket with an IPv6 address.
    # Enable IPv6 on the zmq socket if so.
    scheme, host, _ = split_zmq_path(path)
    if scheme == "tcp" and is_valid_ipv6_address(host):
        socket.setsockopt(zmq.IPV6, 1)

    if bind:
        socket.bind(path)
    else:
        socket.connect(path)

    return socket
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def zmq_socket_ctx(
    path: str,
    socket_type: Any,
    bind: bool | None = None,
    linger: int = 0,
    identity: bytes | None = None,
) -> Iterator[zmq.Socket]:
    """Context manager for a ZMQ socket"""

    # Own a private context so teardown can destroy it unconditionally.
    ctx = zmq.Context()  # type: ignore[attr-defined]
    try:
        yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity)
    except KeyboardInterrupt:
        # Swallow Ctrl-C so the cleanup below still runs.
        logger.debug("Got Keyboard Interrupt.")

    finally:
        # Destroying the context (not just the socket) avoids fd leaks.
        ctx.destroy(linger=linger)
|
||||||
59
vllm/_utils/platform_utils.py
Normal file
59
vllm/_utils/platform_utils.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from concurrent.futures.process import ProcessPoolExecutor
|
||||||
|
from functools import cache
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_is_initialized() -> bool:
    """Check if CUDA is initialized."""
    # Short-circuit on CUDA-less torch builds before querying runtime state.
    return torch.cuda._is_compiled() and torch.cuda.is_initialized()
|
||||||
|
|
||||||
|
|
||||||
|
def xpu_is_initialized() -> bool:
|
||||||
|
"""Check if XPU is initialized."""
|
||||||
|
if not torch.xpu._is_compiled():
|
||||||
|
return False
|
||||||
|
return torch.xpu.is_initialized()
|
||||||
|
|
||||||
|
|
||||||
|
def get_cu_count(device_id: int = 0) -> int:
    """Returns the total number of compute units (CU) on single GPU."""
    # multi_processor_count is torch's name for SM/CU count.
    return torch.cuda.get_device_properties(device_id).multi_processor_count
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_get_device_properties(
    device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]:
    """Get specified CUDA device property values without initializing CUDA in
    the current process.

    Args:
        device: Device passed through to ``torch.cuda.get_device_properties``.
        names: Attribute names to read off the properties object.
        init_cuda: If True, read in this process even if that initializes
            CUDA as a side effect.
    """
    if init_cuda or cuda_is_initialized():
        props = torch.cuda.get_device_properties(device)
        return tuple(getattr(props, name) for name in names)

    # Run in subprocess to avoid initializing CUDA as a side effect.
    # Recursive call with init_cuda=True runs inside the forked child.
    mp_ctx = multiprocessing.get_context("fork")
    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
        return executor.submit(cuda_get_device_properties, device, names, True).result()
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def is_pin_memory_available() -> bool:
    """Whether the current platform supports pinned (page-locked) host memory."""
    # Imported lazily to avoid a circular import at module load time.
    from vllm.platforms import current_platform

    return current_platform.is_pin_memory_available()
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def is_uva_available() -> bool:
    """Check if Unified Virtual Addressing (UVA) is available."""
    # UVA requires pinned memory.
    # TODO: Add more requirements for UVA if needed.
    return is_pin_memory_available()
|
||||||
56
vllm/_utils/profiling.py
Normal file
56
vllm/_utils/profiling.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def cprofile_context(save_file: str | None = None):
|
||||||
|
"""Run a cprofile
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_file: path to save the profile result. "1" or
|
||||||
|
None will result in printing to stdout.
|
||||||
|
"""
|
||||||
|
import cProfile
|
||||||
|
|
||||||
|
prof = cProfile.Profile()
|
||||||
|
prof.enable()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
prof.disable()
|
||||||
|
if save_file and save_file != "1":
|
||||||
|
prof.dump_stats(save_file)
|
||||||
|
else:
|
||||||
|
prof.print_stats(sort="cumtime")
|
||||||
|
|
||||||
|
|
||||||
|
def cprofile(save_file: str | None = None, enabled: bool = True):
|
||||||
|
"""Decorator to profile a Python method using cProfile.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_file: Path to save the profile result.
|
||||||
|
If "1", None, or "", results will be printed to stdout.
|
||||||
|
enabled: Set to false to turn this into a no-op
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args: Any, **kwargs: Any):
|
||||||
|
if not enabled:
|
||||||
|
# If profiling is disabled, just call the function directly.
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
with cprofile_context(save_file):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
49
vllm/_utils/registry.py
Normal file
49
vllm/_utils/registry.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionManager:
    """
    A registry for managing pluggable extension classes.

    Extension classes are registered under a string name and later
    instantiated by that name, which is the basis for simple runtime
    plugin systems.

    Examples:
        Basic usage with a registry instance:

        >>> FOO_REGISTRY = ExtensionManager()
        >>> @FOO_REGISTRY.register("my_foo_impl")
        ... class MyFooImpl(Foo):
        ...     def __init__(self, value):
        ...         self.value = value
        >>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123)

    """

    def __init__(self) -> None:
        """Create a registry with no registered classes."""
        self.name2class: dict[str, type] = {}

    def register(self, name: str):
        """
        Decorator that registers the decorated class under ``name``.
        """

        def _record(extension_cls):
            self.name2class[name] = extension_cls
            return extension_cls

        return _record

    def load(self, cls_name: str, *args, **kwargs) -> Any:
        """
        Instantiate and return a registered extension class by name.
        """
        cls = self.name2class.get(cls_name)
        assert cls is not None, f"Extension class {cls_name} not found"
        return cls(*args, **kwargs)
|
||||||
169
vllm/_utils/serial_utils.py
Normal file
169
vllm/_utils/serial_utils.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import base64
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
|
from vllm import PoolingRequestOutput
|
||||||
|
|
||||||
|
# Host byte order, cached once; used to decide when a byteswap is needed.
sys_byteorder = sys.byteorder


# Maps wire dtype names to the real torch dtype used for conversion.
EMBED_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    # I'm not sure if other platforms' CPUs support the fp8 data format.
    # EMBED_DTYPE only uses the fp8 data representation,
    # does not use fp8 computation, and only occurs on the CPU.
    # Apologize for any possible break.
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}


# Same-width torch stand-in dtypes used only for reinterpreting raw bits
# before handing the data to numpy.
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
    "float32": torch.float32,
    "float16": torch.float16,
    # numpy does not support bfloat16 and fp8
    "bfloat16": torch.float16,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

# Numpy stand-in dtypes mirroring EMBED_DTYPE_TO_TORCH_DTYPE_VIEW.
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
    "float32": np.float32,
    "float16": np.float16,
    # numpy does not support bfloat16 and fp8
    "bfloat16": np.float16,
    "fp8_e4m3": np.uint8,
    "fp8_e5m2": np.uint8,
}

# Runtime-checkable list of accepted endianness values.
ENDIANNESS = ["native", "big", "little"]

EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes"]
|
||||||
|
|
||||||
|
|
||||||
|
def tensor2binary(
    tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
) -> bytes:
    """Serialize a tensor to raw bytes in the requested dtype and byte order."""
    assert isinstance(tensor, torch.Tensor)
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert endianness in ENDIANNESS

    target_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]

    # Convert, flatten, then reinterpret as a numpy-compatible dtype
    # (numpy has no bfloat16/fp8, so same-width stand-ins are used).
    np_array = tensor.to(target_dtype).flatten().contiguous().view(view_dtype).numpy()

    swap_needed = endianness != "native" and endianness != sys_byteorder
    if swap_needed:
        np_array = np_array.byteswap()

    return np_array.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def binary2tensor(
    binary: bytes,
    shape: tuple[int, ...],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> torch.Tensor:
    """Deserialize raw bytes back into a tensor of the given dtype and shape."""
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
    assert endianness in ENDIANNESS

    target_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    view_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]

    np_array = np.frombuffer(binary, dtype=view_dtype).reshape(shape)

    if endianness not in ("native", sys_byteorder):
        np_array = np_array.byteswap()

    # Reinterpret the numpy stand-in dtype back to the true torch dtype.
    return torch.from_numpy(np_array).view(target_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_pooling_output(
    output: PoolingRequestOutput,
    encoding_format: EncodingFormat,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> list[float] | str | bytes:
    """Encode one pooling output as a float list, base64 string, or raw bytes."""
    data = output.outputs.data
    if encoding_format == "float":
        return data.tolist()
    if encoding_format == "base64":
        embedding_bytes = tensor2binary(data, embed_dtype, endianness)
        return base64.b64encode(embedding_bytes).decode("utf-8")
    if encoding_format == "bytes":
        return tensor2binary(data, embed_dtype, endianness)
    # Statically exhaustive: any other value is a typing error.
    assert_never(encoding_format)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MetadataItem:
    """Location and typing info for one pooling output in a binary body."""

    # Position of this output in the original request order.
    index: int
    # Serialization dtype of the payload bytes.
    embed_dtype: EmbedDType
    # Byte order the payload was written with.
    endianness: Endianness
    # Byte offsets [start, end) of this output within the shared body blob.
    start: int
    end: int
    # Tensor shape to restore on decode.
    shape: tuple[int, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
):
    """Serialize pooling outputs into binary chunks plus metadata and usage.

    Returns:
        body: list of per-output binary blobs, in request order.
        items: list of metadata dicts (the fields of MetadataItem) locating
            each blob inside the concatenated body.
        usage: dict with prompt/total token counts.
    """
    num_prompt_tokens = 0
    # Fix: each element is a plain dict mapping MetadataItem *field names*
    # to values; the previous annotation list[dict[str, MetadataItem]]
    # described a different (wrong) structure.
    items: list[dict[str, object]] = []
    body: list[bytes] = []
    offset = 0
    for idx, output in enumerate(pooling_outputs):
        binary = tensor2binary(
            tensor=output.outputs.data,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        size = len(binary)

        item = {
            "index": idx,
            "embed_dtype": embed_dtype,
            "endianness": endianness,
            "start": offset,
            "end": offset + size,
            "shape": output.outputs.data.shape,
        }

        body.append(binary)
        items.append(item)
        prompt_token_ids = output.prompt_token_ids
        num_prompt_tokens += len(prompt_token_ids)
        offset += size

    usage = {
        "prompt_tokens": num_prompt_tokens,
        "total_tokens": num_prompt_tokens,
    }
    return body, items, usage
|
||||||
|
|
||||||
|
|
||||||
|
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Reconstruct tensors from metadata plus the concatenated binary body."""
    # Restore original request order before slicing.
    items.sort(key=lambda item: item.index)

    return [
        binary2tensor(
            body[item.start : item.end], item.shape, item.embed_dtype, item.endianness
        )
        for item in items
    ]
|
||||||
229
vllm/_utils/system_utils.py
Normal file
229
vllm/_utils/system_utils.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from collections.abc import Callable, Iterator
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TextIO
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.ray.lazy_utils import is_in_ray_actor
|
||||||
|
|
||||||
|
from .platform_utils import cuda_is_initialized, xpu_is_initialized
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
CYAN = "\033[1;36m"
|
||||||
|
RESET = "\033[0;0m"
|
||||||
|
|
||||||
|
|
||||||
|
# Environment variable utilities
|
||||||
|
|
||||||
|
|
||||||
|
def update_environment_variables(envs_dict: dict[str, str]):
    """Update multiple environment variables, logging any overwrites."""
    for key, value in envs_dict.items():
        existing = os.environ.get(key)
        if existing is not None and existing != value:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                key,
                existing,
                value,
            )
        os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def set_env_var(key: str, value: str) -> Iterator[None]:
    """Temporarily set an environment variable, restoring it on exit."""
    previous = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = previous
|
||||||
|
|
||||||
|
|
||||||
|
# File path utilities
|
||||||
|
|
||||||
|
|
||||||
|
def unique_filepath(fn: Callable[[int], Path]) -> Path:
    """Generate a unique file path by trying incrementing integers.

    Note: This function has a TOCTOU race condition.
    Caller should use atomic operations (e.g., open with 'x' mode)
    when creating the file to ensure thread safety.
    """
    counter = 0
    while True:
        candidate = fn(counter)
        if not candidate.exists():
            return candidate
        counter += 1
|
||||||
|
|
||||||
|
|
||||||
|
# Process management utilities
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.

    Mutates os.environ (VLLM_WORKER_MULTIPROC_METHOD, and RAY_ADDRESS
    when running inside a Ray actor).
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons = []
    if is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    # Forking after an accelerator runtime is initialized is unsafe.
    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")

    if reasons:
        logger.warning(
            "We must use the `spawn` multiprocessing start method. "
            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
            "See https://docs.vllm.ai/en/latest/usage/"
            "troubleshooting.html#python-multiprocessing "
            "for more information. Reasons: %s",
            "; ".join(reasons),
        )
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
|
|
||||||
|
|
||||||
|
def get_mp_context():
    """Get a multiprocessing context with a particular method (spawn or fork).
    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
    determine the multiprocessing method (default is fork). However, under
    certain conditions, we may enforce spawn and override the value of
    VLLM_WORKER_MULTIPROC_METHOD.
    """
    # May rewrite VLLM_WORKER_MULTIPROC_METHOD in place before we read it.
    _maybe_force_spawn()
    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
    return multiprocessing.get_context(mp_method)
|
||||||
|
|
||||||
|
|
||||||
|
def set_process_title(
    name: str,
    suffix: str = "",
    prefix: str = envs.VLLM_PROCESS_NAME_PREFIX,
) -> None:
    """Set the current process title with optional suffix.

    Final title is "{prefix}::{name}" (with "_{suffix}" appended to the
    name if given). No-op when setproctitle is not installed.
    """
    try:
        import setproctitle
    except ImportError:
        # setproctitle is optional; silently skip when unavailable.
        return

    if suffix:
        name = f"{name}_{suffix}"

    setproctitle.setproctitle(f"{prefix}::{name}")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Add colored prefix to file output for log decoration.

    Monkey-patches `file.write` so every output line begins with
    "(worker_name pid=N) "; whether the next write starts a new line is
    tracked via a `start_new_line` attribute stashed on the file object.
    """
    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    file_write = file.write

    def write_with_prefix(s: str):
        if not s:
            return
        if file.start_new_line:  # type: ignore[attr-defined]
            file_write(prefix)
        idx = 0
        # Emit the prefix after each embedded newline except a trailing one.
        while (next_idx := s.find("\n", idx)) != -1:
            next_idx += 1
            file_write(s[idx:next_idx])
            if next_idx == len(s):
                # String ended exactly at a newline: prefix the next write.
                file.start_new_line = True  # type: ignore[attr-defined]
                return
            file_write(prefix)
            idx = next_idx
        file_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]
|
||||||
|
|
||||||
|
|
||||||
|
def decorate_logs(process_name: str | None = None) -> None:
    """Decorate stdout/stderr with process name and PID prefix.

    When no name is given, the current multiprocessing process name is used.
    """
    if process_name is None:
        process_name = get_mp_context().current_process().name

    pid = os.getpid()
    _add_prefix(sys.stdout, process_name, pid)
    _add_prefix(sys.stderr, process_name, pid)
|
||||||
|
|
||||||
|
|
||||||
|
def kill_process_tree(pid: int):
    """
    Kills all descendant processes of the given pid by sending SIGKILL.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        root = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Kill the descendants first, then the root process itself; a process
    # that exits mid-way is simply skipped.
    for descendant in root.children(recursive=True):
        with contextlib.suppress(ProcessLookupError):
            os.kill(descendant.pid, signal.SIGKILL)

    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)
|
||||||
|
|
||||||
|
|
||||||
|
# Resource utilities
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630
def set_ulimit(target_soft_limit: int = 65535):
    """Raise the soft RLIMIT_NOFILE limit to `target_soft_limit` (POSIX only)."""
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    import resource

    limit_kind = resource.RLIMIT_NOFILE
    soft, hard = resource.getrlimit(limit_kind)

    if soft >= target_soft_limit:
        return

    try:
        resource.setrlimit(limit_kind, (target_soft_limit, hard))
    except ValueError as e:
        # Best-effort: warn instead of failing startup.
        logger.warning(
            "Found ulimit of %s and failed to automatically increase "
            "with error %s. This can cause fd limit errors like "
            "`OSError: [Errno 24] Too many open files`. Consider "
            "increasing with ulimit -n",
            soft,
            e,
        )
|
||||||
255
vllm/_utils/tensor_schema.py
Normal file
255
vllm/_utils/tensor_schema.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from types import UnionType
|
||||||
|
from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorShape:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*dims: int | str,
|
||||||
|
dynamic_dims: set[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dims = dims
|
||||||
|
self.dynamic_dims = dynamic_dims if dynamic_dims else set()
|
||||||
|
|
||||||
|
def resolve(self, **bindings: int) -> tuple[int | str, ...]:
|
||||||
|
resolved = list[int | str]()
|
||||||
|
for dim in self.dims:
|
||||||
|
if isinstance(dim, str) and dim in bindings:
|
||||||
|
resolved.append(bindings[dim])
|
||||||
|
else:
|
||||||
|
resolved.append(dim)
|
||||||
|
return tuple(resolved)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Return a string representation of the tensor shape."""
|
||||||
|
dim_strs = []
|
||||||
|
for dim in self.dims:
|
||||||
|
if isinstance(dim, str):
|
||||||
|
if dim in self.dynamic_dims:
|
||||||
|
dim_strs.append(f"{dim}*") # Mark dynamic dimensions with *
|
||||||
|
else:
|
||||||
|
dim_strs.append(dim)
|
||||||
|
else:
|
||||||
|
dim_strs.append(str(dim))
|
||||||
|
return f"({', '.join(dim_strs)})"
|
||||||
|
|
||||||
|
|
||||||
|
class TensorSchema:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
validate: bool = True,
|
||||||
|
resolve_bindings: dict[str, int] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._resolve_bindings = resolve_bindings if resolve_bindings else {}
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
if validate:
|
||||||
|
self.validate()
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
return getattr(self, key, default)
|
||||||
|
|
||||||
|
def _match_shape_with_dynamic(
|
||||||
|
self,
|
||||||
|
actual: tuple[int, ...],
|
||||||
|
reference: tuple[int, ...],
|
||||||
|
expected_shape: tuple[int | str, ...],
|
||||||
|
dynamic_dims: set[str],
|
||||||
|
) -> bool:
|
||||||
|
if len(actual) != len(reference) or len(actual) > len(expected_shape):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for i, (a, r) in enumerate(zip(actual, reference)):
|
||||||
|
# When validating list inputs, we match shape suffixes only
|
||||||
|
# (e.g. "p", 3, "h", "w"), assuming the list length corresponds
|
||||||
|
# to the leading symbolic dim (e.g. "bn"). This allows comparing
|
||||||
|
# only the trailing dimensions of each element in the list.
|
||||||
|
dim = expected_shape[-len(actual) + i]
|
||||||
|
# Skip this dimension if it's marked dynamic
|
||||||
|
if dim in dynamic_dims:
|
||||||
|
continue
|
||||||
|
if a != r:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
|
||||||
|
if not idxs:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return str(list(idxs))
|
||||||
|
|
||||||
|
def _validate_field(
|
||||||
|
self,
|
||||||
|
value: object,
|
||||||
|
field_name: str,
|
||||||
|
expected_shape: tuple[int | str, ...],
|
||||||
|
dynamic_dims: set[str],
|
||||||
|
leading_idxs: tuple[int, ...] = (),
|
||||||
|
) -> tuple[int, ...]:
|
||||||
|
"""Validate a field and return the actual shape."""
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return () # Scalar
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
return value.shape
|
||||||
|
|
||||||
|
if not isinstance(value, (list, tuple)):
|
||||||
|
raise TypeError(
|
||||||
|
f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
|
||||||
|
f"one of the expected types: int, float, Tensor, list, tuple. "
|
||||||
|
f"Got: {type(value)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(value) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure all tensors in the list have the same
|
||||||
|
# shape, besides dynamic dimensions
|
||||||
|
for i, v in enumerate(value):
|
||||||
|
shape = self._validate_field(
|
||||||
|
v,
|
||||||
|
field_name,
|
||||||
|
expected_shape[1:],
|
||||||
|
dynamic_dims,
|
||||||
|
leading_idxs=leading_idxs + (i,),
|
||||||
|
)
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
first_shape = shape
|
||||||
|
elif not self._match_shape_with_dynamic(
|
||||||
|
shape,
|
||||||
|
first_shape,
|
||||||
|
expected_shape,
|
||||||
|
dynamic_dims,
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"{field_name}{self._fmt_indexer(leading_idxs)} "
|
||||||
|
f"contains inconsistent shapes: {first_shape} "
|
||||||
|
f"(index 0) vs {shape} (index {i})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Treat the list as a stacked tensor:
|
||||||
|
# shape = (len(list), *tensor.shape)
|
||||||
|
return (len(value),) + first_shape
|
||||||
|
|
||||||
|
def _validate_tensor_shape_expected(
|
||||||
|
self,
|
||||||
|
actual_shape: tuple[int, ...],
|
||||||
|
expected_shape: tuple[int | str, ...],
|
||||||
|
field_name: str,
|
||||||
|
shape_env: dict[str, int],
|
||||||
|
dynamic_dims: set[str],
|
||||||
|
) -> None:
|
||||||
|
"""Validate that the actual tensor shape matches the expected shape."""
|
||||||
|
|
||||||
|
if len(actual_shape) != len(expected_shape):
|
||||||
|
raise ValueError(
|
||||||
|
f"{field_name} has rank {len(actual_shape)} "
|
||||||
|
f"but expected {len(expected_shape)}. "
|
||||||
|
f"Expected shape: {expected_shape}, "
|
||||||
|
f"but got {actual_shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, dim in enumerate(expected_shape):
|
||||||
|
if dim in dynamic_dims:
|
||||||
|
continue
|
||||||
|
elif isinstance(dim, int):
|
||||||
|
if actual_shape[i] != dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"{field_name} dim[{i}] expected "
|
||||||
|
f"{dim}, got {actual_shape[i]}. "
|
||||||
|
f"Expected shape: {expected_shape}, "
|
||||||
|
f"but got {actual_shape}"
|
||||||
|
)
|
||||||
|
elif isinstance(dim, str):
|
||||||
|
if dim in shape_env:
|
||||||
|
if actual_shape[i] != shape_env[dim]:
|
||||||
|
raise ValueError(
|
||||||
|
f"{field_name} dim[{i}] expected "
|
||||||
|
f"'{dim}'={shape_env[dim]}, got "
|
||||||
|
f"{actual_shape[i]}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shape_env[dim] = actual_shape[i]
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"{field_name} dim[{i}] has unsupported type: {type(dim)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate(self) -> None:
    """Validate every annotated field of this instance against its TensorShape.

    Walks the class's type hints (with ``Annotated`` metadata); for each
    field annotated with a ``TensorShape``, computes the field's actual
    shape and checks it. Symbolic dims are shared across fields via a
    single ``shape_env``, so two fields annotated with the same symbol
    must have matching sizes. Raises ValueError for missing required
    fields or shape mismatches.
    """
    type_hints = get_type_hints(self.__class__, include_extras=True)
    # Shared symbol bindings across all fields of this instance.
    shape_env = dict[str, int]()

    for field_name, field_type in type_hints.items():
        # Check if field is missing
        if not hasattr(self, field_name) or getattr(self, field_name) is None:
            # Check if field is marked as optional
            actual_type = field_type
            if get_origin(field_type) is Annotated:
                # Unwrap Annotated[...] to inspect the underlying type.
                args = get_args(field_type)
                actual_type = args[0]

            # Check arg was provided as Union
            if get_origin(actual_type) in {Union, UnionType}:
                # Union for Union[X, Y] and UnionType for X | Y
                args = get_args(actual_type)
                # Skip validation when Union contains None
                if type(None) in args:
                    continue
            # Otherwise field is required, raise error
            raise ValueError(f"Required field '{field_name}' is missing")

        # Field exists, proceed with validation
        value = getattr(self, field_name)
        if get_origin(field_type) is not None:
            args = get_args(field_type)

            for arg in args:
                if isinstance(arg, TensorShape):
                    # Resolve any instance-specific bindings before checking.
                    expected_shape = arg.resolve(**self._resolve_bindings)
                    actual_shape = self._validate_field(
                        value,
                        field_name,
                        expected_shape,
                        arg.dynamic_dims,
                    )

                    self._validate_tensor_shape_expected(
                        actual_shape,
                        expected_shape,
                        field_name,
                        shape_env,
                        arg.dynamic_dims,
                    )
||||||
|
def print_shapes(self) -> None:
    """Print TensorShape annotations for debugging."""
    logger.debug("Shapes in %s:", self.__class__.__name__)
    hints = get_type_hints(self.__class__, include_extras=True)

    for name, annotation in hints.items():
        if get_origin(annotation) is None:
            # Plain (non-parameterized) annotation: nothing to inspect.
            continue
        for meta in get_args(annotation):
            if isinstance(meta, TensorShape):
                logger.debug("  %s: %s", name, str(meta))
||||||
657
vllm/_utils/torch_utils.py
Normal file
657
vllm/_utils/torch_utils.py
Normal file
@@ -0,0 +1,657 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
|
import importlib.metadata
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from collections.abc import Callable, Collection
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import TYPE_CHECKING, Any, TypeVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
from packaging.version import Version
|
||||||
|
from torch.library import Library
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
else:
|
||||||
|
ModelConfig = object
|
||||||
|
IntermediateTensors = object
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from user-facing dtype strings to the torch dtype used for storage.
# Note the fp8 variants are stored as raw bytes (torch.uint8), except
# "fp8_inc" which uses the native float8_e4m3fn dtype.
STR_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
    "int8": torch.int8,
    "fp8_inc": torch.float8_e4m3fn,
    "fp8_ds_mla": torch.uint8,
}

# Torch dtypes with a direct numpy equivalent; used when round-tripping data
# through numpy (e.g. make_tensor_with_pad below).
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}


# Generic element type for the padding helpers below.
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype.

    The previous default is restored on exit, even when the managed block
    raises.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore in a finally so an exception inside the block cannot leak
        # a global dtype change into unrelated code.
        torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int):
    """Sets the default number of threads for PyTorch to the given value.

    The previous thread count is restored on exit, even when the managed
    block raises.
    """
    old_num_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        yield
    finally:
        # Restore in a finally so an exception cannot leave the process with
        # the temporary thread count.
        torch.set_num_threads(old_num_threads)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def guard_cuda_initialization():
    """Avoid unexpected CUDA initialization.

    On CUDA platforms, temporarily sets ``CUDA_VISIBLE_DEVICES=""`` so any
    accidental CUDA initialization inside the block fails, and converts the
    resulting torch error into an explicit RuntimeError. The original
    environment is restored on exit. No-op on non-CUDA platforms.
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        yield
        return

    had_key = "CUDA_VISIBLE_DEVICES" in os.environ
    old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
        yield
    except Exception as e:
        # Translate torch's confusing "No CUDA GPUs are available" into an
        # explicit message; re-raise anything else with its own text.
        if "No CUDA GPUs are available" in str(e):
            err_msg = "CUDA initialization is blocked."
        else:
            err_msg = str(e)
        raise RuntimeError(err_msg) from e
    finally:
        if had_key:
            os.environ["CUDA_VISIBLE_DEVICES"] = old_value
        else:
            # Pop with a default so we don't raise KeyError if the managed
            # block itself removed the variable.
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dtype_size(dtype: torch.dtype) -> int:
|
||||||
|
"""Get the size of the data type in bytes."""
|
||||||
|
return torch.tensor([], dtype=dtype).element_size()
|
||||||
|
|
||||||
|
|
||||||
|
# bool = 0, int = 1, float = 2, complex = 3
|
||||||
|
def _get_precision_level(dtype: torch.dtype) -> int:
|
||||||
|
# NOTE: Complex dtypes return `is_floating_point=False`
|
||||||
|
return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2
|
||||||
|
|
||||||
|
|
||||||
|
def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype):
    """
    Test whether it is lossless to cast a tensor from
    `src_dtype` to `tgt_dtype`.
    """
    if src_dtype == tgt_dtype:
        return True

    # Precision category: bool(0) < int(1) < float(2) < complex(3).
    # (Complex dtypes report is_floating_point=False, hence the *2 term.)
    def category(dtype: torch.dtype) -> int:
        return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2

    src_level = category(src_dtype)
    tgt_level = category(tgt_dtype)
    if src_level != tgt_level:
        # Moving up a category (e.g. int -> float) is lossless; down is not.
        return src_level < tgt_level

    if not (src_dtype.is_floating_point or src_dtype.is_complex):
        # Compare integral types by their representable range.
        src_info = torch.iinfo(src_dtype)
        tgt_info = torch.iinfo(tgt_dtype)
        return tgt_info.min <= src_info.min and src_info.max <= tgt_info.max

    # Compare floating-point types by range and resolution.
    src_info = torch.finfo(src_dtype)
    tgt_info = torch.finfo(tgt_dtype)
    if src_info.min < tgt_info.min or src_info.max > tgt_info.max:
        return False
    return src_info.resolution >= tgt_info.resolution
|
||||||
|
|
||||||
|
|
||||||
|
def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
    """
    Get the common `dtype` where all of the other `dtypes` can be
    cast to it without losing any information.
    """
    def lossless_sources(candidate: torch.dtype) -> int:
        # Score each candidate by how many members cast to it losslessly.
        return sum(is_lossless_cast(dt, candidate) for dt in dtypes)

    return max(dtypes, key=lossless_sources)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill `tensor` (fp8 storage) in place with random values in [low, high].

    Draws uniform float16 values and converts them to fp8 via the custom op,
    rather than sampling raw bytes (which could produce Inf/NaN bit
    patterns).
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm import _custom_ops as ops

    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    # Free the temporary staging buffer eagerly.
    del tensor_tmp
|
||||||
|
|
||||||
|
|
||||||
|
def get_kv_cache_torch_dtype(
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
) -> torch.dtype:
    """Resolve the torch dtype used to store the KV cache.

    A torch.dtype passes through unchanged; the string "auto" defers to
    `model_dtype`; any other string is looked up in
    STR_DTYPE_TO_TORCH_DTYPE. Raises ValueError for anything else.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype
    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    if cache_dtype == "auto":
        # "auto" means: follow the model's dtype.
        if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        raise ValueError(f"Invalid model dtype: {model_dtype}")

    if cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def kv_cache_dtype_str_to_dtype(
    kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype:
    """Map a kv-cache dtype string to the torch dtype used for storage."""
    if kv_cache_dtype != "auto":
        return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    # "auto" follows the model dtype. Model config may not be specified for
    # unit tests; default to float16 in that case.
    return model_config.dtype if model_config else torch.half
|
||||||
|
|
||||||
|
|
||||||
|
def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
    seed: int | None = None,
    device: str | None = "cuda",
    cache_layout: str | None = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Allocate per-layer KV caches with random contents (FlashAttention layout).

    Each layer gets one combined K/V allocation of logical shape
    (num_blocks, 2, block_size, num_heads, head_size); "NHD" keeps that
    order, "HND" swaps the block_size and num_heads axes in memory. The
    returned key/value lists are views into slot 0 and 1 of the second axis.
    Raises ValueError for cache dtypes it cannot randomize.
    """
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    assert cache_layout in ("NHD", "HND")
    # Physical allocation order; permuted back below so indexing is uniform.
    stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)

    kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order)
    # 1/sqrt(head_size): keeps the random values in an attention-typical range.
    scale = head_size**-0.5

    key_caches: list[torch.Tensor] = []
    value_caches: list[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(
            size=kv_cache_allocation_shape, dtype=dtype, device=device
        ).permute(*stride_order)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        # Views into the same allocation: index 0 = keys, index 1 = values.
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches
|
||||||
|
|
||||||
|
|
||||||
|
def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
    seed: int | None = None,
    device: str | None = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Allocate per-layer key and value caches with random contents.

    Keys use the blocked layout (num_blocks, num_heads, head_size // x,
    block_size, x) where x packs 16 bytes of elements; values use
    (num_blocks, num_heads, head_size, block_size). Raises ValueError for
    fp8 with head_size not divisible by 16 or unsupported cache dtypes.
    """
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    # 1/sqrt(head_size): keeps random values in an attention-typical range.
    scale = head_size**-0.5
    # x = number of elements that fit in 16 bytes for this dtype.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: list[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: list[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches
|
||||||
|
|
||||||
|
|
||||||
|
def async_tensor_h2d(
|
||||||
|
data: list,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
target_device: str | torch.device,
|
||||||
|
pin_memory: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Asynchronously create a tensor and copy it from host to device."""
|
||||||
|
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
|
||||||
|
return t.to(device=target_device, non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
def make_ndarray_with_pad(
|
||||||
|
x: list[list[T]],
|
||||||
|
pad: T,
|
||||||
|
dtype: npt.DTypeLike,
|
||||||
|
*,
|
||||||
|
max_len: int | None = None,
|
||||||
|
) -> npt.NDArray:
|
||||||
|
"""
|
||||||
|
Make a padded array from 2D inputs.
|
||||||
|
|
||||||
|
The padding is applied to the end of each inner list until it reaches
|
||||||
|
`max_len`.
|
||||||
|
"""
|
||||||
|
if max_len is None:
|
||||||
|
# Unlike for most functions, map is faster than a genexpr over `len`
|
||||||
|
max_len = max(map(len, x), default=0)
|
||||||
|
|
||||||
|
padded_x = np.full((len(x), max_len), pad, dtype=dtype)
|
||||||
|
for ind, blocktb in enumerate(x):
|
||||||
|
assert len(blocktb) <= max_len
|
||||||
|
padded_x[ind, : len(blocktb)] = blocktb
|
||||||
|
|
||||||
|
return padded_x
|
||||||
|
|
||||||
|
|
||||||
|
def make_tensor_with_pad(
    x: list[list[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: int | None = None,
    device: str | torch.device | None = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Make a padded tensor from 2D inputs.

    The padding is applied to the end of each inner list until it reaches
    `max_len`.
    """
    # Build the padded buffer in numpy first, then wrap it zero-copy.
    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    padded = torch.from_numpy(
        make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
    ).to(device)
    return padded.pin_memory() if pin_memory else padded
|
||||||
|
|
||||||
|
|
||||||
|
# Keep a reference to the unpatched torch.cuda.set_stream so the patched
# version below can still delegate to it.
prev_set_stream = torch.cuda.set_stream

# Thread-local record of the most recently set stream; read by
# current_stream() below to avoid the cost of torch.cuda.current_stream().
_current_stream_tls = threading.local()


def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    """Record the stream for this thread, then delegate to the original setter."""
    _current_stream_tls.value = stream
    prev_set_stream(stream)


# Monkey-patch so every torch.cuda.set_stream call is tracked.
torch.cuda.set_stream = _patched_set_stream
|
||||||
|
|
||||||
|
|
||||||
|
class _StreamPlaceholder:
|
||||||
|
def __init__(self):
|
||||||
|
self.synchronize = lambda: None
|
||||||
|
|
||||||
|
|
||||||
|
def current_stream() -> torch.cuda.Stream:
    """
    replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
    it turns out that `torch.cuda.current_stream()` is quite expensive,
    as it will construct a new stream object at each call.
    here we patch `torch.cuda.set_stream` to keep track of the current stream
    directly, so that we can avoid calling `torch.cuda.current_stream()`.

    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
    from C/C++ code.
    """
    from vllm.platforms import current_platform

    # Lazily initialize the per-thread cached stream on first use.
    if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None:
        # when this function is called before any stream is set,
        # we return the default stream.
        # On ROCm using the default 0 stream in combination with RCCL
        # is hurting performance. Therefore creating a dedicated stream
        # per process
        if current_platform.is_rocm():
            # torch.cuda.set_stream here is the alias of _patched_set_stream,
            # so the new stream is also recorded in _current_stream_tls.
            torch.cuda.set_stream(torch.cuda.Stream())
        elif current_platform.is_cpu():
            # CPU has no streams; use a no-op placeholder.
            _current_stream_tls.value = _StreamPlaceholder()
        else:
            current_stream = current_platform.current_stream
            if current_stream is not None:
                _current_stream_tls.value = current_stream()
            else:
                raise ValueError(
                    "Fail to set current stream, current platform "
                    "may not support current_stream with torch API"
                )
    return _current_stream_tls.value
|
||||||
|
|
||||||
|
|
||||||
|
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
#  - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None


def aux_stream() -> torch.cuda.Stream | None:
    """
    Ensures aux_stream is initialized only once

    Returns the shared auxiliary CUDA stream, or None on non-CUDA platforms.
    NOTE(review): the lazy init is not lock-guarded — assumed to first be
    called from a single thread; confirm before using from multiple threads.
    """
    global _aux_stream

    from vllm.platforms import current_platform

    # TODO: validate this works properly on ROCm platform.
    if _aux_stream is None and current_platform.is_cuda():
        _aux_stream = torch.cuda.Stream()

    return _aux_stream
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
    """Count CUDA/ROCm devices without initializing the CUDA context.

    Cached per CUDA_VISIBLE_DEVICES value (the argument exists only as the
    cache key). Falls back to torch's stateful count when the stateless
    NVML/amdsmi query is unavailable (raw_count < 0).
    """
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    from vllm.platforms import current_platform

    if not torch.cuda._is_compiled():
        return 0
    if current_platform.is_rocm():
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        raw_count = (
            torch.cuda._device_count_amdsmi()
            if (hasattr(torch.cuda, "_device_count_amdsmi"))
            else -1
        )
    else:
        raw_count = torch.cuda._device_count_nvml()
    # raw_count < 0 signals the stateless query failed; use the stateful API.
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_device_count_stateless() -> int:
    """Get number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    # Passing the env value makes the lru_cache key track CUDA_VISIBLE_DEVICES.
    return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
|
||||||
|
|
||||||
|
|
||||||
|
def weak_ref_tensor(tensor: Any) -> Any:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.

    Non-tensor inputs are returned unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    return torch.ops._C.weak_ref_tensor(tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def weak_ref_tensors(
    tensors: torch.Tensor
    | list[torch.Tensor]
    | tuple[torch.Tensor]
    | IntermediateTensors,
) -> torch.Tensor | list[Any] | tuple[Any] | Any:
    """
    Convenience function to create weak references to tensors,
    for single tensor, list of tensors or tuple of tensors.

    Container structure is preserved; IntermediateTensors is rebuilt with
    weak-ref values. Raises ValueError for any other input type.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, list):
        return [weak_ref_tensor(t) for t in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(t) for t in tensors)

    # For IntermediateTensors used in pipeline parallelism
    # (imported here to avoid a module-level import cycle)
    from vllm.sequence import IntermediateTensors

    if isinstance(tensors, IntermediateTensors):
        ret = IntermediateTensors(
            {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()}
        )
        return ret
    raise ValueError("Invalid type for tensors")
|
||||||
|
|
||||||
|
|
||||||
|
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).

    Raises:
        ValueError: if the tensor is not in pinned (page-locked) host
            memory, which UVA requires.
    """
    # Don't validate with `assert`: it is stripped under `python -O`, and an
    # unpinned tensor would then fail much less clearly inside the custom op.
    if not cpu_tensor.is_pinned():
        raise ValueError("CPU tensor must be pinned")
    return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function used in testing.
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
    """Compare two version strings with semver-aware (not lexicographic) ordering."""
    return version.parse(torch_version) >= version.parse(target)
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_equal_or_newer(target: str) -> bool:
    """Check if the installed torch version is >= the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        Whether the condition meets.
    """
    try:
        return _is_torch_equal_or_newer(str(torch.__version__), target)
    except Exception:
        # Fallback to PKG-INFO to load the package info, needed by the doc gen.
        return Version(importlib.metadata.version("torch")) >= Version(target)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_torch_equal(target: str) -> bool:
    """Whether the installed torch matches `target` (an exact X.Y.Z string),
    treating local/dev suffixes (e.g. "+cu128", ".dev...") as equal."""
    assert target.count(".") == 2
    torch_version = str(torch.__version__)
    torch_version = version.parse(torch_version)
    # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu"
    # or "2.6.0+cu128" but never "2.6.0.1"
    # So [target, target + ".1") captures exactly the target release.
    return (
        torch_version >= version.parse(target)
        and version.parse(target + ".1") > torch_version
    )
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_equal(target: str) -> bool:
    """Check if the installed torch version is == the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        Whether the condition meets.
    """
    try:
        return _is_torch_equal(target)
    except Exception:
        # Fall back to package metadata (e.g. for doc builds); note this is
        # an exact Version comparison, without the suffix tolerance above.
        return Version(importlib.metadata.version("torch")) == Version(target)
|
||||||
|
|
||||||
|
|
||||||
|
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
def supports_dynamo() -> bool:
    """Whether the installed torch (>= 2.4.0) can run vLLM under dynamo."""
    return is_torch_equal_or_newer("2.4.0")
|
||||||
|
|
||||||
|
|
||||||
|
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
def supports_xccl() -> bool:
    """Whether torch.distributed exposes a usable XCCL backend (XPU)."""
    return (
        is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available()
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Some backends use pytorch version < 2.4.0 which doesn't
|
||||||
|
# support `torch.library.custom_op`.
|
||||||
|
def supports_custom_op() -> bool:
|
||||||
|
return hasattr(torch.library, "custom_op")
|
||||||
|
|
||||||
|
|
||||||
|
# create a library to hold the custom op
# ("FRAGMENT" lets multiple call sites register ops into the same "vllm"
# namespace; the library object must stay alive for its ops to remain valid)
vllm_lib = Library("vllm", "FRAGMENT")  # noqa
|
||||||
|
|
||||||
|
|
||||||
|
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str] | None = None,
    fake_impl: Callable | None = None,
    target_lib: Library | None = None,
    dispatch_key: str | None = None,
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.

    Args:
        op_name: name of the op within the target library.
        op_func: the real (eager) implementation.
        mutates_args: names of arguments the op mutates in place.
        fake_impl: optional meta/"fake" implementation for tracing.
        target_lib: library to register into (defaults to vllm_lib).
        dispatch_key: backend key; defaults to the current platform's.
        tags: optional torch.Tag values attached to the op definition.
    """
    if not supports_custom_op():
        from vllm.platforms import current_platform

        # On CUDA-like platforms this is a hard requirement; elsewhere the
        # registration is silently skipped.
        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies."
        )
        return

    if mutates_args is None:
        mutates_args = []

    if dispatch_key is None:
        from vllm.platforms import current_platform

        dispatch_key = current_platform.dispatch_key

    import torch.library

    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import importlib
import uuid
import warnings
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
# Maps a deprecated `vllm.utils.<name>` attribute to the submodule that now
# hosts it; consulted by the module-level __getattr__ below so old import
# paths keep working (with a DeprecationWarning).
_DEPRECATED_MAPPINGS = {
    "cprofile": "profiling",
    "cprofile_context": "profiling",
    # Used by lm-eval
    "get_open_port": "network_utils",
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:  # noqa: D401 - short deprecation docstring
    """Module-level getattr to handle deprecated utilities.

    Looks up `name` in _DEPRECATED_MAPPINGS, warns, and forwards to the
    submodule that now hosts it; anything else raises AttributeError.
    """
    if name in _DEPRECATED_MAPPINGS:
        submodule_name = _DEPRECATED_MAPPINGS[name]
        warnings.warn(
            f"vllm.utils.{name} is deprecated and will be removed in a future version. "
            f"Use vllm.utils.{submodule_name}.{name} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # importlib.import_module is the documented programmatic import API
        # and returns the leaf module directly, unlike raw __import__.
        module = importlib.import_module(f"vllm.utils.{submodule_name}")
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def __dir__() -> list[str]:
|
||||||
|
# expose deprecated names in dir() for better UX/tab-completion
|
||||||
|
return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
|
||||||
|
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# This value is chosen to have a balance between ITL and TTFT. Note it is
|
||||||
|
# not optimized for throughput.
|
||||||
|
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
|
||||||
|
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
||||||
|
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
||||||
|
|
||||||
|
# Constants related to forcing the attention backend selection
|
||||||
|
|
||||||
|
# String name of register which may be set in order to
|
||||||
|
# force auto-selection of attention backend by Attention
|
||||||
|
# wrapper
|
||||||
|
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
|
||||||
|
|
||||||
|
# Possible string values of STR_BACKEND_ENV_VAR
|
||||||
|
# register, corresponding to possible backends
|
||||||
|
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
|
||||||
|
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
|
||||||
|
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||||
|
STR_INVALID_VAL: str = "INVALID"
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def random_uuid() -> str:
|
||||||
|
return str(uuid.uuid4().hex)
|
||||||
|
|
||||||
|
|
||||||
|
def length_from_prompt_token_ids_or_embeds(
|
||||||
|
prompt_token_ids: list[int] | None,
|
||||||
|
prompt_embeds: torch.Tensor | None,
|
||||||
|
) -> int:
|
||||||
|
"""Calculate the request length (in number of tokens) give either
|
||||||
|
prompt_token_ids or prompt_embeds.
|
||||||
|
"""
|
||||||
|
prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids)
|
||||||
|
prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds)
|
||||||
|
|
||||||
|
if prompt_token_len is None:
|
||||||
|
if prompt_embeds_len is None:
|
||||||
|
raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
|
||||||
|
return prompt_embeds_len
|
||||||
|
else:
|
||||||
|
if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len:
|
||||||
|
raise ValueError(
|
||||||
|
"Prompt token ids and prompt embeds had different lengths"
|
||||||
|
f" prompt_token_ids={prompt_token_len}"
|
||||||
|
f" prompt_embeds={prompt_embeds_len}"
|
||||||
|
)
|
||||||
|
return prompt_token_len
|
||||||
@@ -358,7 +358,7 @@ class ModelConfig:
|
|||||||
for multimodal models."""
|
for multimodal models."""
|
||||||
use_async_output_proc: bool = True
|
use_async_output_proc: bool = True
|
||||||
"""Whether to use async output processor."""
|
"""Whether to use async output processor."""
|
||||||
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
|
config_format: Union[str, ConfigFormat] = "auto"
|
||||||
"""The format of the model config to load:\n
|
"""The format of the model config to load:\n
|
||||||
- "auto" will try to load the config in hf format if available else it
|
- "auto" will try to load the config in hf format if available else it
|
||||||
will try to load in mistral format.\n
|
will try to load in mistral format.\n
|
||||||
@@ -522,8 +522,8 @@ class ModelConfig:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Sleep mode is not supported on current platform.")
|
"Sleep mode is not supported on current platform.")
|
||||||
|
|
||||||
if isinstance(self.config_format, str):
|
# if isinstance(self.config_format, str):
|
||||||
self.config_format = ConfigFormat(self.config_format)
|
# self.config_format = ConfigFormat(self.config_format)
|
||||||
|
|
||||||
hf_config = get_config(self.hf_config_path or self.model,
|
hf_config = get_config(self.hf_config_path or self.model,
|
||||||
self.trust_remote_code, self.revision,
|
self.trust_remote_code, self.revision,
|
||||||
|
|||||||
@@ -522,7 +522,6 @@ class EngineArgs:
|
|||||||
help="Disable async output processing. This may result in "
|
help="Disable async output processing. This may result in "
|
||||||
"lower performance.")
|
"lower performance.")
|
||||||
model_group.add_argument("--config-format",
|
model_group.add_argument("--config-format",
|
||||||
choices=[f.value for f in ConfigFormat],
|
|
||||||
**model_kwargs["config_format"])
|
**model_kwargs["config_format"])
|
||||||
# This one is a special case because it can bool
|
# This one is a special case because it can bool
|
||||||
# or str. TODO: Handle this in get_kwargs
|
# or str. TODO: Handle this in get_kwargs
|
||||||
|
|||||||
@@ -4,25 +4,43 @@
|
|||||||
import enum
|
import enum
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from typing import Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||||
get_linear_quant_method)
|
get_linear_quant_method,
|
||||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
)
|
||||||
GroupQuantScaleParameter,
|
from vllm.model_executor.parameter import (
|
||||||
PackedColumnParameter,
|
ChannelQuantScaleParameter,
|
||||||
PackedvLLMParameter,
|
GroupQuantScaleParameter,
|
||||||
RowvLLMParameter)
|
PackedColumnParameter,
|
||||||
|
PackedvLLMParameter,
|
||||||
|
RowvLLMParameter,
|
||||||
|
)
|
||||||
|
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||||
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
|
else:
|
||||||
|
QuantizationMethods = str
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
class GPTQConfig(QuantizationConfig):
|
class GPTQConfig(QuantizationConfig):
|
||||||
"""Config class for GPTQ.
|
"""Config class for GPTQ.
|
||||||
|
|
||||||
@@ -35,7 +53,10 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
group_size: int,
|
group_size: int,
|
||||||
desc_act: bool,
|
desc_act: bool,
|
||||||
lm_head_quantized: bool,
|
lm_head_quantized: bool,
|
||||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
dynamic: dict[str, dict[str, int | bool]],
|
||||||
|
autoround_version: str = "",
|
||||||
|
modules_in_block_to_quantize: list[str] | None = None,
|
||||||
|
checkpoint_format: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
# GPTQModel use `dynamic` config property to allow per module
|
# GPTQModel use `dynamic` config property to allow per module
|
||||||
# quantization config so each module can be individually optimized.
|
# quantization config so each module can be individually optimized.
|
||||||
@@ -71,23 +92,44 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
if self.weight_bits not in [2, 3, 4, 8]:
|
if self.weight_bits not in [2, 3, 4, 8]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||||
f"supported for GPTQ, but got {self.weight_bits} bits.")
|
f"supported for GPTQ, but got {self.weight_bits} bits."
|
||||||
|
)
|
||||||
|
# Somehow gptq_gemm 4-bit is buggy, maybe fix it in the future.
|
||||||
|
# For now, show a warning, since gptq_marlin will be used by default.
|
||||||
|
if self.weight_bits == 4:
|
||||||
|
logger.warning_once(
|
||||||
|
"Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. "
|
||||||
|
"Please switch to gptq_marlin or gptq_bitblas."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
|
||||||
|
|
||||||
|
# used to identify GPTQ model quantized by autoround
|
||||||
|
self.autoround_version = autoround_version
|
||||||
|
|
||||||
|
# GPTQ v1 and v2 format deals with zero points differently.
|
||||||
|
# Currently GPTQModel stores v1 format checkpoints by default,
|
||||||
|
# but provides the option to set `format="gptq_v2"` in `QuantizeConfig`.
|
||||||
|
self.checkpoint_format = checkpoint_format
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
|
return (
|
||||||
f"group_size={self.group_size}, "
|
f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||||
f"desc_act={self.desc_act}), "
|
f"group_size={self.group_size}, "
|
||||||
f"lm_head_quantized={self.lm_head_quantized}), "
|
f"desc_act={self.desc_act}), "
|
||||||
f"dynamic={self.dynamic}")
|
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||||
|
f"dynamic={self.dynamic}, "
|
||||||
|
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), "
|
||||||
|
f"checkpoint_format={self.checkpoint_format})"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(cls) -> QuantizationMethods:
|
def get_name(cls) -> QuantizationMethods:
|
||||||
return "gptq"
|
return "gptq"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
||||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||||
return [torch.half, torch.bfloat16]
|
return [torch.half]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
# Need to figure it out
|
# Need to figure it out
|
||||||
@@ -106,18 +148,77 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||||
group_size = cls.get_from_keys(config, ["group_size"])
|
group_size = cls.get_from_keys(config, ["group_size"])
|
||||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||||
default=False)
|
autoround_version = cls.get_from_keys_or(
|
||||||
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
|
config, ["autoround_version"], default=""
|
||||||
dynamic)
|
)
|
||||||
|
modules_in_block_to_quantize = cls.get_from_keys_or(
|
||||||
|
config, ["modules_in_block_to_quantize"], default=None
|
||||||
|
)
|
||||||
|
checkpoint_format = cls.get_from_keys_or(
|
||||||
|
config, ["checkpoint_format"], default=""
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
weight_bits,
|
||||||
|
group_size,
|
||||||
|
desc_act,
|
||||||
|
lm_head_quantized,
|
||||||
|
dynamic,
|
||||||
|
autoround_version,
|
||||||
|
modules_in_block_to_quantize,
|
||||||
|
checkpoint_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_quant_method(
|
||||||
|
self, layer: torch.nn.Module, prefix: str
|
||||||
|
) -> Union["GPTQLinearMethod", "QuantizeMethodBase"] | None:
|
||||||
|
if isinstance(layer, FusedMoE):
|
||||||
|
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
|
||||||
|
from .moe_wna16 import MoeWNA16Config
|
||||||
|
|
||||||
|
print("Using MoeWNA16Config for GPTQ MoE layer quantization.")
|
||||||
|
# TODO: maybe update this for GPTQv2 format checkpoints
|
||||||
|
config = {
|
||||||
|
"quant_method": "gptq",
|
||||||
|
"bits": self.weight_bits,
|
||||||
|
"group_size": self.group_size,
|
||||||
|
"sym": True, # GPTQ typically uses symmetric quantization
|
||||||
|
"lm_head": False,
|
||||||
|
}
|
||||||
|
return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
|
||||||
|
|
||||||
def get_quant_method(self, layer: torch.nn.Module,
|
|
||||||
prefix: str) -> Optional["GPTQLinearMethod"]:
|
|
||||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||||
|
|
||||||
|
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||||
|
if self.modules_in_block_to_quantize is not None:
|
||||||
|
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||||
|
self.modules_in_block_to_quantize
|
||||||
|
)
|
||||||
|
|
||||||
|
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||||
|
if self.modules_in_block_to_quantize:
|
||||||
|
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||||
|
# original modules_in_block_to_quantize: list[list[str]]
|
||||||
|
# flatten original modules_in_block_to_quantize
|
||||||
|
self.modules_in_block_to_quantize = [
|
||||||
|
item
|
||||||
|
for sublist in self.modules_in_block_to_quantize
|
||||||
|
for item in sublist
|
||||||
|
]
|
||||||
|
return
|
||||||
|
|
||||||
|
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
metadata = get_safetensors_params_metadata(model_name, revision=revision)
|
||||||
|
quant_layers: set[str] = {
|
||||||
|
param_name.rsplit(".", 1)[0]
|
||||||
|
for param_name, info in metadata.items()
|
||||||
|
if (dtype := info.get("dtype", None))
|
||||||
|
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
|
||||||
|
}
|
||||||
|
self.modules_in_block_to_quantize = list(quant_layers)
|
||||||
|
|
||||||
|
|
||||||
class ExllamaState(Enum):
|
class ExllamaState(Enum):
|
||||||
|
|
||||||
UNUSED = enum.auto()
|
UNUSED = enum.auto()
|
||||||
UNINITIALIZED = enum.auto()
|
UNINITIALIZED = enum.auto()
|
||||||
READY = enum.auto()
|
READY = enum.auto()
|
||||||
@@ -133,6 +234,9 @@ class GPTQLinearMethod(LinearMethodBase):
|
|||||||
def __init__(self, quant_config: GPTQConfig):
|
def __init__(self, quant_config: GPTQConfig):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
# GPTQ v1 and v2 format deals with zero points differently
|
||||||
|
self.use_v2_format = quant_config.checkpoint_format == "gptq_v2"
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -149,14 +253,15 @@ class GPTQLinearMethod(LinearMethodBase):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The input size is not aligned with the quantized "
|
"The input size is not aligned with the quantized "
|
||||||
"weight shape. This can be caused by too large "
|
"weight shape. This can be caused by too large "
|
||||||
"tensor parallel size.")
|
"tensor parallel size."
|
||||||
|
)
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
|
||||||
!= 0):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The output size is not aligned with the quantized "
|
"The output size is not aligned with the quantized "
|
||||||
"weight shape. This can be caused by too large "
|
"weight shape. This can be caused by too large "
|
||||||
"tensor parallel size.")
|
"tensor parallel size."
|
||||||
|
)
|
||||||
|
|
||||||
if self.quant_config.group_size != -1:
|
if self.quant_config.group_size != -1:
|
||||||
group_size = self.quant_config.group_size
|
group_size = self.quant_config.group_size
|
||||||
@@ -165,8 +270,10 @@ class GPTQLinearMethod(LinearMethodBase):
|
|||||||
exllama_state = ExllamaState.UNINITIALIZED
|
exllama_state = ExllamaState.UNINITIALIZED
|
||||||
scale_and_zero_size = input_size // group_size
|
scale_and_zero_size = input_size // group_size
|
||||||
scale_and_zero_input_dim = None
|
scale_and_zero_input_dim = None
|
||||||
if (input_size != input_size_per_partition
|
if (
|
||||||
and self.quant_config.group_size != -1):
|
input_size != input_size_per_partition
|
||||||
|
and self.quant_config.group_size != -1
|
||||||
|
):
|
||||||
# For act-order models, we cannot use Exllama for row parallel layer
|
# For act-order models, we cannot use Exllama for row parallel layer
|
||||||
if self.quant_config.desc_act:
|
if self.quant_config.desc_act:
|
||||||
exllama_state = ExllamaState.UNUSED
|
exllama_state = ExllamaState.UNUSED
|
||||||
@@ -185,56 +292,56 @@ class GPTQLinearMethod(LinearMethodBase):
|
|||||||
output_dim=1,
|
output_dim=1,
|
||||||
packed_dim=0,
|
packed_dim=0,
|
||||||
packed_factor=self.quant_config.pack_factor,
|
packed_factor=self.quant_config.pack_factor,
|
||||||
weight_loader=weight_loader)
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
|
||||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
g_idx = RowvLLMParameter(
|
||||||
[
|
data=torch.tensor(
|
||||||
i // self.quant_config.group_size
|
[
|
||||||
for i in range(input_size_per_partition)
|
i // self.quant_config.group_size
|
||||||
],
|
for i in range(input_size_per_partition)
|
||||||
dtype=torch.int32,
|
],
|
||||||
),
|
dtype=torch.int32,
|
||||||
input_dim=0,
|
),
|
||||||
weight_loader=weight_loader)
|
input_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
qzeros_args = {
|
qzeros_args = {
|
||||||
"data":
|
"data": torch.empty(
|
||||||
torch.empty(
|
|
||||||
scale_and_zero_size,
|
scale_and_zero_size,
|
||||||
output_size_per_partition // self.quant_config.pack_factor,
|
output_size_per_partition // self.quant_config.pack_factor,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
),
|
),
|
||||||
"weight_loader":
|
"weight_loader": weight_loader,
|
||||||
weight_loader
|
|
||||||
}
|
}
|
||||||
weight_scale_args = {
|
weight_scale_args = {
|
||||||
"data":
|
"data": torch.empty(
|
||||||
torch.empty(
|
|
||||||
scale_and_zero_size,
|
scale_and_zero_size,
|
||||||
output_size_per_partition,
|
output_size_per_partition,
|
||||||
dtype=params_dtype,
|
dtype=params_dtype,
|
||||||
),
|
),
|
||||||
"weight_loader":
|
"weight_loader": weight_loader,
|
||||||
weight_loader
|
|
||||||
}
|
}
|
||||||
if scale_and_zero_input_dim is None:
|
if scale_and_zero_input_dim is None:
|
||||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||||
**weight_scale_args)
|
|
||||||
qzeros = PackedColumnParameter(
|
qzeros = PackedColumnParameter(
|
||||||
output_dim=1,
|
output_dim=1,
|
||||||
packed_dim=1,
|
packed_dim=1,
|
||||||
packed_factor=self.quant_config.pack_factor,
|
packed_factor=self.quant_config.pack_factor,
|
||||||
**qzeros_args)
|
**qzeros_args,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
scales = GroupQuantScaleParameter(output_dim=1,
|
scales = GroupQuantScaleParameter(
|
||||||
input_dim=0,
|
output_dim=1, input_dim=0, **weight_scale_args
|
||||||
**weight_scale_args)
|
)
|
||||||
qzeros = PackedvLLMParameter(
|
qzeros = PackedvLLMParameter(
|
||||||
input_dim=0,
|
input_dim=0,
|
||||||
output_dim=1,
|
output_dim=1,
|
||||||
packed_dim=1,
|
packed_dim=1,
|
||||||
packed_factor=self.quant_config.pack_factor,
|
packed_factor=self.quant_config.pack_factor,
|
||||||
**qzeros_args)
|
**qzeros_args,
|
||||||
|
)
|
||||||
|
|
||||||
layer.register_parameter("qweight", qweight)
|
layer.register_parameter("qweight", qweight)
|
||||||
layer.register_parameter("g_idx", g_idx)
|
layer.register_parameter("g_idx", g_idx)
|
||||||
@@ -252,79 +359,23 @@ class GPTQLinearMethod(LinearMethodBase):
|
|||||||
|
|
||||||
# exllama needs to shuffle the weight after the weight is loaded
|
# exllama needs to shuffle the weight after the weight is loaded
|
||||||
# here we do the shuffle on first forward pass
|
# here we do the shuffle on first forward pass
|
||||||
if self.quant_config.group_size == 128 or self.quant_config.group_size == 64:
|
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||||
if self.quant_config.desc_act:
|
if self.quant_config.desc_act:
|
||||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||||
else:
|
else:
|
||||||
layer.g_idx.data = torch.empty((0, ),
|
layer.g_idx.data = torch.empty(
|
||||||
dtype=torch.int,
|
(0,), dtype=torch.int, device=layer.g_idx.device
|
||||||
device=layer.g_idx.device)
|
)
|
||||||
layer.exllama_state = ExllamaState.READY
|
layer.exllama_state = ExllamaState.READY
|
||||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
|
||||||
self.quant_config.weight_bits)
|
|
||||||
|
|
||||||
if layer.scales.dtype != torch.bfloat16:
|
|
||||||
perm_space = torch.empty(0)
|
|
||||||
temp_space = torch.empty(0)
|
|
||||||
if self.quant_config.weight_bits == 4:
|
|
||||||
# warmup
|
|
||||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda")
|
|
||||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
|
||||||
layer.scales, layer.g_idx,
|
|
||||||
layer.exllama_state == ExllamaState.READY,
|
|
||||||
self.quant_config.weight_bits,
|
|
||||||
self.quant_config.group_size,
|
|
||||||
perm_space, temp_space,
|
|
||||||
False)
|
|
||||||
if self.quant_config.weight_bits == 8:
|
|
||||||
# warmup
|
|
||||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda")
|
|
||||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
|
||||||
layer.scales, layer.g_idx,
|
|
||||||
layer.exllama_state == ExllamaState.READY,
|
|
||||||
self.quant_config.weight_bits,
|
|
||||||
self.quant_config.group_size,
|
|
||||||
perm_space, temp_space,
|
|
||||||
False)
|
|
||||||
else:
|
|
||||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
|
||||||
if self.quant_config.desc_act:
|
|
||||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
|
||||||
else:
|
|
||||||
layer.g_idx.data = torch.empty((0, ),
|
|
||||||
dtype=torch.int,
|
|
||||||
device=layer.g_idx.device)
|
|
||||||
layer.exllama_state = ExllamaState.READY
|
|
||||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
|
||||||
self.quant_config.weight_bits)
|
|
||||||
|
|
||||||
"""
|
def apply(
|
||||||
perm_space = torch.empty(0)
|
self,
|
||||||
if self.quant_config.weight_bits == 4:
|
layer: torch.nn.Module,
|
||||||
# warmup
|
x: torch.Tensor,
|
||||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda")
|
bias: torch.Tensor | None = None,
|
||||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
) -> torch.Tensor:
|
||||||
layer.scales, layer.g_idx,
|
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
|
||||||
layer.exllama_state == ExllamaState.READY,
|
|
||||||
self.quant_config.weight_bits,
|
|
||||||
self.quant_config.group_size,
|
|
||||||
perm_space)
|
|
||||||
if self.quant_config.weight_bits == 8:
|
|
||||||
# warmup
|
|
||||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda")
|
|
||||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
|
||||||
layer.scales, layer.g_idx,
|
|
||||||
layer.exllama_state == ExllamaState.READY,
|
|
||||||
self.quant_config.weight_bits,
|
|
||||||
self.quant_config.group_size,
|
|
||||||
perm_space)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def apply(self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
|
|
||||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||||
|
|
||||||
perm_space = torch.empty(0)
|
perm_space = torch.empty(0)
|
||||||
@@ -334,11 +385,12 @@ class GPTQLinearMethod(LinearMethodBase):
|
|||||||
if self.quant_config.desc_act:
|
if self.quant_config.desc_act:
|
||||||
perm_space = torch.empty(reshaped_x.shape[0], reshaped_x.shape[1],
|
perm_space = torch.empty(reshaped_x.shape[0], reshaped_x.shape[1],
|
||||||
dtype=torch.float16, device="cuda")
|
dtype=torch.float16, device="cuda")
|
||||||
|
|
||||||
if reshaped_x.dtype == torch.bfloat16:
|
if reshaped_x.dtype == torch.bfloat16:
|
||||||
temp_space = torch.zeros(reshaped_x.shape[0], layer.qweight.shape[1],
|
temp_space = torch.zeros(reshaped_x.shape[0], layer.qweight.shape[1],
|
||||||
dtype=torch.float32, device="cuda")
|
dtype=torch.float32, device="cuda")
|
||||||
|
# GPTQ v1 and v2 format checkpoints deals with zero points differently,
|
||||||
|
# and require different gemm kernels.
|
||||||
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||||
layer.scales, layer.g_idx,
|
layer.scales, layer.g_idx,
|
||||||
layer.exllama_state == ExllamaState.READY,
|
layer.exllama_state == ExllamaState.READY,
|
||||||
@@ -348,4 +400,4 @@ class GPTQLinearMethod(LinearMethodBase):
|
|||||||
True if reshaped_x.dtype == torch.bfloat16 else False)
|
True if reshaped_x.dtype == torch.bfloat16 else False)
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
output.add_(bias)
|
output.add_(bias)
|
||||||
return output.reshape(out_shape)
|
return output.reshape(out_shape)
|
||||||
@@ -298,6 +298,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
enable_eplb: bool = False,
|
||||||
|
expert_load_view: torch.Tensor | None = None,
|
||||||
|
logical_to_physical_map: torch.Tensor | None = None,
|
||||||
|
logical_replica_count: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
self.gate = ReplicatedLinear(config.hidden_size,
|
self.gate = ReplicatedLinear(config.hidden_size,
|
||||||
config.num_experts,
|
config.num_experts,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=None,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.gate")
|
prefix=f"{prefix}.gate")
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -294,7 +294,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
if residual is None:
|
if residual is None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -532,4 +532,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
|
|||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
torch.Tensor]]) -> set[str]:
|
torch.Tensor]]) -> set[str]:
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
return loader.load_weights(weights)
|
return loader.load_weights(weights)
|
||||||
File diff suppressed because it is too large
Load Diff
20
vllm/transformers_utils/config_parser_base.py
Normal file
20
vllm/transformers_utils/config_parser_base.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigParserBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def parse(
|
||||||
|
self,
|
||||||
|
model: str | Path,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
revision: str | None = None,
|
||||||
|
code_revision: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[dict, PretrainedConfig]:
|
||||||
|
raise NotImplementedError
|
||||||
59
vllm/transformers_utils/dynamic_module.py
Normal file
59
vllm/transformers_utils/dynamic_module.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import os
|
||||||
|
|
||||||
|
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def try_get_class_from_dynamic_module(
|
||||||
|
class_reference: str,
|
||||||
|
pretrained_model_name_or_path: str,
|
||||||
|
cache_dir: str | os.PathLike | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
resume_download: bool | None = None,
|
||||||
|
proxies: dict[str, str] | None = None,
|
||||||
|
token: bool | str | None = None,
|
||||||
|
revision: str | None = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
repo_type: str | None = None,
|
||||||
|
code_revision: str | None = None,
|
||||||
|
warn_on_fail: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> type | None:
|
||||||
|
"""
|
||||||
|
As `transformers.dynamic_module_utils.get_class_from_dynamic_module`,
|
||||||
|
but ignoring any errors.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return get_class_from_dynamic_module(
|
||||||
|
class_reference,
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
repo_type=repo_type,
|
||||||
|
code_revision=code_revision,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
location = "ModelScope" if envs.VLLM_USE_MODELSCOPE else "HF Hub"
|
||||||
|
|
||||||
|
if warn_on_fail:
|
||||||
|
logger.warning(
|
||||||
|
"Unable to load %s from %s on %s.",
|
||||||
|
class_reference,
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
location,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -2,22 +2,32 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from vllm.envs import VLLM_MODEL_REDIRECT_PATH
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def is_s3(model_or_path: str) -> bool:
    """Return True when *model_or_path* is an S3 URI (``s3://`` scheme, case-insensitive)."""
    lowered = model_or_path.lower()
    return lowered.startswith("s3://")
|
|
||||||
|
|
||||||
def check_gguf_file(model: Union[str, PathLike]) -> bool:
|
def is_gcs(model_or_path: str) -> bool:
    """Return True when *model_or_path* is a Google Cloud Storage URI (``gs://``, case-insensitive)."""
    lowered = model_or_path.lower()
    return lowered.startswith("gs://")
|
|
||||||
|
|
||||||
|
def is_cloud_storage(model_or_path: str) -> bool:
    """Return True for any supported cloud-storage URI (S3 or GCS)."""
    # Short-circuits exactly like `is_s3(...) or is_gcs(...)`.
    return any(check(model_or_path) for check in (is_s3, is_gcs))
|
|
||||||
|
|
||||||
|
def check_gguf_file(model: str | PathLike) -> bool:
|
||||||
"""Check if the file is a GGUF model."""
|
"""Check if the file is a GGUF model."""
|
||||||
model = Path(model)
|
model = Path(model)
|
||||||
if not model.is_file():
|
if not model.is_file():
|
||||||
@@ -37,23 +47,26 @@ def check_gguf_file(model: Union[str, PathLike]) -> bool:
|
|||||||
|
|
||||||
def modelscope_list_repo_files(
    repo_id: str,
    revision: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    """List files in a modelscope repo.

    Mirrors the result shape of ``huggingface_hub.list_repo_files``:
    a flat list of repo-relative file paths.
    """
    from modelscope.hub.api import HubApi

    hub = HubApi()
    hub.login(token)
    entries = hub.get_model_files(model_id=repo_id, revision=revision, recursive=True)
    # Keep only actual files ("blob" entries), same as huggingface_hub.list_repo_files.
    return [entry["Path"] for entry in entries if entry["Type"] == "blob"]
|
|
||||||
|
|
||||||
def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]:
|
def _maybe_json_dict(path: str | PathLike) -> dict[str, str]:
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
try:
|
try:
|
||||||
return json.loads(f.read())
|
return json.loads(f.read())
|
||||||
@@ -61,7 +74,7 @@ def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]:
|
|||||||
return dict[str, str]()
|
return dict[str, str]()
|
||||||
|
|
||||||
|
|
||||||
def _maybe_space_split_dict(path: Union[str, PathLike]) -> dict[str, str]:
|
def _maybe_space_split_dict(path: str | PathLike) -> dict[str, str]:
|
||||||
parsed_dict = dict[str, str]()
|
parsed_dict = dict[str, str]()
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
@@ -82,7 +95,7 @@ def maybe_model_redirect(model: str) -> str:
|
|||||||
:return: maybe redirect to a local folder
|
:return: maybe redirect to a local folder
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_redirect_path = VLLM_MODEL_REDIRECT_PATH
|
model_redirect_path = envs.VLLM_MODEL_REDIRECT_PATH
|
||||||
|
|
||||||
if not model_redirect_path:
|
if not model_redirect_path:
|
||||||
return model
|
return model
|
||||||
@@ -90,10 +103,28 @@ def maybe_model_redirect(model: str) -> str:
|
|||||||
if not Path(model_redirect_path).exists():
|
if not Path(model_redirect_path).exists():
|
||||||
return model
|
return model
|
||||||
|
|
||||||
redirect_dict = (_maybe_json_dict(model_redirect_path)
|
redirect_dict = _maybe_json_dict(model_redirect_path) or _maybe_space_split_dict(
|
||||||
or _maybe_space_split_dict(model_redirect_path))
|
model_redirect_path
|
||||||
if (redirect_model := redirect_dict.get(model)):
|
)
|
||||||
|
if redirect_model := redirect_dict.get(model):
|
||||||
logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model)
|
logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model)
|
||||||
return redirect_model
|
return redirect_model
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def parse_safetensors_file_metadata(path: str | PathLike) -> dict[str, Any]:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
length_of_metadata = struct.unpack("<Q", f.read(8))[0]
|
||||||
|
metadata = json.loads(f.read(length_of_metadata).decode("utf-8"))
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def convert_model_repo_to_path(model_repo: str) -> str:
    """When VLLM_USE_MODELSCOPE is True convert a model
    repository string to a Path str.

    The input is returned unchanged when ModelScope is disabled or when
    *model_repo* already names an existing filesystem path.
    """
    # Guard clauses replace the original combined `or` condition.
    if not envs.VLLM_USE_MODELSCOPE:
        return model_repo
    if Path(model_repo).exists():
        return model_repo

    from modelscope.utils.file_utils import get_model_cache_root

    return os.path.join(get_model_cache_root(), model_repo)
||||||
|
|||||||
Reference in New Issue
Block a user