[kernel] Use sgl_kernel rope (#3169)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
chunks_ids[i] = chunks_ids[i][1:]
|
||||
|
||||
# 1. using session control
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
session_id = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000},
|
||||
@@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase):
|
||||
print(outputs_from_session)
|
||||
print("outputs from normal queries:")
|
||||
print(outputs_normal)
|
||||
assert outputs_from_session == outputs_normal
|
||||
assert (
|
||||
outputs_from_session == outputs_normal
|
||||
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
||||
|
||||
async def async_generate(self, payload):
|
||||
url = self.base_url + "/generate"
|
||||
@@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
chunks_ids[i] = chunks_ids[i][1:]
|
||||
|
||||
# 1. using session control
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
session_id = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000},
|
||||
@@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
else:
|
||||
# 2. not using session control
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
output_ids = tokenizer.encode(gen_so_far)
|
||||
if output_ids[0] == tokenizer.bos_token_id:
|
||||
output_ids = output_ids[1:]
|
||||
@@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase):
|
||||
output_no_session = response["text"]
|
||||
print("second request output without session:")
|
||||
print(output_no_session)
|
||||
assert second_output == output_no_session
|
||||
assert (
|
||||
second_output == output_no_session
|
||||
), f"second_output: {second_output}, output_no_session: {output_no_session}"
|
||||
|
||||
def test_session_control_backtrack_with_abort(self):
|
||||
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
|
||||
@@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
assert len(x) == len(chunks_per_step[0])
|
||||
|
||||
# 1. using session control
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
session_id = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000},
|
||||
@@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase):
|
||||
print(outputs_from_session)
|
||||
print("====== outputs from normal queries: =======")
|
||||
print(outputs_normal)
|
||||
assert outputs_from_session == outputs_normal
|
||||
assert (
|
||||
outputs_from_session == outputs_normal
|
||||
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
||||
|
||||
def test_session_control_with_branching(self):
|
||||
root_prompt = "First, let me explain in one sentence about AI"
|
||||
@@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
gen_len = 32
|
||||
|
||||
# 1. using session control
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
session_id = requests.post(
|
||||
self.base_url + "/open_session",
|
||||
json={"capacity_of_str_len": 1000},
|
||||
@@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase):
|
||||
print(outputs_from_session)
|
||||
print("outputs from normal queries:")
|
||||
print(outputs_normal)
|
||||
assert outputs_from_session == outputs_normal
|
||||
assert (
|
||||
outputs_from_session == outputs_normal
|
||||
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user