feat: Add model version tracking with API endpoints and response metadata (#8795)

This commit is contained in:
Chengxing Xie
2025-08-15 03:13:46 +08:00
committed by GitHub
parent 2cc9eeab01
commit c1c7dc4534
9 changed files with 320 additions and 1 deletion

View File

@@ -88,6 +88,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
UpdateWeightVersionReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.template_manager import TemplateManager
@@ -342,10 +343,19 @@ async def get_model_info():
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": _global_state.tokenizer_manager.is_generation,
"preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
"weight_version": _global_state.tokenizer_manager.server_args.weight_version,
}
return result
@app.get("/get_weight_version")
async def get_weight_version():
    """Report the weight version currently configured on this server.

    The value is read from the tokenizer manager's server args, which is
    the single source of truth for the weight version.
    """
    version = _global_state.tokenizer_manager.server_args.weight_version
    return {"weight_version": version}
@app.get("/get_server_info")
async def get_server_info():
# Returns internal states per DP.
@@ -537,6 +547,12 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
success, message, num_paused_requests = (
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
)
# Update weight version if provided and weights update was successful
if success and obj.weight_version is not None:
_update_weight_version_if_provided(obj.weight_version)
message += f" Weight version updated to {obj.weight_version}."
content = {
"success": success,
"message": message,
@@ -583,6 +599,12 @@ async def update_weights_from_tensor(
success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
obj, request
)
# Update weight version if provided and weights update was successful
if success and obj.weight_version is not None:
_update_weight_version_if_provided(obj.weight_version)
message += f" Weight version updated to {obj.weight_version}."
content = {"success": success, "message": message}
return ORJSONResponse(
content, status_code=200 if success else HTTPStatus.BAD_REQUEST
@@ -599,6 +621,12 @@ async def update_weights_from_distributed(
obj, request
)
)
# Update weight version if provided and weights update was successful
if success and obj.weight_version is not None:
_update_weight_version_if_provided(obj.weight_version)
message += f" Weight version updated to {obj.weight_version}."
content = {"success": success, "message": message}
if success:
return ORJSONResponse(content, status_code=200)
@@ -606,6 +634,36 @@ async def update_weights_from_distributed(
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
@app.post("/update_weight_version")
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
    """Update the weight version. This operation requires no active requests."""
    if obj.abort_all_requests:
        _global_state.tokenizer_manager.abort_request(abort_all=True)

    # No locking here: bumping the version string is a trivial mutation that
    # leaves the model weights themselves untouched.
    try:
        # server_args is the single source of truth for the weight version.
        _global_state.tokenizer_manager.server_args.weight_version = obj.new_version
        payload = {
            "success": True,
            "message": f"Weight version updated to {obj.new_version}",
            "new_version": obj.new_version,
        }
        return ORJSONResponse(payload, status_code=HTTPStatus.OK)
    except Exception as exc:
        payload = {
            "success": False,
            "message": f"Failed to update weight version: {str(exc)}",
        }
        return ORJSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
"""Get model parameter by name."""
@@ -966,6 +1024,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
return ORJSONResponse({"predictions": ret})
def _update_weight_version_if_provided(weight_version: Optional[str]) -> None:
"""Update weight version if provided."""
if weight_version is not None:
_global_state.tokenizer_manager.server_args.weight_version = weight_version
def _create_error_response(e):
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST

View File

@@ -240,6 +240,7 @@ class CompletionResponse(BaseModel):
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
metadata: Optional[Dict[str, Any]] = None
class CompletionResponseStreamChoice(BaseModel):
@@ -517,6 +518,7 @@ class ChatCompletionResponse(BaseModel):
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
metadata: Optional[Dict[str, Any]] = None
class DeltaMessage(BaseModel):

View File

@@ -723,6 +723,7 @@ class OpenAIServingChat(OpenAIServingBase):
model=request.model,
choices=choices,
usage=usage,
metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
)
def _process_logprobs_tokens(

View File

@@ -373,6 +373,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
created=created,
choices=choices,
usage=usage,
metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
)
def _get_echo_text(self, request: CompletionRequest, index: int) -> str:

View File

@@ -798,6 +798,8 @@ class UpdateWeightFromDiskReqInput:
load_format: Optional[str] = None
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
# Optional: Update weight version along with weights
weight_version: Optional[str] = None
@dataclass
@@ -819,6 +821,8 @@ class UpdateWeightsFromDistributedReqInput:
flush_cache: bool = True
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
# Optional: Update weight version along with weights
weight_version: Optional[str] = None
@dataclass
@@ -842,6 +846,8 @@ class UpdateWeightsFromTensorReqInput:
flush_cache: bool = True
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
# Optional: Update weight version along with weights
weight_version: Optional[str] = None
@dataclass
@@ -872,6 +878,14 @@ class InitWeightsUpdateGroupReqOutput:
message: str
@dataclass
class UpdateWeightVersionReqInput:
    """Request payload for the /update_weight_version endpoint."""

    # The new weight version
    new_version: str
    # Whether to abort all running requests before updating
    abort_all_requests: bool = True
@dataclass
class GetWeightsByNameReqInput:
name: str

View File

@@ -1529,6 +1529,7 @@ class TokenizerManager:
"id": rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
"weight_version": self.server_args.weight_version,
}
if getattr(state.obj, "return_logprob", False):

View File

@@ -124,6 +124,7 @@ class ServerArgs:
# API related
api_key: Optional[str] = None
served_model_name: Optional[str] = None
weight_version: str = "default"
chat_template: Optional[str] = None
completion_template: Optional[str] = None
file_storage_path: str = "sglang_storage"
@@ -1163,6 +1164,12 @@ class ServerArgs:
default=ServerArgs.served_model_name,
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
)
parser.add_argument(
"--weight-version",
type=str,
default=ServerArgs.weight_version,
help="Version identifier for the model weights. Defaults to 'default' if not specified.",
)
parser.add_argument(
"--chat-template",
type=str,