From 689847ed0c861d6036bea1a1005ecaf6e571ddbd Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Thu, 20 Feb 2025 11:21:42 -0800 Subject: [PATCH 01/14] Adding pyrit dependency --- sdk/evaluation/azure-ai-evaluation/setup.py | 3 ++- shared_requirements.txt | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index 8c1ba9f75ca1..b691cf67177e 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -71,7 +71,8 @@ "azure-identity>=1.16.0", "azure-core>=1.30.2", "nltk>=3.9.1", - "azure-storage-blob>=12.10.0" + "azure-storage-blob>=12.10.0", + "pyrit>=0.5.2" ], project_urls={ "Bug Reports": "https://github.com/Azure/azure-sdk-for-python/issues", diff --git a/shared_requirements.txt b/shared_requirements.txt index 8664d40ceba2..8aed1a7b0ae3 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -71,4 +71,5 @@ dnspython promptflow-core promptflow-devkit nltk -azure-monitor-opentelemetry \ No newline at end of file +azure-monitor-opentelemetry +pyrit \ No newline at end of file From 0bf11086c60594eab51fd74bc6c8f8a55072ae60 Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Thu, 20 Feb 2025 17:09:48 -0800 Subject: [PATCH 02/14] updates --- .../_safety_evaluation/_safety_evaluation.py | 76 ++++++++++++++++++- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py index 730165e75bab..53f9c9e8b9f5 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py @@ -8,7 +8,7 @@ import logging from datetime import datetime from azure.ai.evaluation._common._experimental import experimental -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from azure.ai.evaluation._common.math import list_mean_nan_safe from azure.ai.evaluation._constants import CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT from azure.ai.evaluation._evaluators import _content_safety, _protected_material, _groundedness, _relevance, _similarity, _fluency, _xpia, _coherence @@ -17,11 +17,16 @@ from azure.ai.evaluation._model_configurations import AzureAIProject, EvaluationResult from azure.ai.evaluation.simulator import Simulator, AdversarialSimulator, AdversarialScenario, AdversarialScenarioJailbreak, IndirectAttackSimulator, DirectAttackSimulator from azure.ai.evaluation.simulator._utils import JsonLineList +from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager, TokenScope, RAIClient, AdversarialTemplateHandler from azure.ai.evaluation._common.utils import validate_azure_ai_project from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration from azure.core.credentials import TokenCredential import json from pathlib import Path +import re +from pyrit.common import initialize_pyrit, IN_MEMORY +from pyrit.orchestrator import PromptSendingOrchestrator +from pyrit.prompt_target import OpenAIChatTarget logger = logging.getLogger(__name__) @@ -87,6 +92,21 @@ def __init__( self.credential=credential self.logger = _setup_logger() + # For pyrit + self.token_manager 
= ManagedIdentityAPITokenManager( + token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT, + logger=logging.getLogger("AdversarialSimulator"), + credential=cast(TokenCredential, credential), + ) + + self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager) + self.adversarial_template_handler = AdversarialTemplateHandler( + azure_ai_project=self.azure_ai_project, rai_client=self.rai_client + ) + + initialize_pyrit(memory_db_type=IN_MEMORY) + + @staticmethod def _validate_model_config(model_config: Any): """ @@ -524,7 +544,28 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult: } evaluation_result['studio_url'] = evaluation_result_dict['jailbreak']['studio_url'] + '\t' + evaluation_result_dict['regular']['studio_url'] return evaluation_result - + + async def _get_all_prompts(self, scenario: AdversarialScenario) -> List[str]: + templates = await self.adversarial_template_handler._get_content_harm_template_collections( + AdversarialScenario.ADVERSARIAL_QA.value + ) + parameter_lists = [t.template_parameters for t in templates] + zipped_parameters = list(zip(*parameter_lists)) + + def fill_template(template_str, parameter): + pattern = re.compile(r'{{\s*(.*?)\s*}}') + def replacer(match): + placeholder = match.group(1) + return parameter.get(placeholder, f"{{{{ {placeholder} }}}}") + filled_text = pattern.sub(replacer, template_str) + return filled_text + + all_prompts = [] + for param_group in zipped_parameters: + for template_obj, parameter in zip(templates, param_group): + filled_template = fill_template(parameter['conversation_starter'], parameter) + all_prompts.append(filled_template) + return all_prompts async def __call__( self, @@ -637,4 +678,33 @@ async def __call__( target=ErrorTarget.UNKNOWN, category=ErrorCategory.MISSING_FIELD, blame=ErrorBlame.USER_ERROR, - ) \ No newline at end of file + ) + + + +if __name__ == "__main__": + from dotenv import load_dotenv + from azure.identity import DefaultAzureCredential + import asyncio + import re + load_dotenv() + def test_target(query: str) -> str: + return "Hello, how can I help you?" 
+ + azure_ai_project = { + "subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"), + "resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP_NAME"), + "project_name": os.environ.get("AZURE_PROJECT_NAME"), + } + + credential = DefaultAzureCredential() + + safety_evaluation = _SafetyEvaluation(azure_ai_project=azure_ai_project, credential=credential) + target = OpenAIChatTarget() + all_prompts_list = asyncio.run(safety_evaluation._get_all_prompts(AdversarialScenario.ADVERSARIAL_QA)) + orchestrator = PromptSendingOrchestrator(objective_target=target) + asyncio.run(orchestrator.send_prompts_async(prompt_list=all_prompts_list)) + asyncio.run(orchestrator.print_conversations_async()) + + + From 5adf8646aaf8d911e838df849825315d8cf9b8e4 Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Fri, 21 Feb 2025 10:09:43 -0800 Subject: [PATCH 03/14] updates --- .../_safety_evaluation/_safety_evaluation.py | 136 +++++++++++++----- 1 file changed, 99 insertions(+), 37 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py index 53f9c9e8b9f5..e747cac2fc7e 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py @@ -24,7 +24,7 @@ import json from pathlib import Path import re -from pyrit.common import initialize_pyrit, IN_MEMORY +from pyrit.common import initialize_pyrit, DUCK_DB from pyrit.orchestrator import PromptSendingOrchestrator from pyrit.prompt_target import OpenAIChatTarget @@ -104,7 +104,7 @@ def __init__( azure_ai_project=self.azure_ai_project, rai_client=self.rai_client ) - initialize_pyrit(memory_db_type=IN_MEMORY) + initialize_pyrit(memory_db_type=DUCK_DB) @staticmethod @@ -446,11 +446,27 @@ def _check_target_returns_context(target: Callable) -> bool: if ret_type is tuple: return True return False + + @staticmethod + def _check_target_returns_str(target: Callable) -> bool: + ''' + Checks if the target function returns a string. + + :param target: The target function to check. + :type target: Callable + ''' + sig = inspect.signature(target) + ret_type = sig.return_annotation + if ret_type == inspect.Signature.empty: + return False + if ret_type is str: + return True + return False def _validate_inputs( self, evaluators: List[_SafetyEvaluator], - target: Callable, + target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration], num_turns: int = 1, scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]] = None, source_text: Optional[str] = None, @@ -467,10 +483,23 @@ def _validate_inputs( :type scenario: Optional[Union[AdversarialScenario, AdversarialScenarioJailbreak]] :param source_text: The source text to use as grounding document in the evaluation. :type source_text: Optional[str] - ''' - if _SafetyEvaluator.GROUNDEDNESS in evaluators and not (self._check_target_returns_context(target) or source_text): - self.logger.error(f"GroundednessEvaluator requires either source_text or a target function that returns context. 
Source text: {source_text}, _check_target_returns_context: {self._check_target_returns_context(target)}") - msg = "GroundednessEvaluator requires either source_text or a target function that returns context" + ''' + if not callable(target): + self._validate_model_config(target) + elif not self._check_target_returns_str(target): + self.logger.error(f"Target function {target} does not return a string.") + msg = f"Target function {target} does not return a string." + raise EvaluationException( + message=msg, + internal_message=msg, + target=ErrorTarget.UNKNOWN, + category=ErrorCategory.INVALID_VALUE, + blame=ErrorBlame.USER_ERROR, + ) + + if _SafetyEvaluator.GROUNDEDNESS in evaluators and not source_text: + self.logger.error(f"GroundednessEvaluator requires source_text. Source text: {source_text}") + msg = "GroundednessEvaluator requires source_text" raise EvaluationException( message=msg, internal_message=msg, @@ -547,7 +576,7 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult: async def _get_all_prompts(self, scenario: AdversarialScenario) -> List[str]: templates = await self.adversarial_template_handler._get_content_harm_template_collections( - AdversarialScenario.ADVERSARIAL_QA.value + scenario.value ) parameter_lists = [t.template_parameters for t in templates] zipped_parameters = list(zip(*parameter_lists)) @@ -562,14 +591,68 @@ def replacer(match): all_prompts = [] for param_group in zipped_parameters: - for template_obj, parameter in zip(templates, param_group): + for _, parameter in zip(templates, param_group): filled_template = fill_template(parameter['conversation_starter'], parameter) all_prompts.append(filled_template) return all_prompts + + async def _pyrit(self, target: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration], scenario: AdversarialScenario) -> str: + chat_target: OpenAIChatTarget = None + if "azure_deployment" in target and "azure_endpoint" in target: # Azure OpenAI + api_key = target.get("api_key", None) + if not api_key: + chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], use_aad_auth=True) + else: + chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], api_key=api_key) + else: + chat_target = OpenAIChatTarget(deployment=target["model"], endpoint=target.get("base_url", None), key=target["api_key"], is_azure_target=False) + + all_prompts_list = await self._get_all_prompts(scenario) + + orchestrator = PromptSendingOrchestrator(objective_target=chat_target) + await orchestrator.send_prompts_async(prompt_list=all_prompts_list) + memory = orchestrator.get_memory() + + # Get conversations as a List[List[ChatMessage]] + conversations = [] + conversation_id = None + current_conversation = [] + for item in memory: + if conversation_id == None or conversation_id == item.conversation_id: + conversation_id = item.conversation_id + current_conversation.append(item.to_chat_message()) + else: + conversations.append(current_conversation) + current_conversation = [item.to_chat_message()] + conversation_id = item.conversation_id + + #Convert to json lines + json_lines = "" + for conversation in conversations: # each conversation is a List[ChatMessage] + user_message = None + assistant_message = None + for message in conversation: # each message is a ChatMessage + if message.role == "user": + user_message = message.content + elif message.role == "assistant": + assistant_message = message.content + + if user_message and 
assistant_message: + json_lines += ( + json.dumps({"query": user_message, "response": assistant_message, "category": None}) + + "\n" + ) + user_message = assistant_message = None + + data_path = "pyrit_outputs.jsonl" + with Path(data_path).open("w") as f: + f.writelines(json_lines) + + return data_path async def __call__( self, - target: Callable, + target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration], evaluators: List[_SafetyEvaluator] = [], evaluation_name: Optional[str] = None, num_turns : int = 1, @@ -624,12 +707,17 @@ async def __call__( # Get scenario adversarial_scenario = self._get_scenario(evaluators, num_turns=num_turns, scenario=scenario) + self.logger.info(f"Using scenario: {adversarial_scenario}") + + if not isinstance(target, Callable) and isinstance(adversarial_scenario, AdversarialScenario): + self.logger.info(f"Running Pyrit with inputs target={target}, scenario={scenario}") + data_path = await self._pyrit(target, adversarial_scenario) ## Get evaluators evaluators_dict = self._get_evaluators(evaluators) ## If `data_path` is not provided, run simulator - if data_path is None and jailbreak_data_path is None: + if data_path is None and jailbreak_data_path is None and isinstance(target, Callable): self.logger.info(f"No data_path provided. Running simulator.") data_paths = await self._simulate( target=target, @@ -679,32 +767,6 @@ async def __call__( category=ErrorCategory.MISSING_FIELD, blame=ErrorBlame.USER_ERROR, ) - - - -if __name__ == "__main__": - from dotenv import load_dotenv - from azure.identity import DefaultAzureCredential - import asyncio - import re - load_dotenv() - def test_target(query: str) -> str: - return "Hello, how can I help you?" - - azure_ai_project = { - "subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"), - "resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP_NAME"), - "project_name": os.environ.get("AZURE_PROJECT_NAME"), - } - - credential = DefaultAzureCredential() - - safety_evaluation = _SafetyEvaluation(azure_ai_project=azure_ai_project, credential=credential) - target = OpenAIChatTarget() - all_prompts_list = asyncio.run(safety_evaluation._get_all_prompts(AdversarialScenario.ADVERSARIAL_QA)) - orchestrator = PromptSendingOrchestrator(objective_target=target) - asyncio.run(orchestrator.send_prompts_async(prompt_list=all_prompts_list)) - asyncio.run(orchestrator.print_conversations_async()) From ef1a43153b730048c165c4064c1d099c2c81009e Mon Sep 17 00:00:00 2001 From: Nagkumar Arkalgud Date: Fri, 21 Feb 2025 10:20:58 -0800 Subject: [PATCH 04/14] Make pyrit extra --- sdk/evaluation/azure-ai-evaluation/setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index b691cf67177e..dbb842f07c9c 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -72,8 +72,10 @@ "azure-core>=1.30.2", "nltk>=3.9.1", "azure-storage-blob>=12.10.0", - "pyrit>=0.5.2" ], + extras_require={ + "pyrit": ["git+https://github.com/Azure/PyRIT.git#egg=pyrit"] + }, project_urls={ "Bug Reports": "https://github.com/Azure/azure-sdk-for-python/issues", "Source": "https://github.com/Azure/azure-sdk-for-python", From fdad913672e114df2c821c1d0dfb0be41925df84 Mon Sep 17 00:00:00 2001 From: Nagkumar Arkalgud Date: Fri, 21 Feb 2025 11:34:13 -0800 Subject: [PATCH 05/14] Set up pyrit as extra correctly --- sdk/evaluation/azure-ai-evaluation/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index dbb842f07c9c..94d86277f439 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -74,7 +74,7 @@ "azure-storage-blob>=12.10.0", ], extras_require={ - "pyrit": ["git+https://github.com/Azure/PyRIT.git#egg=pyrit"] + "pyrit": ["pyrit @ git+https://github.com/Azure/PyRIT.git"] }, project_urls={ "Bug Reports": "https://github.com/Azure/azure-sdk-for-python/issues", From b5263df24bda006f3ec89453c657b9fb2f49ac0c Mon Sep 17 00:00:00 2001 From: Nagkumar Arkalgud Date: Fri, 21 Feb 2025 11:43:06 -0800 Subject: [PATCH 06/14] Limit the number of turns --- .../_safety_evaluation/_safety_evaluation.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py index e747cac2fc7e..e682322e3c3c 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py @@ -574,7 +574,7 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult: evaluation_result['studio_url'] = evaluation_result_dict['jailbreak']['studio_url'] + '\t' + evaluation_result_dict['regular']['studio_url'] return evaluation_result - async def _get_all_prompts(self, scenario: AdversarialScenario) -> List[str]: + async def _get_all_prompts(self, scenario: AdversarialScenario, num_turns: int = 3) -> List[str]: templates = await self.adversarial_template_handler._get_content_harm_template_collections( scenario.value ) @@ -590,13 +590,24 @@ def replacer(match): return filled_text all_prompts = [] + count = 0 for param_group in zipped_parameters: for _, parameter in zip(templates, param_group): filled_template = fill_template(parameter['conversation_starter'], parameter) all_prompts.append(filled_template) + count += 1 + if count >= num_turns: + break + if count >= num_turns: + break return all_prompts - async def _pyrit(self, target: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration], scenario: AdversarialScenario) -> str: + async def _pyrit( + self, + target: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration], + scenario: AdversarialScenario, + num_turns: int = 1, + ) -> str: chat_target: OpenAIChatTarget = None if "azure_deployment" in target and "azure_endpoint" in target: # Azure OpenAI api_key = target.get("api_key", None) @@ -607,7 +618,7 @@ async def _pyrit(self, target: Union[AzureOpenAIModelConfiguration, OpenAIModelC else: chat_target = OpenAIChatTarget(deployment=target["model"], endpoint=target.get("base_url", None), key=target["api_key"], is_azure_target=False) - all_prompts_list = await self._get_all_prompts(scenario) + all_prompts_list = await self._get_all_prompts(scenario, num_turns=num_turns) orchestrator = PromptSendingOrchestrator(objective_target=chat_target) await orchestrator.send_prompts_async(prompt_list=all_prompts_list) @@ -647,7 +658,6 @@ async def _pyrit(self, target: Union[AzureOpenAIModelConfiguration, OpenAIModelC data_path = "pyrit_outputs.jsonl" with Path(data_path).open("w") as f: f.writelines(json_lines) - return data_path async def __call__( @@ -711,7 +721,7 @@ async def __call__( if not isinstance(target, Callable) and 
isinstance(adversarial_scenario, AdversarialScenario): self.logger.info(f"Running Pyrit with inputs target={target}, scenario={scenario}") - data_path = await self._pyrit(target, adversarial_scenario) + data_path = await self._pyrit(target, adversarial_scenario, num_turns=num_turns) ## Get evaluators evaluators_dict = self._get_evaluators(evaluators) From c81d2b6920556c654e39def706ecf6345ebbba6c Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Fri, 21 Feb 2025 11:50:20 -0800 Subject: [PATCH 07/14] adding sample --- .../evaluation_samples_safety_evaluation.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py index f90064e17908..842a3097aab8 100644 --- a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py @@ -48,6 +48,26 @@ def test_target(query: str) -> str: )) # [END default_safety_evaluation] + # [START default_safety_evaluation_model_target] + model_config = { + "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT"), + "azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT"), + } + + azure_ai_project = { + "subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"), + "resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP_NAME"), + "project_name": os.environ.get("AZURE_PROJECT_NAME"), + } + + credential = DefaultAzureCredential() + + safety_evaluation_default = _SafetyEvaluation(azure_ai_project=azure_ai_project, credential=credential) + safety_evaluation_default_results = asyncio.run(safety_evaluation_default( + target=model_config, + )) + # [END default_safety_evaluation_model_target] + # [START content_safety_safety_evaluation] def test_target(query: str) -> str: @@ -115,7 +135,7 @@ def test_target(query: str) -> str: safety_evaluation_content_safety_scenario_results = asyncio.run(safety_evaluation_content_safety_scenario( evaluators=[_SafetyEvaluator.CONTENT_SAFETY], target=test_target, - scenario=AdversarialScenario.ADVERSARIAL_SUMMARIZATION,, + scenario=AdversarialScenario.ADVERSARIAL_SUMMARIZATION, num_rows=3, output_path="evaluation_outputs_safety_scenario.jsonl", )) From 3ff1d7f3be252676c310dcc88272195ce332c9d1 Mon Sep 17 00:00:00 2001 From: Nagkumar Arkalgud Date: Fri, 21 Feb 2025 12:31:26 -0800 Subject: [PATCH 08/14] baseline sample with evals --- .../_safety_evaluation/_safety_evaluation.py | 37 +++++----------- .../azure-ai-evaluation/samples/pyrit_sim.py | 43 +++++++++++++++++++ 2 files changed, 54 insertions(+), 26 deletions(-) create mode 100644 sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py index e682322e3c3c..0345cfccd61d 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py @@ -24,6 +24,7 @@ import json from pathlib import Path import re +import itertools from pyrit.common import initialize_pyrit, DUCK_DB from pyrit.orchestrator import PromptSendingOrchestrator from pyrit.prompt_target import OpenAIChatTarget @@ -625,35 +626,19 @@ async def _pyrit( memory = 
orchestrator.get_memory() # Get conversations as a List[List[ChatMessage]] - conversations = [] - conversation_id = None - current_conversation = [] - for item in memory: - if conversation_id == None or conversation_id == item.conversation_id: - conversation_id = item.conversation_id - current_conversation.append(item.to_chat_message()) - else: - conversations.append(current_conversation) - current_conversation = [item.to_chat_message()] - conversation_id = item.conversation_id - + + conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(memory, key=lambda x: x.conversation_id)] + + # TODO: convert this to a helper + def message_to_dict(message: "ChatMessage"): + return { + "role": message.role, + "content": message.content, + } #Convert to json lines json_lines = "" for conversation in conversations: # each conversation is a List[ChatMessage] - user_message = None - assistant_message = None - for message in conversation: # each message is a ChatMessage - if message.role == "user": - user_message = message.content - elif message.role == "assistant": - assistant_message = message.content - - if user_message and assistant_message: - json_lines += ( - json.dumps({"query": user_message, "response": assistant_message, "category": None}) - + "\n" - ) - user_message = assistant_message = None + json_lines += json.dumps({"conversation": {"messages": [message_to_dict(message) for message in conversation]}}) + "\n" data_path = "pyrit_outputs.jsonl" with Path(data_path).open("w") as f: diff --git a/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py b/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py new file mode 100644 index 000000000000..11fcb52a5ace --- /dev/null +++ b/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py @@ -0,0 +1,43 @@ +""" +please install the pyrit extra to run this example + +cd azure-sdk-for-python/sdk/evaluation/azure-ai-evaluation +pip install -e ".[pyrit]" +""" + + +from azure.ai.evaluation._safety_evaluation._safety_evaluation import _SafetyEvaluation +import os +from azure.identity import DefaultAzureCredential + + +async def main(): + model_config = { + "azure_endpoint": os.environ.get("AZURE_ENDPOINT"), + "azure_deployment": os.environ.get("AZURE_DEPLOYMENT_NAME"), + } + azure_ai_project = { + "subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"), + "resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP"), + "project_name": os.environ.get("AZURE_PROJECT_NAME"), + } + safety_eval = _SafetyEvaluation( + azure_ai_project=azure_ai_project, + model_config=model_config, + credential=DefaultAzureCredential(), + ) + + outputs = await safety_eval( + target=model_config, + num_turns=8, + ) + print(outputs) + +if __name__ == "__main__": + import asyncio + import time + start = time.perf_counter() + asyncio.run(main()) + end = time.perf_counter() + print(f"Runtime: {end - start:.2f} seconds") + From e7e1e92417cb0be008c0ecdc600699f1aaf49799 Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Fri, 21 Feb 2025 14:58:00 -0800 Subject: [PATCH 09/14] updates --- .../_safety_evaluation/_safety_evaluation.py | 34 +++++++++---------- .../evaluation_samples_safety_evaluation.py | 6 ++++ .../azure-ai-evaluation/samples/pyrit_sim.py | 2 +- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py index 0345cfccd61d..91be9360d027 
100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py @@ -28,6 +28,7 @@ from pyrit.common import initialize_pyrit, DUCK_DB from pyrit.orchestrator import PromptSendingOrchestrator from pyrit.prompt_target import OpenAIChatTarget +from pyrit.models import ChatMessage logger = logging.getLogger(__name__) @@ -508,8 +509,8 @@ def _validate_inputs( category=ErrorCategory.MISSING_FIELD, blame=ErrorBlame.USER_ERROR, ) - - if scenario and not _SafetyEvaluator.CONTENT_SAFETY in evaluators: + + if scenario and len(evaluators)>0 and not _SafetyEvaluator.CONTENT_SAFETY in evaluators: self.logger.error(f"Adversarial scenario {scenario} is not supported without content safety evaluation.") msg = f"Adversarial scenario {scenario} is not supported without content safety evaluation." raise EvaluationException( @@ -520,7 +521,7 @@ def _validate_inputs( blame=ErrorBlame.USER_ERROR, ) - if _SafetyEvaluator.CONTENT_SAFETY in evaluators and scenario and num_turns > 1: + if _SafetyEvaluator.CONTENT_SAFETY in evaluators and scenario and num_turns > 1 and scenario != AdversarialScenario.ADVERSARIAL_CONVERSATION: self.logger.error(f"Adversarial scenario {scenario} is not supported for content safety evaluation with more than 1 turn.") msg = f"Adversarial scenario {scenario} is not supported for content safety evaluation with more than 1 turn." raise EvaluationException( @@ -575,7 +576,7 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult: evaluation_result['studio_url'] = evaluation_result_dict['jailbreak']['studio_url'] + '\t' + evaluation_result_dict['regular']['studio_url'] return evaluation_result - async def _get_all_prompts(self, scenario: AdversarialScenario, num_turns: int = 3) -> List[str]: + async def _get_all_prompts(self, scenario: AdversarialScenario, num_rows: int = 3) -> List[str]: templates = await self.adversarial_template_handler._get_content_harm_template_collections( scenario.value ) @@ -597,17 +598,23 @@ def replacer(match): filled_template = fill_template(parameter['conversation_starter'], parameter) all_prompts.append(filled_template) count += 1 - if count >= num_turns: + if count >= num_rows: break - if count >= num_turns: + if count >= num_rows: break return all_prompts + def message_to_dict(self, message: ChatMessage): + return { + "role": message.role, + "content": message.content, + } + async def _pyrit( self, target: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration], scenario: AdversarialScenario, - num_turns: int = 1, + num_rows: int = 1, ) -> str: chat_target: OpenAIChatTarget = None if "azure_deployment" in target and "azure_endpoint" in target: # Azure OpenAI @@ -619,26 +626,19 @@ async def _pyrit( else: chat_target = OpenAIChatTarget(deployment=target["model"], endpoint=target.get("base_url", None), key=target["api_key"], is_azure_target=False) - all_prompts_list = await self._get_all_prompts(scenario, num_turns=num_turns) + all_prompts_list = await self._get_all_prompts(scenario, num_rows=num_rows) orchestrator = PromptSendingOrchestrator(objective_target=chat_target) await orchestrator.send_prompts_async(prompt_list=all_prompts_list) memory = orchestrator.get_memory() # Get conversations as a List[List[ChatMessage]] - conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(memory, key=lambda x: x.conversation_id)] - # TODO: convert this to a 
helper - def message_to_dict(message: "ChatMessage"): - return { - "role": message.role, - "content": message.content, - } #Convert to json lines json_lines = "" for conversation in conversations: # each conversation is a List[ChatMessage] - json_lines += json.dumps({"conversation": {"messages": [message_to_dict(message) for message in conversation]}}) + "\n" + json_lines += json.dumps({"conversation": {"messages": [self.message_to_dict(message) for message in conversation]}}) + "\n" data_path = "pyrit_outputs.jsonl" with Path(data_path).open("w") as f: @@ -706,7 +706,7 @@ async def __call__( if not isinstance(target, Callable) and isinstance(adversarial_scenario, AdversarialScenario): self.logger.info(f"Running Pyrit with inputs target={target}, scenario={scenario}") - data_path = await self._pyrit(target, adversarial_scenario, num_turns=num_turns) + data_path = await self._pyrit(target, adversarial_scenario, num_rows=num_rows) ## Get evaluators evaluators_dict = self._get_evaluators(evaluators) diff --git a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py index 842a3097aab8..ab3e4de7b97d 100644 --- a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_safety_evaluation.py @@ -49,6 +49,12 @@ def test_target(query: str) -> str: # [END default_safety_evaluation] # [START default_safety_evaluation_model_target] + """ + please install the pyrit extra to run this example + + cd azure-sdk-for-python/sdk/evaluation/azure-ai-evaluation + pip install -e ".[pyrit]" + """ model_config = { "azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT"), "azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT"), diff --git a/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py b/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py index 11fcb52a5ace..1a95f98e4146 100644 --- a/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py +++ b/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py @@ -29,7 +29,7 @@ async def main(): outputs = await safety_eval( target=model_config, - num_turns=8, + num_rows=8, ) print(outputs) From 353245d1cdc3443e0c5e7bbed9369487a40f9c97 Mon Sep 17 00:00:00 2001 From: Nagkumar Arkalgud Date: Mon, 24 Feb 2025 08:20:30 -0800 Subject: [PATCH 10/14] Update setup to use roman's branch --- sdk/evaluation/azure-ai-evaluation/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index 94d86277f439..d80d1246860a 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -74,7 +74,7 @@ "azure-storage-blob>=12.10.0", ], extras_require={ - "pyrit": ["pyrit @ git+https://github.com/Azure/PyRIT.git"] + "pyrit": ["pyrit @ git+https://github.com/romanlutz/PyRIT.git@romanlutz/callback_chat_target"] }, project_urls={ "Bug Reports": "https://github.com/Azure/azure-sdk-for-python/issues", From 05dd4a661857285a58240640c1849a8424683c7a Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Mon, 24 Feb 2025 11:17:19 -0800 Subject: [PATCH 11/14] callback chat target --- .../_safety_evaluation/_safety_evaluation.py | 51 +++++++++++++++---- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py 
b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py index 91be9360d027..a339b001fd6f 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py @@ -28,6 +28,7 @@ from pyrit.common import initialize_pyrit, DUCK_DB from pyrit.orchestrator import PromptSendingOrchestrator from pyrit.prompt_target import OpenAIChatTarget +from pyrit.prompt_target.callback_chat_target import CallbackChatTarget from pyrit.models import ChatMessage logger = logging.getLogger(__name__) @@ -604,7 +605,7 @@ def replacer(match): break return all_prompts - def message_to_dict(self, message: ChatMessage): + def _message_to_dict(self, message: ChatMessage): return { "role": message.role, "content": message.content, @@ -612,19 +613,47 @@ def message_to_dict(self, message: ChatMessage): async def _pyrit( self, - target: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration], + target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration], scenario: AdversarialScenario, num_rows: int = 1, ) -> str: chat_target: OpenAIChatTarget = None - if "azure_deployment" in target and "azure_endpoint" in target: # Azure OpenAI - api_key = target.get("api_key", None) - if not api_key: - chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], use_aad_auth=True) - else: - chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], api_key=api_key) + if not isinstance(target, Callable): + if "azure_deployment" in target and "azure_endpoint" in target: # Azure OpenAI + api_key = target.get("api_key", None) + if not api_key: + chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], use_aad_auth=True) + else: + chat_target = OpenAIChatTarget(deployment_name=target["azure_deployment"], endpoint=target["azure_endpoint"], api_key=api_key) + else: + chat_target = OpenAIChatTarget(deployment=target["model"], endpoint=target.get("base_url", None), key=target["api_key"], is_azure_target=False) else: - chat_target = OpenAIChatTarget(deployment=target["model"], endpoint=target.get("base_url", None), key=target["api_key"], is_azure_target=False) + async def callback_target( + messages: List[Dict], + stream: bool = False, + session_state: Optional[str] = None, + context: Optional[Dict] = None + ) -> dict: + messages_list = [self._message_to_dict(chat_message) for chat_message in messages] # type: ignore + latest_message = messages_list[-1] + application_input = latest_message["content"] + try: + response = target(query=application_input) + except Exception as e: + response = f"Something went wrong {e!s}" + + ## We format the response to follow the openAI chat protocol format + formatted_response = { + "content": response, + "role": "assistant", + "context":{}, + } + ## NOTE: In the future, instead of appending to messages we should just return `formatted_response` + messages_list.append(formatted_response) # type: ignore + return {"messages": messages_list, "stream": stream, "session_state": session_state, "context": {}} + + + chat_target = CallbackChatTarget(callback=callback_target) all_prompts_list = await self._get_all_prompts(scenario, num_rows=num_rows) @@ -638,7 +667,7 @@ async def _pyrit( #Convert to json lines json_lines = "" for conversation in conversations: # each 
conversation is a List[ChatMessage] - json_lines += json.dumps({"conversation": {"messages": [self.message_to_dict(message) for message in conversation]}}) + "\n" + json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n" data_path = "pyrit_outputs.jsonl" with Path(data_path).open("w") as f: @@ -704,7 +733,7 @@ async def __call__( adversarial_scenario = self._get_scenario(evaluators, num_turns=num_turns, scenario=scenario) self.logger.info(f"Using scenario: {adversarial_scenario}") - if not isinstance(target, Callable) and isinstance(adversarial_scenario, AdversarialScenario): + if isinstance(adversarial_scenario, AdversarialScenario) and num_turns==1: self.logger.info(f"Running Pyrit with inputs target={target}, scenario={scenario}") data_path = await self._pyrit(target, adversarial_scenario, num_rows=num_rows) From e15a4e75d9315eb79aa54eaee6354b66f993867d Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Mon, 24 Feb 2025 11:40:18 -0800 Subject: [PATCH 12/14] sample update --- .../azure-ai-evaluation/samples/pyrit_sim.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py b/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py index 1a95f98e4146..5b4016976998 100644 --- a/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py +++ b/sdk/evaluation/azure-ai-evaluation/samples/pyrit_sim.py @@ -9,6 +9,7 @@ from azure.ai.evaluation._safety_evaluation._safety_evaluation import _SafetyEvaluation import os from azure.identity import DefaultAzureCredential +from azure.ai.evaluation.simulator import AdversarialScenario async def main(): @@ -16,18 +17,34 @@ async def main(): "azure_endpoint": os.environ.get("AZURE_ENDPOINT"), "azure_deployment": os.environ.get("AZURE_DEPLOYMENT_NAME"), } + + def test_target_fn(query: str) -> str: + return "mock response" + azure_ai_project = { "subscription_id": os.environ.get("AZURE_SUBSCRIPTION_ID"), "resource_group_name": os.environ.get("AZURE_RESOURCE_GROUP"), "project_name": os.environ.get("AZURE_PROJECT_NAME"), } - safety_eval = _SafetyEvaluation( + + safety_eval_callback_target = _SafetyEvaluation( + azure_ai_project=azure_ai_project, + credential=DefaultAzureCredential(), + ) + + outputs = await safety_eval_callback_target( + target=test_target_fn, + num_rows=8, + ) + + print(outputs) + + safety_eval_model_target = _SafetyEvaluation( azure_ai_project=azure_ai_project, - model_config=model_config, credential=DefaultAzureCredential(), ) - outputs = await safety_eval( + outputs = await safety_eval_model_target( target=model_config, num_rows=8, ) From cf0b3f5f5963abc65accf3f0ca4ed5f6666a7d31 Mon Sep 17 00:00:00 2001 From: Sydney Lister Date: Mon, 24 Feb 2025 13:15:08 -0800 Subject: [PATCH 13/14] add local callback chat target --- .../_callback_chat_target.py | 75 +++++++++++++++++++ .../_safety_evaluation/_safety_evaluation.py | 2 +- 2 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_callback_chat_target.py diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_callback_chat_target.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_callback_chat_target.py new file mode 100644 index 000000000000..14dc989b7c90 --- /dev/null +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_callback_chat_target.py @@ -0,0 +1,75 @@ +# 
--------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import Any, Callable, Dict, List, Optional + +from pyrit.models import ( + PromptRequestResponse, + construct_response_from_request, +) +from pyrit.prompt_target import PromptChatTarget + +logger = logging.getLogger(__name__) + + +class CallbackChatTarget(PromptChatTarget): + def __init__( + self, + *, + callback: Callable[[List[Dict], bool, Optional[str], Optional[Dict[str, Any]]], Dict], + stream: bool = False, + ) -> None: + """ + Initializes an instance of the CallbackChatTarget class. + + It is intended to be used with PYRIT where users define a callback function + that handles sending a prompt to a target and receiving a response. + The CallbackChatTarget class is a wrapper around the callback function that allows it to be used + as a target in the PyRIT framework. + For that reason, it merely handles additional functionality such as memory. + + Args: + callback (Callable): The callback function that sends a prompt to a target and receives a response. + stream (bool, optional): Indicates whether the target supports streaming. Defaults to False. + """ + PromptChatTarget.__init__(self) + self._callback = callback + self._stream = stream + + async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: + + self._validate_request(prompt_request=prompt_request) + request = prompt_request.request_pieces[0] + + messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id) + + messages.append(request.to_chat_message()) + + logger.info(f"Sending the following prompt to the prompt target: {request}") + + # response_context contains "messages", "stream", "session_state, "context" + response_context = await self._callback(messages=messages, stream=self._stream, session_state=None, context=None) + + response_text = response_context["messages"][-1]["content"] + response_entry = construct_response_from_request( + request=request, response_text_pieces=[response_text] + ) + + logger.info( + "Received the following response from the prompt target" + + f"{response_text}" + ) + return response_entry + + def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None: + if len(prompt_request.request_pieces) != 1: + raise ValueError("This target only supports a single prompt request piece.") + + if prompt_request.request_pieces[0].converted_value_data_type != "text": + raise ValueError("This target only supports text prompt input.") + + def is_json_response_supported(self) -> bool: + """Indicates that this target supports JSON response format.""" + return False diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py index a339b001fd6f..c80a5cd0b504 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py @@ -20,6 +20,7 @@ from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager, TokenScope, RAIClient, AdversarialTemplateHandler from azure.ai.evaluation._common.utils import validate_azure_ai_project from azure.ai.evaluation._model_configurations import 
AzureOpenAIModelConfiguration, OpenAIModelConfiguration +from azure.ai.evaluation._safety_evaluation._callback_chat_target import CallbackChatTarget from azure.core.credentials import TokenCredential import json from pathlib import Path @@ -28,7 +29,6 @@ from pyrit.common import initialize_pyrit, DUCK_DB from pyrit.orchestrator import PromptSendingOrchestrator from pyrit.prompt_target import OpenAIChatTarget -from pyrit.prompt_target.callback_chat_target import CallbackChatTarget from pyrit.models import ChatMessage logger = logging.getLogger(__name__) From 1997babb53829c7186145485980e4a407d2af191 Mon Sep 17 00:00:00 2001 From: Nagkumar Arkalgud Date: Mon, 24 Feb 2025 13:41:40 -0800 Subject: [PATCH 14/14] Revert to adding main pyrit as extra --- sdk/evaluation/azure-ai-evaluation/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/evaluation/azure-ai-evaluation/setup.py b/sdk/evaluation/azure-ai-evaluation/setup.py index d80d1246860a..94d86277f439 100644 --- a/sdk/evaluation/azure-ai-evaluation/setup.py +++ b/sdk/evaluation/azure-ai-evaluation/setup.py @@ -74,7 +74,7 @@ "azure-storage-blob>=12.10.0", ], extras_require={ - "pyrit": ["pyrit @ git+https://github.com/romanlutz/PyRIT.git@romanlutz/callback_chat_target"] + "pyrit": ["pyrit @ git+https://github.com/Azure/PyRIT.git"] }, project_urls={ "Bug Reports": "https://github.com/Azure/azure-sdk-for-python/issues",
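
PATCH 13 introduces CallbackChatTarget and describes its callback contract in the docstring, and PATCH 11 wires that contract up inside _SafetyEvaluation._pyrit, but neither shows the contract exercised on its own. Below is a minimal, self-contained sketch of that contract for reviewers, assuming the package is installed from this branch with the pyrit extra (pip install -e ".[pyrit]"); my_application and the example prompt are placeholders and not part of the change, and the supported public path remains calling _SafetyEvaluation with a callable or model-config target as in samples/pyrit_sim.py (PATCHES 08 and 12).

# Sketch of the CallbackChatTarget callback contract (PATCH 13), wired into PyRIT
# the same way _SafetyEvaluation._pyrit does in PATCH 11. Assumes this branch is
# installed with the pyrit extra; my_application stands in for the system under test.
import asyncio
from typing import Any, Dict, List, Optional

from pyrit.common import DUCK_DB, initialize_pyrit
from pyrit.orchestrator import PromptSendingOrchestrator

from azure.ai.evaluation._safety_evaluation._callback_chat_target import CallbackChatTarget


def my_application(query: str) -> str:
    # Placeholder target application; replace with the real system being probed.
    return "I can't help with that."


async def callback(
    messages: List[Any],  # pyrit ChatMessage objects for the conversation so far
    stream: bool = False,
    session_state: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    # Convert the ChatMessage history to plain dicts, answer the latest user turn,
    # and append the assistant reply; CallbackChatTarget reads messages[-1]["content"].
    history = [{"role": m.role, "content": m.content} for m in messages]
    response = my_application(query=history[-1]["content"])
    history.append({"role": "assistant", "content": response, "context": {}})
    return {"messages": history, "stream": stream, "session_state": session_state, "context": {}}


async def main() -> None:
    # CallbackChatTarget relies on PyRIT memory, so initialize it first (as _SafetyEvaluation does).
    initialize_pyrit(memory_db_type=DUCK_DB)
    target = CallbackChatTarget(callback=callback)
    orchestrator = PromptSendingOrchestrator(objective_target=target)
    await orchestrator.send_prompts_async(prompt_list=["example adversarial prompt"])
    for item in orchestrator.get_memory():
        print(item.to_chat_message())


if __name__ == "__main__":
    asyncio.run(main())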