Fix cuda illegal memory access in overlap mode (#2070)
This commit is contained in:
@@ -1055,9 +1055,6 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
|
||||
_ = self.seq_lens[0].item()
|
||||
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
|
||||
@@ -390,6 +390,9 @@ class Scheduler:
|
||||
batch = self.get_next_batch_to_run()
|
||||
self.cur_batch = batch
|
||||
if batch:
|
||||
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
|
||||
_ = batch.seq_lens[0].item()
|
||||
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.few_shot_gsm8k_engine import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
)
|
||||
@@ -43,7 +42,7 @@ class TestSRTEngine(unittest.TestCase):
|
||||
|
||||
print("==== Answer 2 ====")
|
||||
print(out2)
|
||||
assert out1 == out2, f"{out1} != {out2}"
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_2_engine_multiple_generate(self):
|
||||
# just to ensure there is no issue running multiple generate calls
|
||||
@@ -106,14 +105,14 @@ class TestSRTEngine(unittest.TestCase):
|
||||
def test_4_gsm8k(self):
|
||||
|
||||
args = SimpleNamespace(
|
||||
model_path=DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
local_data_path=None,
|
||||
num_shots=5,
|
||||
num_questions=200,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
assert metrics["accuracy"] > 0.7
|
||||
self.assertGreater(metrics["accuracy"], 0.3)
|
||||
|
||||
def test_5_prompt_input_ids_consistency(self):
|
||||
prompt = "The capital of UK is"
|
||||
@@ -136,7 +135,7 @@ class TestSRTEngine(unittest.TestCase):
|
||||
|
||||
print("==== Answer 2 ====")
|
||||
print(out2)
|
||||
assert out1 == out2, f"{out1} != {out2}"
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_6_engine_runtime_encode_consistency(self):
|
||||
prompt = "Today is a sunny day and I like"
|
||||
@@ -156,11 +155,11 @@ class TestSRTEngine(unittest.TestCase):
|
||||
|
||||
def test_7_engine_offline_throughput(self):
|
||||
server_args = ServerArgs(
|
||||
model_path=DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
)
|
||||
bench_args = BenchArgs(num_prompts=100)
|
||||
bench_args = BenchArgs(num_prompts=10)
|
||||
result = throughput_test(server_args=server_args, bench_args=bench_args)
|
||||
self.assertTrue(result["total_throughput"] > 3000)
|
||||
self.assertGreater(result["total_throughput"], 3500)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user