diff --git a/benchmark/reasoning_benchmark/README.md b/benchmark/reasoning_benchmark/README.md new file mode 100644 index 000000000..539c2e710 --- /dev/null +++ b/benchmark/reasoning_benchmark/README.md @@ -0,0 +1,54 @@ +# Run benchmark + +This benchmark is primarily intended to be used with reasoning models like `DeepSeek-R1` and its distilled models like `DeepSeek-R1-Distill-Qwen-1.5B`. Please use + +```bash +pip install antlr4-python3-runtime +``` + +for `parse_latex` which we use for symbolic equality check. + +## Benchmark sglang + +1. Launch the Server + +```bash +python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000 +``` + +Note that depending on the GPU this benchmark will take quiet some time. To employ data parallelism please use: + +```bash +python3 -m sglang_router.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000 --dp-size 4 +``` + +2. Benchmarking + +We use [suggested](https://github.com/deepseek-ai/DeepSeek-R1) parameters of `temperature=0.6`, `top_p=.95`, `max_new_tokens=32768`. The command line argument `num-tries` can be used to evaluate the model multiple times on the same question. We use the suggested `64` from the repo for AIME 2024. For LIMO, we use `8` as the number of tries due to the size of the dataset. + +By default evaluate on LIMO dataset. + +```bash +python3 bench_sglang.py --parallel 256 --num-tries 64 --port 30000 +``` + +Evaluate on AIME 2024 dataset. + +```bash +python3 bench_sglang.py --parallel 256 --port 30000 --data-path Maxwell-Jia/AIME_2024 --question-key Problem --answer-key Answer --num-tries 64 +``` + +Evaluate on [AIME 2025 I dataset](https://huggingface.co/datasets/opencompass/AIME2025). For benchmark result see [here](https://matharena.ai/). + +```bash +python3 bench_sglang.py --parallel 256 --port 30000 --data-path opencompass/AIME2025 --question-key question --answer-key answer --num-tries 64 +``` + + +## Results + +| Dataset | Num Tries | Accuracy | Reference | +|------------|-----------|----------|-----------| +| LIMO | 8 | 47.7% | ? | +| AIME 2024 | 64 | 33.2% | 28.9% | +| AIME 2025 I| 64 | 29.9% | 25.0% | diff --git a/benchmark/reasoning_benchmark/answer_extraction.py b/benchmark/reasoning_benchmark/answer_extraction.py new file mode 100644 index 000000000..45ce59afb --- /dev/null +++ b/benchmark/reasoning_benchmark/answer_extraction.py @@ -0,0 +1,269 @@ +# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/data_processing/answer_extraction.py + +import re + +import regex + + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if len(substr) > 0 and substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + if "sqrt" not in a: + a = int(a) + if "sqrt" not in b: + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: + return string + + +def _fix_sqrt(string): + _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string) + _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string) + return _string + + +def _fix_tan(string): + _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string) + _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string) + return _string + + +def strip_string(string): + string = str(string).strip() + # linebreaks + string = string.replace("\n", "") + + # right "." + string = string.rstrip(".") + + # remove inverse spaces + string = string.replace("\\!", "") + # string = string.replace("\\ ", "") + + # replace \\ with \ + # string = string.replace("\\\\", "\\") + # string = string.replace("\\\\", "\\") + + if string.startswith("\\text{") and string.endswith("}"): + string = string.split("{", 1)[1][:-1] + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = string.replace("cfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove unit: miles, dollars if after is not none + _string = re.sub(r"\\text{.*?}$", "", string).strip() + if _string != "" and _string != string: + # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) + string = _string + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "").strip() + string = string.replace("^\\circ", "").strip() + + string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip() + string = regex.sub(r"p\.m\.$", "", string).strip() + string = regex.sub(r"(\d)\s*t$", r"\1", string).strip() + + # remove dollar signs + string = string.replace("\\$", "") + string = string.replace("$", "") + + # string = string.replace("\\text", "") + string = string.replace("x\\in", "") + + # remove percentage + string = string.replace("\\%", "%") + string = string.replace("\%", "%") + # string = string.replace("%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + + # cdot + string = string.replace("\\cdot", "") + + # inf + string = string.replace("infinity", "\\infty") + if "\\infty" not in string: + string = string.replace("inf", "\\infty") + string = string.replace("+\\inity", "\\infty") + + # and + # string = string.replace("and", "") + string = string.replace("\\mathbf", "") + string = string.replace("\\mathrm", "") + + # use regex to remove \mbox{...} + string = re.sub(r"\\mbox{.*?}", "", string) + + # quote + string.replace("'", "") + string.replace('"', "") + + # i, j + if "j" in string and "i" not in string: + string = string.replace("j", "i") + + # replace a.000b where b is not number or b is end, with ab, use regex + string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) + string = re.sub(r"(\d+)\.0+$", r"\1", string) + + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + # if len(string.split("=")) == 2: + # if len(string.split("=")[0]) <= 2: + # string = string.split("=")[1] + + string = _fix_sqrt(string) + string = _fix_tan(string) + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + string = regex.sub(r"(\\|,|\.)+$", "", string) + + return string + + +def extract_boxed_answers(text): + answers = [] + for piece in text.split("boxed{")[1:]: + n = 0 + for i in range(len(piece)): + if piece[i] == "{": + n += 1 + elif piece[i] == "}": + n -= 1 + if n < 0: + if i + 1 < len(piece) and piece[i + 1] == "%": + answers.append(piece[: i + 1]) + else: + answers.append(piece[:i]) + break + return answers + + +def extract_program_output(pred_str): + """ + extract output between the last ```output\n...\n``` + """ + if "```output" not in pred_str: + return "" + if "```output" in pred_str: + pred_str = pred_str.split("```output")[-1] + if "```" in pred_str: + pred_str = pred_str.split("```")[0] + output = pred_str.strip() + return output + + +def extract_answer(pred_str, exhaust=False): + pred = [] + if "final answer is $" in pred_str and "$. I hope" in pred_str: + tmp = pred_str.split("final answer is $", 1)[1] + pred = [tmp.split("$. I hope", 1)[0].strip()] + elif "boxed" in pred_str: + pred = extract_boxed_answers(pred_str) + elif "he answer is" in pred_str: + pred = [pred_str.split("he answer is")[-1].strip()] + else: + program_output = extract_program_output(pred_str) + if program_output != "": + # fall back to program + pred.append(program_output) + else: # use the last number + pattern = "-?\d*\.?\d+" + ans = re.findall(pattern, pred_str.replace(",", "")) + if len(ans) >= 1: + ans = ans[-1] + else: + ans = "" + if ans: + pred.append(ans) + + # multiple line + _pred = [] + for ans in pred: + ans = ans.strip().split("\n")[0] + ans = ans.lstrip(":") + ans = ans.rstrip(".") + ans = ans.rstrip("/") + ans = strip_string(ans) + _pred.append(ans) + if exhaust: + return _pred + else: + return _pred[-1] if _pred else "" + + +def extract_math_answer(question, reasoning, task): + answer = [] + for ans in extract_answer(reasoning, exhaust=True): + if "separated by commas" in question and all(ch not in ans for ch in "()[]"): + answer.extend([a.strip() for a in ans.split(",")]) + elif regex.search(r"\\text\{\s*and\s*\}", ans): + answer.extend( + [ + a.strip() + for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split( + "[SEP]" + ) + ] + ) + else: + answer.append(ans.strip()) + return answer diff --git a/benchmark/reasoning_benchmark/bench_sglang.py b/benchmark/reasoning_benchmark/bench_sglang.py new file mode 100644 index 000000000..de0f13971 --- /dev/null +++ b/benchmark/reasoning_benchmark/bench_sglang.py @@ -0,0 +1,119 @@ +import argparse +import json +import time + +import answer_extraction +import eval_utils +from datasets import load_dataset + +import sglang as sgl +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) +from sglang.utils import dump_state_text + + +@sgl.function +def reasoning_gen(s, question: str): + s += sgl.user( + question + + "\nPlease reason step by step, and put your final answer within \boxed{}." + ) + s += sgl.assistant( + sgl.gen( + "answer", + ) + ) + + +def convert_dataset(path: str, question_key: str, answer_key: str, num_tries: int): + raw_dataset = load_dataset(path) + questions = [] + answers = [] + for data in raw_dataset["train"]: + question = data[question_key] + answer = data[answer_key] + for _ in range(num_tries): + questions.append({"question": question}) + answers.append({"answer": answer}) + return questions, answers + + +def main(args): + # Select backend + sgl.set_default_backend(select_sglang_backend(args)) + + # Get dataset + questions, answers = convert_dataset( + args.data_path, args.question_key, args.answer_key, args.num_tries + ) + + # Run requests + tic = time.time() + states = reasoning_gen.run_batch( + questions, + num_threads=args.parallel, + progress_bar=True, + temperature=0.6, + max_new_tokens=32768, + top_p=0.95, + ) + latency = time.time() - tic + + # Extract answers + correct = 0 + for i, state in enumerate(states): + try: + pred_answer = answer_extraction.extract_math_answer( + questions[i]["question"], state["answer"], "limo" + ) + gt_answer = str(answers[i]["answer"]) + # Use last answer if multiple were extracted + pred_answer = ( + pred_answer[-1] if isinstance(pred_answer, list) else pred_answer + ) + correct += 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0 + except Exception as e: + print(f"Error extracting answer: {e}") + pass + + # Calculate accuracy + accuracy = correct / len(questions) + print(f"Accuracy: {accuracy}") + + # Calculate output throughput + num_output_tokens = sum( + s.get_meta_info("answer")["completion_tokens"] for s in states + ) + output_throughput = num_output_tokens / latency + print(f"Output throughput: {output_throughput} token/s") + + # Dump results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + # Write results + with open(args.result_file, "a") as fout: + value = { + "task": "limo", + "backend": args.backend, + "latency": round(latency, 3), + "accuracy": round(accuracy, 3), + "num_requests": len(questions), + "other": { + "num_questions": len(questions), + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="GAIR/LIMO") + parser.add_argument("--question-key", type=str, default="question") + parser.add_argument("--answer-key", type=str, default="answer") + parser.add_argument("--num-tries", type=int, default=1) + add_common_sglang_args_and_parse(parser) + args = parser.parse_args() + main(args) diff --git a/benchmark/reasoning_benchmark/eval_utils.py b/benchmark/reasoning_benchmark/eval_utils.py new file mode 100644 index 000000000..ab736954f --- /dev/null +++ b/benchmark/reasoning_benchmark/eval_utils.py @@ -0,0 +1,206 @@ +# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py + +from math import isclose + +import regex +from sympy import N, simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr + + +def parse_digits(num): + # format: 234.23 || 23% + num = regex.sub(",", "", str(num)) + try: + return float(num) + except: + if num.endswith("%"): + num = num[:-1] + if num.endswith("\\"): + num = num[:-1] + try: + return float(num) / 100 + except: + pass + return None + + +def is_digit(num): + # paired with parse_digits + return parse_digits(num) is not None + + +def symbolic_equal(a, b): + def _parse(s): + for f in [parse_latex, parse_expr]: + try: + return f(s) + except: + pass + return s + + a = _parse(a) + b = _parse(b) + + try: + if simplify(a - b) == 0: + return True + except: + pass + + try: + if isclose(N(a), N(b), abs_tol=1e-3): + return True + except: + pass + return False + + +def math_equal(prediction, reference, include_percentage=True, is_close=True): + """ + Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + if str(prediction) == str(reference): + return True + + try: # 1. numerical equal + if is_digit(prediction) and is_digit(reference): + prediction = parse_digits(prediction) + reference = parse_digits(reference) + # number questions + if include_percentage: + gt_result = [reference / 100, reference, reference * 100] + else: + gt_result = [reference] + for item in gt_result: + try: + if is_close: + if isclose(item, prediction, abs_tol=1e-3): + return True + else: + if item == prediction: + return True + except Exception: + continue + return False + except: + pass + + if not prediction and prediction not in [0, False]: + return False + + # 2. symbolic equal + reference = str(reference).strip() + prediction = str(prediction).strip() + + if ( + regex.match(r"(\(|\[).+(\)|\])", prediction) is not None + and regex.match(r"(\(|\[).+(\)|\])", reference) is not None + ): + pred_parts = prediction[1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts): + if all( + [ + math_equal( + pred_parts[i], ref_parts[i], include_percentage, is_close + ) + for i in range(len(pred_parts)) + ] + ): + return True + + # Add back matrix comparison + if ( + ( + prediction.startswith("\\begin{pmatrix}") + or prediction.startswith("\\begin{bmatrix}") + ) + and ( + prediction.endswith("\\end{pmatrix}") + or prediction.endswith("\\end{bmatrix}") + ) + and ( + reference.startswith("\\begin{pmatrix}") + or reference.startswith("\\begin{bmatrix}") + ) + and ( + reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}") + ) + ): + pred_lines = [ + line.strip() + for line in prediction[ + len("\\begin{pmatrix}") : -len("\\end{pmatrix}") + ].split("\\\\") + if line.strip() + ] + ref_lines = [ + line.strip() + for line in reference[ + len("\\begin{pmatrix}") : -len("\\end{pmatrix}") + ].split("\\\\") + if line.strip() + ] + matched = True + if len(pred_lines) == len(ref_lines): + for pred_line, ref_line in zip(pred_lines, ref_lines): + pred_parts = pred_line.split("&") + ref_parts = ref_line.split("&") + if len(pred_parts) == len(ref_parts): + if not all( + [ + math_equal( + pred_parts[i], + ref_parts[i], + include_percentage, + is_close, + ) + for i in range(len(pred_parts)) + ] + ): + matched = False + break + else: + matched = False + if not matched: + break + else: + matched = False + if matched: + return True + + # Add back equation comparison + if prediction.count("=") == 1 and reference.count("=") == 1: + pred = prediction.split("=") + pred = f"{pred[0].strip()} - ({pred[1].strip()})" + ref = reference.split("=") + ref = f"{ref[0].strip()} - ({ref[1].strip()})" + if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): + return True + elif ( + prediction.count("=") == 1 + and len(prediction.split("=")[0].strip()) <= 2 + and "=" not in reference + ): + if math_equal( + prediction.split("=")[1], reference, include_percentage, is_close + ): + return True + elif ( + reference.count("=") == 1 + and len(reference.split("=")[0].strip()) <= 2 + and "=" not in prediction + ): + if math_equal( + prediction, reference.split("=")[1], include_percentage, is_close + ): + return True + + # symbolic equal with sympy + if symbolic_equal(prediction, reference): + return True + + return False