fix black in pre-commit (#1940)

This commit is contained in:
Chayenne
2024-11-07 15:42:47 -08:00
committed by GitHub
parent dca87ec348
commit c77c1e05ba
29 changed files with 641 additions and 508 deletions

View File

@@ -30,6 +30,6 @@ repos:
rev: 24.10.0 rev: 24.10.0
hooks: hooks:
- id: black - id: black
additional_dependencies: ['.[jupyter]'] types: [python]
types: [python, jupyter] - id: black-jupyter
types_or: [python, jupyter] types: [jupyter]

View File

@@ -34,10 +34,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:08.536886Z", "iopub.execute_input": "2024-11-07T18:44:42.063503Z",
"iopub.status.busy": "2024-11-05T05:08:08.536763Z", "iopub.status.busy": "2024-11-07T18:44:42.063379Z",
"iopub.status.idle": "2024-11-05T05:08:34.725831Z", "iopub.status.idle": "2024-11-07T18:45:07.255300Z",
"shell.execute_reply": "2024-11-05T05:08:34.725316Z" "shell.execute_reply": "2024-11-07T18:45:07.254547Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -73,10 +73,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:34.727530Z", "iopub.execute_input": "2024-11-07T18:45:07.258292Z",
"iopub.status.busy": "2024-11-05T05:08:34.727333Z", "iopub.status.busy": "2024-11-07T18:45:07.257710Z",
"iopub.status.idle": "2024-11-05T05:08:35.359784Z", "iopub.status.idle": "2024-11-07T18:45:07.611559Z",
"shell.execute_reply": "2024-11-05T05:08:35.359090Z" "shell.execute_reply": "2024-11-07T18:45:07.610842Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -101,10 +101,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:35.362286Z", "iopub.execute_input": "2024-11-07T18:45:07.613911Z",
"iopub.status.busy": "2024-11-05T05:08:35.362140Z", "iopub.status.busy": "2024-11-07T18:45:07.613746Z",
"iopub.status.idle": "2024-11-05T05:08:35.368711Z", "iopub.status.idle": "2024-11-07T18:45:07.620286Z",
"shell.execute_reply": "2024-11-05T05:08:35.368220Z" "shell.execute_reply": "2024-11-07T18:45:07.619779Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -132,10 +132,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:35.371313Z", "iopub.execute_input": "2024-11-07T18:45:07.622407Z",
"iopub.status.busy": "2024-11-05T05:08:35.370877Z", "iopub.status.busy": "2024-11-07T18:45:07.622267Z",
"iopub.status.idle": "2024-11-05T05:08:35.376712Z", "iopub.status.idle": "2024-11-07T18:45:07.628290Z",
"shell.execute_reply": "2024-11-05T05:08:35.376230Z" "shell.execute_reply": "2024-11-07T18:45:07.627793Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -164,10 +164,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:35.378982Z", "iopub.execute_input": "2024-11-07T18:45:07.630585Z",
"iopub.status.busy": "2024-11-05T05:08:35.378597Z", "iopub.status.busy": "2024-11-07T18:45:07.630235Z",
"iopub.status.idle": "2024-11-05T05:08:35.391820Z", "iopub.status.idle": "2024-11-07T18:45:07.643498Z",
"shell.execute_reply": "2024-11-05T05:08:35.391336Z" "shell.execute_reply": "2024-11-07T18:45:07.643007Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -183,10 +183,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:35.393748Z", "iopub.execute_input": "2024-11-07T18:45:07.645336Z",
"iopub.status.busy": "2024-11-05T05:08:35.393606Z", "iopub.status.busy": "2024-11-07T18:45:07.645196Z",
"iopub.status.idle": "2024-11-05T05:08:35.398645Z", "iopub.status.idle": "2024-11-07T18:45:07.650363Z",
"shell.execute_reply": "2024-11-05T05:08:35.398145Z" "shell.execute_reply": "2024-11-07T18:45:07.649837Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -211,10 +211,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:35.400683Z", "iopub.execute_input": "2024-11-07T18:45:07.652212Z",
"iopub.status.busy": "2024-11-05T05:08:35.400419Z", "iopub.status.busy": "2024-11-07T18:45:07.652076Z",
"iopub.status.idle": "2024-11-05T05:08:35.406146Z", "iopub.status.idle": "2024-11-07T18:45:07.658633Z",
"shell.execute_reply": "2024-11-05T05:08:35.405661Z" "shell.execute_reply": "2024-11-07T18:45:07.658119Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -241,10 +241,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:35.408176Z", "iopub.execute_input": "2024-11-07T18:45:07.660468Z",
"iopub.status.busy": "2024-11-05T05:08:35.407884Z", "iopub.status.busy": "2024-11-07T18:45:07.660325Z",
"iopub.status.idle": "2024-11-05T05:08:35.413587Z", "iopub.status.idle": "2024-11-07T18:45:07.666476Z",
"shell.execute_reply": "2024-11-05T05:08:35.413108Z" "shell.execute_reply": "2024-11-07T18:45:07.665984Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -271,10 +271,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:35.416090Z", "iopub.execute_input": "2024-11-07T18:45:07.668242Z",
"iopub.status.busy": "2024-11-05T05:08:35.415793Z", "iopub.status.busy": "2024-11-07T18:45:07.668108Z",
"iopub.status.idle": "2024-11-05T05:08:36.552549Z", "iopub.status.idle": "2024-11-07T18:45:08.725709Z",
"shell.execute_reply": "2024-11-05T05:08:36.551870Z" "shell.execute_reply": "2024-11-07T18:45:08.725021Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -296,10 +296,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:36.554823Z", "iopub.execute_input": "2024-11-07T18:45:08.727865Z",
"iopub.status.busy": "2024-11-05T05:08:36.554680Z", "iopub.status.busy": "2024-11-07T18:45:08.727721Z",
"iopub.status.idle": "2024-11-05T05:08:38.053945Z", "iopub.status.idle": "2024-11-07T18:45:11.165841Z",
"shell.execute_reply": "2024-11-05T05:08:38.053034Z" "shell.execute_reply": "2024-11-07T18:45:11.165282Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -335,10 +335,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:08:38.056783Z", "iopub.execute_input": "2024-11-07T18:45:11.167853Z",
"iopub.status.busy": "2024-11-05T05:08:38.056497Z", "iopub.status.busy": "2024-11-07T18:45:11.167711Z",
"iopub.status.idle": "2024-11-05T05:09:04.436030Z", "iopub.status.idle": "2024-11-07T18:45:39.542988Z",
"shell.execute_reply": "2024-11-05T05:09:04.435311Z" "shell.execute_reply": "2024-11-07T18:45:39.542135Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -360,10 +360,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:09:04.438987Z", "iopub.execute_input": "2024-11-07T18:45:39.545416Z",
"iopub.status.busy": "2024-11-05T05:09:04.438568Z", "iopub.status.busy": "2024-11-07T18:45:39.545005Z",
"iopub.status.idle": "2024-11-05T05:09:04.485291Z", "iopub.status.idle": "2024-11-07T18:45:39.588793Z",
"shell.execute_reply": "2024-11-05T05:09:04.484829Z" "shell.execute_reply": "2024-11-07T18:45:39.588054Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -392,10 +392,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:09:04.487191Z", "iopub.execute_input": "2024-11-07T18:45:39.590729Z",
"iopub.status.busy": "2024-11-05T05:09:04.486929Z", "iopub.status.busy": "2024-11-07T18:45:39.590446Z",
"iopub.status.idle": "2024-11-05T05:09:25.553481Z", "iopub.status.idle": "2024-11-07T18:45:59.660376Z",
"shell.execute_reply": "2024-11-05T05:09:25.552747Z" "shell.execute_reply": "2024-11-07T18:45:59.659992Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -419,10 +419,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:09:25.555813Z", "iopub.execute_input": "2024-11-07T18:45:59.661779Z",
"iopub.status.busy": "2024-11-05T05:09:25.555666Z", "iopub.status.busy": "2024-11-07T18:45:59.661641Z",
"iopub.status.idle": "2024-11-05T05:09:26.354372Z", "iopub.status.idle": "2024-11-07T18:46:00.475726Z",
"shell.execute_reply": "2024-11-05T05:09:26.353693Z" "shell.execute_reply": "2024-11-07T18:46:00.475269Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -445,10 +445,7 @@
"prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n", "prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n",
"\n", "\n",
"url = \"http://localhost:30030/classify\"\n", "url = \"http://localhost:30030/classify\"\n",
"data = {\n", "data = {\"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \"text\": prompts}\n",
" \"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \n",
" \"text\": prompts\n",
"}\n",
"\n", "\n",
"responses = requests.post(url, json=data).json()\n", "responses = requests.post(url, json=data).json()\n",
"for response in responses:\n", "for response in responses:\n",
@@ -460,10 +457,10 @@
"execution_count": 15, "execution_count": 15,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:09:26.356532Z", "iopub.execute_input": "2024-11-07T18:46:00.477283Z",
"iopub.status.busy": "2024-11-05T05:09:26.356327Z", "iopub.status.busy": "2024-11-07T18:46:00.477025Z",
"iopub.status.idle": "2024-11-05T05:09:26.396590Z", "iopub.status.idle": "2024-11-07T18:46:00.525758Z",
"shell.execute_reply": "2024-11-05T05:09:26.395914Z" "shell.execute_reply": "2024-11-07T18:46:00.525236Z"
} }
}, },
"outputs": [], "outputs": [],

View File

