Add Gemma2 (#592)
This commit is contained in:
@@ -165,6 +165,7 @@ def decode(input_token_ids, batch, model_runner):
|
||||
return next_token_ids, output.next_token_logits
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def correctness_test(
|
||||
server_args,
|
||||
bench_args,
|
||||
@@ -178,9 +179,10 @@ def correctness_test(
|
||||
# Prepare inputs
|
||||
input_ids, reqs = prepare_inputs(bench_args, tokenizer)
|
||||
|
||||
# Prefill
|
||||
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
|
||||
rank_print("prefill logits (first half)", next_token_logits)
|
||||
if bench_args.cut_len > 0:
|
||||
# Prefill
|
||||
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
|
||||
rank_print("prefill logits (first half)", next_token_logits)
|
||||
|
||||
# Prepare extend inputs
|
||||
reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
|
||||
@@ -190,7 +192,7 @@ def correctness_test(
|
||||
rank_print("prefill logits (final)", next_token_logits)
|
||||
|
||||
# Decode
|
||||
output_ids = [list(req.input_ids) for req in reqs]
|
||||
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
|
||||
for _ in range(bench_args.output_len):
|
||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||
for i in range(len(reqs)):
|
||||
|
||||
Reference in New Issue
Block a user