Sync from v0.13
This commit is contained in:
153
tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py
Normal file
153
tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from .conftest import (
|
||||
HEADER_SAGEMAKER_CLOSED_SESSION_ID,
|
||||
HEADER_SAGEMAKER_NEW_SESSION_ID,
|
||||
HEADER_SAGEMAKER_SESSION_ID,
|
||||
MODEL_NAME_SMOLLM,
|
||||
)
|
||||
|
||||
CLOSE_BADREQUEST_CASES = [
|
||||
(
|
||||
"nonexistent_session_id",
|
||||
{"session_id": "nonexistent-session-id"},
|
||||
{},
|
||||
"session not found",
|
||||
),
|
||||
("malformed_close_request", {}, {"extra-field": "extra-field-data"}, None),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_badrequest(basic_server_with_lora: RemoteOpenAIServer):
|
||||
bad_response = requests.post(
|
||||
basic_server_with_lora.url_for("invocations"),
|
||||
json={"requestType": "NEW_SESSION", "extra-field": "extra-field-data"},
|
||||
)
|
||||
|
||||
assert bad_response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_name,session_id_change,request_body_change,expected_error",
|
||||
CLOSE_BADREQUEST_CASES,
|
||||
)
|
||||
async def test_close_session_badrequest(
|
||||
basic_server_with_lora: RemoteOpenAIServer,
|
||||
test_name: str,
|
||||
session_id_change: dict[str, str],
|
||||
request_body_change: dict[str, str],
|
||||
expected_error: str | None,
|
||||
):
|
||||
# first attempt to create a session
|
||||
url = basic_server_with_lora.url_for("invocations")
|
||||
create_response = requests.post(url, json={"requestType": "NEW_SESSION"})
|
||||
create_response.raise_for_status()
|
||||
valid_session_id, expiration = create_response.headers.get(
|
||||
HEADER_SAGEMAKER_NEW_SESSION_ID, ""
|
||||
).split(";")
|
||||
assert valid_session_id
|
||||
|
||||
close_request_json = {"requestType": "CLOSE"}
|
||||
if request_body_change:
|
||||
close_request_json.update(request_body_change)
|
||||
bad_session_id = session_id_change.get("session_id")
|
||||
bad_close_response = requests.post(
|
||||
url,
|
||||
headers={HEADER_SAGEMAKER_SESSION_ID: bad_session_id or valid_session_id},
|
||||
json=close_request_json,
|
||||
)
|
||||
|
||||
# clean up created session, should succeed
|
||||
clean_up_response = requests.post(
|
||||
url,
|
||||
headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id},
|
||||
json={"requestType": "CLOSE"},
|
||||
)
|
||||
clean_up_response.raise_for_status()
|
||||
|
||||
assert bad_close_response.status_code == 400
|
||||
if expected_error:
|
||||
assert expected_error in bad_close_response.json()["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_session_invalidrequest(
|
||||
basic_server_with_lora: RemoteOpenAIServer, async_client: openai.AsyncOpenAI
|
||||
):
|
||||
# first attempt to create a session
|
||||
url = basic_server_with_lora.url_for("invocations")
|
||||
create_response = requests.post(url, json={"requestType": "NEW_SESSION"})
|
||||
create_response.raise_for_status()
|
||||
valid_session_id, expiration = create_response.headers.get(
|
||||
HEADER_SAGEMAKER_NEW_SESSION_ID, ""
|
||||
).split(";")
|
||||
assert valid_session_id
|
||||
|
||||
close_request_json = {"requestType": "CLOSE"}
|
||||
invalid_close_response = requests.post(
|
||||
url,
|
||||
# no headers to specify session_id
|
||||
json=close_request_json,
|
||||
)
|
||||
|
||||
# clean up created session, should succeed
|
||||
clean_up_response = requests.post(
|
||||
url,
|
||||
headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id},
|
||||
json={"requestType": "CLOSE"},
|
||||
)
|
||||
clean_up_response.raise_for_status()
|
||||
|
||||
assert invalid_close_response.status_code == 424
|
||||
assert "invalid session_id" in invalid_close_response.json()["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session(basic_server_with_lora: RemoteOpenAIServer):
|
||||
# first attempt to create a session
|
||||
url = basic_server_with_lora.url_for("invocations")
|
||||
create_response = requests.post(url, json={"requestType": "NEW_SESSION"})
|
||||
create_response.raise_for_status()
|
||||
valid_session_id, expiration = create_response.headers.get(
|
||||
HEADER_SAGEMAKER_NEW_SESSION_ID, ""
|
||||
).split(";")
|
||||
assert valid_session_id
|
||||
|
||||
# test invocation with session id
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME_SMOLLM,
|
||||
"prompt": "what is 1+1?",
|
||||
"max_completion_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": False,
|
||||
}
|
||||
|
||||
invocation_response = requests.post(
|
||||
basic_server_with_lora.url_for("invocations"),
|
||||
headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id},
|
||||
json=request_args,
|
||||
)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
# close created session, should succeed
|
||||
close_response = requests.post(
|
||||
url,
|
||||
headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id},
|
||||
json={"requestType": "CLOSE"},
|
||||
)
|
||||
close_response.raise_for_status()
|
||||
|
||||
assert (
|
||||
close_response.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID)
|
||||
== valid_session_id
|
||||
)
|
||||
Reference in New Issue
Block a user