[Draft] Adding pyrit dependency #39809

Draft: wants to merge 16 commits into main
@@ -0,0 +1,75 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
from typing import Any, Callable, Dict, List, Optional

from pyrit.models import (
PromptRequestResponse,
construct_response_from_request,
)
from pyrit.prompt_target import PromptChatTarget

logger = logging.getLogger(__name__)


class CallbackChatTarget(PromptChatTarget):
def __init__(
self,
*,
callback: Callable[[List[Dict], bool, Optional[str], Optional[Dict[str, Any]]], Dict],
stream: bool = False,
) -> None:
"""
Initializes an instance of the CallbackChatTarget class.

It is intended to be used with PyRIT, where users define a callback function
that handles sending a prompt to a target and receiving a response.
CallbackChatTarget is a thin wrapper that allows such a callback to be used
as a target in the PyRIT framework; beyond invoking the callback, it only adds
framework plumbing such as memory handling.

Args:
callback (Callable): The callback function that sends a prompt to a target and receives a response.
stream (bool, optional): Indicates whether the target supports streaming. Defaults to False.
"""
PromptChatTarget.__init__(self)
self._callback = callback
self._stream = stream

async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:

self._validate_request(prompt_request=prompt_request)
request = prompt_request.request_pieces[0]

messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)

messages.append(request.to_chat_message())

logger.info(f"Sending the following prompt to the prompt target: {request}")

# The callback receives the full chat history (ChatMessage objects);
# response_context contains "messages", "stream", "session_state", "context"
response_context = await self._callback(messages=messages, stream=self._stream, session_state=None, context=None)

response_text = response_context["messages"][-1]["content"]
response_entry = construct_response_from_request(
request=request, response_text_pieces=[response_text]
)

logger.info(f"Received the following response from the prompt target: {response_text}")
return response_entry

def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
if len(prompt_request.request_pieces) != 1:
raise ValueError("This target only supports a single prompt request piece.")

if prompt_request.request_pieces[0].converted_value_data_type != "text":
raise ValueError("This target only supports text prompt input.")

def is_json_response_supported(self) -> bool:
"""Indicates that this target supports JSON response format."""
return False
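# A minimal usage sketch (illustrative only, not part of this change): a callback that
# matches the contract CallbackChatTarget expects. send_prompt_async passes the chat
# history as pyrit ChatMessage objects and reads the last entry of the returned
# "messages" list as the response text.
#
#     async def my_callback(messages, stream=False, session_state=None, context=None) -> Dict:
#         # Convert the incoming ChatMessage objects to OpenAI-style dicts
#         messages_list = [{"role": m.role, "content": m.content} for m in messages]
#         reply = {"role": "assistant", "content": "Hello from my application.", "context": {}}
#         messages_list.append(reply)
#         return {"messages": messages_list, "stream": stream, "session_state": session_state, "context": {}}
#
#     target = CallbackChatTarget(callback=my_callback)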
@@ -8,7 +8,7 @@
import logging
from datetime import datetime
from azure.ai.evaluation._common._experimental import experimental
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from azure.ai.evaluation._common.math import list_mean_nan_safe
from azure.ai.evaluation._constants import CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT
from azure.ai.evaluation._evaluators import _content_safety, _protected_material, _groundedness, _relevance, _similarity, _fluency, _xpia, _coherence
@@ -17,11 +17,19 @@
from azure.ai.evaluation._model_configurations import AzureAIProject, EvaluationResult
from azure.ai.evaluation.simulator import Simulator, AdversarialSimulator, AdversarialScenario, AdversarialScenarioJailbreak, IndirectAttackSimulator, DirectAttackSimulator
from azure.ai.evaluation.simulator._utils import JsonLineList
from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager, TokenScope, RAIClient, AdversarialTemplateHandler
from azure.ai.evaluation._common.utils import validate_azure_ai_project
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
from azure.ai.evaluation._safety_evaluation._callback_chat_target import CallbackChatTarget
from azure.core.credentials import TokenCredential
import json
from pathlib import Path
import re
import itertools
from pyrit.common import initialize_pyrit, DUCK_DB
from pyrit.orchestrator import PromptSendingOrchestrator
from pyrit.prompt_target import OpenAIChatTarget
from pyrit.models import ChatMessage

logger = logging.getLogger(__name__)

@@ -87,6 +95,21 @@ def __init__(
self.credential = credential
self.logger = _setup_logger()

# Set up PyRIT dependencies (token manager, RAI client, template handler, and memory)
self.token_manager = ManagedIdentityAPITokenManager(
token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
logger=logging.getLogger("AdversarialSimulator"),
credential=cast(TokenCredential, credential),
)

self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager)
self.adversarial_template_handler = AdversarialTemplateHandler(
azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
)

initialize_pyrit(memory_db_type=DUCK_DB)


@staticmethod
def _validate_model_config(model_config: Any):
"""
@@ -426,11 +449,27 @@ def _check_target_returns_context(target: Callable) -> bool:
if ret_type is tuple:
return True
return False

@staticmethod
def _check_target_returns_str(target: Callable) -> bool:
'''
Checks whether the target function is annotated to return a string.

:param target: The target function to check.
:type target: Callable
:return: True if the target's return annotation is str, False otherwise.
:rtype: bool
'''
sig = inspect.signature(target)
ret_type = sig.return_annotation
if ret_type == inspect.Signature.empty:
return False
if ret_type is str:
return True
return False
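# Illustration (hypothetical targets, not part of this change): the check above is purely
# annotation-based, so only an explicit `-> str` return annotation passes.
#
#     def typed_target(query: str) -> str: ...    # _check_target_returns_str -> True
#     def untyped_target(query): ...              # no annotation -> False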

def _validate_inputs(
self,
evaluators: List[_SafetyEvaluator],
target: Callable,
target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
num_turns: int = 1,
scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]] = None,
source_text: Optional[str] = None,
@@ -447,19 +486,32 @@ def _validate_inputs(
:type scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]]
:param source_text: The source text to use as grounding document in the evaluation.
:type source_text: Optional[str]
'''
if _SafetyEvaluator.GROUNDEDNESS in evaluators and not (self._check_target_returns_context(target) or source_text):
self.logger.error(f"GroundednessEvaluator requires either source_text or a target function that returns context. Source text: {source_text}, _check_target_returns_context: {self._check_target_returns_context(target)}")
msg = "GroundednessEvaluator requires either source_text or a target function that returns context"
'''
if not callable(target):
self._validate_model_config(target)
elif not self._check_target_returns_str(target):
self.logger.error(f"Target function {target} is not annotated to return a string.")
msg = f"Target function {target} must be annotated to return a string."
raise EvaluationException(
message=msg,
internal_message=msg,
target=ErrorTarget.UNKNOWN,
category=ErrorCategory.INVALID_VALUE,
blame=ErrorBlame.USER_ERROR,
)

if _SafetyEvaluator.GROUNDEDNESS in evaluators and not source_text:
self.logger.error(f"GroundednessEvaluator requires source_text. Source text: {source_text}")
msg = "GroundednessEvaluator requires source_text"
raise EvaluationException(
message=msg,
internal_message=msg,
target=ErrorTarget.GROUNDEDNESS_EVALUATOR,
category=ErrorCategory.MISSING_FIELD,
blame=ErrorBlame.USER_ERROR,
)

if scenario and not _SafetyEvaluator.CONTENT_SAFETY in evaluators:
if scenario and len(evaluators) > 0 and _SafetyEvaluator.CONTENT_SAFETY not in evaluators:
self.logger.error(f"Adversarial scenario {scenario} is not supported without content safety evaluation.")
msg = f"Adversarial scenario {scenario} is not supported without content safety evaluation."
raise EvaluationException(
@@ -470,7 +522,7 @@
blame=ErrorBlame.USER_ERROR,
)

if _SafetyEvaluator.CONTENT_SAFETY in evaluators and scenario and num_turns > 1:
if _SafetyEvaluator.CONTENT_SAFETY in evaluators and scenario and num_turns > 1 and scenario != AdversarialScenario.ADVERSARIAL_CONVERSATION:
self.logger.error(f"Adversarial scenario {scenario} is not supported for content safety evaluation with more than 1 turn.")
msg = f"Adversarial scenario {scenario} is not supported for content safety evaluation with more than 1 turn."
raise EvaluationException(
@@ -524,11 +576,107 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult:
}
evaluation_result['studio_url'] = evaluation_result_dict['jailbreak']['studio_url'] + '\t' + evaluation_result_dict['regular']['studio_url']
return evaluation_result

async def _get_all_prompts(self, scenario: AdversarialScenario, num_rows: int = 3) -> List[str]:
'''Fetches the adversarial templates for the given scenario and returns up to num_rows prompts with their placeholders filled in.'''
templates = await self.adversarial_template_handler._get_content_harm_template_collections(
scenario.value
)
parameter_lists = [t.template_parameters for t in templates]
zipped_parameters = list(zip(*parameter_lists))

def fill_template(template_str, parameter):
# Substitute Jinja-style {{ placeholder }} tokens with values from `parameter`,
# leaving unresolved placeholders intact.
pattern = re.compile(r'{{\s*(.*?)\s*}}')
def replacer(match):
placeholder = match.group(1)
return parameter.get(placeholder, f"{{{{ {placeholder} }}}}")
filled_text = pattern.sub(replacer, template_str)
return filled_text

all_prompts = []
count = 0
for param_group in zipped_parameters:
for _, parameter in zip(templates, param_group):
filled_template = fill_template(parameter['conversation_starter'], parameter)
all_prompts.append(filled_template)
count += 1
if count >= num_rows:
break
if count >= num_rows:
break
return all_prompts
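# Example of the placeholder substitution performed by fill_template above
# (template text and parameter values are hypothetical):
#
#     fill_template("Summarize this: {{ conversation_starter }}", {"conversation_starter": "hello"})
#     # -> "Summarize this: hello"
#     fill_template("{{ unknown_key }}", {})
#     # -> "{{ unknown_key }}"  (unresolved placeholders are left in place)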

def _message_to_dict(self, message: ChatMessage):
return {
"role": message.role,
"content": message.content,
}

async def _pyrit(
self,
target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
scenario: AdversarialScenario,
num_rows: int = 1,
) -> str:
chat_target: Union[OpenAIChatTarget, CallbackChatTarget, None] = None
if not callable(target):
if "azure_deployment" in target and "azure_endpoint" in target: # Azure OpenAI
api_key = target.get("api_key", None)
if not api_key:
chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], use_aad_auth=True)
else:
chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], api_key=api_key)
else:
chat_target = OpenAIChatTarget(deployment_name=target["model"], endpoint=target.get("base_url", None), api_key=target["api_key"], is_azure_target=False)
else:
async def callback_target(
messages: List[Dict],
stream: bool = False,
session_state: Optional[str] = None,
context: Optional[Dict] = None
) -> dict:
messages_list = [self._message_to_dict(chat_message) for chat_message in messages] # type: ignore
latest_message = messages_list[-1]
application_input = latest_message["content"]
try:
response = target(query=application_input)
except Exception as e:
response = f"Something went wrong {e!s}"

# We format the response to follow the OpenAI chat protocol
formatted_response = {
"content": response,
"role": "assistant",
"context": {},
}
# NOTE: In the future, instead of appending to messages we should just return `formatted_response`
messages_list.append(formatted_response)  # type: ignore
return {"messages": messages_list, "stream": stream, "session_state": session_state, "context": {}}


chat_target = CallbackChatTarget(callback=callback_target)

all_prompts_list = await self._get_all_prompts(scenario, num_rows=num_rows)

orchestrator = PromptSendingOrchestrator(objective_target=chat_target)
await orchestrator.send_prompts_async(prompt_list=all_prompts_list)
memory = orchestrator.get_memory()

# Get conversations as a List[List[ChatMessage]]; sort first since itertools.groupby only groups consecutive items
memory_sorted = sorted(memory, key=lambda x: x.conversation_id)
conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(memory_sorted, key=lambda x: x.conversation_id)]

# Convert to JSON lines
json_lines = ""
for conversation in conversations:  # each conversation is a List[ChatMessage]
json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"

data_path = "pyrit_outputs.jsonl"
with Path(data_path).open("w") as f:
f.write(json_lines)
return data_path
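# Each line of the JSONL file written above has the following shape
# (message contents are illustrative):
#
#     {"conversation": {"messages": [
#         {"role": "user", "content": "<adversarial prompt>"},
#         {"role": "assistant", "content": "<target response>"}
#     ]}}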

async def __call__(
self,
target: Callable,
target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
evaluators: List[_SafetyEvaluator] = [],
evaluation_name: Optional[str] = None,
num_turns : int = 1,
@@ -583,12 +731,17 @@ async def __call__(

# Get scenario
adversarial_scenario = self._get_scenario(evaluators, num_turns=num_turns, scenario=scenario)
self.logger.info(f"Using scenario: {adversarial_scenario}")

if isinstance(adversarial_scenario, AdversarialScenario) and num_turns == 1:
self.logger.info(f"Running PyRIT with inputs target={target}, scenario={adversarial_scenario}")
data_path = await self._pyrit(target, adversarial_scenario, num_rows=num_rows)

## Get evaluators
evaluators_dict = self._get_evaluators(evaluators)

## If `data_path` is not provided, run simulator
if data_path is None and jailbreak_data_path is None:
if data_path is None and jailbreak_data_path is None and callable(target):
self.logger.info(f"No data_path provided. Running simulator.")
data_paths = await self._simulate(
target=target,
@@ -637,4 +790,7 @@ async def __call__(
target=ErrorTarget.UNKNOWN,
category=ErrorCategory.MISSING_FIELD,
blame=ErrorBlame.USER_ERROR,
)
)



@@ -48,6 +48,32 @@ def test_target(query: str) -> str:
))
# [END default_safety_evaluation]

# [START default_safety_evaluation_model_target]
"""
please install the pyrit extra to run this example

cd azure-sdk-for-python/sdk/evaluation/azure-ai-evaluation
pip install -e ".[pyrit]"
"""
model_config = {
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT"),
"azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT"),
}

azure_ai_project = {
"subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"),
"resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP_NAME"),
"project_name": os.environ.get("AZURE_PROJECT_NAME"),
}

credential = DefaultAzureCredential()

safety_evaluation_default = _SafetyEvaluation(azure_ai_project=azure_ai_project, credential=credential)
safety_evaluation_default_results = asyncio.run(safety_evaluation_default(
target=model_config,
))
# [END default_safety_evaluation_model_target]
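# A model target can also be a non-Azure OpenAI configuration. A hypothetical sketch
# (the environment variable names here are assumptions, not part of this sample):
#
# openai_model_config = {
#     "model": os.environ.get("OPENAI_MODEL"),
#     "base_url": os.environ.get("OPENAI_BASE_URL"),
#     "api_key": os.environ.get("OPENAI_API_KEY"),
# }
# safety_evaluation_openai_results = asyncio.run(safety_evaluation_default(
#     target=openai_model_config,
# ))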

# [START content_safety_safety_evaluation]

def test_target(query: str) -> str:
@@ -115,7 +141,7 @@ def test_target(query: str) -> str:
safety_evaluation_content_safety_scenario_results = asyncio.run(safety_evaluation_content_safety_scenario(
evaluators=[_SafetyEvaluator.CONTENT_SAFETY],
target=test_target,
scenario=AdversarialScenario.ADVERSARIAL_SUMMARIZATION,,
scenario=AdversarialScenario.ADVERSARIAL_SUMMARIZATION,
num_rows=3,
output_path="evaluation_outputs_safety_scenario.jsonl",
))