diff --git a/docs/basic_usage/native_api.ipynb b/docs/basic_usage/native_api.ipynb
index 53dde48ec..33dffea74 100644
--- a/docs/basic_usage/native_api.ipynb
+++ b/docs/basic_usage/native_api.ipynb
@@ -83,7 +83,8 @@
     "- `model_path`: The path/name of the model.\n",
     "- `is_generation`: Whether the model is used as generation model or embedding model.\n",
     "- `tokenizer_path`: The path/name of the tokenizer.\n",
-    "- `preferred_sampling_params`: The default sampling params specified via `--preferred-sampling-params`. `None` is returned in this example as we did not explicitly configure it in server args."
+    "- `preferred_sampling_params`: The default sampling params specified via `--preferred-sampling-params`. `None` is returned in this example as we did not explicitly configure it in server args.\n",
+    "- `weight_version`: This field contains the version of the model weights. This is often used to track changes or updates to the model’s trained parameters."
    ]
   },
   {
@@ -106,6 +107,7 @@
     "    \"is_generation\",\n",
     "    \"tokenizer_path\",\n",
     "    \"preferred_sampling_params\",\n",
+    "    \"weight_version\",\n",
     "}"
    ]
   },
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 2f0c8a41d..2dd2c75f1 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -88,6 +88,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
+    UpdateWeightVersionReqInput,
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.template_manager import TemplateManager
@@ -342,10 +343,19 @@ async def get_model_info():
         "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
         "is_generation": _global_state.tokenizer_manager.is_generation,
         "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
+        "weight_version": _global_state.tokenizer_manager.server_args.weight_version,
     }
     return result
 
 
+@app.get("/get_weight_version")
+async def get_weight_version():
+    """Get the current weight version."""
+    return {
+        "weight_version": _global_state.tokenizer_manager.server_args.weight_version
+    }
+
+
 @app.get("/get_server_info")
 async def get_server_info():
     # Returns interna states per DP.
@@ -537,6 +547,12 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
     success, message, num_paused_requests = (
         await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
     )
