Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup ML Endpoints #8

Merged
merged 28 commits into from
Aug 6, 2024
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
88e7999
add local and remote functions
aaravnavani Jun 28, 2024
4c1bee9
make sure formats are the same
aaravnavani Jun 28, 2024
124002a
fix any type
aaravnavani Jun 28, 2024
cce06ea
fixing validate to use inference rather than nlp model
wylansford Jun 28, 2024
7930788
kwargs updates
aaravnavani Jul 1, 2024
ba091fb
resolve conflicts
aaravnavani Jul 1, 2024
eb6b95f
fix args
aaravnavani Jul 1, 2024
cbe18ee
fix inference call
aaravnavani Jul 1, 2024
33d83aa
formatting
aaravnavani Jul 1, 2024
fde7231
Update main.py
wylansford Jul 1, 2024
357caf2
fix response output
aaravnavani Jul 2, 2024
e85a239
Merge branch 'ml_endpoint_setup' of https://github.com/guardrails-ai/…
aaravnavani Jul 2, 2024
d926a2f
remove use local check
aaravnavani Jul 2, 2024
cd6e099
Update main.py
wylansford Jul 8, 2024
fb487a2
Update main.py
wylansford Jul 8, 2024
37c828d
fix req body
aaravnavani Jul 16, 2024
4e4f875
json encoding
aaravnavani Jul 16, 2024
9cda7ab
Merge branch 'main' into ml_endpoint_setup
aaravnavani Jul 16, 2024
4e1c348
fix tests
aaravnavani Jul 18, 2024
7d75a24
test fix
aaravnavani Jul 18, 2024
08458bc
fix shape
aaravnavani Aug 5, 2024
4c8d787
model loading
aaravnavani Aug 5, 2024
f18c77d
Merge branch 'main' into ml_endpoint_setup
aaravnavani Aug 5, 2024
bced58e
Revert "fix shape"
aaravnavani Aug 5, 2024
7da461b
Merge branch 'ml_endpoint_setup' of https://github.com/guardrails-ai/…
aaravnavani Aug 5, 2024
76404c4
Revert "Revert "fix shape""
aaravnavani Aug 5, 2024
f183b43
go back to old code
aaravnavani Aug 5, 2024
3f15a53
remove duplicate error spans function
aaravnavani Aug 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 73 additions & 19 deletions validator/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import nltk
import spacy
import json
import re
from typing import Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

from guardrails.validator_base import ErrorSpan
import nltk
import spacy
from guardrails.logger import logger
from guardrails.validator_base import ErrorSpan
from guardrails.validators import (
FailResult,
PassResult,
Expand All @@ -14,7 +15,9 @@
)


@register_validator(name="guardrails/competitor_check", data_type="string")
@register_validator(
name="guardrails/competitor_check", data_type="string", has_guardrails_endpoint=True
)
class CompetitorCheck(Validator):
"""Validates that LLM-generated text is not naming any competitors from a
given list.
Expand Down Expand Up @@ -56,11 +59,17 @@ def __init__(
self,
competitors: List[str],
on_fail: Optional[Callable] = None,
**kwargs,
):
super().__init__(competitors=competitors, on_fail=on_fail)
super().__init__(
competitors=competitors,
on_fail=on_fail,
**kwargs,
)
self._competitors = competitors
model = "en_core_web_trf"
self.nlp = spacy.load(model)
if self.use_local:
self.nlp = spacy.load(model)

def exact_match(self, text: str, competitors: List[str]) -> List[str]:
"""Performs exact match to find competitors from a list in a given
Expand All @@ -82,19 +91,18 @@ def exact_match(self, text: str, competitors: List[str]) -> List[str]:
found_entities.append(entity)
return found_entities

def perform_ner(self, text: str, nlp) -> List[str]:
def perform_ner(self, text: str) -> List[str]:
"""Performs named entity recognition on text using a provided NLP
model.

Args:
text (str): The text to perform named entity recognition on.
nlp: The NLP model to use for entity recognition.

Returns:
entities: A list of entities found.
"""

doc = nlp(text)
doc = self.nlp(text)
entities = []
for ent in doc.ents:
entities.append(ent.text)
Expand Down Expand Up @@ -142,10 +150,13 @@ def validate(self, value: str, metadata=Dict) -> ValidationResult:
list_of_competitors_found = []
start_ind = 0
for sentence in sentences:

entities = self.exact_match(sentence, self._competitors)
if entities:
ner_entities = self.perform_ner(sentence, self.nlp)
ner_entities = self._inference(
{"text": sentence, "competitors": self._competitors}
)
if isinstance(ner_entities, str):
ner_entities = [ner_entities]
found_competitors = self.is_entity_in_list(ner_entities, entities)
if found_competitors:
flagged_sentences.append((found_competitors, sentence))
Expand All @@ -168,17 +179,22 @@ def find_all(a_str, sub):
start = 0
while True:
start = a_str.find(sub, start)
if start == -1:
if start == -1:
return
yield start
start += len(sub) # use start += 1 to find overlapping matches
start += len(sub) # use start += 1 to find overlapping matches

error_spans = []
for entity in found_entities:
for entity in found_entities:
starts = list(find_all(value, entity))
for start in starts:
error_spans.append(ErrorSpan(start=start, end=start+len(entity), reason=f'Competitor found: {value[start:start+len(entity)]}'))

error_spans.append(
ErrorSpan(
start=start,
end=start + len(entity),
reason=f"Competitor found: {value[start:start+len(entity)]}",
)
)

if len(flagged_sentences):
return FailResult(
Expand All @@ -187,8 +203,46 @@ def find_all(a_str, sub):
"Please avoid naming those competitors next time"
),
fix_value=filtered_output,
error_spans=error_spans
error_spans=error_spans,
)
else:
return PassResult()


def _inference_local(self, model_input: Any) -> str:
    """Anonymize competitor names in a text using the local NLP pipeline.

    Runs ``self.nlp`` over ``model_input["text"]`` and replaces every
    occurrence of each recognized entity whose text is listed in
    ``model_input["competitors"]`` with the literal ``"[COMPETITOR]"``.

    Args:
        model_input: Mapping with a ``"text"`` entry (the text to scan) and
            a ``"competitors"`` entry (the names to anonymize).

    Returns:
        The text with matching entity names replaced by ``"[COMPETITOR]"``.
        If no recognized entity is a listed competitor, the text is
        returned unchanged.
    """
    text = model_input["text"]
    competitors = model_input["competitors"]

    doc = self.nlp(text)

    # Collect each matching entity text once; repeating the same global
    # replacement would be a no-op anyway.
    matched_names = {ent.text for ent in doc.ents if ent.text in competitors}

    # Replace longer names first. Otherwise replacing "Apple" before
    # "Apple Music" rewrites the text so the longer name can no longer be
    # found, leaving a partial leak such as "[COMPETITOR] Music". The
    # secondary sort key only makes the order deterministic.
    anonymized_text = text
    for name in sorted(matched_names, key=lambda s: (-len(s), s)):
        anonymized_text = anonymized_text.replace(name, "[COMPETITOR]")
    return anonymized_text

def _inference_remote(self, model_input: Any) -> str:
    """Run inference against the hosted ML endpoint.

    Builds a request with two ``BYTES`` inputs (``"text"`` and
    ``"competitors"``), sends it as JSON through
    ``self._hub_inference_request`` to ``self.validation_endpoint``, and
    returns the first data element of the first output.

    Args:
        model_input: Mapping with a ``"text"`` entry and a
            ``"competitors"`` entry (a sized collection of names).

    Returns:
        ``response["outputs"][0]["data"][0]`` from the endpoint response.

    Raises:
        ValueError: If the response is empty, has no ``"outputs"`` entry,
            or its first output carries no data element.
    """
    competitors = model_input["competitors"]
    # NOTE(review): the name/shape/data/datatype layout looks like the
    # Open Inference (KServe v2) request format -- confirm against the
    # endpoint's contract.
    request_body = {
        "inputs": [
            {
                "name": "text",
                "shape": [1],
                "data": [model_input["text"]],
                "datatype": "BYTES",
            },
            {
                "name": "competitors",
                "shape": [len(competitors)],
                "data": competitors,
                "datatype": "BYTES",
            },
        ]
    }
    response = self._hub_inference_request(
        json.dumps(request_body), self.validation_endpoint
    )

    if not response or "outputs" not in response:
        raise ValueError("Invalid response from remote inference", response)

    # An "outputs" key alone does not guarantee a usable payload: an empty
    # list, a missing/empty "data" field, or a non-subscriptable value
    # would otherwise surface as IndexError/KeyError/TypeError. Report all
    # of them as the same invalid-response error.
    try:
        return response["outputs"][0]["data"][0]
    except (IndexError, KeyError, TypeError) as err:
        raise ValueError(
            "Invalid response from remote inference", response
        ) from err
Loading