1
- from typing import Any , Dict , List , Optional
1
+ from typing import Any , Dict , List , Optional , Tuple
2
+ import uuid
2
3
3
4
import regex as re
4
5
import structlog
5
6
from litellm import ChatCompletionRequest , ChatCompletionSystemMessage , ModelResponse
6
7
from litellm .types .utils import Delta , StreamingChoices
7
8
8
9
from codegate .config import Config
10
+ from codegate .db .models import AlertSeverity
9
11
from codegate .pipeline .base import (
10
12
PipelineContext ,
11
13
PipelineResult ,
12
14
PipelineStep ,
13
15
)
14
16
from codegate .pipeline .output import OutputPipelineContext , OutputPipelineStep
15
- from codegate .pipeline .pii .manager import PiiManager
17
+ from codegate .pipeline .pii .analyzer import PiiAnalyzer
18
+ from codegate .pipeline .sensitive_data .manager import SensitiveData , SensitiveDataManager
16
19
from codegate .pipeline .systemmsg import add_or_update_system_message
17
20
18
21
logger = structlog .get_logger ("codegate" )
@@ -25,7 +28,7 @@ class CodegatePii(PipelineStep):
25
28
26
29
Methods:
27
30
__init__:
28
- Initializes the CodegatePii pipeline step and sets up the PiiManager .
31
+ Initializes the CodegatePii pipeline step and sets up the SensitiveDataManager .
29
32
30
33
name:
31
34
Returns the name of the pipeline step.
@@ -37,14 +40,15 @@ class CodegatePii(PipelineStep):
37
40
Processes the chat completion request to detect and redact PII. Updates the request with
38
41
anonymized text and stores PII details in the context metadata.
39
42
40
- restore_pii(anonymized_text: str) -> str:
41
- Restores the original PII from the anonymized text using the PiiManager .
43
+ restore_pii(session_id: str, anonymized_text: str) -> str:
44
+ Restores the original PII from the anonymized text using the SensitiveDataManager .
42
45
"""
43
46
44
- def __init__ (self ):
47
+ def __init__ (self , sensitive_data_manager : SensitiveDataManager ):
45
48
"""Initialize the CodegatePii pipeline step."""
46
49
super ().__init__ ()
47
- self .pii_manager = PiiManager ()
50
+ self .sensitive_data_manager = sensitive_data_manager
51
+ self .analyzer = PiiAnalyzer .get_instance ()
48
52
49
53
@property
50
54
def name (self ) -> str :
@@ -65,6 +69,68 @@ def _get_redacted_snippet(self, message: str, pii_details: List[Dict[str, Any]])
65
69
66
70
return message [start :end ]
67
71
72
+ def process_results (
73
+ self , session_id : str , text : str , results : List , context : PipelineContext
74
+ ) -> Tuple [List , str ]:
75
+ # Track found PII
76
+ found_pii = []
77
+
78
+ # Log each found PII instance and anonymize
79
+ anonymized_text = text
80
+ for result in results :
81
+ pii_value = text [result .start : result .end ]
82
+
83
+ # add to session store
84
+ obj = SensitiveData (original = pii_value , service = "pii" , type = result .entity_type )
85
+ uuid_placeholder = self .sensitive_data_manager .store (session_id , obj )
86
+ anonymized_text = anonymized_text .replace (pii_value , uuid_placeholder )
87
+
88
+ # Add to found PII list
89
+ pii_info = {
90
+ "type" : result .entity_type ,
91
+ "value" : pii_value ,
92
+ "score" : result .score ,
93
+ "start" : result .start ,
94
+ "end" : result .end ,
95
+ "uuid_placeholder" : uuid_placeholder ,
96
+ }
97
+ found_pii .append (pii_info )
98
+
99
+ # Log each PII detection with its UUID mapping
100
+ logger .info (
101
+ "PII detected and mapped" ,
102
+ pii_type = result .entity_type ,
103
+ score = f"{ result .score :.2f} " ,
104
+ uuid = uuid_placeholder ,
105
+ # Don't log the actual PII value for security
106
+ value_length = len (pii_value ),
107
+ session_id = session_id ,
108
+ )
109
+
110
+ # Log summary of all PII found in this analysis
111
+ if found_pii and context :
112
+ # Create notification string for alert
113
+ notify_string = (
114
+ f"**PII Detected** 🔒\n "
115
+ f"- Total PII Found: { len (found_pii )} \n "
116
+ f"- Types Found: { ', ' .join (set (p ['type' ] for p in found_pii ))} \n "
117
+ )
118
+ context .add_alert (
119
+ self .name ,
120
+ trigger_string = notify_string ,
121
+ severity_category = AlertSeverity .CRITICAL ,
122
+ )
123
+
124
+ logger .info (
125
+ "PII analysis complete" ,
126
+ total_pii_found = len (found_pii ),
127
+ pii_types = [p ["type" ] for p in found_pii ],
128
+ session_id = session_id ,
129
+ )
130
+
131
+ # Return the anonymized text, PII details, and session store
132
+ return found_pii , anonymized_text
133
+
68
134
async def process (
69
135
self , request : ChatCompletionRequest , context : PipelineContext
70
136
) -> PipelineResult :
@@ -75,33 +141,39 @@ async def process(
75
141
total_pii_found = 0
76
142
all_pii_details : List [Dict [str , Any ]] = []
77
143
last_redacted_text = ""
144
+ session_id = context .sensitive .session_id
78
145
79
146
for i , message in enumerate (new_request ["messages" ]):
80
147
if "content" in message and message ["content" ]:
81
148
# This is where analyze and anonymize the text
82
149
original_text = str (message ["content" ])
83
- anonymized_text , pii_details = self .pii_manager .analyze (original_text , context )
84
-
85
- if pii_details :
86
- total_pii_found += len (pii_details )
87
- all_pii_details .extend (pii_details )
88
- new_request ["messages" ][i ]["content" ] = anonymized_text
89
-
90
- # If this is a user message, grab the redacted snippet!
91
- if message .get ("role" ) == "user" :
92
- last_redacted_text = self ._get_redacted_snippet (
93
- anonymized_text , pii_details
94
- )
150
+ results = self .analyzer .analyze (original_text , context )
151
+ if results :
152
+ pii_details , anonymized_text = self .process_results (
153
+ session_id , original_text , results , context
154
+ )
155
+
156
+ if pii_details :
157
+ total_pii_found += len (pii_details )
158
+ all_pii_details .extend (pii_details )
159
+ new_request ["messages" ][i ]["content" ] = anonymized_text
160
+
161
+ # If this is a user message, grab the redacted snippet!
162
+ if message .get ("role" ) == "user" :
163
+ last_redacted_text = self ._get_redacted_snippet (
164
+ anonymized_text , pii_details
165
+ )
95
166
96
167
logger .info (f"Total PII instances redacted: { total_pii_found } " )
97
168
98
169
# Store the count, details, and redacted text in context metadata
99
170
context .metadata ["redacted_pii_count" ] = total_pii_found
100
171
context .metadata ["redacted_pii_details" ] = all_pii_details
101
172
context .metadata ["redacted_text" ] = last_redacted_text
173
+ context .metadata ["session_id" ] = session_id
102
174
103
175
if total_pii_found > 0 :
104
- context .metadata ["pii_manager " ] = self .pii_manager
176
+ context .metadata ["sensitive_data_manager " ] = self .sensitive_data_manager
105
177
106
178
system_message = ChatCompletionSystemMessage (
107
179
content = Config .get_config ().prompts .pii_redacted ,
@@ -113,8 +185,31 @@ async def process(
113
185
114
186
return PipelineResult (request = new_request , context = context )
115
187
116
- def restore_pii (self , anonymized_text : str ) -> str :
117
- return self .pii_manager .restore_pii (anonymized_text )
188
+ def restore_pii (self , session_id : str , anonymized_text : str ) -> str :
189
+ """
190
+ Restore the original PII (Personally Identifiable Information) in the given anonymized text.
191
+
192
+ This method replaces placeholders in the anonymized text with their corresponding original
193
+ PII values using the mappings stored in the provided SessionStore.
194
+
195
+ Args:
196
+ anonymized_text (str): The text containing placeholders for PII.
197
+ session_id (str): The session id containing mappings of placeholders
198
+ to original PII.
199
+
200
+ Returns:
201
+ str: The text with the original PII restored.
202
+ """
203
+ session_data = self .sensitive_data_manager .get_by_session_id (session_id )
204
+ if not session_data :
205
+ logger .warning (
206
+ "No active PII session found for given session ID. Unable to restore PII."
207
+ )
208
+ return anonymized_text
209
+
210
+ for uuid_placeholder , original_pii in session_data .items ():
211
+ anonymized_text = anonymized_text .replace (uuid_placeholder , original_pii )
212
+ return anonymized_text
118
213
119
214
120
215
class PiiUnRedactionStep (OutputPipelineStep ):
@@ -136,12 +231,12 @@ class PiiUnRedactionStep(OutputPipelineStep):
136
231
"""
137
232
138
233
def __init__ (self ):
139
- self .redacted_pattern = re .compile (r"< ([0-9a-f-]{0,36})> " )
234
+ self .redacted_pattern = re .compile (r"# ([0-9a-f-]{0,36})# " )
140
235
self .complete_uuid_pattern = re .compile (
141
236
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
142
237
) # noqa: E501
143
- self .marker_start = "< "
144
- self .marker_end = "> "
238
+ self .marker_start = "# "
239
+ self .marker_end = "# "
145
240
146
241
@property
147
242
def name (self ) -> str :
@@ -151,7 +246,7 @@ def _is_complete_uuid(self, uuid_str: str) -> bool:
151
246
"""Check if the string is a complete UUID"""
152
247
return bool (self .complete_uuid_pattern .match (uuid_str ))
153
248
154
- async def process_chunk (
249
+ async def process_chunk ( # noqa: C901
155
250
self ,
156
251
chunk : ModelResponse ,
157
252
context : OutputPipelineContext ,
@@ -162,6 +257,10 @@ async def process_chunk(
162
257
return [chunk ]
163
258
164
259
content = chunk .choices [0 ].delta .content
260
+ session_id = input_context .sensitive .session_id
261
+ if not session_id :
262
+ logger .error ("Could not get any session id, cannot process pii" )
263
+ return [chunk ]
165
264
166
265
# Add current chunk to buffer
167
266
if context .prefix_buffer :
@@ -172,13 +271,13 @@ async def process_chunk(
172
271
current_pos = 0
173
272
result = []
174
273
while current_pos < len (content ):
175
- start_idx = content .find ("<" , current_pos )
274
+ start_idx = content .find (self . marker_start , current_pos )
176
275
if start_idx == - 1 :
177
276
# No more markers!, add remaining content
178
277
result .append (content [current_pos :])
179
278
break
180
279
181
- end_idx = content .find (">" , start_idx )
280
+ end_idx = content .find (self . marker_end , start_idx + 1 )
182
281
if end_idx == - 1 :
183
282
# Incomplete marker, buffer the rest
184
283
context .prefix_buffer = content [current_pos :]
@@ -190,16 +289,18 @@ async def process_chunk(
190
289
191
290
# Extract potential UUID if it's a valid format!
192
291
uuid_marker = content [start_idx : end_idx + 1 ]
193
- uuid_value = uuid_marker [1 :- 1 ] # Remove < >
292
+ uuid_value = uuid_marker [1 :- 1 ] # Remove # #
194
293
195
294
if self ._is_complete_uuid (uuid_value ):
196
295
# Get the PII manager from context metadata
197
296
logger .debug (f"Valid UUID found: { uuid_value } " )
198
- pii_manager = input_context .metadata .get ("pii_manager" ) if input_context else None
199
- if pii_manager and pii_manager .session_store :
297
+ sensitive_data_manager = (
298
+ input_context .metadata .get ("sensitive_data_manager" ) if input_context else None
299
+ )
300
+ if sensitive_data_manager and sensitive_data_manager .session_store :
200
301
# Restore original value from PII manager
201
302
logger .debug ("Attempting to restore PII from UUID marker" )
202
- original = pii_manager . session_store . get_pii ( uuid_marker )
303
+ original = sensitive_data_manager . get_original_value ( session_id , uuid_marker )
203
304
logger .debug (f"Restored PII: { original } " )
204
305
result .append (original )
205
306
else :
0 commit comments