Skip to content

Commit

Permalink
Merge pull request #8 from guardrails-ai/ml_endpoint_setup
Browse files Browse the repository at this point in the history
Setup ML Endpoints
  • Loading branch information
CalebCourier authored Aug 6, 2024
2 parents edba7c0 + 3f15a53 commit 0f8be03
Showing 1 changed file with 73 additions and 19 deletions.
92 changes: 73 additions & 19 deletions validator/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import nltk
import spacy
import json
import re
from typing import Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

from guardrails.validator_base import ErrorSpan
import nltk
import spacy
from guardrails.logger import logger
from guardrails.validator_base import ErrorSpan
from guardrails.validators import (
FailResult,
PassResult,
Expand All @@ -14,7 +15,9 @@
)


@register_validator(name="guardrails/competitor_check", data_type="string")
@register_validator(
name="guardrails/competitor_check", data_type="string", has_guardrails_endpoint=True
)
class CompetitorCheck(Validator):
"""Validates that LLM-generated text is not naming any competitors from a
given list.
Expand Down Expand Up @@ -56,11 +59,17 @@ def __init__(
self,
competitors: List[str],
on_fail: Optional[Callable] = None,
**kwargs,
):
super().__init__(competitors=competitors, on_fail=on_fail)
super().__init__(
competitors=competitors,
on_fail=on_fail,
**kwargs,
)
self._competitors = competitors
model = "en_core_web_trf"
self.nlp = spacy.load(model)
if self.use_local:
self.nlp = spacy.load(model)

def exact_match(self, text: str, competitors: List[str]) -> List[str]:
"""Performs exact match to find competitors from a list in a given
Expand All @@ -82,19 +91,18 @@ def exact_match(self, text: str, competitors: List[str]) -> List[str]:
found_entities.append(entity)
return found_entities

def perform_ner(self, text: str, nlp) -> List[str]:
def perform_ner(self, text: str) -> List[str]:
"""Performs named entity recognition on text using a provided NLP
model.
Args:
text (str): The text to perform named entity recognition on.
nlp: The NLP model to use for entity recognition.
Returns:
entities: A list of entities found.
"""

doc = nlp(text)
doc = self.nlp(text)
entities = []
for ent in doc.ents:
entities.append(ent.text)
Expand Down Expand Up @@ -142,10 +150,13 @@ def validate(self, value: str, metadata=Dict) -> ValidationResult:
list_of_competitors_found = []
start_ind = 0
for sentence in sentences:

entities = self.exact_match(sentence, self._competitors)
if entities:
ner_entities = self.perform_ner(sentence, self.nlp)
ner_entities = self._inference(
{"text": sentence, "competitors": self._competitors}
)
if isinstance(ner_entities, str):
ner_entities = [ner_entities]
found_competitors = self.is_entity_in_list(ner_entities, entities)
if found_competitors:
flagged_sentences.append((found_competitors, sentence))
Expand All @@ -168,17 +179,22 @@ def find_all(a_str, sub):
start = 0
while True:
start = a_str.find(sub, start)
if start == -1:
if start == -1:
return
yield start
start += len(sub) # use start += 1 to find overlapping matches
start += len(sub) # use start += 1 to find overlapping matches

error_spans = []
for entity in found_entities:
for entity in found_entities:
starts = list(find_all(value, entity))
for start in starts:
error_spans.append(ErrorSpan(start=start, end=start+len(entity), reason=f'Competitor found: {value[start:start+len(entity)]}'))

error_spans.append(
ErrorSpan(
start=start,
end=start + len(entity),
reason=f"Competitor found: {value[start:start+len(entity)]}",
)
)

if len(flagged_sentences):
return FailResult(
Expand All @@ -187,8 +203,46 @@ def find_all(a_str, sub):
"Please avoid naming those competitors next time"
),
fix_value=filtered_output,
error_spans=error_spans
error_spans=error_spans,
)
else:
return PassResult()


def _inference_local(self, model_input: Any) -> str:
"""Local inference method to detect and anonymize competitor names."""
text = model_input["text"]
competitors = model_input["competitors"]

doc = self.nlp(text)
anonymized_text = text
for ent in doc.ents:
if ent.text in competitors:
anonymized_text = anonymized_text.replace(ent.text, "[COMPETITOR]")
return anonymized_text

def _inference_remote(self, model_input: Any) -> str:
"""Remote inference method for a hosted ML endpoint."""
request_body = {
"inputs": [
{
"name": "text",
"shape": [1],
"data": [model_input["text"]],
"datatype": "BYTES"
},
{
"name": "competitors",
"shape": [len(model_input["competitors"])],
"data": model_input["competitors"],
"datatype": "BYTES"
}
]
}
response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)

if not response or "outputs" not in response:
raise ValueError("Invalid response from remote inference", response)

outputs = response["outputs"][0]["data"][0]

return outputs

0 comments on commit 0f8be03

Please sign in to comment.