@@ -35,10 +35,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:21:27.503026Z", "iopub.execute_input": "2024-11-07T18:46:04.789536Z",
"iopub.status.busy": "2024-11-05T05:21:27.502741Z", "iopub.status.busy": "2024-11-07T18:46:04.789418Z",
"iopub.status.idle": "2024-11-05T05:21:49.554631Z", "iopub.status.idle": "2024-11-07T18:46:27.038169Z",
"shell.execute_reply": "2024-11-05T05:21:49.553690Z" "shell.execute_reply": "2024-11-07T18:46:27.037540Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -64,10 +64,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:21:49.558275Z", "iopub.execute_input": "2024-11-07T18:46:27.040005Z",
"iopub.status.busy": "2024-11-05T05:21:49.558110Z", "iopub.status.busy": "2024-11-07T18:46:27.039872Z",
"iopub.status.idle": "2024-11-05T05:21:52.717287Z", "iopub.status.idle": "2024-11-07T18:46:30.203840Z",
"shell.execute_reply": "2024-11-05T05:21:52.716842Z" "shell.execute_reply": "2024-11-07T18:46:30.203368Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -99,10 +99,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:21:52.721738Z", "iopub.execute_input": "2024-11-07T18:46:30.205880Z",
"iopub.status.busy": "2024-11-05T05:21:52.720908Z", "iopub.status.busy": "2024-11-07T18:46:30.205719Z",
"iopub.status.idle": "2024-11-05T05:22:01.770341Z", "iopub.status.idle": "2024-11-07T18:46:39.256561Z",
"shell.execute_reply": "2024-11-05T05:22:01.769510Z" "shell.execute_reply": "2024-11-07T18:46:39.255880Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -137,10 +137,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:01.772662Z", "iopub.execute_input": "2024-11-07T18:46:39.259464Z",
"iopub.status.busy": "2024-11-05T05:22:01.772377Z", "iopub.status.busy": "2024-11-07T18:46:39.259309Z",
"iopub.status.idle": "2024-11-05T05:22:04.897499Z", "iopub.status.idle": "2024-11-07T18:46:42.384955Z",
"shell.execute_reply": "2024-11-05T05:22:04.896867Z" "shell.execute_reply": "2024-11-07T18:46:42.384378Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -179,10 +179,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:04.899754Z", "iopub.execute_input": "2024-11-07T18:46:42.387431Z",
"iopub.status.busy": "2024-11-05T05:22:04.899478Z", "iopub.status.busy": "2024-11-07T18:46:42.387279Z",
"iopub.status.idle": "2024-11-05T05:22:13.970245Z", "iopub.status.idle": "2024-11-07T18:46:51.448572Z",
"shell.execute_reply": "2024-11-05T05:22:13.969779Z" "shell.execute_reply": "2024-11-07T18:46:51.447781Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -216,10 +216,10 @@
"execution_count": 6, "execution_count": 6,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:13.972039Z", "iopub.execute_input": "2024-11-07T18:46:51.451177Z",
"iopub.status.busy": "2024-11-05T05:22:13.971846Z", "iopub.status.busy": "2024-11-07T18:46:51.450952Z",
"iopub.status.idle": "2024-11-05T05:22:14.027421Z", "iopub.status.idle": "2024-11-07T18:46:51.497530Z",
"shell.execute_reply": "2024-11-05T05:22:14.027003Z" "shell.execute_reply": "2024-11-07T18:46:51.496850Z"
} }
}, },
"outputs": [], "outputs": [],

View File

@@ -39,10 +39,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:09:30.637832Z", "iopub.execute_input": "2024-11-07T18:46:54.813876Z",
"iopub.status.busy": "2024-11-05T05:09:30.637709Z", "iopub.status.busy": "2024-11-07T18:46:54.813741Z",
"iopub.status.idle": "2024-11-05T05:09:58.830158Z", "iopub.status.idle": "2024-11-07T18:47:24.015527Z",
"shell.execute_reply": "2024-11-05T05:09:58.829395Z" "shell.execute_reply": "2024-11-07T18:47:24.014987Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -79,10 +79,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:09:58.833008Z", "iopub.execute_input": "2024-11-07T18:47:24.018153Z",
"iopub.status.busy": "2024-11-05T05:09:58.832805Z", "iopub.status.busy": "2024-11-07T18:47:24.017755Z",
"iopub.status.idle": "2024-11-05T05:10:00.187146Z", "iopub.status.idle": "2024-11-07T18:47:25.374821Z",
"shell.execute_reply": "2024-11-05T05:10:00.186657Z" "shell.execute_reply": "2024-11-07T18:47:25.374397Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -119,10 +119,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:00.189444Z", "iopub.execute_input": "2024-11-07T18:47:25.376617Z",
"iopub.status.busy": "2024-11-05T05:10:00.189289Z", "iopub.status.busy": "2024-11-07T18:47:25.376495Z",
"iopub.status.idle": "2024-11-05T05:10:03.291891Z", "iopub.status.idle": "2024-11-07T18:47:28.482537Z",
"shell.execute_reply": "2024-11-05T05:10:03.291173Z" "shell.execute_reply": "2024-11-07T18:47:28.482125Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -165,10 +165,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:03.294389Z", "iopub.execute_input": "2024-11-07T18:47:28.484819Z",
"iopub.status.busy": "2024-11-05T05:10:03.294237Z", "iopub.status.busy": "2024-11-07T18:47:28.484673Z",
"iopub.status.idle": "2024-11-05T05:10:03.469357Z", "iopub.status.idle": "2024-11-07T18:47:28.659814Z",
"shell.execute_reply": "2024-11-05T05:10:03.468661Z" "shell.execute_reply": "2024-11-07T18:47:28.659435Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -198,10 +198,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:03.471573Z", "iopub.execute_input": "2024-11-07T18:47:28.661844Z",
"iopub.status.busy": "2024-11-05T05:10:03.471430Z", "iopub.status.busy": "2024-11-07T18:47:28.661710Z",
"iopub.status.idle": "2024-11-05T05:10:04.977081Z", "iopub.status.idle": "2024-11-07T18:47:30.168922Z",
"shell.execute_reply": "2024-11-05T05:10:04.976391Z" "shell.execute_reply": "2024-11-07T18:47:30.168600Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -234,10 +234,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:04.979428Z", "iopub.execute_input": "2024-11-07T18:47:30.171319Z",
"iopub.status.busy": "2024-11-05T05:10:04.979272Z", "iopub.status.busy": "2024-11-07T18:47:30.171176Z",
"iopub.status.idle": "2024-11-05T05:10:08.568761Z", "iopub.status.idle": "2024-11-07T18:47:33.760113Z",
"shell.execute_reply": "2024-11-05T05:10:08.568355Z" "shell.execute_reply": "2024-11-07T18:47:33.759713Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -273,10 +273,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:08.571102Z", "iopub.execute_input": "2024-11-07T18:47:33.762729Z",
"iopub.status.busy": "2024-11-05T05:10:08.570964Z", "iopub.status.busy": "2024-11-07T18:47:33.762590Z",
"iopub.status.idle": "2024-11-05T05:10:23.214087Z", "iopub.status.idle": "2024-11-07T18:47:34.255316Z",
"shell.execute_reply": "2024-11-05T05:10:23.213664Z" "shell.execute_reply": "2024-11-07T18:47:34.254907Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -297,7 +297,10 @@
"response = client.chat.completions.create(\n", "response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=[\n", " messages=[\n",
" {\"role\": \"user\", \"content\": \"Give me the information of the capital of France in the JSON format.\"},\n", " {\n",
" \"role\": \"user\",\n",
" \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
" },\n",
" ],\n", " ],\n",
" temperature=0,\n", " temperature=0,\n",
" max_tokens=128,\n", " max_tokens=128,\n",
@@ -322,10 +325,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:23.216229Z", "iopub.execute_input": "2024-11-07T18:47:34.257393Z",
"iopub.status.busy": "2024-11-05T05:10:23.216076Z", "iopub.status.busy": "2024-11-07T18:47:34.257246Z",
"iopub.status.idle": "2024-11-05T05:10:23.884236Z", "iopub.status.idle": "2024-11-07T18:47:34.413506Z",
"shell.execute_reply": "2024-11-05T05:10:23.883897Z" "shell.execute_reply": "2024-11-07T18:47:34.413172Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -365,10 +368,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:23.886276Z", "iopub.execute_input": "2024-11-07T18:47:34.414816Z",
"iopub.status.busy": "2024-11-05T05:10:23.886136Z", "iopub.status.busy": "2024-11-07T18:47:34.414541Z",
"iopub.status.idle": "2024-11-05T05:10:23.905880Z", "iopub.status.idle": "2024-11-07T18:47:34.431341Z",
"shell.execute_reply": "2024-11-05T05:10:23.905529Z" "shell.execute_reply": "2024-11-07T18:47:34.431081Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -427,10 +430,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:23.907468Z", "iopub.execute_input": "2024-11-07T18:47:34.432325Z",
"iopub.status.busy": "2024-11-05T05:10:23.907247Z", "iopub.status.busy": "2024-11-07T18:47:34.432208Z",
"iopub.status.idle": "2024-11-05T05:10:26.920212Z", "iopub.status.idle": "2024-11-07T18:47:37.444337Z",
"shell.execute_reply": "2024-11-05T05:10:26.919865Z" "shell.execute_reply": "2024-11-07T18:47:37.444000Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -482,10 +485,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:26.922675Z", "iopub.execute_input": "2024-11-07T18:47:37.445894Z",
"iopub.status.busy": "2024-11-05T05:10:26.922413Z", "iopub.status.busy": "2024-11-07T18:47:37.445744Z",
"iopub.status.idle": "2024-11-05T05:10:51.961703Z", "iopub.status.idle": "2024-11-07T18:48:02.482532Z",
"shell.execute_reply": "2024-11-05T05:10:51.960846Z" "shell.execute_reply": "2024-11-07T18:48:02.482042Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -565,10 +568,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:10:51.964749Z", "iopub.execute_input": "2024-11-07T18:48:02.485206Z",
"iopub.status.busy": "2024-11-05T05:10:51.964215Z", "iopub.status.busy": "2024-11-07T18:48:02.485064Z",
"iopub.status.idle": "2024-11-05T05:11:05.023450Z", "iopub.status.idle": "2024-11-07T18:48:15.521489Z",
"shell.execute_reply": "2024-11-05T05:11:05.023101Z" "shell.execute_reply": "2024-11-07T18:48:15.521156Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -660,10 +663,10 @@
"execution_count": 13, "execution_count": 13,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:05.024877Z", "iopub.execute_input": "2024-11-07T18:48:15.522794Z",
"iopub.status.busy": "2024-11-05T05:11:05.024561Z", "iopub.status.busy": "2024-11-07T18:48:15.522657Z",
"iopub.status.idle": "2024-11-05T05:11:06.358695Z", "iopub.status.idle": "2024-11-07T18:48:16.875740Z",
"shell.execute_reply": "2024-11-05T05:11:06.357635Z" "shell.execute_reply": "2024-11-07T18:48:16.874847Z"
} }
}, },
"outputs": [], "outputs": [],

