From 76e3a3229a836e003a895950827bb75d24a9b367 Mon Sep 17 00:00:00 2001
From: Marcos Martinez
Date: Mon, 17 Oct 2022 14:18:40 +0100
Subject: [PATCH] Feat/zshot version (#17)

* :art: Improved structure of setup and init.
* :pencil2: Fixed minor typos and format in evaluator
* :white_check_mark: Update evaluation tests to work with latest version of evaluate
* :bug: Fixed bug while importing version
---
 setup.cfg                                 |  3 +++
 setup.py                                  |  2 --
 zshot/__init__.py                         |  2 ++
 zshot/evaluation/evaluator.py             |  3 ++-
 zshot/evaluation/zshot_evaluate.py        |  6 ++++--
 zshot/tests/evaluation/test_evaluation.py | 17 +++++++++--------
 6 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 49a8107..3cd4e0c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,5 @@
 [egg_info]
 tag_svn_revision = true
+
+[metadata]
+version = attr: zshot.__version__
\ No newline at end of file
diff --git a/setup.py b/setup.py
index b83fd60..e237cf7 100644
--- a/setup.py
+++ b/setup.py
@@ -4,10 +4,8 @@
 this_directory = Path(__file__).parent
 long_description = (this_directory / "README.md").read_text()
 
-version = '0.0.2'
 
 setup(name='zshot',
-      version=version,
       description="Zero and Few shot named entity recognition",
       long_description_content_type='text/markdown',
       long_description=long_description,
diff --git a/zshot/__init__.py b/zshot/__init__.py
index f19d039..35177ff 100644
--- a/zshot/__init__.py
+++ b/zshot/__init__.py
@@ -1,2 +1,4 @@
 from zshot.zshot import MentionsExtractor, Linker, Zshot, PipelineConfig  # noqa: F401
 from zshot.utils.displacy import displacy  # noqa: F401
+
+__version__ = '0.0.3'
diff --git a/zshot/evaluation/evaluator.py b/zshot/evaluation/evaluator.py
index 9333200..8e5b794 100644
--- a/zshot/evaluation/evaluator.py
+++ b/zshot/evaluation/evaluator.py
@@ -34,7 +34,8 @@ def prepare_pipeline(
         feature_extractor=None,  # noqa: F821
         device: int = None,
     ):
-        pipe = super(TokenClassificationEvaluator, self).prepare_pipeline(model_or_pipeline, tokenizer, feature_extractor, device)
+        pipe = super(TokenClassificationEvaluator, self).prepare_pipeline(model_or_pipeline, tokenizer,
+                                                                          feature_extractor, device)
         return pipe
 
 
diff --git a/zshot/evaluation/zshot_evaluate.py b/zshot/evaluation/zshot_evaluate.py
index fb266df..1d21f5a 100644
--- a/zshot/evaluation/zshot_evaluate.py
+++ b/zshot/evaluation/zshot_evaluate.py
@@ -5,13 +5,12 @@
 from prettytable import PrettyTable
 
 from zshot.evaluation import load_medmentions, load_ontonotes
-from zshot.evaluation.dataset.dataset import DatasetWithEntities
 from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
 from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline
 
 
 def evaluate(nlp: spacy.language.Language,
-             datasets: Union[DatasetWithEntities, List[DatasetWithEntities]],
+             datasets: Union[str, List[str]],
              splits: Optional[Union[str, List[str]]] = None,
              metric: Optional[Union[str, EvaluationModule]] = None,
              batch_size: Optional[int] = 16) -> str:
@@ -31,6 +30,9 @@ def evaluate(nlp: spacy.language.Language,
     if type(splits) == str:
         splits = [splits]
 
+    if type(datasets) == str:
+        datasets = [datasets]
+
     result = {}
     field_names = ["Metric"]
     for dataset_name in datasets:
diff --git a/zshot/tests/evaluation/test_evaluation.py b/zshot/tests/evaluation/test_evaluation.py
index cef2c6d..80f55f1 100644
--- a/zshot/tests/evaluation/test_evaluation.py
+++ b/zshot/tests/evaluation/test_evaluation.py
@@ -113,7 +113,7 @@ def test_prediction_token_based_evaluation_all_matching(self):
         dataset = get_dataset(gt, sentences)
 
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
-        metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")
+        metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1)]), dataset, metric="seqeval")
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
 
@@ -128,7 +128,7 @@ def test_prediction_token_based_evaluation_overlapping_spans(self):
 
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
         metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]), dataset,
-                                           "seqeval")
+                                           metric="seqeval")
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
 
@@ -144,7 +144,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.expand)
         pipe = get_linker_pipe([('New Yo', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
 
@@ -160,7 +160,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_contract(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.contract)
         pipe = get_linker_pipe([('New York i', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
 
@@ -176,7 +176,7 @@ def test_prediction_token_based_evaluation_partial_and_overlapping_spans(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.contract)
         pipe = get_linker_pipe([('New York i', 'FAC', 1), ('w York', 'LOC', 0.7)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
 
@@ -207,7 +207,8 @@ def test_prediction_token_based_evaluation_all_matching(self):
         dataset = get_dataset(gt, sentences)
 
         custom_evaluator = MentionsExtractorEvaluator("token-classification")
-        metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")
+        metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1)]), dataset,
+                                           metric="seqeval")
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
 
@@ -222,7 +223,7 @@ def test_prediction_token_based_evaluation_overlapping_spans(self):
 
         custom_evaluator = MentionsExtractorEvaluator("token-classification")
         metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]),
-                                           dataset, "seqeval")
+                                           dataset, metric="seqeval")
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
 
@@ -238,7 +239,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):
         custom_evaluator = MentionsExtractorEvaluator("token-classification",
                                                       alignment_mode=AlignmentMode.expand)
         pipe = get_mentions_extractor_pipe([('New Yo', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
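Note: with this patch, evaluate() accepts dataset names as plain strings and normalizes a bare string into a one-element list before iterating, mirroring the existing handling of splits. A minimal usage sketch follows; the dataset names and the pipeline configuration below are illustrative assumptions, not part of this patch:

    import spacy

    from zshot import PipelineConfig  # noqa: F401
    from zshot.evaluation.zshot_evaluate import evaluate

    # A spaCy pipeline with a zshot component is assumed; the exact
    # PipelineConfig arguments depend on the linker under evaluation.
    nlp = spacy.blank("en")
    # nlp.add_pipe("zshot", config=PipelineConfig(...))  # config assumed

    # Both call styles are equivalent after this change: a bare string is
    # wrapped into a one-element list inside evaluate(), which returns the
    # results table as a string.
    print(evaluate(nlp, "ontonotes", splits="test", metric="seqeval"))
    print(evaluate(nlp, ["ontonotes", "medmentions"], splits=["test"], batch_size=16))

On the packaging side, the setup.cfg directive "version = attr: zshot.__version__" makes setuptools read the version from zshot/__init__.py at build time, so the version string is defined in exactly one place instead of being duplicated between setup.py and the package.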