Fix cuda illegal memory access in overlap mode (#2070)
This commit is contained in:
@@ -16,7 +16,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.few_shot_gsm8k_engine import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
)
|
||||
@@ -43,7 +42,7 @@ class TestSRTEngine(unittest.TestCase):
|
||||
|
||||
print("==== Answer 2 ====")
|
||||
print(out2)
|
||||
assert out1 == out2, f"{out1} != {out2}"
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_2_engine_multiple_generate(self):
|
||||
# just to ensure there is no issue running multiple generate calls
|
||||
@@ -106,14 +105,14 @@ class TestSRTEngine(unittest.TestCase):
|
||||
def test_4_gsm8k(self):
|
||||
|
||||
args = SimpleNamespace(
|
||||
model_path=DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
local_data_path=None,
|
||||
num_shots=5,
|
||||
num_questions=200,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
assert metrics["accuracy"] > 0.7
|
||||
self.assertGreater(metrics["accuracy"], 0.3)
|
||||
|
||||
def test_5_prompt_input_ids_consistency(self):
|
||||
prompt = "The capital of UK is"
|
||||
@@ -136,7 +135,7 @@ class TestSRTEngine(unittest.TestCase):
|
||||
|
||||
print("==== Answer 2 ====")
|
||||
print(out2)
|
||||
assert out1 == out2, f"{out1} != {out2}"
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_6_engine_runtime_encode_consistency(self):
|
||||
prompt = "Today is a sunny day and I like"
|
||||
@@ -156,11 +155,11 @@ class TestSRTEngine(unittest.TestCase):
|
||||
|
||||
def test_7_engine_offline_throughput(self):
|
||||
server_args = ServerArgs(
|
||||
model_path=DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
)
|
||||
bench_args = BenchArgs(num_prompts=100)
|
||||
bench_args = BenchArgs(num_prompts=10)
|
||||
result = throughput_test(server_args=server_args, bench_args=bench_args)
|
||||
self.assertTrue(result["total_throughput"] > 3000)
|
||||
self.assertGreater(result["total_throughput"], 3500)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user