View File

@@ -35,10 +35,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:17.227174Z", "iopub.execute_input": "2024-11-07T18:48:21.128020Z",
"iopub.status.busy": "2024-11-05T05:22:17.226952Z", "iopub.status.busy": "2024-11-07T18:48:21.127898Z",
"iopub.status.idle": "2024-11-05T05:22:42.445791Z", "iopub.status.idle": "2024-11-07T18:48:45.310371Z",
"shell.execute_reply": "2024-11-05T05:22:42.444980Z" "shell.execute_reply": "2024-11-07T18:48:45.309469Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -72,10 +72,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:42.448147Z", "iopub.execute_input": "2024-11-07T18:48:45.313506Z",
"iopub.status.busy": "2024-11-05T05:22:42.447775Z", "iopub.status.busy": "2024-11-07T18:48:45.313123Z",
"iopub.status.idle": "2024-11-05T05:22:42.495311Z", "iopub.status.idle": "2024-11-07T18:48:45.364918Z",
"shell.execute_reply": "2024-11-05T05:22:42.495027Z" "shell.execute_reply": "2024-11-07T18:48:45.364155Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -106,10 +106,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:42.496666Z", "iopub.execute_input": "2024-11-07T18:48:45.367776Z",
"iopub.status.busy": "2024-11-05T05:22:42.496524Z", "iopub.status.busy": "2024-11-07T18:48:45.367490Z",
"iopub.status.idle": "2024-11-05T05:22:42.540687Z", "iopub.status.idle": "2024-11-07T18:48:45.411386Z",
"shell.execute_reply": "2024-11-05T05:22:42.540060Z" "shell.execute_reply": "2024-11-07T18:48:45.411134Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -140,10 +140,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:42.542551Z", "iopub.execute_input": "2024-11-07T18:48:45.412462Z",
"iopub.status.busy": "2024-11-05T05:22:42.542282Z", "iopub.status.busy": "2024-11-07T18:48:45.412351Z",
"iopub.status.idle": "2024-11-05T05:22:42.928542Z", "iopub.status.idle": "2024-11-07T18:48:45.768796Z",
"shell.execute_reply": "2024-11-05T05:22:42.928181Z" "shell.execute_reply": "2024-11-07T18:48:45.768406Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -176,10 +176,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:42.930093Z", "iopub.execute_input": "2024-11-07T18:48:45.770227Z",
"iopub.status.busy": "2024-11-05T05:22:42.929954Z", "iopub.status.busy": "2024-11-07T18:48:45.770106Z",
"iopub.status.idle": "2024-11-05T05:22:44.799945Z", "iopub.status.idle": "2024-11-07T18:48:47.447065Z",
"shell.execute_reply": "2024-11-05T05:22:44.799562Z" "shell.execute_reply": "2024-11-07T18:48:47.446733Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -208,10 +208,10 @@
"execution_count": 6, "execution_count": 6,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:44.801418Z", "iopub.execute_input": "2024-11-07T18:48:47.448510Z",
"iopub.status.busy": "2024-11-05T05:22:44.801192Z", "iopub.status.busy": "2024-11-07T18:48:47.448337Z",
"iopub.status.idle": "2024-11-05T05:22:45.094634Z", "iopub.status.idle": "2024-11-07T18:48:47.743336Z",
"shell.execute_reply": "2024-11-05T05:22:45.093950Z" "shell.execute_reply": "2024-11-07T18:48:47.742276Z"
} }
}, },
"outputs": [], "outputs": [],

View File

@@ -39,10 +39,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:22:49.320999Z", "iopub.execute_input": "2024-11-07T18:43:47.311708Z",
"iopub.status.busy": "2024-11-05T05:22:49.320880Z", "iopub.status.busy": "2024-11-07T18:43:47.311517Z",
"iopub.status.idle": "2024-11-05T05:23:21.537478Z", "iopub.status.idle": "2024-11-07T18:44:18.512576Z",
"shell.execute_reply": "2024-11-05T05:23:21.536956Z" "shell.execute_reply": "2024-11-07T18:44:18.511909Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -78,10 +78,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:23:21.539953Z", "iopub.execute_input": "2024-11-07T18:44:18.515678Z",
"iopub.status.busy": "2024-11-05T05:23:21.539100Z", "iopub.status.busy": "2024-11-07T18:44:18.515314Z",
"iopub.status.idle": "2024-11-05T05:23:25.880179Z", "iopub.status.idle": "2024-11-07T18:44:22.880793Z",
"shell.execute_reply": "2024-11-05T05:23:25.879744Z" "shell.execute_reply": "2024-11-07T18:44:22.880303Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -129,10 +129,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:23:25.881742Z", "iopub.execute_input": "2024-11-07T18:44:22.883309Z",
"iopub.status.busy": "2024-11-05T05:23:25.881595Z", "iopub.status.busy": "2024-11-07T18:44:22.883160Z",
"iopub.status.idle": "2024-11-05T05:23:26.758503Z", "iopub.status.idle": "2024-11-07T18:44:27.048810Z",
"shell.execute_reply": "2024-11-05T05:23:26.758084Z" "shell.execute_reply": "2024-11-07T18:44:27.048074Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -176,10 +176,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:23:26.760098Z", "iopub.execute_input": "2024-11-07T18:44:27.051312Z",
"iopub.status.busy": "2024-11-05T05:23:26.759955Z", "iopub.status.busy": "2024-11-07T18:44:27.051190Z",
"iopub.status.idle": "2024-11-05T05:23:27.849510Z", "iopub.status.idle": "2024-11-07T18:44:32.358097Z",
"shell.execute_reply": "2024-11-05T05:23:27.849117Z" "shell.execute_reply": "2024-11-07T18:44:32.357628Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -227,10 +227,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:23:27.850994Z", "iopub.execute_input": "2024-11-07T18:44:32.359532Z",
"iopub.status.busy": "2024-11-05T05:23:27.850864Z", "iopub.status.busy": "2024-11-07T18:44:32.359413Z",
"iopub.status.idle": "2024-11-05T05:23:31.609137Z", "iopub.status.idle": "2024-11-07T18:44:36.164664Z",
"shell.execute_reply": "2024-11-05T05:23:31.608748Z" "shell.execute_reply": "2024-11-07T18:44:36.164005Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -276,10 +276,10 @@
"execution_count": 6, "execution_count": 6,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:23:31.610683Z", "iopub.execute_input": "2024-11-07T18:44:36.167123Z",
"iopub.status.busy": "2024-11-05T05:23:31.610560Z", "iopub.status.busy": "2024-11-07T18:44:36.166535Z",
"iopub.status.idle": "2024-11-05T05:23:32.965146Z", "iopub.status.idle": "2024-11-07T18:44:37.743761Z",
"shell.execute_reply": "2024-11-05T05:23:32.963922Z" "shell.execute_reply": "2024-11-07T18:44:37.742510Z"
} }
}, },
"outputs": [], "outputs": [],

View File

