diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_callback_chat_target.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_callback_chat_target.py
new file mode 100644
index 000000000000..14dc989b7c90
--- /dev/null
+++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_callback_chat_target.py
@@ -0,0 +1,75 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+from pyrit.models import (
+    PromptRequestResponse,
+    construct_response_from_request,
+)
+from pyrit.prompt_target import PromptChatTarget
+
+logger = logging.getLogger(__name__)
+
+
+class CallbackChatTarget(PromptChatTarget):
+    def __init__(
+        self,
+        *,
+        callback: Callable[[List[Dict], bool, Optional[str], Optional[Dict[str, Any]]], Dict],
+        stream: bool = False,
+    ) -> None:
+        """
+        Initializes an instance of the CallbackChatTarget class.
+
+        It is intended to be used with PyRIT, where users define a callback function
+        that handles sending a prompt to a target and receiving a response.
+        The CallbackChatTarget class is a wrapper around the callback function that allows it to be used
+        as a target in the PyRIT framework; beyond invoking the callback, it only handles
+        additional functionality such as conversation memory.
+
+        Args:
+            callback (Callable): The callback function that sends a prompt to a target and receives a response.
+            stream (bool, optional): Indicates whether the target supports streaming. Defaults to False.
+        """
+        PromptChatTarget.__init__(self)
+        self._callback = callback
+        self._stream = stream
+
+    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
+
+        self._validate_request(prompt_request=prompt_request)
+        request = prompt_request.request_pieces[0]
+
+        messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)
+
+        messages.append(request.to_chat_message())
+
+        logger.info(f"Sending the following prompt to the prompt target: {request}")
+
+        # response_context contains "messages", "stream", "session_state", "context"
+        response_context = await self._callback(messages=messages, stream=self._stream, session_state=None, context=None)
+
+        response_text = response_context["messages"][-1]["content"]
+        response_entry = construct_response_from_request(
+            request=request, response_text_pieces=[response_text]
+        )
+
+        logger.info(
+            "Received the following response from the prompt target: "
+            + f"{response_text}"
+        )
+        return response_entry
+
+    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
+        if len(prompt_request.request_pieces) != 1:
+            raise ValueError("This target only supports a single prompt request piece.")
+
+        if prompt_request.request_pieces[0].converted_value_data_type != "text":
+            raise ValueError("This target only supports text prompt input.")
+
+    def is_json_response_supported(self) -> bool:
+        """Indicates whether this target supports JSON response format; it does not."""
+        return False
diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py
index 730165e75bab..c80a5cd0b504 100644
--- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py
+++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py
@@ -8,7 +8,7 @@
 import logging
 from datetime import datetime
 from azure.ai.evaluation._common._experimental import experimental
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 from azure.ai.evaluation._common.math import list_mean_nan_safe
 from azure.ai.evaluation._constants import CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT
 from azure.ai.evaluation._evaluators import _content_safety, _protected_material, _groundedness, _relevance, _similarity, _fluency, _xpia, _coherence
@@ -17,11 +17,19 @@
 from azure.ai.evaluation._model_configurations import AzureAIProject, EvaluationResult
 from azure.ai.evaluation.simulator import Simulator, AdversarialSimulator, AdversarialScenario, AdversarialScenarioJailbreak, IndirectAttackSimulator, DirectAttackSimulator
 from azure.ai.evaluation.simulator._utils import JsonLineList
+from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager, TokenScope, RAIClient, AdversarialTemplateHandler
 from azure.ai.evaluation._common.utils import validate_azure_ai_project
 from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
+from azure.ai.evaluation._safety_evaluation._callback_chat_target import CallbackChatTarget
 from azure.core.credentials import TokenCredential
 import json
 from pathlib import Path
+import re
+import itertools
+from pyrit.common import initialize_pyrit, DUCK_DB
+from pyrit.orchestrator import PromptSendingOrchestrator
+from pyrit.prompt_target import OpenAIChatTarget
+from pyrit.models import ChatMessage
 
 
 logger = logging.getLogger(__name__)
@@ -87,6 +95,21 @@ def __init__(
         self.credential=credential
         self.logger = _setup_logger()
 
+        # For pyrit
+        self.token_manager = ManagedIdentityAPITokenManager(
+            token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
+            logger=logging.getLogger("AdversarialSimulator"),
+            credential=cast(TokenCredential, credential),
+        )
+
+        self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager)
+        self.adversarial_template_handler = AdversarialTemplateHandler(
+            azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
+        )
+
+        initialize_pyrit(memory_db_type=DUCK_DB)
+
+
     @staticmethod
     def _validate_model_config(model_config: Any):
         """
@@ -426,11 +449,27 @@ def _check_target_returns_context(target: Callable) -> bool:
         if ret_type is tuple:
             return True
         return False
+
+    @staticmethod
+    def _check_target_returns_str(target: Callable) -> bool:
+        '''
+        Checks if the target function returns a string.
+
+        :param target: The target function to check.
+        :type target: Callable
+        '''
+        sig = inspect.signature(target)
+        ret_type = sig.return_annotation
+        if ret_type == inspect.Signature.empty:
+            return False
+        if ret_type is str:
+            return True
+        return False
 
     def _validate_inputs(
         self,
         evaluators: List[_SafetyEvaluator],
-        target: Callable,
+        target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
         num_turns: int = 1,
         scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]] = None,
         source_text: Optional[str] = None,
@@ -447,10 +486,23 @@ def _validate_inputs(
         :type scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]]
         :param source_text: The source text to use as grounding document in the evaluation.
         :type source_text: Optional[str]
-        '''
-        if _SafetyEvaluator.GROUNDEDNESS in evaluators and not (self._check_target_returns_context(target) or source_text):
-            self.logger.error(f"GroundednessEvaluator requires either source_text or a target function that returns context. Source text: {source_text}, _check_target_returns_context: {self._check_target_returns_context(target)}")
-            msg = "GroundednessEvaluator requires either source_text or a target function that returns context"
+        '''
+        if not callable(target):
+            self._validate_model_config(target)
+        elif not self._check_target_returns_str(target):
+            self.logger.error(f"Target function {target} does not return a string.")
+            msg = f"Target function {target} does not return a string."
+            raise EvaluationException(
+                message=msg,
+                internal_message=msg,
+                target=ErrorTarget.UNKNOWN,
+                category=ErrorCategory.INVALID_VALUE,
+                blame=ErrorBlame.USER_ERROR,
+            )
+
+        if _SafetyEvaluator.GROUNDEDNESS in evaluators and not source_text:
+            self.logger.error(f"GroundednessEvaluator requires source_text. Source text: {source_text}")
+            msg = "GroundednessEvaluator requires source_text"
             raise EvaluationException(
                 message=msg,
                 internal_message=msg,
@@ -458,8 +510,8 @@
                 category=ErrorCategory.MISSING_FIELD,
                 blame=ErrorBlame.USER_ERROR,
             )
-
-        if scenario and not _SafetyEvaluator.CONTENT_SAFETY in evaluators:
+
+        if scenario and len(evaluators) > 0 and not _SafetyEvaluator.CONTENT_SAFETY in evaluators:
             self.logger.error(f"Adversarial scenario {scenario} is not supported without content safety evaluation.")
             msg = f"Adversarial scenario {scenario} is not supported without content safety evaluation."
             raise EvaluationException(
@@ -470,7 +522,7 @@
                 blame=ErrorBlame.USER_ERROR,
             )
 
-        if _SafetyEvaluator.CONTENT_SAFETY in evaluators and scenario and num_turns > 1:
+        if _SafetyEvaluator.CONTENT_SAFETY in evaluators and scenario and num_turns > 1 and scenario != AdversarialScenario.ADVERSARIAL_CONVERSATION:
             self.logger.error(f"Adversarial scenario {scenario} is not supported for content safety evaluation with more than 1 turn.")
             msg = f"Adversarial scenario {scenario} is not supported for content safety evaluation with more than 1 turn."
             raise EvaluationException(
@@ -524,11 +576,107 @@
         }
         evaluation_result['studio_url'] = evaluation_result_dict['jailbreak']['studio_url'] + '\t' + evaluation_result_dict['regular']['studio_url']
         return evaluation_result
+
+    async def _get_all_prompts(self, scenario: AdversarialScenario, num_rows: int = 3) -> List[str]:
+        templates = await self.adversarial_template_handler._get_content_harm_template_collections(
+            scenario.value
+        )
+        parameter_lists = [t.template_parameters for t in templates]
+        zipped_parameters = list(zip(*parameter_lists))
+
+        def fill_template(template_str, parameter):
+            pattern = re.compile(r'{{\s*(.*?)\s*}}')
+            def replacer(match):
+                placeholder = match.group(1)
+                return parameter.get(placeholder, f"{{{{ {placeholder} }}}}")
+            filled_text = pattern.sub(replacer, template_str)
+            return filled_text
+
+        all_prompts = []
+        count = 0
+        for param_group in zipped_parameters:
+            for _, parameter in zip(templates, param_group):
+                filled_template = fill_template(parameter['conversation_starter'], parameter)
+                all_prompts.append(filled_template)
+                count += 1
+                if count >= num_rows:
+                    break
+            if count >= num_rows:
+                break
+        return all_prompts
+
+    def _message_to_dict(self, message: ChatMessage):
+        return {
+            "role": message.role,
+            "content": message.content,
+        }
+
+    async def _pyrit(
+        self,
+        target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
+        scenario: AdversarialScenario,
+        num_rows: int = 1,
+    ) -> str:
+        chat_target: Optional[Union[OpenAIChatTarget, CallbackChatTarget]] = None
+        if not isinstance(target, Callable):
+            if "azure_deployment" in target and "azure_endpoint" in target: # Azure OpenAI
+                api_key = target.get("api_key", None)
+                if not api_key:
+                    chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], use_aad_auth=True)
+                else:
+                    chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], api_key=api_key)
+            else:
+                chat_target = OpenAIChatTarget(deployment=target["model"], endpoint=target.get("base_url", None), key=target["api_key"], is_azure_target=False)
+        else:
+            async def callback_target(
+                messages: List[Dict],
+                stream: bool = False,
+                session_state: Optional[str] = None,
+                context: Optional[Dict] = None
+            ) -> dict:
+                messages_list = [self._message_to_dict(chat_message) for chat_message in messages]  # type: ignore
+                latest_message = messages_list[-1]
+                application_input = latest_message["content"]
+                try:
+                    response = target(query=application_input)
+                except Exception as e:
+                    response = f"Something went wrong {e!s}"
+
+                ## We format the response to follow the OpenAI chat protocol format
+                formatted_response = {
+                    "content": response,
+                    "role": "assistant",
+                    "context": {},
+                }
+                ## NOTE: In the future, instead of appending to messages we should just return `formatted_response`
+                messages_list.append(formatted_response)  # type: ignore
+                return {"messages": messages_list, "stream": stream, "session_state": session_state, "context": {}}
+
+
+            chat_target = CallbackChatTarget(callback=callback_target)
+
+        all_prompts_list = await self._get_all_prompts(scenario, num_rows=num_rows)
+        orchestrator = PromptSendingOrchestrator(objective_target=chat_target)
+        await orchestrator.send_prompts_async(prompt_list=all_prompts_list)
+        memory = orchestrator.get_memory()
+
+        # Get conversations as a List[List[ChatMessage]]
+        conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(memory, key=lambda x: x.conversation_id)]
+
+        # Convert to JSON lines
+        json_lines = ""
+        for conversation in conversations:  # each conversation is a List[ChatMessage]
+            json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
+
+        data_path = "pyrit_outputs.jsonl"
+        with Path(data_path).open("w") as f:
+            f.writelines(json_lines)
+        return data_path
 
 
     async def __call__(
         self,
-        target: Callable,
+        target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
         evaluators: List[_SafetyEvaluator] = [],
         evaluation_name: Optional[str] = None,
         num_turns : int = 1,
@@ -583,12 +731,17 @@
 
         # Get scenario
         adversarial_scenario = self._get_scenario(evaluators, num_turns=num_turns, scenario=scenario)
+        self.logger.info(f"Using scenario: {adversarial_scenario}")
+
+        if isinstance(adversarial_scenario, AdversarialScenario) and num_turns==1:
+            self.logger.info(f"Running PyRIT with inputs target={target}, scenario={scenario}")
+            data_path = await self._pyrit(target, adversarial_scenario, num_rows=num_rows)
 
         ## Get evaluators
         evaluators_dict = self._get_evaluators(evaluators)
 
         ## If `data_path` is not provided, run simulator
-        if data_path is None and jailbreak_data_path is None:
+        if data_path is None and jailbreak_data_path is None and isinstance(target, Callable):
             self.logger.info(f"No data_path provided. Running simulator.")
             data_paths = await self._simulate(
                 target=target,
@@ -637,4 +790,7 @@
                 target=ErrorTarget.UNKNOWN,
                 category=ErrorCategory.MISSING_FIELD,
                 blame=ErrorBlame.USER_ERROR,
-            )
\ No newline at end of file
+            )
+
+
+
diff --git a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py
index f90064e17908..ab3e4de7b97d 100644
--- a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py
+++ b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py
@@ -48,6 +48,32 @@ def test_target(query: str) -> str:
         ))
         # [END default_safety_evaluation]
 
+        # [START default_safety_evaluation_model_target]
+        """
+        Please install the pyrit extra to run this example:
+
+        cd azure-sdk-for-python/sdk/evaluation/azure-ai-evaluation
+        pip install -e ".[pyrit]"
+        """
+        model_config = {
+            "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT"),
+            "azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT"),
+        }
+
+        azure_ai_project = {
+            "subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"),
+            "resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP_NAME"),
+            "project_name": os.environ.get("AZURE_PROJECT_NAME"),
+        }
+
+        credential = DefaultAzureCredential()
+
+        safety_evaluation_default = _SafetyEvaluation(azure_ai_project=azure_ai_project, credential=credential)
+        safety_evaluation_default_results = asyncio.run(safety_evaluation_default(
+            target=model_config,
+        ))
+        # [END default_safety_evaluation_model_target]
+
         # [START content_safety_safety_evaluation]
 
         def test_target(query: str) -> str:
@@ -115,7 +141,7 @@ def test_target(query: str) -> str:
         safety_evaluation_content_safety_scenario_results = asyncio.run(safety_evaluation_content_safety_scenario(
             evaluators=[_SafetyEvaluator.CONTENT_SAFETY],
             target=test_target,
-            scenario=AdversarialScenario.ADVERSARIAL_SUMMARIZATION,,
+            scenario=AdversarialScenario.ADVERSARIAL_SUMMARIZATION,
             num_rows=3,
             output_path="evaluation_outputs_safety_scenario.jsonl",
         ))
diff --git a/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py b/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py
new file mode 100644
index 000000000000..5b4016976998
--- /dev/null
+++ b/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py
@@ -0,0 +1,60 @@
+"""
+Please install the pyrit extra to run this example:
+
+cd azure-sdk-for-python/sdk/evaluation/azure-ai-evaluation
+pip install -e ".[pyrit]"
+"""
+
+
+from azure.ai.evaluation._safety_evaluation._safety_evaluation import _SafetyEvaluation
+import os
+from azure.identity import DefaultAzureCredential
+from azure.ai.evaluation.simulator import AdversarialScenario
+
+
+async def main():
+    model_config = {
+        "azure_endpoint": os.environ.get("AZURE_ENDPOINT"),
+        "azure_deployment": os.environ.get("AZURE_DEPLOYMENT_NAME"),
+    }
+
+    def test_target_fn(query: str) -> str:
+        return "mock response"
+
+    azure_ai_project = {
+        "subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"),
+        "resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP"),
+        "project_name": os.environ.get("AZURE_PROJECT_NAME"),
+    }
+
+    safety_eval_callback_target = _SafetyEvaluation(
+        azure_ai_project=azure_ai_project,
+        credential=DefaultAzureCredential(),
+    )
+
+    outputs = await safety_eval_callback_target(
+        target=test_target_fn,
+        num_rows=8,
+    )
+
+    print(outputs)
+
+    safety_eval_model_target = _SafetyEvaluation(
+        azure_ai_project=azure_ai_project,
+        credential=DefaultAzureCredential(),
+    )
+
+    outputs = await safety_eval_model_target(
+        target=model_config,
+        num_rows=8,
+    )
+    print(outputs)
+
+if __name__ == "__main__":
+    import asyncio
+    import time
+    start = time.perf_counter()
+    asyncio.run(main())
+    end = time.perf_counter()
+    print(f"Runtime: {end - start:.2f} seconds")
+
diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py
index 8c1ba9f75ca1..94d86277f439 100644
--- a/sdk/evaluation/azure-ai-evaluation/setup.py
+++ b/sdk/evaluation/azure-ai-evaluation/setup.py
@@ -71,8 +71,11 @@
         "azure-identity>=1.16.0",
         "azure-core>=1.30.2",
         "nltk>=3.9.1",
-        "azure-storage-blob>=12.10.0"
+        "azure-storage-blob>=12.10.0",
     ],
+    extras_require={
+        "pyrit": ["pyrit @ git+https://github.com/Azure/PyRIT.git"]
+    },
     project_urls={
         "Bug Reports": "https://github.com/Azure/azure-sdk-for-python/issues",
         "Source": "https://github.com/Azure/azure-sdk-for-python",
diff --git a/shared_requirements.txt b/shared_requirements.txt
index 8664d40ceba2..8aed1a7b0ae3 100644
--- a/shared_requirements.txt
+++ b/shared_requirements.txt
@@ -71,4 +71,5 @@ dnspython
 promptflow-core
 promptflow-devkit
 nltk
-azure-monitor-opentelemetry
\ No newline at end of file
+azure-monitor-opentelemetry
+pyrit
\ No newline at end of file
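Not part of the patch: a minimal sketch of driving the new CallbackChatTarget directly from PyRIT's PromptSendingOrchestrator, mirroring what _SafetyEvaluation._pyrit does internally. The echo reply and the single hard-coded prompt are illustrative placeholders, and the only PyRIT calls used are ones that already appear in this diff (initialize_pyrit, PromptSendingOrchestrator, send_prompts_async, get_memory, to_chat_message).

import asyncio
from typing import Any, Dict, List, Optional

from pyrit.common import DUCK_DB, initialize_pyrit
from pyrit.orchestrator import PromptSendingOrchestrator

from azure.ai.evaluation._safety_evaluation._callback_chat_target import CallbackChatTarget


async def echo_callback(
    messages: List[Dict],
    stream: bool = False,
    session_state: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
) -> Dict:
    # Messages arrive as PyRIT ChatMessage objects; convert them to plain dicts,
    # the same way the callback inside _SafetyEvaluation._pyrit does.
    messages_list = [{"role": m.role, "content": m.content} for m in messages]
    query = messages_list[-1]["content"]
    # Placeholder for a real application under test.
    reply = f"echo: {query}"
    # CallbackChatTarget reads response["messages"][-1]["content"] as the reply text.
    messages_list.append({"role": "assistant", "content": reply, "context": {}})
    return {"messages": messages_list, "stream": stream, "session_state": session_state, "context": {}}


async def main() -> None:
    # _SafetyEvaluation.__init__ does this once; it is required before using PyRIT targets.
    initialize_pyrit(memory_db_type=DUCK_DB)
    chat_target = CallbackChatTarget(callback=echo_callback)
    orchestrator = PromptSendingOrchestrator(objective_target=chat_target)
    await orchestrator.send_prompts_async(prompt_list=["Tell me about your refund policy."])
    for piece in orchestrator.get_memory():
        chat_message = piece.to_chat_message()
        print(chat_message.role, chat_message.content)


if __name__ == "__main__":
    asyncio.run(main())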