[kernel] Use sgl_kernel rope (#3169)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Byron Hsu
2025-01-27 22:33:11 -08:00
committed by GitHub
parent 81262c7b72
commit 988d0a4bfc
2 changed files with 45 additions and 16 deletions

View File

@@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase):
chunks_ids[i] = chunks_ids[i][1:]
# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
@@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase):
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
async def async_generate(self, payload):
url = self.base_url + "/generate"
@@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase):
chunks_ids[i] = chunks_ids[i][1:]
# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
@@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase):
assert response["meta_info"]["finish_reason"]["type"] == "abort"
else:
# 2. not using session control
requests.post(self.base_url + "/flush_cache")
output_ids = tokenizer.encode(gen_so_far)
if output_ids[0] == tokenizer.bos_token_id:
output_ids = output_ids[1:]
@@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase):
output_no_session = response["text"]
print("second request output without session:")
print(output_no_session)
assert second_output == output_no_session
assert (
second_output == output_no_session
), f"second_output: {second_output}, output_no_session: {output_no_session}"
def test_session_control_backtrack_with_abort(self):
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
@@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase):
assert len(x) == len(chunks_per_step[0])
# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
@@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase):
print(outputs_from_session)
print("====== outputs from normal queries: =======")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
def test_session_control_with_branching(self):
root_prompt = "First, let me explain in one sentence about AI"
@@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase):
gen_len = 32
# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
@@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase):
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
if __name__ == "__main__":