@@ -31,7 +31,7 @@ extensions = [
] ]
nbsphinx_allow_errors = True nbsphinx_allow_errors = True
nbsphinx_execute = 'never' nbsphinx_execute = "never"
autosectionlabel_prefix_document = True autosectionlabel_prefix_document = True
nbsphinx_allow_directives = True nbsphinx_allow_directives = True
@@ -49,7 +49,7 @@ myst_enable_extensions = [
myst_heading_anchors = 3 myst_heading_anchors = 3
nbsphinx_kernel_name = 'python3' nbsphinx_kernel_name = "python3"
nbsphinx_execute_arguments = [ nbsphinx_execute_arguments = [
"--InlineBackend.figure_formats={'svg', 'pdf'}", "--InlineBackend.figure_formats={'svg', 'pdf'}",
"--InlineBackend.rc={'figure.dpi': 96}", "--InlineBackend.rc={'figure.dpi': 96}",
@@ -130,8 +130,10 @@ html_context = {
html_static_path = ["_static"] html_static_path = ["_static"]
html_css_files = ["css/custom_log.css"] html_css_files = ["css/custom_log.css"]
def setup(app): def setup(app):
app.add_css_file('css/custom_log.css') app.add_css_file("css/custom_log.css")
myst_enable_extensions = [ myst_enable_extensions = [
"dollarmath", "dollarmath",

View File

@@ -33,10 +33,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:10.680191Z", "iopub.execute_input": "2024-11-07T18:48:52.032229Z",
"iopub.status.busy": "2024-11-05T05:11:10.679710Z", "iopub.status.busy": "2024-11-07T18:48:52.032105Z",
"iopub.status.idle": "2024-11-05T05:11:39.882385Z", "iopub.status.idle": "2024-11-07T18:49:20.226042Z",
"shell.execute_reply": "2024-11-05T05:11:39.881827Z" "shell.execute_reply": "2024-11-07T18:49:20.225562Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -49,7 +49,7 @@
")\n", ")\n",
"\n", "\n",
"server_process = execute_shell_command(\n", "server_process = execute_shell_command(\n",
"\"\"\"\n", " \"\"\"\n",
"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
"--port 30000 --host 0.0.0.0\n", "--port 30000 --host 0.0.0.0\n",
"\"\"\"\n", "\"\"\"\n",
@@ -70,10 +70,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:39.883923Z", "iopub.execute_input": "2024-11-07T18:49:20.228006Z",
"iopub.status.busy": "2024-11-05T05:11:39.883721Z", "iopub.status.busy": "2024-11-07T18:49:20.227572Z",
"iopub.status.idle": "2024-11-05T05:11:40.124980Z", "iopub.status.idle": "2024-11-07T18:49:20.469885Z",
"shell.execute_reply": "2024-11-05T05:11:40.124557Z" "shell.execute_reply": "2024-11-07T18:49:20.469518Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -101,10 +101,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:40.126564Z", "iopub.execute_input": "2024-11-07T18:49:20.471956Z",
"iopub.status.busy": "2024-11-05T05:11:40.126369Z", "iopub.status.busy": "2024-11-07T18:49:20.471811Z",
"iopub.status.idle": "2024-11-05T05:11:40.324316Z", "iopub.status.idle": "2024-11-07T18:49:20.667997Z",
"shell.execute_reply": "2024-11-05T05:11:40.323693Z" "shell.execute_reply": "2024-11-07T18:49:20.667630Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -115,9 +115,7 @@
"\n", "\n",
"data = {\n", "data = {\n",
" \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" \"messages\": [\n", " \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}],\n",
" {\"role\": \"user\", \"content\": \"What is the capital of France?\"}\n",
" ]\n",
"}\n", "}\n",
"\n", "\n",
"response = requests.post(url, json=data)\n", "response = requests.post(url, json=data)\n",
@@ -136,10 +134,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:40.327043Z", "iopub.execute_input": "2024-11-07T18:49:20.669977Z",
"iopub.status.busy": "2024-11-05T05:11:40.326759Z", "iopub.status.busy": "2024-11-07T18:49:20.669826Z",
"iopub.status.idle": "2024-11-05T05:11:41.687336Z", "iopub.status.idle": "2024-11-07T18:49:22.004855Z",
"shell.execute_reply": "2024-11-05T05:11:41.686855Z" "shell.execute_reply": "2024-11-07T18:49:22.004472Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -171,10 +169,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:41.688676Z", "iopub.execute_input": "2024-11-07T18:49:22.006983Z",
"iopub.status.busy": "2024-11-05T05:11:41.688527Z", "iopub.status.busy": "2024-11-07T18:49:22.006858Z",
"iopub.status.idle": "2024-11-05T05:11:42.717140Z", "iopub.status.idle": "2024-11-07T18:49:23.029098Z",
"shell.execute_reply": "2024-11-05T05:11:42.716452Z" "shell.execute_reply": "2024-11-07T18:49:23.028697Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -197,7 +195,7 @@
"# Handle the streaming output\n", "# Handle the streaming output\n",
"for chunk in response:\n", "for chunk in response:\n",
" if chunk.choices[0].delta.content:\n", " if chunk.choices[0].delta.content:\n",
" print(chunk.choices[0].delta.content, end='', flush=True)" " print(chunk.choices[0].delta.content, end=\"\", flush=True)"
] ]
}, },
{ {
@@ -214,10 +212,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:42.720467Z", "iopub.execute_input": "2024-11-07T18:49:23.031712Z",
"iopub.status.busy": "2024-11-05T05:11:42.720182Z", "iopub.status.busy": "2024-11-07T18:49:23.031571Z",
"iopub.status.idle": "2024-11-05T05:11:43.480765Z", "iopub.status.idle": "2024-11-07T18:49:23.787752Z",
"shell.execute_reply": "2024-11-05T05:11:43.480143Z" "shell.execute_reply": "2024-11-07T18:49:23.787368Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -250,10 +248,10 @@
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:43.483575Z", "iopub.execute_input": "2024-11-07T18:49:23.789840Z",
"iopub.status.busy": "2024-11-05T05:11:43.483295Z", "iopub.status.busy": "2024-11-07T18:49:23.789702Z",
"iopub.status.idle": "2024-11-05T05:11:44.242950Z", "iopub.status.idle": "2024-11-07T18:49:24.545631Z",
"shell.execute_reply": "2024-11-05T05:11:44.242248Z" "shell.execute_reply": "2024-11-07T18:49:24.545241Z"
} }
}, },
"outputs": [], "outputs": [],
@@ -290,10 +288,10 @@
"execution_count": 8, "execution_count": 8,
"metadata": { "metadata": {
"execution": { "execution": {
"iopub.execute_input": "2024-11-05T05:11:44.245660Z", "iopub.execute_input": "2024-11-07T18:49:24.547641Z",
"iopub.status.busy": "2024-11-05T05:11:44.245373Z", "iopub.status.busy": "2024-11-07T18:49:24.547497Z",
"iopub.status.idle": "2024-11-05T05:11:45.591682Z", "iopub.status.idle": "2024-11-07T18:49:25.888864Z",
"shell.execute_reply": "2024-11-05T05:11:45.591184Z" "shell.execute_reply": "2024-11-07T18:49:25.888114Z"
} }
}, },
"outputs": [], "outputs": [],

View File

