diff --git a/lm_eval/tasks/llama3/instruct/mgsm_chat.yaml b/lm_eval/tasks/llama3/instruct/mgsm_chat.yaml
new file mode 100644
index 0000000000..6fdafc7bde
--- /dev/null
+++ b/lm_eval/tasks/llama3/instruct/mgsm_chat.yaml
@@ -0,0 +1,43 @@
+tag: llama3
+task: mgsm_chat
+dataset_path: meta-llama/Llama-3.2-3B-Instruct-evals
+dataset_name: Llama-3.2-3B-Instruct-evals__mgsm__details
+output_type: generate_until
+test_split: latest
+doc_to_text: "{{
+  input_final_prompts
+  |replace('<|eot_id|><|start_header_id|>assistant<|end_header_id|>', '')
+  |replace('<|start_header_id|>', '')
+  |replace('<|end_header_id|>', '')
+  |replace('<|eot_id|>', '')
+  |trim
+}}"
+doc_to_target: "input_correct_responses"
+process_results: !function utils.process_results_mgsm
+generation_kwargs:
+  until: []
+  do_sample: false
+  temperature: 0.0
+  max_gen_toks: 2048
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+filter_list:
+  - name: "strict-match"
+    filter:
+      - function: "regex"
+        regex_pattern: "Answer: (\\-?[0-9\\.\\,]+)"
+      - function: "take_first"
+  - filter:
+      - function: regex
+        group_select: -1
+        regex_pattern: "Answer: (-?[$0-9.,]{2,})|(-?[0-9]+)"
+      - function: take_first
+      - function: remove_whitespace
+      - function: take_first
+    name: flexible-extract
+metadata:
+  version: 2.0
diff --git a/lm_eval/tasks/llama3/instruct/utils.py b/lm_eval/tasks/llama3/instruct/utils.py
new file mode 100644
index 0000000000..4a33d71dc3
--- /dev/null
+++ b/lm_eval/tasks/llama3/instruct/utils.py
@@ -0,0 +1,15 @@
+from typing import List
+
+from lm_eval.api.metrics import exact_match_fn
+
+
+def process_results_mgsm(doc, prediction):
+    gold: List = doc["input_correct_responses"]
+    return {
+        "exact_match": int(
+            exact_match_fn(
+                predictions=prediction * len(gold), references=gold, ignore_case=True
+            )["exact_match"]
+            > 0
+        )
+    }
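
For context, a minimal sketch of what `process_results_mgsm` computes. The `any_exact_match` helper below is a hypothetical stand-in (not part of this PR) for `exact_match_fn` restricted to the `ignore_case=True` path used above: the single filtered prediction is replicated to the length of `input_correct_responses`, so the document scores 1 exactly when the extracted answer matches any of the accepted responses.

```python
# Sketch only: mirrors the replicate-and-compare logic in process_results_mgsm,
# assuming the filter pipeline hands process_results a single-element list of
# extracted answers (e.g. ["11"]) and that only ignore_case normalization applies.
from typing import List


def any_exact_match(prediction: List[str], gold: List[str]) -> int:
    replicated = prediction * len(gold)  # e.g. ["11"] -> ["11", "11"]
    hits = sum(p.lower() == g.lower() for p, g in zip(replicated, gold))
    return int(hits > 0)  # 1 if the prediction matched any accepted answer


assert any_exact_match(["11"], ["11", "eleven"]) == 1
assert any_exact_match(["12"], ["11", "eleven"]) == 0
```

Replicating the prediction rather than looping keeps the comparison as a single element-wise `exact_match_fn` call over the references; thresholding the resulting mean with `> 0` turns it into an any-match score.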