From 88e7999d2955bd576e29a92493ed7a1268fa4b08 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:08:20 -0700 Subject: [PATCH 01/23] add local and remote functions --- validator/main.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index 2b66013..86e0155 100644 --- a/validator/main.py +++ b/validator/main.py @@ -12,6 +12,8 @@ Validator, register_validator, ) +import json + @register_validator(name="guardrails/competitor_check", data_type="string") @@ -56,6 +58,7 @@ def __init__( self, competitors: List[str], on_fail: Optional[Callable] = None, + use_local: bool = True ): super().__init__(competitors=competitors, on_fail=on_fail) self._competitors = competitors @@ -190,4 +193,20 @@ def find_all(a_str, sub): error_spans=error_spans ) else: - return PassResult() \ No newline at end of file + return PassResult() + + def _inference_local(self, value: str, metadata: Dict) -> ValidationResult: + """Local inference method for the competitor check validator.""" + return self.perform_ner(value, metadata) + + + def _inference_remote(self, value: str, metadata: Dict) -> ValidationResult: + """Remote inference method for the competitor check validator.""" + request_body = { + "model_name": "CompetitorCheck", + "text": value, + "competitors": self._competitors, + } + request_body = json.dumps(request_body, ensure_ascii=False) + response = self._hub_inference_request(request_body) + return response \ No newline at end of file From 4c1bee95e885ae1d884175e9e5bf1a335269de23 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:24:34 -0700 Subject: [PATCH 02/23] make sure formats are the same --- validator/main.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/validator/main.py b/validator/main.py index 86e0155..f5667bc 100644 --- a/validator/main.py +++ b/validator/main.py @@ -195,18 +195,35 @@ def find_all(a_str, sub): else: return PassResult() - def _inference_local(self, value: str, metadata: Dict) -> ValidationResult: - """Local inference method for the competitor check validator.""" - return self.perform_ner(value, metadata) - + def _inference_local(self, model_input: Any) -> str: + """Local inference method to detect and anonymize competitor names.""" + text = model_input["text"] + competitors = model_input["competitors"] + + doc = self.nlp(text) + anonymized_text = text + for ent in doc.ents: + if ent.text in competitors: + anonymized_text = anonymized_text.replace(ent.text, "[COMPETITOR]") + return anonymized_text - def _inference_remote(self, value: str, metadata: Dict) -> ValidationResult: - """Remote inference method for the competitor check validator.""" + def _inference_remote(self, model_input: Any) -> str: + """Remote inference method for a hosted ML endpoint.""" request_body = { "model_name": "CompetitorCheck", - "text": value, - "competitors": self._competitors, + "text": model_input["text"], + "competitors": model_input["competitors"] } - request_body = json.dumps(request_body, ensure_ascii=False) response = self._hub_inference_request(request_body) - return response \ No newline at end of file + + if not response or 'outputs' not in response: + raise ValueError("Invalid response from remote inference") + + outputs = response['outputs'][0]['data'][0] + result = json.loads(outputs) + + if 'output' in result: + return result['output'] + else: + raise ValueError("Invalid format of the response from remote inference") + \ No newline at end of file From 124002aea5e713a3d4e59ed28819e91e9528072f Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:33:05 -0700 Subject: [PATCH 03/23] fix any type --- validator/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index f5667bc..75b359c 100644 --- a/validator/main.py +++ b/validator/main.py @@ -1,7 +1,7 @@ import nltk import spacy import re -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional from guardrails.validator_base import ErrorSpan from guardrails.logger import logger From cce06eaf78b795c8215df3cb82093a12f097ff2a Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Thu, 27 Jun 2024 22:07:32 -0700 Subject: [PATCH 04/23] fixing validate to use inference rather than nlp model --- validator/main.py | 62 +++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/validator/main.py b/validator/main.py index 75b359c..e8e1d6a 100644 --- a/validator/main.py +++ b/validator/main.py @@ -1,10 +1,11 @@ -import nltk -import spacy +import json import re from typing import Any, Callable, Dict, List, Optional -from guardrails.validator_base import ErrorSpan +import nltk +import spacy from guardrails.logger import logger +from guardrails.validator_base import ErrorSpan from guardrails.validators import ( FailResult, PassResult, @@ -12,8 +13,6 @@ Validator, register_validator, ) -import json - @register_validator(name="guardrails/competitor_check", data_type="string") @@ -57,13 +56,14 @@ def chunking_function(self, chunk: str): def __init__( self, competitors: List[str], + use_local: bool = False, on_fail: Optional[Callable] = None, - use_local: bool = True ): - super().__init__(competitors=competitors, on_fail=on_fail) + super().__init__(competitors=competitors, use_local=use_local, on_fail=on_fail) self._competitors = competitors model = "en_core_web_trf" - self.nlp = spacy.load(model) + if self.use_local: + self.nlp = spacy.load(model) def exact_match(self, text: str, competitors: List[str]) -> List[str]: """Performs exact match to find competitors from a list in a given @@ -85,13 +85,12 @@ def exact_match(self, text: str, competitors: List[str]) -> List[str]: found_entities.append(entity) return found_entities - def perform_ner(self, text: str, nlp) -> List[str]: + def perform_ner(self, text: str) -> List[str]: """Performs named entity recognition on text using a provided NLP model. Args: text (str): The text to perform named entity recognition on. - nlp: The NLP model to use for entity recognition. Returns: entities: A list of entities found. @@ -141,14 +140,15 @@ def validate(self, value: str, metadata=Dict) -> ValidationResult: sentences = nltk.sent_tokenize(value) flagged_sentences = [] filtered_sentences = [] - error_spans:List[ErrorSpan] = [] + error_spans: List[ErrorSpan] = [] list_of_competitors_found = [] start_ind = 0 for sentence in sentences: - entities = self.exact_match(sentence, self._competitors) if entities: - ner_entities = self.perform_ner(sentence, self.nlp) + ner_entities = self.inference( + {"text": sentence, "competitors": self._competitors} + ) found_competitors = self.is_entity_in_list(ner_entities, entities) if found_competitors: flagged_sentences.append((found_competitors, sentence)) @@ -171,17 +171,22 @@ def find_all(a_str, sub): start = 0 while True: start = a_str.find(sub, start) - if start == -1: + if start == -1: return yield start - start += len(sub) # use start += 1 to find overlapping matches + start += len(sub) # use start += 1 to find overlapping matches error_spans = [] - for entity in found_entities: + for entity in found_entities: starts = list(find_all(value, entity)) for start in starts: - error_spans.append(ErrorSpan(start=start, end=start+len(entity), reason=f'Competitor found: {value[start:start+len(entity)]}')) - + error_spans.append( + ErrorSpan( + start=start, + end=start + len(entity), + reason=f"Competitor found: {value[start:start+len(entity)]}", + ) + ) if len(flagged_sentences): return FailResult( @@ -190,11 +195,11 @@ def find_all(a_str, sub): "Please avoid naming those competitors next time" ), fix_value=filtered_output, - error_spans=error_spans + error_spans=error_spans, ) else: return PassResult() - + def _inference_local(self, model_input: Any) -> str: """Local inference method to detect and anonymize competitor names.""" text = model_input["text"] @@ -212,18 +217,17 @@ def _inference_remote(self, model_input: Any) -> str: request_body = { "model_name": "CompetitorCheck", "text": model_input["text"], - "competitors": model_input["competitors"] + "competitors": model_input["competitors"], } response = self._hub_inference_request(request_body) - - if not response or 'outputs' not in response: + + if not response or "outputs" not in response: raise ValueError("Invalid response from remote inference") - - outputs = response['outputs'][0]['data'][0] + + outputs = response["outputs"][0]["data"][0] result = json.loads(outputs) - - if 'output' in result: - return result['output'] + + if "output" in result: + return result["output"] else: raise ValueError("Invalid format of the response from remote inference") - \ No newline at end of file From 793078865f3c9367ce581f590802c34831953d64 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 1 Jul 2024 12:10:44 -0700 Subject: [PATCH 05/23] kwargs updates --- validator/main.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/validator/main.py b/validator/main.py index 75b359c..9e529b8 100644 --- a/validator/main.py +++ b/validator/main.py @@ -16,7 +16,7 @@ -@register_validator(name="guardrails/competitor_check", data_type="string") +@register_validator(name="guardrails/competitor_check", data_type="string", has_guardrails_endpoint=True) class CompetitorCheck(Validator): """Validates that LLM-generated text is not naming any competitors from a given list. @@ -58,10 +58,15 @@ def __init__( self, competitors: List[str], on_fail: Optional[Callable] = None, - use_local: bool = True + **kwargs, ): - super().__init__(competitors=competitors, on_fail=on_fail) + super().__init__( + competitors=competitors, + on_fail=on_fail, + **kwargs, + ) self._competitors = competitors + self.use_local = kwargs.get("use_local", None) model = "en_core_web_trf" self.nlp = spacy.load(model) From eb6b95fb1143e91083c6a6670bb22a338767b501 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 1 Jul 2024 12:14:26 -0700 Subject: [PATCH 06/23] fix args --- validator/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index bfa574e..cac07fc 100644 --- a/validator/main.py +++ b/validator/main.py @@ -56,7 +56,6 @@ def chunking_function(self, chunk: str): def __init__( self, competitors: List[str], - use_local: bool = False, on_fail: Optional[Callable] = None, **kwargs, ): From cbe18ee87be8144c4e557f97c7c24634ff866c33 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 1 Jul 2024 12:14:57 -0700 Subject: [PATCH 07/23] fix inference call --- validator/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index cac07fc..2cc9e39 100644 --- a/validator/main.py +++ b/validator/main.py @@ -151,7 +151,7 @@ def validate(self, value: str, metadata=Dict) -> ValidationResult: for sentence in sentences: entities = self.exact_match(sentence, self._competitors) if entities: - ner_entities = self.inference( + ner_entities = self._inference( {"text": sentence, "competitors": self._competitors} ) found_competitors = self.is_entity_in_list(ner_entities, entities) From 33d83aa08b1b2246f25e402a338008cbce36c348 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 1 Jul 2024 13:56:08 -0700 Subject: [PATCH 08/23] formatting --- validator/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index 2cc9e39..d96a858 100644 --- a/validator/main.py +++ b/validator/main.py @@ -15,7 +15,9 @@ ) -@register_validator(name="guardrails/competitor_check", data_type="string", has_guardrails_endpoint=True) +@register_validator( + name="guardrails/competitor_check", data_type="string", has_guardrails_endpoint=True +) class CompetitorCheck(Validator): """Validates that LLM-generated text is not naming any competitors from a given list. From fde7231deddf6ca1b0b82a0530169773f805a630 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 1 Jul 2024 15:09:29 -0700 Subject: [PATCH 09/23] Update main.py --- validator/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index d96a858..737dcbb 100644 --- a/validator/main.py +++ b/validator/main.py @@ -67,7 +67,6 @@ def __init__( **kwargs, ) self._competitors = competitors - self.use_local = kwargs.get("use_local", None) model = "en_core_web_trf" if self.use_local: self.nlp = spacy.load(model) From 357caf2d7864aa09222afcffbce2644ec83c4094 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 1 Jul 2024 17:13:45 -0700 Subject: [PATCH 10/23] fix response output --- validator/main.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/validator/main.py b/validator/main.py index d96a858..9d22340 100644 --- a/validator/main.py +++ b/validator/main.py @@ -222,7 +222,6 @@ def _inference_local(self, model_input: Any) -> str: def _inference_remote(self, model_input: Any) -> str: """Remote inference method for a hosted ML endpoint.""" request_body = { - "model_name": "CompetitorCheck", "text": model_input["text"], "competitors": model_input["competitors"], } @@ -234,7 +233,4 @@ def _inference_remote(self, model_input: Any) -> str: outputs = response["outputs"][0]["data"][0] result = json.loads(outputs) - if "output" in result: - return result["output"] - else: - raise ValueError("Invalid format of the response from remote inference") + return result From d926a2fde561049a84856aadfb5fb491e03ddca0 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 1 Jul 2024 17:16:04 -0700 Subject: [PATCH 11/23] remove use local check --- validator/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/validator/main.py b/validator/main.py index c74b9c5..15f93fb 100644 --- a/validator/main.py +++ b/validator/main.py @@ -68,8 +68,7 @@ def __init__( ) self._competitors = competitors model = "en_core_web_trf" - if self.use_local: - self.nlp = spacy.load(model) + self.nlp = spacy.load(model) def exact_match(self, text: str, competitors: List[str]) -> List[str]: """Performs exact match to find competitors from a list in a given From cd6e09982b68665375679068372f7c9abbe98931 Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:50:16 -0700 Subject: [PATCH 12/23] Update main.py --- validator/main.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/validator/main.py b/validator/main.py index 15f93fb..02a0e34 100644 --- a/validator/main.py +++ b/validator/main.py @@ -62,7 +62,7 @@ def __init__( **kwargs, ): super().__init__( - competitors=competitors, + competitors=competitors, on_fail=on_fail, **kwargs, ) @@ -101,7 +101,7 @@ def perform_ner(self, text: str) -> List[str]: entities: A list of entities found. """ - doc = nlp(text) + doc = self.nlp(text) entities = [] for ent in doc.ents: entities.append(ent.text) @@ -154,6 +154,8 @@ def validate(self, value: str, metadata=Dict) -> ValidationResult: ner_entities = self._inference( {"text": sentence, "competitors": self._competitors} ) + if isinstance(ner_entities, str): + ner_entities = [ner_entities] found_competitors = self.is_entity_in_list(ner_entities, entities) if found_competitors: flagged_sentences.append((found_competitors, sentence)) @@ -192,7 +194,8 @@ def find_all(a_str, sub): reason=f"Competitor found: {value[start:start+len(entity)]}", ) ) - + print("FLAGGED SENTENCES:", flagged_sentences) + print("ERROR SPANS:", error_spans) if len(flagged_sentences): return FailResult( error_message=( @@ -223,10 +226,10 @@ def _inference_remote(self, model_input: Any) -> str: "text": model_input["text"], "competitors": model_input["competitors"], } - response = self._hub_inference_request(request_body) + response = self._hub_inference_request(request_body, self.validation_endpoint) if not response or "outputs" not in response: - raise ValueError("Invalid response from remote inference") + raise ValueError("Invalid response from remote inference", response) outputs = response["outputs"][0]["data"][0] result = json.loads(outputs) From fb487a2c9790ad0667dec1bf222d8c3485e6063f Mon Sep 17 00:00:00 2001 From: Wyatt Lansford <22553069+wylansford@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:50:38 -0700 Subject: [PATCH 13/23] Update main.py --- validator/main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/validator/main.py b/validator/main.py index 02a0e34..aaaecc0 100644 --- a/validator/main.py +++ b/validator/main.py @@ -194,8 +194,6 @@ def find_all(a_str, sub): reason=f"Competitor found: {value[start:start+len(entity)]}", ) ) - print("FLAGGED SENTENCES:", flagged_sentences) - print("ERROR SPANS:", error_spans) if len(flagged_sentences): return FailResult( error_message=( From 37c828d8027de16ef288379bcbedf4e6d39b7c1a Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Tue, 16 Jul 2024 12:37:32 -0700 Subject: [PATCH 14/23] fix req body --- validator/main.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/validator/main.py b/validator/main.py index aaaecc0..25f9e99 100644 --- a/validator/main.py +++ b/validator/main.py @@ -221,8 +221,20 @@ def _inference_local(self, model_input: Any) -> str: def _inference_remote(self, model_input: Any) -> str: """Remote inference method for a hosted ML endpoint.""" request_body = { - "text": model_input["text"], - "competitors": model_input["competitors"], + "inputs": [ + { + "name": "text", + "shape": [1], + "data": [model_input["text"]], + "datatype": "BYTES" + }, + { + "name": "competitors", + "shape": [1], + "data": model_input["competitors"], + "datatype": "BYTES" + } + ] } response = self._hub_inference_request(request_body, self.validation_endpoint) From 4e4f875405e9aaf50b7fb5000dc599e3d6b19b9b Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Tue, 16 Jul 2024 12:38:22 -0700 Subject: [PATCH 15/23] json encoding --- validator/main.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/validator/main.py b/validator/main.py index 25f9e99..1d38ab7 100644 --- a/validator/main.py +++ b/validator/main.py @@ -236,12 +236,11 @@ def _inference_remote(self, model_input: Any) -> str: } ] } - response = self._hub_inference_request(request_body, self.validation_endpoint) + response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint) if not response or "outputs" not in response: raise ValueError("Invalid response from remote inference", response) outputs = response["outputs"][0]["data"][0] - result = json.loads(outputs) - return result + return outputs From 4e1c3484aff63bf61324e1e6c120734741af933a Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Thu, 18 Jul 2024 12:20:47 -0700 Subject: [PATCH 16/23] fix tests --- app.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/app.py b/app.py index b29ed48..45d8cfa 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,3 @@ -import json -import torch import nltk from typing import Any, Dict, List import spacy From 7d75a24b1ca6b9058f393bec00b4455dd1d58b19 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Thu, 18 Jul 2024 12:24:17 -0700 Subject: [PATCH 17/23] test fix --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index 45d8cfa..2d28117 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,5 @@ import nltk -from typing import Any, Dict, List +from typing import Any, Dict import spacy class InferlessPythonModel: From 08458bcf36e2564a5afbdfc2fc33d523629da7a3 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 5 Aug 2024 08:35:31 -0700 Subject: [PATCH 18/23] fix shape --- validator/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index 00faba2..c50a19a 100644 --- a/validator/main.py +++ b/validator/main.py @@ -246,7 +246,7 @@ def _inference_remote(self, model_input: Any) -> str: }, { "name": "competitors", - "shape": [1], + "shape": [len(model_input["competitors"])], "data": model_input["competitors"], "datatype": "BYTES" } From 4c8d787bc3d84a0275a2a4b55faf24e531381752 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 5 Aug 2024 12:58:38 -0700 Subject: [PATCH 19/23] model loading --- validator/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index c50a19a..5061e8f 100644 --- a/validator/main.py +++ b/validator/main.py @@ -68,7 +68,8 @@ def __init__( ) self._competitors = competitors model = "en_core_web_trf" - self.nlp = spacy.load(model) + if self.use_local: + self.nlp = spacy.load(model) def exact_match(self, text: str, competitors: List[str]) -> List[str]: """Performs exact match to find competitors from a list in a given From bced58e4533dd096f73f380d92e1a8dbb2e5ebd9 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 5 Aug 2024 14:39:59 -0700 Subject: [PATCH 20/23] Revert "fix shape" This reverts commit 08458bcf36e2564a5afbdfc2fc33d523629da7a3. --- validator/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index 5061e8f..55d48d2 100644 --- a/validator/main.py +++ b/validator/main.py @@ -247,7 +247,7 @@ def _inference_remote(self, model_input: Any) -> str: }, { "name": "competitors", - "shape": [len(model_input["competitors"])], + "shape": [1], "data": model_input["competitors"], "datatype": "BYTES" } From 76404c4b8be75edd270b7cb9233dd0bd8c84c1f3 Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 5 Aug 2024 14:45:09 -0700 Subject: [PATCH 21/23] Revert "Revert "fix shape"" This reverts commit bced58e4533dd096f73f380d92e1a8dbb2e5ebd9. --- validator/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index ffa80c6..c5f0b84 100644 --- a/validator/main.py +++ b/validator/main.py @@ -249,7 +249,7 @@ def _inference_remote(self, model_input: Any) -> str: }, { "name": "competitors", - "shape": [1], + "shape": [len(model_input["competitors"])], "data": model_input["competitors"], "datatype": "BYTES" } From f183b43479b3f643c175bdeff35d4a807d0e5aaf Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 5 Aug 2024 14:48:03 -0700 Subject: [PATCH 22/23] go back to old code --- validator/main.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/validator/main.py b/validator/main.py index c5f0b84..ee8e88f 100644 --- a/validator/main.py +++ b/validator/main.py @@ -4,8 +4,6 @@ import nltk import spacy - -from guardrails.validator_base import ErrorSpan from guardrails.logger import logger from guardrails.validator_base import ErrorSpan from guardrails.validators import ( @@ -262,9 +260,4 @@ def _inference_remote(self, model_input: Any) -> str: outputs = response["outputs"][0]["data"][0] - return outputs - error_spans=error_spans - ) - else: - return PassResult() - \ No newline at end of file + return outputs \ No newline at end of file From 3f15a53ae581ac2ce0dfbe84fc7994a329ab532b Mon Sep 17 00:00:00 2001 From: Aarav Navani <38411399+oofmeister27@users.noreply.github.com> Date: Mon, 5 Aug 2024 23:57:14 -0700 Subject: [PATCH 23/23] remove duplicate error spans function --- validator/main.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/validator/main.py b/validator/main.py index ee8e88f..ba88caa 100644 --- a/validator/main.py +++ b/validator/main.py @@ -195,21 +195,6 @@ def find_all(a_str, sub): reason=f"Competitor found: {value[start:start+len(entity)]}", ) ) - def find_all(a_str, sub): - start = 0 - while True: - start = a_str.find(sub, start) - if start == -1: - return - yield start - start += len(sub) # use start += 1 to find overlapping matches - - error_spans = [] - for entity in found_entities: - starts = list(find_all(value, entity)) - for start in starts: - error_spans.append(ErrorSpan(start=start, end=start+len(entity), reason=f'Competitor found: {value[start:start+len(entity)]}')) - if len(flagged_sentences): return FailResult(