@@ -71,7 +71,7 @@
"source": [ "source": [
"import json\n", "import json\n",
"import os\n", "import os\n",
"from typing import List\n", "from typing import List\n",
"\n", "\n",
"import chromadb\n", "import chromadb\n",
"\n", "\n",
@@ -80,7 +80,7 @@
"if not os.path.exists(path_qca):\n", "if not os.path.exists(path_qca):\n",
" !wget https://virattt.github.io/datasets/abnb-2023-10k.json -O airbnb-2023-10k-qca.json\n", " !wget https://virattt.github.io/datasets/abnb-2023-10k.json -O airbnb-2023-10k-qca.json\n",
"\n", "\n",
"with open(path_qca, 'r') as f:\n", "with open(path_qca, \"r\") as f:\n",
" question_context_answers = json.load(f)\n", " question_context_answers = json.load(f)\n",
"\n", "\n",
"chroma_client = chromadb.PersistentClient()\n", "chroma_client = chromadb.PersistentClient()\n",
@@ -88,7 +88,7 @@
"if collection.count() == 0:\n", "if collection.count() == 0:\n",
" collection.add(\n", " collection.add(\n",
" documents=[qca[\"context\"] for qca in question_context_answers],\n", " documents=[qca[\"context\"] for qca in question_context_answers],\n",
" ids=[str(i) for i in range(len(question_context_answers))]\n", " ids=[str(i) for i in range(len(question_context_answers))],\n",
" )" " )"
], ],
"metadata": { "metadata": {
@@ -123,7 +123,7 @@
"\n", "\n",
"load_dotenv()\n", "load_dotenv()\n",
"\n", "\n",
"os.environ['TOKENIZERS_PARALLELISM'] = \"false\"\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n", "\n",
"p = Parea(api_key=os.getenv(\"PAREA_API_KEY\"), project_name=\"rag_sglang\")\n", "p = Parea(api_key=os.getenv(\"PAREA_API_KEY\"), project_name=\"rag_sglang\")\n",
"p.integrate_with_sglang()\n", "p.integrate_with_sglang()\n",
@@ -150,10 +150,7 @@
"source": [ "source": [
"@trace\n", "@trace\n",
"def retrieval(question: str) -> List[str]:\n", "def retrieval(question: str) -> List[str]:\n",
" return collection.query(\n", " return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
" query_texts=[question],\n",
" n_results=1\n",
" )['documents'][0]"
], ],
"metadata": { "metadata": {
"collapsed": false "collapsed": false
@@ -176,7 +173,9 @@
"@function\n", "@function\n",
"def generation_sglang(s, question: str, *context: str):\n", "def generation_sglang(s, question: str, *context: str):\n",
" context = \"\\n\".join(context)\n", " context = \"\\n\".join(context)\n",
" s += user(f'Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.')\n", " s += user(\n",
" f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
" )\n",
" s += assistant(gen(\"answer\"))\n", " s += assistant(gen(\"answer\"))\n",
"\n", "\n",
"\n", "\n",
@@ -223,7 +222,9 @@
" return generation(question, *contexts)\n", " return generation(question, *contexts)\n",
"\n", "\n",
"\n", "\n",
"rag_pipeline(\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\")" "rag_pipeline(\n",
" \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
")"
] ]
}, },
{ {
@@ -271,7 +272,10 @@
"execution_count": null, "execution_count": null,
"outputs": [], "outputs": [],
"source": [ "source": [
"from parea.evals.rag import context_query_relevancy_factory, percent_target_supported_by_context_factory\n", "from parea.evals.rag import (\n",
" context_query_relevancy_factory,\n",
" percent_target_supported_by_context_factory,\n",
")\n",
"\n", "\n",
"\n", "\n",
"context_relevancy_eval = context_query_relevancy_factory()\n", "context_relevancy_eval = context_query_relevancy_factory()\n",
@@ -280,10 +284,7 @@
"\n", "\n",
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n", "@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
"def retrieval(question: str) -> List[str]:\n", "def retrieval(question: str) -> List[str]:\n",
" return collection.query(\n", " return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
" query_texts=[question],\n",
" n_results=1\n",
" )['documents'][0]"
], ],
"metadata": { "metadata": {
"collapsed": false "collapsed": false
@@ -310,10 +311,13 @@
"answer_context_faithfulness = answer_context_faithfulness_statement_level_factory()\n", "answer_context_faithfulness = answer_context_faithfulness_statement_level_factory()\n",
"answer_matches_target_llm_grader = answer_matches_target_llm_grader_factory()\n", "answer_matches_target_llm_grader = answer_matches_target_llm_grader_factory()\n",
"\n", "\n",
"\n",
"@function\n", "@function\n",
"def generation_sglang(s, question: str, *context: str):\n", "def generation_sglang(s, question: str, *context: str):\n",
" context = \"\\n\".join(context)\n", " context = \"\\n\".join(context)\n",
" s += user(f'Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.')\n", " s += user(\n",
" f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
" )\n",
" s += assistant(gen(\"answer\", max_tokens=1_000))\n", " s += assistant(gen(\"answer\", max_tokens=1_000))\n",
"\n", "\n",
"\n", "\n",
@@ -357,7 +361,9 @@
" return generation(question, *contexts)\n", " return generation(question, *contexts)\n",
"\n", "\n",
"\n", "\n",
"rag_pipeline(\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\")" "rag_pipeline(\n",
" \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
")"
], ],
"metadata": { "metadata": {
"collapsed": false "collapsed": false
@@ -402,6 +408,7 @@
"source": [ "source": [
"!pip install nest-asyncio\n", "!pip install nest-asyncio\n",
"import nest_asyncio\n", "import nest_asyncio\n",
"\n",
"nest_asyncio.apply()" "nest_asyncio.apply()"
], ],
"metadata": { "metadata": {
@@ -461,7 +468,7 @@
], ],
"source": [ "source": [
"e = p.experiment(\n", "e = p.experiment(\n",
" 'RAG',\n", " \"RAG\",\n",
" data=[\n", " data=[\n",
" {\n", " {\n",
" \"question\": qca[\"question\"],\n", " \"question\": qca[\"question\"],\n",
@@ -469,7 +476,7 @@
" }\n", " }\n",
" for qca in question_context_answers\n", " for qca in question_context_answers\n",
" ],\n", " ],\n",
" func=rag_pipeline\n", " func=rag_pipeline,\n",
").run()" ").run()"
], ],
"metadata": { "metadata": {

View File

@@ -7,6 +7,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct" MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
def main(): def main():
# Sample prompts. # Sample prompts.
prompts = [ prompts = [

View File

@@ -39,7 +39,7 @@ class ModelConfig:
revision: Optional[str] = None, revision: Optional[str] = None,
context_length: Optional[int] = None, context_length: Optional[int] = None,
model_override_args: Optional[dict] = None, model_override_args: Optional[dict] = None,
is_embedding: Optional[bool] = None is_embedding: Optional[bool] = None,
) -> None: ) -> None:
# Parse args # Parse args
self.model_override_args = json.loads(model_override_args) self.model_override_args = json.loads(model_override_args)
@@ -52,7 +52,9 @@ class ModelConfig:
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
# Check model type # Check model type
self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding) self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures) self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)

View File

@@ -122,16 +122,14 @@ class QuantizationConfig(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def method_has_implemented_embedding(
method_class: Type[QuantizeMethodBase]) -> bool: def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
""" """
Not all quant methods have embedding implemented, so we need to check that Not all quant methods have embedding implemented, so we need to check that
it exists for our given method. We check this by making sure the function it exists for our given method. We check this by making sure the function
has been changed from the base implementation. has been changed from the base implementation.
""" """
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
None)
class_embedding = inspect.getattr_static(method_class, "embedding", None) class_embedding = inspect.getattr_static(method_class, "embedding", None)
return (class_embedding is not None return class_embedding is not None and class_embedding is not base_embedding
and class_embedding is not base_embedding)

View File

@@ -27,59 +27,67 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
class UnquantizedEmbeddingMethod(QuantizeMethodBase): class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings.""" """Unquantized method for embeddings."""
def create_weights(self, layer: torch.nn.Module, def create_weights(
input_size_per_partition: int, self,
output_partition_sizes: List[int], input_size: int, layer: torch.nn.Module,
output_size: int, params_dtype: torch.dtype, input_size_per_partition: int,
**extra_weight_attrs): output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""Create weights for embedding layer.""" """Create weights for embedding layer."""
weight = Parameter(torch.empty(sum(output_partition_sizes), weight = Parameter(
input_size_per_partition, torch.empty(
dtype=params_dtype), sum(output_partition_sizes),
requires_grad=False) input_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs) set_weight_attrs(weight, extra_weight_attrs)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return F.linear(x, layer.weight, bias) return F.linear(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module, def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
input_: torch.Tensor) -> torch.Tensor:
return F.embedding(input_, layer.weight) return F.embedding(input_, layer.weight)
def pad_vocab_size(vocab_size: int, def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value.""" """Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size( def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, per_partition_vocab_size: int, rank: int, offset: int = 0
rank: int, ) -> Sequence[int]:
offset: int = 0) -> Sequence[int]:
index_f = rank * per_partition_vocab_size index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size index_l = index_f + per_partition_vocab_size
return index_f + offset, index_l + offset return index_f + offset, index_l + offset
def vocab_range_from_global_vocab_size(global_vocab_size: int, def vocab_range_from_global_vocab_size(
rank: int, global_vocab_size: int, rank: int, world_size: int, offset: int = 0
world_size: int, ) -> Sequence[int]:
offset: int = 0) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size) per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, return vocab_range_from_per_partition_vocab_size(
rank, per_partition_vocab_size, rank, offset=offset
offset=offset) )
@dataclass @dataclass
class VocabParallelEmbeddingShardIndices: class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding.""" """Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int padded_org_vocab_start_index: int
padded_org_vocab_end_index: int padded_org_vocab_end_index: int
padded_added_vocab_start_index: int padded_added_vocab_start_index: int
@@ -100,13 +108,11 @@ class VocabParallelEmbeddingShardIndices:
@property @property
def num_org_elements_padded(self) -> int: def num_org_elements_padded(self) -> int:
return (self.padded_org_vocab_end_index - return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
self.padded_org_vocab_start_index)
@property @property
def num_added_elements_padded(self) -> int: def num_added_elements_padded(self) -> int:
return (self.padded_added_vocab_end_index - return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
self.padded_added_vocab_start_index)
@property @property
def num_org_vocab_padding(self) -> int: def num_org_vocab_padding(self) -> int:
@@ -122,17 +128,14 @@ class VocabParallelEmbeddingShardIndices:
def __post_init__(self): def __post_init__(self):
# sanity checks # sanity checks
assert (self.padded_org_vocab_start_index <= assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
self.padded_org_vocab_end_index) assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
assert (self.padded_added_vocab_start_index <=
self.padded_added_vocab_end_index)
assert self.org_vocab_start_index <= self.org_vocab_end_index assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert (self.added_vocab_start_index <= assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
self.padded_added_vocab_start_index)
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
@@ -142,20 +145,27 @@ class VocabParallelEmbeddingShardIndices:
@torch.jit.script @torch.jit.script
def get_masked_input_and_mask( def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int, input_: torch.Tensor,
org_vocab_end_index: int, num_org_vocab_padding: int, org_vocab_start_index: int,
added_vocab_start_index: int, org_vocab_end_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.jit.script will fuse all of the pointwise ops below # torch.jit.script will fuse all of the pointwise ops below
# into a single kernel, making it very fast # into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & ( added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index) input_ < added_vocab_end_index
added_offset = added_vocab_start_index - ( )
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding added_offset = (
valid_offset = (org_vocab_start_index * added_vocab_start_index
org_vocab_mask) + (added_offset * added_vocab_mask) - (org_vocab_end_index - org_vocab_start_index)
- num_org_vocab_padding
)
valid_offset = (org_vocab_start_index * org_vocab_mask) + (
added_offset * added_vocab_mask
)
vocab_mask = org_vocab_mask | added_vocab_mask vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset) input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask return input_, ~vocab_mask
@@ -200,15 +210,17 @@ class VocabParallelEmbedding(torch.nn.Module):
prefix: full name of the layer in the state dict prefix: full name of the layer in the state dict
""" # noqa: E501 """ # noqa: E501
def __init__(self, def __init__(
num_embeddings: int, self,
embedding_dim: int, num_embeddings: int,
params_dtype: Optional[torch.dtype] = None, embedding_dim: int,
org_num_embeddings: Optional[int] = None, params_dtype: Optional[torch.dtype] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, org_num_embeddings: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
prefix: str = "", quant_config: Optional[QuantizationConfig] = None,
enable_tp: bool = True): prefix: str = "",
enable_tp: bool = True,
):
super().__init__() super().__init__()
self.enable_tp = enable_tp self.enable_tp = enable_tp
@@ -223,18 +235,22 @@ class VocabParallelEmbedding(torch.nn.Module):
self.padding_size = padding_size self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.org_vocab_size_padded = pad_vocab_size(
self.padding_size) self.org_vocab_size, self.padding_size
)
self.num_embeddings_padded = pad_vocab_size( self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings, self.org_vocab_size_padded + num_added_embeddings, self.padding_size
self.padding_size) )
assert self.org_vocab_size_padded <= self.num_embeddings_padded assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded, self.shard_indices = self._get_indices(
self.org_vocab_size_padded, self.num_embeddings_padded,
self.num_embeddings, self.org_vocab_size_padded,
self.org_vocab_size, tp_rank, self.num_embeddings,
self.tp_size) self.org_vocab_size,
tp_rank,
self.tp_size,
)
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
linear_method = None linear_method = None
@@ -248,11 +264,13 @@ class VocabParallelEmbedding(torch.nn.Module):
# layer type like ParallelLMHead, this is not important. # layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
linear_method_implements_embedding = method_has_implemented_embedding( linear_method_implements_embedding = method_has_implemented_embedding(
type(linear_method)) type(linear_method)
)
if is_embedding_layer and not linear_method_implements_embedding: if is_embedding_layer and not linear_method_implements_embedding:
raise NotImplementedError( raise NotImplementedError(
f"The class {type(linear_method).__name__} must implement " f"The class {type(linear_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.") "the 'embedding' method, see UnquantizedEmbeddingMethod."
)
self.linear_method: QuantizeMethodBase = linear_method self.linear_method: QuantizeMethodBase = linear_method
@@ -260,53 +278,68 @@ class VocabParallelEmbedding(torch.nn.Module):
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocaburaly dimension. # Divide the weight matrix along the vocaburaly dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded, self.num_embeddings_per_partition = divide(
self.tp_size) self.num_embeddings_padded, self.tp_size
assert (self.shard_indices.num_elements_padded == )
self.num_embeddings_per_partition) assert (
self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
)
self.num_org_embeddings_per_partition = ( self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index - self.shard_indices.org_vocab_end_index
self.shard_indices.org_vocab_start_index) - self.shard_indices.org_vocab_start_index
)
self.num_added_embeddings_per_partition = ( self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index - self.shard_indices.added_vocab_end_index
self.shard_indices.added_vocab_start_index) - self.shard_indices.added_vocab_start_index
)
self.linear_method.create_weights(self, self.linear_method.create_weights(
self.embedding_dim, self,
[self.num_embeddings_per_partition], self.embedding_dim,
self.embedding_dim, [self.num_embeddings_per_partition],
self.num_embeddings_padded, self.embedding_dim,
params_dtype=params_dtype, self.num_embeddings_padded,
weight_loader=self.weight_loader) params_dtype=params_dtype,
weight_loader=self.weight_loader,
)
@classmethod @classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, def _get_indices(
vocab_size: int, org_vocab_size: int, tp_rank: int, cls,
tp_size: int) -> VocabParallelEmbeddingShardIndices: vocab_size_padded: int,
org_vocab_size_padded: int,
vocab_size: int,
org_vocab_size: int,
tp_rank: int,
tp_size: int,
) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the """Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and layout outlined in the class docstring, based on the given tp_rank and
tp_size.""" tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = ( padded_org_vocab_start_index, padded_org_vocab_end_index = (
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)
tp_size)) )
padded_added_vocab_start_index, padded_added_vocab_end_index = ( padded_added_vocab_start_index, padded_added_vocab_end_index = (
vocab_range_from_global_vocab_size(num_added_embeddings_padded, vocab_range_from_global_vocab_size(
tp_rank, num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
tp_size, )
offset=org_vocab_size)) )
# remove padding # remove padding
org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index, added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices( return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index, padded_org_vocab_end_index, padded_org_vocab_start_index,
padded_added_vocab_start_index, padded_added_vocab_end_index, padded_org_vocab_end_index,
org_vocab_start_index, org_vocab_end_index, padded_added_vocab_start_index,
added_vocab_start_index, added_vocab_end_index) padded_added_vocab_end_index,
org_vocab_start_index,
org_vocab_end_index,
added_vocab_start_index,
added_vocab_end_index,
)
def get_sharded_to_full_mapping(self) -> Optional[List[int]]: def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
"""Get a mapping that can be used to reindex the gathered """Get a mapping that can be used to reindex the gathered
@@ -326,32 +359,49 @@ class VocabParallelEmbedding(torch.nn.Module):
added_embeddings: List[int] = [] added_embeddings: List[int] = []
padding: List[int] = [] padding: List[int] = []
for tp_rank in range(self.tp_size): for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded, shard_indices = self._get_indices(
self.org_vocab_size_padded, self.num_embeddings_padded,
self.num_embeddings, self.org_vocab_size_padded,
self.org_vocab_size, tp_rank, self.num_embeddings,
self.tp_size) self.org_vocab_size,
tp_rank,
self.tp_size,
)
range_start = self.num_embeddings_per_partition * tp_rank range_start = self.num_embeddings_per_partition * tp_rank
range_end = self.num_embeddings_per_partition * (tp_rank + 1) range_end = self.num_embeddings_per_partition * (tp_rank + 1)
base_embeddings.extend( base_embeddings.extend(
range(range_start, range(range_start, range_start + shard_indices.num_org_elements)
range_start + shard_indices.num_org_elements)) )
padding.extend( padding.extend(
range(range_start + shard_indices.num_org_elements, range(
range_start + shard_indices.num_org_elements_padded)) range_start + shard_indices.num_org_elements,
range_start + shard_indices.num_org_elements_padded,
)
)
added_embeddings.extend( added_embeddings.extend(
range( range(
range_start + shard_indices.num_org_elements_padded, range_start + shard_indices.num_org_elements_padded,
range_start + shard_indices.num_org_elements_padded + range_start
shard_indices.num_added_elements)) + shard_indices.num_org_elements_padded
+ shard_indices.num_added_elements,
)
)
padding.extend( padding.extend(
range( range(
range_start + shard_indices.num_org_elements_padded + range_start
shard_indices.num_added_elements, + shard_indices.num_org_elements_padded
range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements,
shard_indices.num_added_elements_padded)) range_start
assert (range_start + shard_indices.num_org_elements_padded + + shard_indices.num_org_elements_padded
shard_indices.num_added_elements_padded == range_end) + shard_indices.num_added_elements_padded,
)
)
assert (
range_start
+ shard_indices.num_org_elements_padded
+ shard_indices.num_added_elements_padded
== range_end
)
ret = base_embeddings + added_embeddings + padding ret = base_embeddings + added_embeddings + padding
assert len(ret) == self.num_embeddings_padded assert len(ret) == self.num_embeddings_padded
return ret return ret
@@ -385,10 +435,14 @@ class VocabParallelEmbedding(torch.nn.Module):
# If param packed on the same dim we are sharding on, then # If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor. # need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim: if packed_dim is not None and packed_dim == output_dim:
packed_factor = param.packed_factor if isinstance( packed_factor = (
param, BasevLLMParameter) else param.pack_factor param.packed_factor
assert loaded_weight.shape[output_dim] == (self.org_vocab_size // if isinstance(param, BasevLLMParameter)
param.packed_factor) else param.pack_factor
)
assert loaded_weight.shape[output_dim] == (
self.org_vocab_size // param.packed_factor
)
start_idx = start_idx // packed_factor start_idx = start_idx // packed_factor
shard_size = shard_size // packed_factor shard_size = shard_size // packed_factor
else: else:
@@ -396,23 +450,24 @@ class VocabParallelEmbedding(torch.nn.Module):
# Copy the data. # Copy the data.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[:loaded_weight.shape[0]].data.copy_(loaded_weight) param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0) param[loaded_weight.shape[0] :].data.fill_(0)
def forward(self, input_): def forward(self, input_):
if self.tp_size > 1: if self.tp_size > 1:
# Build the mask. # Build the mask.
masked_input, input_mask = get_masked_input_and_mask( masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index, input_,
self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index, self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding, self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index, self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index) self.shard_indices.added_vocab_end_index,
)
else: else:
masked_input = input_ masked_input = input_
# Get the embeddings. # Get the embeddings.
output_parallel = self.linear_method.embedding(self, output_parallel = self.linear_method.embedding(self, masked_input.long())
masked_input.long())
# Mask the output embedding. # Mask the output embedding.
if self.tp_size > 1: if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
@@ -426,9 +481,9 @@ class VocabParallelEmbedding(torch.nn.Module):
s = f"num_embeddings={self.num_embeddings_per_partition}" s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}" s += f", embedding_dim={self.embedding_dim}"
s += f", org_vocab_size={self.org_vocab_size}" s += f", org_vocab_size={self.org_vocab_size}"
s += f', num_embeddings_padded={self.num_embeddings_padded}' s += f", num_embeddings_padded={self.num_embeddings_padded}"
if self.enable_tp: if self.enable_tp:
s += f', tp_size={self.tp_size}' s += f", tp_size={self.tp_size}"
return s return s
@@ -448,27 +503,38 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size: padding size for the vocabulary. padding_size: padding size for the vocabulary.
""" """
def __init__(self, def __init__(
num_embeddings: int, self,
embedding_dim: int, num_embeddings: int,
bias: bool = False, embedding_dim: int,
params_dtype: Optional[torch.dtype] = None, bias: bool = False,
org_num_embeddings: Optional[int] = None, params_dtype: Optional[torch.dtype] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, org_num_embeddings: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
prefix: str = ""): quant_config: Optional[QuantizationConfig] = None,
super().__init__(num_embeddings, embedding_dim, params_dtype, prefix: str = "",
org_num_embeddings, padding_size, quant_config, ):
prefix) super().__init__(
num_embeddings,
embedding_dim,
params_dtype,
org_num_embeddings,
padding_size,
quant_config,
prefix,
)
self.quant_config = quant_config self.quant_config = quant_config
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
dtype=params_dtype)) )
set_weight_attrs(self.bias, { set_weight_attrs(
"output_dim": 0, self.bias,
"weight_loader": self.weight_loader, {
}) "output_dim": 0,
"weight_loader": self.weight_loader,
},
)
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)

