|
|
|
|
@@ -88,6 +88,7 @@ from sglang.srt.managers.io_struct import (
|
|
|
|
|
UpdateWeightFromDiskReqInput,
|
|
|
|
|
UpdateWeightsFromDistributedReqInput,
|
|
|
|
|
UpdateWeightsFromTensorReqInput,
|
|
|
|
|
UpdateWeightVersionReqInput,
|
|
|
|
|
VertexGenerateReqInput,
|
|
|
|
|
)
|
|
|
|
|
from sglang.srt.managers.template_manager import TemplateManager
|
|
|
|
|
@@ -342,10 +343,19 @@ async def get_model_info():
|
|
|
|
|
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
|
|
|
|
|
"is_generation": _global_state.tokenizer_manager.is_generation,
|
|
|
|
|
"preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
|
|
|
|
|
"weight_version": _global_state.tokenizer_manager.server_args.weight_version,
|
|
|
|
|
}
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/get_weight_version")
async def get_weight_version():
    """Report the weight version currently stored in the server args.

    Returns:
        A dict with a single ``"weight_version"`` key whose value is read
        from ``_global_state.tokenizer_manager.server_args.weight_version``.
    """
    server_args = _global_state.tokenizer_manager.server_args
    current_version = server_args.weight_version
    return {"weight_version": current_version}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/get_server_info")
|
|
|
|
|
async def get_server_info():
|
|
|
|
|
# Returns internal states per DP.
|
|
|
|
|
@@ -537,6 +547,12 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
|
|
|
|
success, message, num_paused_requests = (
|
|
|
|
|
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Update weight version if provided and weights update was successful
|
|
|
|
|
if success and obj.weight_version is not None:
|
|
|
|
|
_update_weight_version_if_provided(obj.weight_version)
|
|
|
|
|
message += f" Weight version updated to {obj.weight_version}."
|
|
|
|
|
|
|
|
|
|
content = {
|
|
|
|
|
"success": success,
|
|
|
|
|
"message": message,
|
|
|
|
|
@@ -583,6 +599,12 @@ async def update_weights_from_tensor(
|
|
|
|
|
success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
|
|
|
|
|
obj, request
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Update weight version if provided and weights update was successful
|
|
|
|
|
if success and obj.weight_version is not None:
|
|
|
|
|
_update_weight_version_if_provided(obj.weight_version)
|
|
|
|
|
message += f" Weight version updated to {obj.weight_version}."
|
|
|
|
|
|
|
|
|
|
content = {"success": success, "message": message}
|
|
|
|
|
return ORJSONResponse(
|
|
|
|
|
content, status_code=200 if success else HTTPStatus.BAD_REQUEST
|
|
|
|
|
@@ -599,6 +621,12 @@ async def update_weights_from_distributed(
|
|
|
|
|
obj, request
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Update weight version if provided and weights update was successful
|
|
|
|
|
if success and obj.weight_version is not None:
|
|
|
|
|
_update_weight_version_if_provided(obj.weight_version)
|
|
|
|
|
message += f" Weight version updated to {obj.weight_version}."
|
|
|
|
|
|
|
|
|
|
content = {"success": success, "message": message}
|
|
|
|
|
if success:
|
|
|
|
|
return ORJSONResponse(content, status_code=200)
|
|
|
|
|
@@ -606,6 +634,36 @@ async def update_weights_from_distributed(
|
|
|
|
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/update_weight_version")
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
    """Overwrite the weight version stored in the server args.

    If ``obj.abort_all_requests`` is truthy, the tokenizer manager is first
    asked to abort all requests. This handler itself performs no check for
    requests that are still in flight.

    Args:
        obj: Request payload; ``new_version`` and ``abort_all_requests`` are
            the fields read here.
        request: The incoming HTTP request (not used by this handler).

    Returns:
        An ``ORJSONResponse`` with status 200 carrying ``success``,
        ``message`` and ``new_version`` when the update goes through, or a
        status 400 response carrying ``success`` and ``message`` if any
        exception is raised while applying it.
    """
    if obj.abort_all_requests:
        _global_state.tokenizer_manager.abort_request(abort_all=True)

    # The weight-update lock is deliberately not taken here: only the version
    # label changes, the model weights themselves are left untouched.
    try:
        requested_version = obj.new_version
        # server_args is the single source of truth for the weight version.
        server_args = _global_state.tokenizer_manager.server_args
        server_args.weight_version = requested_version

        success_payload = {
            "success": True,
            "message": f"Weight version updated to {obj.new_version}",
            "new_version": obj.new_version,
        }
        return ORJSONResponse(success_payload, status_code=HTTPStatus.OK)
    except Exception as e:
        failure_payload = {
            "success": False,
            "message": f"Failed to update weight version: {str(e)}",
        }
        return ORJSONResponse(failure_payload, status_code=HTTPStatus.BAD_REQUEST)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
|
|
|
|
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
|
|
|
|
"""Get model parameter by name."""
|
|
|
|
|
@@ -966,6 +1024,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
|
|
|
|
|
return ORJSONResponse({"predictions": ret})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _update_weight_version_if_provided(weight_version: Optional[str]) -> None:
    """Store ``weight_version`` in the server args; do nothing when it is ``None``.

    Args:
        weight_version: The new version label, or ``None`` to leave the
            currently recorded version unchanged.
    """
    if weight_version is None:
        return

    server_args = _global_state.tokenizer_manager.server_args
    server_args.weight_version = weight_version
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_error_response(e):
|
|
|
|
|
return ORJSONResponse(
|
|
|
|
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
|
|
|
|