Add Gemma2 (#592)

This commit is contained in:
Ying Sheng
2024-07-05 09:48:54 -07:00
committed by GitHub
parent d737da5f17
commit 5a57b8addd
7 changed files with 467 additions and 30 deletions

View File

@@ -165,6 +165,7 @@ def decode(input_token_ids, batch, model_runner):
return next_token_ids, output.next_token_logits
@torch.inference_mode()
def correctness_test(
server_args,
bench_args,
@@ -178,9 +179,10 @@ def correctness_test(
# Prepare inputs
input_ids, reqs = prepare_inputs(bench_args, tokenizer)
# Prefill
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print("prefill logits (first half)", next_token_logits)
if bench_args.cut_len > 0:
# Prefill
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
rank_print("prefill logits (first half)", next_token_logits)
# Prepare extend inputs
reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
@@ -190,7 +192,7 @@ def correctness_test(
rank_print("prefill logits (final)", next_token_logits)
# Decode
output_ids = [list(req.input_ids) for req in reqs]
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
for _ in range(bench_args.output_len):
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
for i in range(len(reqs)):