[EAGLE] many fixes for eagle (#4195)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
@@ -123,7 +123,7 @@ class TestEAGLEEngine(unittest.TestCase):
|
||||
def _test_acc_length(self, engine):
|
||||
prompt = [
|
||||
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
|
||||
]
|
||||
] * 5
|
||||
sampling_params = {"temperature": 0, "max_new_tokens": 512}
|
||||
output = engine.generate(prompt, sampling_params)
|
||||
output = output[0]
|
||||
@@ -141,10 +141,14 @@ class TestEAGLEEngine(unittest.TestCase):
|
||||
/ output["meta_info"]["e2e_latency"]
|
||||
)
|
||||
print(f"{acc_length=}")
|
||||
self.assertGreater(acc_length, 3.6)
|
||||
|
||||
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
|
||||
self.assertGreater(acc_length, 3.6)
|
||||
else:
|
||||
self.assertGreater(acc_length, 2.6)
|
||||
|
||||
|
||||
class TestEAGLEEngineTokenMap(unittest.TestCase):
|
||||
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
|
||||
BASE_CONFIG = {
|
||||
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
|
||||
@@ -155,6 +159,7 @@ class TestEAGLEEngineTokenMap(unittest.TestCase):
|
||||
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
|
||||
"mem_fraction_static": 0.7,
|
||||
"cuda_graph_max_bs": 5,
|
||||
"dtype": "float16",
|
||||
}
|
||||
NUM_CONFIGS = 1
|
||||
|
||||
@@ -245,8 +250,25 @@ class TestEAGLEServer(unittest.TestCase):
|
||||
for p in threads:
|
||||
p.join()
|
||||
|
||||
def test_max_token_one(self):
    """Smoke-test generation when every request asks for a single new token.

    Flushes the server cache, runs ``run_eval`` against the server at
    ``self.base_url`` with ``max_new_tokens=1``, and asserts that the
    reported ``output_throughput`` metric is greater than 50.
    """
    requests.get(self.base_url + "/flush_cache")

    # The port is whatever follows the last ":" in the base URL.
    server_port = int(self.base_url.split(":")[-1])
    eval_args = SimpleNamespace(
        num_shots=5,
        data_path=None,
        num_questions=200,
        max_new_tokens=1,
        parallel=128,
        host="http://127.0.0.1",
        port=server_port,
    )

    # Just run and check it does not hang
    eval_metrics = run_eval(eval_args)
    throughput = eval_metrics["output_throughput"]
    self.assertGreater(throughput, 50)
|
||||
|
||||
def test_gsm8k(self):
|
||||
server_info = requests.get(self.base_url + "/flush_cache")
|
||||
requests.get(self.base_url + "/flush_cache")
|
||||
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
@@ -391,6 +413,53 @@ class TestEAGLEServer(unittest.TestCase):
|
||||
with ThreadPoolExecutor(8) as executor:
|
||||
list(executor.map(func, args))
|
||||
|
||||
def run_decode(self, sampling_params):
    """POST one ``/generate`` request and check that it succeeds.

    Args:
        sampling_params: dict of sampling options merged over the defaults
            (``max_new_tokens=48``, ``n=1``, ``temperature=0.7``); keys
            given here override the defaults.

    The request also asks for logprobs (top 5, with text, starting at
    position 0). The method asserts an HTTP 200 status and prints the JSON
    response followed by a separator line.
    """
    # Defaults first so that caller-supplied keys take precedence.
    merged_sampling = {
        "max_new_tokens": 48,
        "n": 1,
        "temperature": 0.7,
        **sampling_params,
    }
    payload = {
        "text": "Human: Write a travel blog post to Hawaii.\n\nAssistant:",
        "sampling_params": merged_sampling,
        "return_logprob": True,
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
        "logprob_start_len": 0,
    }

    response = requests.post(self.base_url + "/generate", json=payload)
    self.assertEqual(response.status_code, 200)
    print(json.dumps(response.json()))
    print("=" * 100)
|
||||
|
||||
def test_penalty_mixed(self):
|
||||
args = [
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{"frequency_penalty": 2},
|
||||
{"presence_penalty": 1},
|
||||
{"min_new_tokens": 16},
|
||||
{"frequency_penalty": 0.2},
|
||||
{"presence_penalty": 0.4},
|
||||
{"min_new_tokens": 8},
|
||||
{"frequency_penalty": 0.4, "presence_penalty": 0.8},
|
||||
{"frequency_penalty": 0.4, "min_new_tokens": 12},
|
||||
{"presence_penalty": 0.8, "min_new_tokens": 12},
|
||||
{"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32},
|
||||
{"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32},
|
||||
]
|
||||
random.shuffle(args * 5)
|
||||
with ThreadPoolExecutor(8) as executor:
|
||||
list(executor.map(self.run_decode, args))
|
||||
|
||||
|
||||
class TestEAGLERetract(TestEAGLEServer):
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user