feat: Add model version tracking with API endpoints and response metadata (#8795)
This commit is contained in:
@@ -83,7 +83,8 @@
|
|||||||
"- `model_path`: The path/name of the model.\n",
|
"- `model_path`: The path/name of the model.\n",
|
||||||
"- `is_generation`: Whether the model is used as generation model or embedding model.\n",
|
"- `is_generation`: Whether the model is used as generation model or embedding model.\n",
|
||||||
"- `tokenizer_path`: The path/name of the tokenizer.\n",
|
"- `tokenizer_path`: The path/name of the tokenizer.\n",
|
||||||
"- `preferred_sampling_params`: The default sampling params specified via `--preferred-sampling-params`. `None` is returned in this example as we did not explicitly configure it in server args."
|
"- `preferred_sampling_params`: The default sampling params specified via `--preferred-sampling-params`. `None` is returned in this example as we did not explicitly configure it in server args.\n",
|
||||||
|
"- `weight_version`: This field contains the version of the model weights. This is often used to track changes or updates to the model’s trained parameters."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -106,6 +107,7 @@
|
|||||||
" \"is_generation\",\n",
|
" \"is_generation\",\n",
|
||||||
" \"tokenizer_path\",\n",
|
" \"tokenizer_path\",\n",
|
||||||
" \"preferred_sampling_params\",\n",
|
" \"preferred_sampling_params\",\n",
|
||||||
|
" \"weight_version\",\n",
|
||||||
"}"
|
"}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
|
UpdateWeightVersionReqInput,
|
||||||
VertexGenerateReqInput,
|
VertexGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.template_manager import TemplateManager
|
from sglang.srt.managers.template_manager import TemplateManager
|
||||||
@@ -342,10 +343,19 @@ async def get_model_info():
|
|||||||
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
|
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
|
||||||
"is_generation": _global_state.tokenizer_manager.is_generation,
|
"is_generation": _global_state.tokenizer_manager.is_generation,
|
||||||
"preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
|
"preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
|
||||||
|
"weight_version": _global_state.tokenizer_manager.server_args.weight_version,
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_weight_version")
|
||||||
|
async def get_weight_version():
|
||||||
|
"""Get the current weight version."""
|
||||||
|
return {
|
||||||
|
"weight_version": _global_state.tokenizer_manager.server_args.weight_version
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_server_info")
|
@app.get("/get_server_info")
|
||||||
async def get_server_info():
|
async def get_server_info():
|
||||||
# Returns interna states per DP.
|
# Returns interna states per DP.
|
||||||
@@ -537,6 +547,12 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
|||||||
success, message, num_paused_requests = (
|
success, message, num_paused_requests = (
|
||||||
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
|
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update weight version if provided and weights update was successful
|
||||||
|
if success and obj.weight_version is not None:
|
||||||
|
_update_weight_version_if_provided(obj.weight_version)
|
||||||
|
message += f" Weight version updated to {obj.weight_version}."
|
||||||
|
|
||||||
content = {
|
content = {
|
||||||
"success": success,
|
"success": success,
|
||||||
"message": message,
|
"message": message,
|
||||||
@@ -583,6 +599,12 @@ async def update_weights_from_tensor(
|
|||||||
success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
|
success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
|
||||||
obj, request
|
obj, request
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update weight version if provided and weights update was successful
|
||||||
|
if success and obj.weight_version is not None:
|
||||||
|
_update_weight_version_if_provided(obj.weight_version)
|
||||||
|
message += f" Weight version updated to {obj.weight_version}."
|
||||||
|
|
||||||
content = {"success": success, "message": message}
|
content = {"success": success, "message": message}
|
||||||
return ORJSONResponse(
|
return ORJSONResponse(
|
||||||
content, status_code=200 if success else HTTPStatus.BAD_REQUEST
|
content, status_code=200 if success else HTTPStatus.BAD_REQUEST
|
||||||
@@ -599,6 +621,12 @@ async def update_weights_from_distributed(
|
|||||||
obj, request
|
obj, request
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update weight version if provided and weights update was successful
|
||||||
|
if success and obj.weight_version is not None:
|
||||||
|
_update_weight_version_if_provided(obj.weight_version)
|
||||||
|
message += f" Weight version updated to {obj.weight_version}."
|
||||||
|
|
||||||
content = {"success": success, "message": message}
|
content = {"success": success, "message": message}
|
||||||
if success:
|
if success:
|
||||||
return ORJSONResponse(content, status_code=200)
|
return ORJSONResponse(content, status_code=200)
|
||||||
@@ -606,6 +634,36 @@ async def update_weights_from_distributed(
|
|||||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/update_weight_version")
|
||||||
|
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
|
||||||
|
"""Update the weight version. This operation requires no active requests."""
|
||||||
|
if obj.abort_all_requests:
|
||||||
|
_global_state.tokenizer_manager.abort_request(abort_all=True)
|
||||||
|
|
||||||
|
# Use a simple approach without the complex lock mechanism for now
|
||||||
|
# since weight_version update is a simple operation that doesn't affect model weights
|
||||||
|
try:
|
||||||
|
# Update the weight version in server args (the single source of truth)
|
||||||
|
_global_state.tokenizer_manager.server_args.weight_version = obj.new_version
|
||||||
|
|
||||||
|
return ORJSONResponse(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"message": f"Weight version updated to {obj.new_version}",
|
||||||
|
"new_version": obj.new_version,
|
||||||
|
},
|
||||||
|
status_code=HTTPStatus.OK,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ORJSONResponse(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"Failed to update weight version: {str(e)}",
|
||||||
|
},
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
||||||
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||||
"""Get model parameter by name."""
|
"""Get model parameter by name."""
|
||||||
@@ -966,6 +1024,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
|
|||||||
return ORJSONResponse({"predictions": ret})
|
return ORJSONResponse({"predictions": ret})
|
||||||
|
|
||||||
|
|
||||||
|
def _update_weight_version_if_provided(weight_version: Optional[str]) -> None:
|
||||||
|
"""Update weight version if provided."""
|
||||||
|
if weight_version is not None:
|
||||||
|
_global_state.tokenizer_manager.server_args.weight_version = weight_version
|
||||||
|
|
||||||
|
|
||||||
def _create_error_response(e):
|
def _create_error_response(e):
|
||||||
return ORJSONResponse(
|
return ORJSONResponse(
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
|||||||
@@ -240,6 +240,7 @@ class CompletionResponse(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
choices: List[CompletionResponseChoice]
|
choices: List[CompletionResponseChoice]
|
||||||
usage: UsageInfo
|
usage: UsageInfo
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class CompletionResponseStreamChoice(BaseModel):
|
class CompletionResponseStreamChoice(BaseModel):
|
||||||
@@ -517,6 +518,7 @@ class ChatCompletionResponse(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
choices: List[ChatCompletionResponseChoice]
|
choices: List[ChatCompletionResponseChoice]
|
||||||
usage: UsageInfo
|
usage: UsageInfo
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class DeltaMessage(BaseModel):
|
class DeltaMessage(BaseModel):
|
||||||
|
|||||||
@@ -723,6 +723,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
model=request.model,
|
model=request.model,
|
||||||
choices=choices,
|
choices=choices,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
|
metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _process_logprobs_tokens(
|
def _process_logprobs_tokens(
|
||||||
|
|||||||
@@ -373,6 +373,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
created=created,
|
created=created,
|
||||||
choices=choices,
|
choices=choices,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
|
metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
|
def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
|
||||||
|
|||||||
@@ -798,6 +798,8 @@ class UpdateWeightFromDiskReqInput:
|
|||||||
load_format: Optional[str] = None
|
load_format: Optional[str] = None
|
||||||
# Whether to abort all requests before updating weights
|
# Whether to abort all requests before updating weights
|
||||||
abort_all_requests: bool = False
|
abort_all_requests: bool = False
|
||||||
|
# Optional: Update weight version along with weights
|
||||||
|
weight_version: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -819,6 +821,8 @@ class UpdateWeightsFromDistributedReqInput:
|
|||||||
flush_cache: bool = True
|
flush_cache: bool = True
|
||||||
# Whether to abort all requests before updating weights
|
# Whether to abort all requests before updating weights
|
||||||
abort_all_requests: bool = False
|
abort_all_requests: bool = False
|
||||||
|
# Optional: Update weight version along with weights
|
||||||
|
weight_version: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -842,6 +846,8 @@ class UpdateWeightsFromTensorReqInput:
|
|||||||
flush_cache: bool = True
|
flush_cache: bool = True
|
||||||
# Whether to abort all requests before updating weights
|
# Whether to abort all requests before updating weights
|
||||||
abort_all_requests: bool = False
|
abort_all_requests: bool = False
|
||||||
|
# Optional: Update weight version along with weights
|
||||||
|
weight_version: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -872,6 +878,14 @@ class InitWeightsUpdateGroupReqOutput:
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpdateWeightVersionReqInput:
|
||||||
|
# The new weight version
|
||||||
|
new_version: str
|
||||||
|
# Whether to abort all running requests before updating
|
||||||
|
abort_all_requests: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GetWeightsByNameReqInput:
|
class GetWeightsByNameReqInput:
|
||||||
name: str
|
name: str
|
||||||
|
|||||||
@@ -1529,6 +1529,7 @@ class TokenizerManager:
|
|||||||
"id": rid,
|
"id": rid,
|
||||||
"finish_reason": recv_obj.finished_reasons[i],
|
"finish_reason": recv_obj.finished_reasons[i],
|
||||||
"prompt_tokens": recv_obj.prompt_tokens[i],
|
"prompt_tokens": recv_obj.prompt_tokens[i],
|
||||||
|
"weight_version": self.server_args.weight_version,
|
||||||
}
|
}
|
||||||
|
|
||||||
if getattr(state.obj, "return_logprob", False):
|
if getattr(state.obj, "return_logprob", False):
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ class ServerArgs:
|
|||||||
# API related
|
# API related
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
served_model_name: Optional[str] = None
|
served_model_name: Optional[str] = None
|
||||||
|
weight_version: str = "default"
|
||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
completion_template: Optional[str] = None
|
completion_template: Optional[str] = None
|
||||||
file_storage_path: str = "sglang_storage"
|
file_storage_path: str = "sglang_storage"
|
||||||
@@ -1163,6 +1164,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.served_model_name,
|
default=ServerArgs.served_model_name,
|
||||||
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
|
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--weight-version",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.weight_version,
|
||||||
|
help="Version identifier for the model weights. Defaults to 'default' if not specified.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--chat-template",
|
"--chat-template",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
227
test/srt/test_weight_version.py
Normal file
227
test/srt/test_weight_version.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""
|
||||||
|
Test weight version functionality.
|
||||||
|
|
||||||
|
This test suite verifies the weight_version feature implementation including:
|
||||||
|
1. Default weight_version setting
|
||||||
|
2. /get_weight_version endpoint
|
||||||
|
3. /update_weight_version endpoint
|
||||||
|
4. /generate request meta_info contains weight_version
|
||||||
|
5. OpenAI API response metadata contains weight_version
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWeightVersion(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
"""Start server once for all tests with custom weight version."""
|
||||||
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = "http://127.0.0.1:30000"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
base_url=cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--weight-version",
|
||||||
|
"test_version_1.0",
|
||||||
|
"--attention-backend",
|
||||||
|
"flashinfer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
"""Terminate server after all tests complete."""
|
||||||
|
if cls.process:
|
||||||
|
cls.process.terminate()
|
||||||
|
|
||||||
|
def test_weight_version_comprehensive(self):
|
||||||
|
"""Comprehensive test for all weight_version functionality."""
|
||||||
|
|
||||||
|
response = requests.get(f"{self.base_url}/get_model_info")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertIn("weight_version", data)
|
||||||
|
self.assertEqual(data["weight_version"], "test_version_1.0")
|
||||||
|
|
||||||
|
response = requests.get(f"{self.base_url}/get_weight_version")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertIn("weight_version", data)
|
||||||
|
self.assertEqual(data["weight_version"], "test_version_1.0")
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"text": "Hello, how are you?",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_new_tokens": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
response = requests.post(f"{self.base_url}/generate", json=request_data)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertIn("meta_info", data)
|
||||||
|
self.assertIn("weight_version", data["meta_info"])
|
||||||
|
self.assertEqual(data["meta_info"]["weight_version"], "test_version_1.0")
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
"temperature": 0.0,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/v1/chat/completions", json=request_data
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertIn("metadata", data)
|
||||||
|
self.assertIn("weight_version", data["metadata"])
|
||||||
|
self.assertEqual(data["metadata"]["weight_version"], "test_version_1.0")
|
||||||
|
|
||||||
|
request_data = {
|
||||||
|
"model": self.model,
|
||||||
|
"prompt": "Hello",
|
||||||
|
"max_tokens": 5,
|
||||||
|
"temperature": 0.0,
|
||||||
|
}
|
||||||
|
response = requests.post(f"{self.base_url}/v1/completions", json=request_data)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertIn("metadata", data)
|
||||||
|
self.assertIn("weight_version", data["metadata"])
|
||||||
|
self.assertEqual(data["metadata"]["weight_version"], "test_version_1.0")
|
||||||
|
|
||||||
|
update_data = {
|
||||||
|
"new_version": "updated_version_2.0",
|
||||||
|
"abort_all_requests": False,
|
||||||
|
}
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/update_weight_version", json=update_data
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertTrue(data["success"])
|
||||||
|
self.assertEqual(data["new_version"], "updated_version_2.0")
|
||||||
|
|
||||||
|
response = requests.get(f"{self.base_url}/get_weight_version")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertEqual(data["weight_version"], "updated_version_2.0")
|
||||||
|
|
||||||
|
gen_data = {
|
||||||
|
"text": "Test persistence",
|
||||||
|
"sampling_params": {"temperature": 0.0, "max_new_tokens": 3},
|
||||||
|
}
|
||||||
|
response = requests.post(f"{self.base_url}/generate", json=gen_data)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertEqual(data["meta_info"]["weight_version"], "updated_version_2.0")
|
||||||
|
|
||||||
|
chat_data = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [{"role": "user", "content": "Test"}],
|
||||||
|
"max_tokens": 3,
|
||||||
|
"temperature": 0.0,
|
||||||
|
}
|
||||||
|
response = requests.post(f"{self.base_url}/v1/chat/completions", json=chat_data)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertEqual(data["metadata"]["weight_version"], "updated_version_2.0")
|
||||||
|
|
||||||
|
update_data = {"new_version": "final_version_3.0", "abort_all_requests": True}
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/update_weight_version", json=update_data
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.json()
|
||||||
|
self.assertTrue(data["success"])
|
||||||
|
self.assertEqual(data["new_version"], "final_version_3.0")
|
||||||
|
|
||||||
|
# Check /get_weight_version
|
||||||
|
response = requests.get(f"{self.base_url}/get_weight_version")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(response.json()["weight_version"], "final_version_3.0")
|
||||||
|
|
||||||
|
# Check /get_model_info
|
||||||
|
response = requests.get(f"{self.base_url}/get_model_info")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(response.json()["weight_version"], "final_version_3.0")
|
||||||
|
|
||||||
|
# Check /generate meta_info
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/generate",
|
||||||
|
json={
|
||||||
|
"text": "Final test",
|
||||||
|
"sampling_params": {"temperature": 0.0, "max_new_tokens": 2},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(
|
||||||
|
response.json()["meta_info"]["weight_version"], "final_version_3.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check OpenAI chat metadata
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [{"role": "user", "content": "Final"}],
|
||||||
|
"max_tokens": 2,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(
|
||||||
|
response.json()["metadata"]["weight_version"], "final_version_3.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("All weight_version functionality tests passed!")
|
||||||
|
|
||||||
|
def test_update_weight_version_with_weight_updates(self):
|
||||||
|
"""Test that weight_version can be updated along with weight updates using real model data."""
|
||||||
|
print("Testing weight_version update with real weight operations...")
|
||||||
|
|
||||||
|
# Get current model info for reference
|
||||||
|
model_info_response = requests.get(f"{self.base_url}/get_model_info")
|
||||||
|
self.assertEqual(model_info_response.status_code, 200)
|
||||||
|
current_model_path = model_info_response.json()["model_path"]
|
||||||
|
|
||||||
|
update_data = {
|
||||||
|
"model_path": current_model_path,
|
||||||
|
"load_format": "auto",
|
||||||
|
"abort_all_requests": False,
|
||||||
|
"weight_version": "disk_update_v2.0.0",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/update_weights_from_disk", json=update_data
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
response.status_code,
|
||||||
|
200,
|
||||||
|
f"update_weights_from_disk failed with status {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify version was updated
|
||||||
|
version_response = requests.get(f"{self.base_url}/get_weight_version")
|
||||||
|
self.assertEqual(version_response.status_code, 200)
|
||||||
|
self.assertEqual(
|
||||||
|
version_response.json()["weight_version"], "disk_update_v2.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Weight update with weight_version test completed!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user