+
+    # Update weight version if provided and weights update was successful
+    if success and obj.weight_version is not None:
+        _update_weight_version_if_provided(obj.weight_version)
+        message += f" Weight version updated to {obj.weight_version}."
+
     content = {
         "success": success,
         "message": message,
@@ -583,6 +599,12 @@ async def update_weights_from_tensor(
     success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
         obj, request
     )
+
+    # Update weight version if provided and weights update was successful
+    if success and obj.weight_version is not None:
+        _update_weight_version_if_provided(obj.weight_version)
+        message += f" Weight version updated to {obj.weight_version}."
+
     content = {"success": success, "message": message}
     return ORJSONResponse(
         content, status_code=200 if success else HTTPStatus.BAD_REQUEST
@@ -599,6 +621,12 @@ async def update_weights_from_distributed(
             obj, request
         )
     )
+
+    # Update weight version if provided and weights update was successful
+    if success and obj.weight_version is not None:
+        _update_weight_version_if_provided(obj.weight_version)
+        message += f" Weight version updated to {obj.weight_version}."
+
     content = {"success": success, "message": message}
     if success:
         return ORJSONResponse(content, status_code=200)
@@ -606,6 +634,36 @@ async def update_weights_from_distributed(
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
 
 
+@app.post("/update_weight_version")
+async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
+    """Update the weight version. This operation requires no active requests."""
+    if obj.abort_all_requests:
+        _global_state.tokenizer_manager.abort_request(abort_all=True)
+
+    # Use a simple approach without the complex lock mechanism for now
+    # since weight_version update is a simple operation that doesn't affect model weights
+    try:
+        # Update the weight version in server args (the single source of truth)
+        _global_state.tokenizer_manager.server_args.weight_version = obj.new_version
+
+        return ORJSONResponse(
+            {
+                "success": True,
+                "message": f"Weight version updated to {obj.new_version}",
+                "new_version": obj.new_version,
+            },
+            status_code=HTTPStatus.OK,
+        )
+    except Exception as e:
+        return ORJSONResponse(
+            {
+                "success": False,
+                "message": f"Failed to update weight version: {str(e)}",
+            },
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 @app.api_route("/get_weights_by_name", methods=["GET", "POST"])
 async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     """Get model parameter by name."""
@@ -966,6 +1024,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
     return ORJSONResponse({"predictions": ret})
 
 
+def _update_weight_version_if_provided(weight_version: Optional[str]) -> None:
+    """Update weight version if provided."""
+    if weight_version is not None:
+        _global_state.tokenizer_manager.server_args.weight_version = weight_version
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 3a761d9f6..9360993df 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -240,6 +240,7 @@ class CompletionResponse(BaseModel):
     model: str
     choices: List[CompletionResponseChoice]
     usage: UsageInfo
+    metadata: Optional[Dict[str, Any]] = None
 
 
 class CompletionResponseStreamChoice(BaseModel):
@@ -517,6 +518,7 @@ class ChatCompletionResponse(BaseModel):
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
+    metadata: Optional[Dict[str, Any]] = None
 
 
 class DeltaMessage(BaseModel):
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 9914a4c2e..d87c50dd6 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -723,6 +723,7 @@ class OpenAIServingChat(OpenAIServingBase):
             model=request.model,
             choices=choices,
             usage=usage,
+            metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
         )
 
     def _process_logprobs_tokens(
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index 992787132..51fa31296 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -373,6 +373,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             created=created,
             choices=choices,
             usage=usage,
+            metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
         )
 
     def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 3c7c0069e..c126dd35b 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -798,6 +798,8 @@ class UpdateWeightFromDiskReqInput:
     load_format: Optional[str] = None
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
 
 
 @dataclass
@@ -819,6 +821,8 @@ class UpdateWeightsFromDistributedReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
 
 
 @dataclass
@@ -842,6 +846,8 @@ class UpdateWeightsFromTensorReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
 
 
 @dataclass
@@ -872,6 +878,14 @@ class InitWeightsUpdateGroupReqOutput:
     message: str
 
 
+@dataclass
+class UpdateWeightVersionReqInput:
+    # The new weight version
+    new_version: str
+    # Whether to abort all running requests before updating
+    abort_all_requests: bool = True
+
+
 @dataclass
 class GetWeightsByNameReqInput:
     name: str
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 2b9aa8219..58220b1d6 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -1529,6 +1529,7 @@ class TokenizerManager:
                 "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
+                "weight_version": self.server_args.weight_version,
             }
 
             if getattr(state.obj, "return_logprob", False):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 0f2879fde..d7f2ebe2b 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -124,6 +124,7 @@ class ServerArgs:
     # API related
     api_key: Optional[str] = None
     served_model_name: Optional[str] = None
+    weight_version: str = "default"
     chat_template: Optional[str] = None
     completion_template: Optional[str] = None
     file_storage_path: str = "sglang_storage"
@@ -1163,6 +1164,12 @@ class ServerArgs:
         parser.add_argument(
             "--served-model-name",
             type=str,
             default=ServerArgs.served_model_name,
             help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
         )
+        parser.add_argument(
+            "--weight-version",
+            type=str,
+            default=ServerArgs.weight_version,
+            help="Version identifier for the model weights. Defaults to 'default' if not specified.",
+        )
         parser.add_argument(
             "--chat-template",
             type=str,
diff --git a/test/srt/test_weight_version.py b/test/srt/test_weight_version.py
new file mode 100644
index 000000000..5011ee701
--- /dev/null
+++ b/test/srt/test_weight_version.py
@@ -0,0 +1,227 @@
+"""
+Test weight version functionality.
+
+This test suite verifies the weight_version feature implementation including:
+1. Default weight_version setting
+2. /get_weight_version endpoint
+3. /update_weight_version endpoint
+4. /generate request meta_info contains weight_version
+5. OpenAI API response metadata contains weight_version
+"""
+
+import unittest
+
+import requests
+
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestWeightVersion(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Start server once for all tests with custom weight version."""
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = "http://127.0.0.1:30000"
+        cls.process = popen_launch_server(
+            cls.model,
+            base_url=cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--weight-version",
+                "test_version_1.0",
+                "--attention-backend",
+                "flashinfer",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        """Terminate server after all tests complete."""
+        if cls.process:
+            cls.process.terminate()
+
+    def test_weight_version_comprehensive(self):
+        """Comprehensive test for all weight_version functionality."""
+
+        response = requests.get(f"{self.base_url}/get_model_info")
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertIn("weight_version", data)
+        self.assertEqual(data["weight_version"], "test_version_1.0")
+
+        response = requests.get(f"{self.base_url}/get_weight_version")
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertIn("weight_version", data)
+        self.assertEqual(data["weight_version"], "test_version_1.0")
+
+        request_data = {
+            "text": "Hello, how are you?",
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": 5,
+            },
+        }
+        response = requests.post(f"{self.base_url}/generate", json=request_data)
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertIn("meta_info", data)
+        self.assertIn("weight_version", data["meta_info"])
+        self.assertEqual(data["meta_info"]["weight_version"], "test_version_1.0")
+
+        request_data = {
+            "model": self.model,
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 5,
+            "temperature": 0.0,
+        }
+        response = requests.post(
+            f"{self.base_url}/v1/chat/completions", json=request_data
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertIn("metadata", data)
+        self.assertIn("weight_version", data["metadata"])
+        self.assertEqual(data["metadata"]["weight_version"], "test_version_1.0")
+
+        request_data = {
+            "model": self.model,
+            "prompt": "Hello",
+            "max_tokens": 5,
+            "temperature": 0.0,
+        }
+        response = requests.post(f"{self.base_url}/v1/completions", json=request_data)
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertIn("metadata", data)
+        self.assertIn("weight_version", data["metadata"])
+        self.assertEqual(data["metadata"]["weight_version"], "test_version_1.0")
+
+        update_data = {
+            "new_version": "updated_version_2.0",
+            "abort_all_requests": False,
+        }
+        response = requests.post(
+            f"{self.base_url}/update_weight_version", json=update_data
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertTrue(data["success"])
+        self.assertEqual(data["new_version"], "updated_version_2.0")
+
+        response = requests.get(f"{self.base_url}/get_weight_version")
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertEqual(data["weight_version"], "updated_version_2.0")
+
+        gen_data = {
+            "text": "Test persistence",
+            "sampling_params": {"temperature": 0.0, "max_new_tokens": 3},
+        }
+        response = requests.post(f"{self.base_url}/generate", json=gen_data)
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertEqual(data["meta_info"]["weight_version"], "updated_version_2.0")
+
+        chat_data = {
+            "model": self.model,
+            "messages": [{"role": "user", "content": "Test"}],
+            "max_tokens": 3,
+            "temperature": 0.0,
+        }
+        response = requests.post(f"{self.base_url}/v1/chat/completions", json=chat_data)
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertEqual(data["metadata"]["weight_version"], "updated_version_2.0")
+
+        update_data = {"new_version": "final_version_3.0", "abort_all_requests": True}
+        response = requests.post(
+            f"{self.base_url}/update_weight_version", json=update_data
+        )
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertTrue(data["success"])
+        self.assertEqual(data["new_version"], "final_version_3.0")
+
+        # Check /get_weight_version
+        response = requests.get(f"{self.base_url}/get_weight_version")
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.json()["weight_version"], "final_version_3.0")
+
+        # Check /get_model_info
+        response = requests.get(f"{self.base_url}/get_model_info")
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.json()["weight_version"], "final_version_3.0")
+
+        # Check /generate meta_info
+        response = requests.post(
+            f"{self.base_url}/generate",
+            json={
+                "text": "Final test",
+                "sampling_params": {"temperature": 0.0, "max_new_tokens": 2},
+            },
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(
+            response.json()["meta_info"]["weight_version"], "final_version_3.0"
+        )
+
+        # Check OpenAI chat metadata
+        response = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            json={
+                "model": self.model,
+                "messages": [{"role": "user", "content": "Final"}],
+                "max_tokens": 2,
+                "temperature": 0.0,
+            },
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(
+            response.json()["metadata"]["weight_version"], "final_version_3.0"
+        )
+
+        print("All weight_version functionality tests passed!")
+
+    def test_update_weight_version_with_weight_updates(self):
+        """Test that weight_version can be updated along with weight updates using real model data."""
+        print("Testing weight_version update with real weight operations...")
+
+        # Get current model info for reference
+        model_info_response = requests.get(f"{self.base_url}/get_model_info")
+        self.assertEqual(model_info_response.status_code, 200)
+        current_model_path = model_info_response.json()["model_path"]
+
+        update_data = {
+            "model_path": current_model_path,
+            "load_format": "auto",
+            "abort_all_requests": False,
+            "weight_version": "disk_update_v2.0.0",
+        }
+
+        response = requests.post(
+            f"{self.base_url}/update_weights_from_disk", json=update_data
+        )
+        self.assertEqual(
+            response.status_code,
+            200,
+            f"update_weights_from_disk failed with status {response.status_code}",
+        )
+
+        # Verify version was updated
+        version_response = requests.get(f"{self.base_url}/get_weight_version")
+        self.assertEqual(version_response.status_code, 200)
+        self.assertEqual(
+            version_response.json()["weight_version"], "disk_update_v2.0.0"
+        )
+
+        print("Weight update with weight_version test completed!")
+
+
+if __name__ == "__main__":
+    unittest.main()