View File

@@ -86,8 +86,10 @@ class GenerateReqInput:
self.parallel_sample_num = self.sampling_params.get("n", 1) self.parallel_sample_num = self.sampling_params.get("n", 1)
else: # isinstance(self.sampling_params, list): else: # isinstance(self.sampling_params, list):
self.parallel_sample_num = self.sampling_params[0].get("n", 1) self.parallel_sample_num = self.sampling_params[0].get("n", 1)
assert all(self.parallel_sample_num == sampling_params.get("n", 1) for sampling_params in self.sampling_params), ( assert all(
"The parallel_sample_num should be the same for all samples in sample params.") self.parallel_sample_num == sampling_params.get("n", 1)
for sampling_params in self.sampling_params
), "The parallel_sample_num should be the same for all samples in sample params."
if self.parallel_sample_num > 1 and self.is_single: if self.parallel_sample_num > 1 and self.is_single:
self.is_single = False self.is_single = False

View File

@@ -911,8 +911,7 @@ class ScheduleBatch:
keep_indices = [ keep_indices = [
i i
for i in range(len(self.reqs)) for i in range(len(self.reqs))
if not self.reqs[i].finished() if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
and self.reqs[i] is not being_chunked_req
] ]
if keep_indices is None or len(keep_indices) == 0: if keep_indices is None or len(keep_indices) == 0:
@@ -1043,6 +1042,7 @@ class ScheduleBatch:
for req in self.reqs: for req in self.reqs:
req.started_time = time.time() req.started_time = time.time()
@dataclasses.dataclass @dataclasses.dataclass
class ModelWorkerBatch: class ModelWorkerBatch:
# The batch id # The batch id

View File

@@ -224,8 +224,8 @@ class Scheduler:
self.forward_ct = 0 self.forward_ct = 0
self.forward_ct_decode = 0 self.forward_ct_decode = 0
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.last_stats_tic = time.time() # time of last stats for every iter self.last_stats_tic = time.time() # time of last stats for every iter
self.last_log_tic = time.time() # time of last log for print decode log self.last_log_tic = time.time() # time of last log for print decode log
self.stream_interval = server_args.stream_interval self.stream_interval = server_args.stream_interval
# Init chunked prefill # Init chunked prefill
@@ -566,9 +566,7 @@ class Scheduler:
and not self.last_batch.is_empty() and not self.last_batch.is_empty()
): ):
if self.being_chunked_req: if self.being_chunked_req:
self.last_batch.filter_batch( self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
being_chunked_req=self.being_chunked_req
)
self.tree_cache.cache_unfinished_req(self.being_chunked_req) self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Inflight request keeps its rid but will get a new req_pool_idx. # Inflight request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx) self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
@@ -628,9 +626,7 @@ class Scheduler:
has_inflight = self.being_chunked_req is not None has_inflight = self.being_chunked_req is not None
if has_inflight: if has_inflight:
self.being_chunked_req.init_next_round_input() self.being_chunked_req.init_next_round_input()
self.being_chunked_req = adder.add_inflight_req( self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
self.being_chunked_req
)
if self.lora_paths: if self.lora_paths:
lora_set = ( lora_set = (
@@ -813,7 +809,8 @@ class Scheduler:
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = embeddings, model_worker_batch.bid ret = embeddings, model_worker_batch.bid
return ret return ret
def get_stats(self,batch: ScheduleBatch):
def get_stats(self, batch: ScheduleBatch):
# TODO: get stats for chunked prefill # TODO: get stats for chunked prefill
now = time.time() now = time.time()
@@ -829,8 +826,8 @@ class Scheduler:
# set stats from prefill # set stats from prefill
if self.stats is not None: if self.stats is not None:
# new_seq=self.stats.new_seq # new_seq=self.stats.new_seq
cache_hit_rate=self.stats.cache_hit_rate cache_hit_rate = self.stats.cache_hit_rate
token_usage=self.stats.token_usage token_usage = self.stats.token_usage
# Iteration stats # Iteration stats
num_prompt_tokens_iter = 0 num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0 num_generation_tokens_iter = 0
@@ -851,15 +848,19 @@ class Scheduler:
# _, next_token_ids, _ = result # _, next_token_ids, _ = result
if batch is not None: if batch is not None:
num_generation_tokens_iter = len(batch.output_ids) num_generation_tokens_iter = len(batch.output_ids)
gen_throughput = round(num_generation_tokens_iter / (now - self.last_stats_tic), 2) gen_throughput = round(
num_generation_tokens_iter / (now - self.last_stats_tic), 2
)
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
# NOTE: Batch forward mode is extend befor start decode, # NOTE: Batch forward mode is extend befor start decode,
if batch.forward_mode.is_extend(): if batch.forward_mode.is_extend():
num_prompt_tokens_iter=len(batch.input_ids)+sum(batch.prefix_lens) num_prompt_tokens_iter = len(batch.input_ids) + sum(
batch.prefix_lens
)
time_to_first_tokens_iter.append(now - req.started_time) time_to_first_tokens_iter.append(now - req.started_time)
else: else:
time_per_output_tokens_iter.append(now-self.last_stats_tic) time_per_output_tokens_iter.append(now - self.last_stats_tic)
if req.finished(): if req.finished():
time_e2e_requests.append(now - req.created_time) time_e2e_requests.append(now - req.created_time)
@@ -867,9 +868,10 @@ class Scheduler:
num_prompt_tokens_requests.append(len(req.origin_input_ids)) num_prompt_tokens_requests.append(len(req.origin_input_ids))
num_generation_tokens_requests.append(len(req.output_ids)) num_generation_tokens_requests.append(len(req.output_ids))
finished_reason_requests.append( finished_reason_requests.append(
req.finished_reason.to_json() req.finished_reason.to_json()
if req.finished_reason is not None if req.finished_reason is not None
else None) else None
)
return Stats( return Stats(
new_seq=new_seq, new_seq=new_seq,
@@ -893,7 +895,7 @@ class Scheduler:
max_running_requests=self.max_running_requests, max_running_requests=self.max_running_requests,
) )
def log_stats(self,stats:Stats): def log_stats(self, stats: Stats):
self.metrics_collector.log_stats(stats) self.metrics_collector.log_stats(stats)
def process_batch_result(self, batch: ScheduleBatch, result): def process_batch_result(self, batch: ScheduleBatch, result):
@@ -1003,9 +1005,7 @@ class Scheduler:
if req.is_retracted: if req.is_retracted:
continue continue
if self.server_args.enable_overlap_schedule and ( if self.server_args.enable_overlap_schedule and (req.finished()):
req.finished()
):
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue continue
@@ -1031,7 +1031,10 @@ class Scheduler:
self.token_to_kv_pool.free_group_end() self.token_to_kv_pool.free_group_end()
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0: if (
self.tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.print_decode_stats() self.print_decode_stats()
def add_logprob_return_values( def add_logprob_return_values(

View File

@@ -215,7 +215,7 @@ class TokenizerManager:
logprob_start_len, logprob_start_len,
top_logprobs_num, top_logprobs_num,
obj.stream, obj.stream,
obj.lora_path obj.lora_path,
) )
elif isinstance(obj, EmbeddingReqInput): elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput( tokenized_obj = TokenizedEmbeddingReqInput(
@@ -290,7 +290,9 @@ class TokenizerManager:
# Tokenize all requests # Tokenize all requests
objs = [obj[i] for i in range(batch_size)] objs = [obj[i] for i in range(batch_size)]
tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs)) tokenized_objs = await asyncio.gather(
*(self._tokenize_one_request(obj) for obj in objs)
)
# Cache the common prefix for parallel sampling # Cache the common prefix for parallel sampling
for i in range(batch_size): for i in range(batch_size):
@@ -322,7 +324,9 @@ class TokenizerManager:
rid_to_index = {rid: i for i, rid in enumerate(rids)} rid_to_index = {rid: i for i, rid in enumerate(rids)}
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators} task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
while task_map: while task_map:
done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED) done, _ = await asyncio.wait(
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done: for task in done:
gen = task_map.pop(task) gen = task_map.pop(task)
@@ -367,7 +371,7 @@ class TokenizerManager:
if self.server_args.dp_size == 1: if self.server_args.dp_size == 1:
res = await self.mem_pool_size res = await self.mem_pool_size
return res.size return res.size
else: # self.server_args.dp_size > 1 else: # self.server_args.dp_size > 1
self.mem_pool_size_tmp = [] self.mem_pool_size_tmp = []
res = await self.mem_pool_size res = await self.mem_pool_size
ret = [r.size for r in res] ret = [r.size for r in res]
@@ -399,7 +403,7 @@ class TokenizerManager:
self.server_args.load_format = obj.load_format self.server_args.load_format = obj.load_format
self.model_path = obj.model_path self.model_path = obj.model_path
return result.success, result.message return result.success, result.message
else: # self.server_args.dp_size > 1 else: # self.server_args.dp_size > 1
self.model_update_tmp = [] self.model_update_tmp = []
result = await self.model_update_result result = await self.model_update_result
@@ -470,7 +474,7 @@ class TokenizerManager:
if isinstance(recv_obj, UpdateWeightReqOutput): if isinstance(recv_obj, UpdateWeightReqOutput):
if self.server_args.dp_size == 1: if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj) self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1 else: # self.server_args.dp_size > 1
self.model_update_tmp.append(recv_obj) self.model_update_tmp.append(recv_obj)
# set future if the all results are recevied # set future if the all results are recevied
if len(self.model_update_tmp) == self.server_args.dp_size: if len(self.model_update_tmp) == self.server_args.dp_size:
@@ -479,7 +483,7 @@ class TokenizerManager:
elif isinstance(recv_obj, GetMemPoolSizeReqOutput): elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
if self.server_args.dp_size == 1: if self.server_args.dp_size == 1:
self.mem_pool_size.set_result(recv_obj) self.mem_pool_size.set_result(recv_obj)
else: # self.sever_args.dp_size > 1 else: # self.sever_args.dp_size > 1
self.mem_pool_size_tmp.append(recv_obj) self.mem_pool_size_tmp.append(recv_obj)
# set future if the all results are received # set future if the all results are received
if len(self.mem_pool_size_tmp) == self.server_args.dp_size: if len(self.mem_pool_size_tmp) == self.server_args.dp_size:

View File

@@ -130,27 +130,65 @@ class Metrics:
self.counter_prompt_tokens = Counter( self.counter_prompt_tokens = Counter(
name="sglang:prompt_tokens_total", name="sglang:prompt_tokens_total",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
labelnames=labelnames) labelnames=labelnames,
)
self.counter_generation_tokens = Counter( self.counter_generation_tokens = Counter(
name="sglang:generation_tokens_total", name="sglang:generation_tokens_total",
documentation="Number of generation tokens processed.", documentation="Number of generation tokens processed.",
labelnames=labelnames) labelnames=labelnames,
)
self.histogram_time_to_first_token = Histogram( self.histogram_time_to_first_token = Histogram(
name="sglang:time_to_first_token_seconds", name="sglang:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.", documentation="Histogram of time to first token in seconds.",
labelnames=labelnames, labelnames=labelnames,
buckets=[ buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.001,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 25.0, 30.0 0.005,
]) 0.01,
0.02,
0.04,
0.06,
0.08,
0.1,
0.25,
0.5,
0.75,
1.0,
2.5,
5.0,
7.5,
10.0,
15.0,
20.0,
25.0,
30.0,
],
)
self.histogram_time_per_output_token = Histogram( self.histogram_time_per_output_token = Histogram(
name="sglang:time_per_output_token_seconds", name="sglang:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.", documentation="Histogram of time per output token in seconds.",
labelnames=labelnames, labelnames=labelnames,
buckets=[ buckets=[
0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 0.005,
1.0, 2.5 0.01,
]) 0.015,
0.02,
0.025,
0.03,
0.04,
0.05,
0.075,
0.1,
0.15,
0.2,
0.3,
0.4,
0.5,
0.75,
1.0,
2.5,
],
)
# Request Stats # Request Stats
# Metadata # Metadata
@@ -245,14 +283,19 @@ class PrometheusMetricsCollector(MetricsCollector):
stats.num_generation_tokens_requests, stats.num_generation_tokens_requests,
) )
self._log_counter(self.metrics.counter_prompt_tokens, self._log_counter(
stats.num_prompt_tokens_iter) self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter
self._log_counter(self.metrics.counter_generation_tokens, )
stats.num_generation_tokens_iter) self._log_counter(
self._log_histogram(self.metrics.histogram_time_to_first_token, self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter
stats.time_to_first_tokens_iter) )
self._log_histogram(self.metrics.histogram_time_per_output_token, self._log_histogram(
stats.time_per_output_tokens_iter) self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
)
self._log_histogram(
self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter,
)
# self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys) # self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.num_running_sys, stats.num_running_req) self._log_gauge(self.metrics.num_running_sys, stats.num_running_req)

View File

@@ -28,7 +28,7 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
#from sglang.srt.layers.activation import get_act_fn # from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
@@ -47,15 +47,14 @@ class GPT2Attention(nn.Module):
self, self,
layer_id: int, layer_id: int,
config: GPT2Config, config: GPT2Config,
cache_config = None, cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = ( tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0 assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads self.head_dim = self.hidden_size // total_num_heads
@@ -76,11 +75,13 @@ class GPT2Attention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_proj", prefix=f"{prefix}.c_proj",
) )
self.attn = RadixAttention(self.num_heads, self.attn = RadixAttention(
self.head_dim, self.num_heads,
scaling=self.scale, self.head_dim,
num_kv_heads=total_num_heads, scaling=self.scale,
layer_id=layer_id) num_kv_heads=total_num_heads,
layer_id=layer_id,
)
def forward( def forward(
self, self,
@@ -119,10 +120,14 @@ class GPT2MLP(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_proj", prefix=f"{prefix}.c_proj",
) )
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(
intermediate_size) config.activation_function, quant_config, intermediate_size
)
def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor: def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states) hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states) hidden_states, _ = self.c_proj(hidden_states)
@@ -135,27 +140,20 @@ class GPT2Block(nn.Module):
self, self,
layer_id: int, layer_id: int,
config: GPT2Config, config: GPT2Config,
cache_config = None, cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 * inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(layer_id, self.attn = GPT2Attention(
config, layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
cache_config, )
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
config,
quant_config,
prefix=f"{prefix}.mlp")
def forward( def forward(
self, self,
@@ -179,13 +177,12 @@ class GPT2Block(nn.Module):
return hidden_states return hidden_states
class GPT2Model(nn.Module): class GPT2Model(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
cache_config = None, cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
@@ -229,16 +226,15 @@ class GPT2LMHeadModel(nn.Module):
def __init__( def __init__(
self, self,
config: GPT2Config, config: GPT2Config,
cache_config = None, cache_config=None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPT2Model(config, self.transformer = GPT2Model(
cache_config, config, cache_config, quant_config, prefix="transformer"
quant_config, )
prefix="transformer")
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
@@ -254,8 +250,6 @@ class GPT2LMHeadModel(nn.Module):
input_ids, hidden_states, self.lm_head.weight, forward_batch input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
@@ -280,8 +274,8 @@ class GPT2LMHeadModel(nn.Module):
if not name.endswith(".weight"): if not name.endswith(".weight"):
continue continue
loaded_weight = loaded_weight.t() loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader", default_weight_loader)
default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
EntryClass = GPT2LMHeadModel EntryClass = GPT2LMHeadModel

View File

@@ -419,6 +419,7 @@ def launch_engine(
for i in range(len(scheduler_pipe_readers)): for i in range(len(scheduler_pipe_readers)):
scheduler_pipe_readers[i].recv() scheduler_pipe_readers[i].recv()
def add_prometheus_middleware(app: FastAPI): def add_prometheus_middleware(app: FastAPI):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
@@ -490,6 +491,7 @@ def launch_server(
finally: finally:
t.join() t.join()
def _set_prometheus_env(): def _set_prometheus_env():
# Set prometheus multiprocess directory # Set prometheus multiprocess directory
# sglang uses prometheus multiprocess mode # sglang uses prometheus multiprocess mode
@@ -506,6 +508,7 @@ def _set_prometheus_env():
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}") logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
def _set_envs_and_config(server_args: ServerArgs): def _set_envs_and_config(server_args: ServerArgs):
# Set global environments # Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -763,8 +766,8 @@ class Engine:
# runtime server default log level is log # runtime server default log level is log
# offline engine works in scripts, so we set it to error # offline engine works in scripts, so we set it to error
if 'log_level' not in kwargs: if "log_level" not in kwargs:
kwargs['log_level'] = 'error' kwargs["log_level"] = "error"
server_args = ServerArgs(*args, **kwargs) server_args = ServerArgs(*args, **kwargs)
launch_engine(server_args=server_args) launch_engine(server_args=server_args)

View File

@@ -448,7 +448,7 @@ class ServerArgs:
"--decode-log-interval", "--decode-log-interval",
type=int, type=int,
default=ServerArgs.decode_log_interval, default=ServerArgs.decode_log_interval,
help="The log interval of decode batch" help="The log interval of decode batch",
) )
# Data parallelism # Data parallelism

View File

@@ -742,7 +742,13 @@ def run_mmlu_test(
finally: finally:
pass pass
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size) run_and_check_memory_leak(
workload_func,
disable_radix_cache,
enable_mixed_chunk,
enable_overlap,
chunked_prefill_size,
)
def run_mulit_request_test( def run_mulit_request_test(
@@ -775,4 +781,10 @@ def run_mulit_request_test(
with ThreadPoolExecutor(2) as executor: with ThreadPoolExecutor(2) as executor:
list(executor.map(run_one, list(range(4)))) list(executor.map(run_one, list(range(4))))
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size) run_and_check_memory_leak(
workload_func,
disable_radix_cache,
enable_mixed_chunk,
enable_overlap,
chunked_prefill_size,
)

View File

@@ -349,6 +349,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
def terminate_process(process): def terminate_process(process):
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_child_process
kill_child_process(process.pid, include_self=True) kill_child_process(process.pid, include_self=True)

View File

@@ -11,7 +11,7 @@ router = router.Router(
"http://localhost:30000", "http://localhost:30000",
"http://localhost:30002", "http://localhost:30002",
], ],
policy="random" policy="random",
) )
# Start the router - this will block and run the server # Start the router - this will block and run the server

View File

@@ -104,15 +104,9 @@ if __name__ == "__main__":
default="TinyLlama/TinyLlama-1.1B-Chat-v0.4", default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
# default="meta-llama/Llama-2-7b-chat-hf", # default="meta-llama/Llama-2-7b-chat-hf",
) )
parser.add_argument( parser.add_argument("--max-new-tokens", type=int, default=16)
"--max-new-tokens",
type=int,
default=16)
parser.add_argument( parser.add_argument("--dtype", type=str, default="float16")
"--dtype",
type=str,
default="float16")
args = parser.parse_args() args = parser.parse_args()

View File

@@ -56,7 +56,7 @@ ALL_OTHER_MODELS = [
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True), ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True), ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
ModelCase("THUDM/glm-4-9b-chat"), ModelCase("THUDM/glm-4-9b-chat"),
ModelCase("openai-community/gpt2") ModelCase("openai-community/gpt2"),
] ]
TORCH_DTYPES = [torch.float16] TORCH_DTYPES = [torch.float16]

View File

@@ -3,6 +3,7 @@ python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
""" """
import json import json
import time import time
import unittest import unittest

View File

@@ -1,6 +1,7 @@
""" """
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
""" """
import json import json
import unittest import unittest

View File

@@ -110,7 +110,6 @@ class TestSRTEngine(unittest.TestCase):
def test_5_prompt_input_ids_consistency(self): def test_5_prompt_input_ids_consistency(self):
prompt = "The capital of UK is" prompt = "The capital of UK is"
model_path = DEFAULT_MODEL_NAME_FOR_TEST model_path = DEFAULT_MODEL_NAME_FOR_TEST
engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error") engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
sampling_params = {"temperature": 0, "max_new_tokens": 8} sampling_params = {"temperature": 0, "max_new_tokens": 8}
@@ -118,7 +117,9 @@ class TestSRTEngine(unittest.TestCase):
tokenizer = get_tokenizer(model_path) tokenizer = get_tokenizer(model_path)
token_ids = tokenizer.encode(prompt) token_ids = tokenizer.encode(prompt)
out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)["text"] out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[
"text"
]
engine.shutdown() engine.shutdown()