Skip to content

Commit 58132ba

Browse files
authoredMar 4, 2025··
feat: remove not needed encryption of secrets (#1123)
* feat: remove not needed encryption of secrets Instead use an uuid generator as we do for pii, and reuse same session store mechanism Closes: #929 * fix tests * unify interface in sensitive data * add missing tests * changes from rebase * fixes from review * fixes in tests * fix tests
1 parent a031791 commit 58132ba

21 files changed

+500
-1086
lines changed
 

‎src/codegate/cli.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from codegate.config import Config, ConfigurationError
1717
from codegate.db.connection import init_db_sync, init_session_if_not_exists
1818
from codegate.pipeline.factory import PipelineFactory
19-
from codegate.pipeline.secrets.manager import SecretsManager
19+
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
2020
from codegate.providers import crud as provendcrud
2121
from codegate.providers.copilot.provider import CopilotProvider
2222
from codegate.server import init_app
@@ -331,8 +331,8 @@ def serve( # noqa: C901
331331
click.echo("Existing Certificates are already present.")
332332

333333
# Initialize secrets manager and pipeline factory
334-
secrets_manager = SecretsManager()
335-
pipeline_factory = PipelineFactory(secrets_manager)
334+
sensitive_data_manager = SensitiveDataManager()
335+
pipeline_factory = PipelineFactory(sensitive_data_manager)
336336

337337
app = init_app(pipeline_factory)
338338

‎src/codegate/pipeline/base.py

+8-19
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,23 @@
1212
from codegate.clients.clients import ClientType
1313
from codegate.db.models import Alert, AlertSeverity, Output, Prompt
1414
from codegate.extract_snippets.message_extractor import CodeSnippet
15-
from codegate.pipeline.secrets.manager import SecretsManager
15+
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
1616

1717
logger = structlog.get_logger("codegate")
1818

1919

2020
@dataclass
2121
class PipelineSensitiveData:
22-
manager: SecretsManager
22+
manager: SensitiveDataManager
2323
session_id: str
24-
api_key: Optional[str] = None
2524
model: Optional[str] = None
26-
provider: Optional[str] = None
27-
api_base: Optional[str] = None
2825

2926
def secure_cleanup(self):
3027
"""Securely cleanup sensitive data for this session"""
3128
if self.manager is None or self.session_id == "":
3229
return
33-
3430
self.manager.cleanup_session(self.session_id)
3531
self.session_id = ""
36-
37-
# Securely wipe the API key using the same method as secrets manager
38-
if self.api_key is not None:
39-
api_key_bytes = bytearray(self.api_key.encode())
40-
self.manager.crypto.wipe_bytearray(api_key_bytes)
41-
self.api_key = None
42-
4332
self.model = None
4433

4534

@@ -274,19 +263,19 @@ class InputPipelineInstance:
274263
def __init__(
275264
self,
276265
pipeline_steps: List[PipelineStep],
277-
secret_manager: SecretsManager,
266+
sensitive_data_manager: SensitiveDataManager,
278267
is_fim: bool,
279268
client: ClientType = ClientType.GENERIC,
280269
):
281270
self.pipeline_steps = pipeline_steps
282-
self.secret_manager = secret_manager
271+
self.sensitive_data_manager = sensitive_data_manager
283272
self.is_fim = is_fim
284273
self.context = PipelineContext(client=client)
285274

286275
# we create the sesitive context here so that it is not shared between individual requests
287276
# TODO: could we get away with just generating the session ID for an instance?
288277
self.context.sensitive = PipelineSensitiveData(
289-
manager=self.secret_manager,
278+
manager=self.sensitive_data_manager,
290279
session_id=str(uuid.uuid4()),
291280
)
292281
self.context.metadata["is_fim"] = is_fim
@@ -343,20 +332,20 @@ class SequentialPipelineProcessor:
343332
def __init__(
344333
self,
345334
pipeline_steps: List[PipelineStep],
346-
secret_manager: SecretsManager,
335+
sensitive_data_manager: SensitiveDataManager,
347336
client_type: ClientType,
348337
is_fim: bool,
349338
):
350339
self.pipeline_steps = pipeline_steps
351-
self.secret_manager = secret_manager
340+
self.sensitive_data_manager = sensitive_data_manager
352341
self.is_fim = is_fim
353342
self.instance = self._create_instance(client_type)
354343

355344
def _create_instance(self, client_type: ClientType) -> InputPipelineInstance:
356345
"""Create a new pipeline instance for processing a request"""
357346
return InputPipelineInstance(
358347
self.pipeline_steps,
359-
self.secret_manager,
348+
self.sensitive_data_manager,
360349
self.is_fim,
361350
client_type,
362351
)

‎src/codegate/pipeline/factory.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@
1212
PiiRedactionNotifier,
1313
PiiUnRedactionStep,
1414
)
15-
from codegate.pipeline.secrets.manager import SecretsManager
1615
from codegate.pipeline.secrets.secrets import (
1716
CodegateSecrets,
1817
SecretRedactionNotifier,
1918
SecretUnredactionStep,
2019
)
20+
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
2121
from codegate.pipeline.system_prompt.codegate import SystemPrompt
2222

2323

2424
class PipelineFactory:
25-
def __init__(self, secrets_manager: SecretsManager):
26-
self.secrets_manager = secrets_manager
25+
def __init__(self, sensitive_data_manager: SensitiveDataManager):
26+
self.sensitive_data_manager = sensitive_data_manager
2727

2828
def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelineProcessor:
2929
input_steps: List[PipelineStep] = [
@@ -32,7 +32,7 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr
3232
# and without obfuscating the secrets, we'd leak the secrets during those
3333
# later steps
3434
CodegateSecrets(),
35-
CodegatePii(),
35+
CodegatePii(self.sensitive_data_manager),
3636
CodegateCli(),
3737
CodegateContextRetriever(),
3838
SystemPrompt(
@@ -41,19 +41,19 @@ def create_input_pipeline(self, client_type: ClientType) -> SequentialPipelinePr
4141
]
4242
return SequentialPipelineProcessor(
4343
input_steps,
44-
self.secrets_manager,
44+
self.sensitive_data_manager,
4545
client_type,
4646
is_fim=False,
4747
)
4848

4949
def create_fim_pipeline(self, client_type: ClientType) -> SequentialPipelineProcessor:
5050
fim_steps: List[PipelineStep] = [
5151
CodegateSecrets(),
52-
CodegatePii(),
52+
CodegatePii(self.sensitive_data_manager),
5353
]
5454
return SequentialPipelineProcessor(
5555
fim_steps,
56-
self.secrets_manager,
56+
self.sensitive_data_manager,
5757
client_type,
5858
is_fim=True,
5959
)

‎src/codegate/pipeline/pii/analyzer.py

+18-102
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,16 @@
1-
import uuid
2-
from typing import Any, Dict, List, Optional, Tuple
1+
from typing import Any, List, Optional
32

43
import structlog
54
from presidio_analyzer import AnalyzerEngine
65
from presidio_anonymizer import AnonymizerEngine
76

87
from codegate.db.models import AlertSeverity
98
from codegate.pipeline.base import PipelineContext
9+
from codegate.pipeline.sensitive_data.session_store import SessionStore
1010

1111
logger = structlog.get_logger("codegate.pii.analyzer")
1212

1313

14-
class PiiSessionStore:
15-
"""
16-
A class to manage PII (Personally Identifiable Information) session storage.
17-
18-
Attributes:
19-
session_id (str): The unique identifier for the session. If not provided, a new UUID
20-
is generated. mappings (Dict[str, str]): A dictionary to store mappings between UUID
21-
placeholders and PII.
22-
23-
Methods:
24-
add_mapping(pii: str) -> str:
25-
Adds a PII string to the session store and returns a UUID placeholder for it.
26-
27-
get_pii(uuid_placeholder: str) -> str:
28-
Retrieves the PII string associated with the given UUID placeholder. If the placeholder
29-
is not found, returns the placeholder itself.
30-
"""
31-
32-
def __init__(self, session_id: str = None):
33-
self.session_id = session_id or str(uuid.uuid4())
34-
self.mappings: Dict[str, str] = {}
35-
36-
def add_mapping(self, pii: str) -> str:
37-
uuid_placeholder = f"<{str(uuid.uuid4())}>"
38-
self.mappings[uuid_placeholder] = pii
39-
return uuid_placeholder
40-
41-
def get_pii(self, uuid_placeholder: str) -> str:
42-
return self.mappings.get(uuid_placeholder, uuid_placeholder)
43-
44-
4514
class PiiAnalyzer:
4615
"""
4716
PiiAnalyzer class for analyzing and anonymizing text containing PII.
@@ -52,12 +21,12 @@ class PiiAnalyzer:
5221
Get or create the singleton instance of PiiAnalyzer.
5322
analyze:
5423
text (str): The text to analyze for PII.
55-
Tuple[str, List[Dict[str, Any]], PiiSessionStore]: The anonymized text, a list of
24+
Tuple[str, List[Dict[str, Any]], SessionStore]: The anonymized text, a list of
5625
found PII details, and the session store.
5726
entities (List[str]): The PII entities to analyze for.
5827
restore_pii:
5928
anonymized_text (str): The text with anonymized PII.
60-
session_store (PiiSessionStore): The PiiSessionStore used for anonymization.
29+
session_store (SessionStore): The SessionStore used for anonymization.
6130
str: The text with original PII restored.
6231
"""
6332

@@ -95,13 +64,11 @@ def __init__(self):
9564
# Create analyzer with custom NLP engine
9665
self.analyzer = AnalyzerEngine(nlp_engine=nlp_engine)
9766
self.anonymizer = AnonymizerEngine()
98-
self.session_store = PiiSessionStore()
67+
self.session_store = SessionStore()
9968

10069
PiiAnalyzer._instance = self
10170

102-
def analyze(
103-
self, text: str, context: Optional[PipelineContext] = None
104-
) -> Tuple[str, List[Dict[str, Any]], PiiSessionStore]:
71+
def analyze(self, text: str, context: Optional[PipelineContext] = None) -> List:
10572
# Prioritize credit card detection first
10673
entities = [
10774
"PHONE_NUMBER",
@@ -125,81 +92,30 @@ def analyze(
12592
language="en",
12693
score_threshold=0.3, # Lower threshold to catch more potential matches
12794
)
95+
return analyzer_results
12896

129-
# Track found PII
130-
found_pii = []
131-
132-
# Only anonymize if PII was found
133-
if analyzer_results:
134-
# Log each found PII instance and anonymize
135-
anonymized_text = text
136-
for result in analyzer_results:
137-
pii_value = text[result.start : result.end]
138-
uuid_placeholder = self.session_store.add_mapping(pii_value)
139-
pii_info = {
140-
"type": result.entity_type,
141-
"value": pii_value,
142-
"score": result.score,
143-
"start": result.start,
144-
"end": result.end,
145-
"uuid_placeholder": uuid_placeholder,
146-
}
147-
found_pii.append(pii_info)
148-
anonymized_text = anonymized_text.replace(pii_value, uuid_placeholder)
149-
150-
# Log each PII detection with its UUID mapping
151-
logger.info(
152-
"PII detected and mapped",
153-
pii_type=result.entity_type,
154-
score=f"{result.score:.2f}",
155-
uuid=uuid_placeholder,
156-
# Don't log the actual PII value for security
157-
value_length=len(pii_value),
158-
session_id=self.session_store.session_id,
159-
)
160-
161-
# Log summary of all PII found in this analysis
162-
if found_pii and context:
163-
# Create notification string for alert
164-
notify_string = (
165-
f"**PII Detected** 🔒\n"
166-
f"- Total PII Found: {len(found_pii)}\n"
167-
f"- Types Found: {', '.join(set(p['type'] for p in found_pii))}\n"
168-
)
169-
context.add_alert(
170-
self._name,
171-
trigger_string=notify_string,
172-
severity_category=AlertSeverity.CRITICAL,
173-
)
174-
175-
logger.info(
176-
"PII analysis complete",
177-
total_pii_found=len(found_pii),
178-
pii_types=[p["type"] for p in found_pii],
179-
session_id=self.session_store.session_id,
180-
)
181-
182-
# Return the anonymized text, PII details, and session store
183-
return anonymized_text, found_pii, self.session_store
184-
185-
# If no PII found, return original text, empty list, and session store
186-
return text, [], self.session_store
187-
188-
def restore_pii(self, anonymized_text: str, session_store: PiiSessionStore) -> str:
97+
def restore_pii(self, session_id: str, anonymized_text: str) -> str:
18998
"""
19099
Restore the original PII (Personally Identifiable Information) in the given anonymized text.
191100
192101
This method replaces placeholders in the anonymized text with their corresponding original
193-
PII values using the mappings stored in the provided PiiSessionStore.
102+
PII values using the mappings stored in the provided SessionStore.
194103
195104
Args:
196105
anonymized_text (str): The text containing placeholders for PII.
197-
session_store (PiiSessionStore): The session store containing mappings of placeholders
106+
session_id (str): The session id containing mappings of placeholders
198107
to original PII.
199108
200109
Returns:
201110
str: The text with the original PII restored.
202111
"""
203-
for uuid_placeholder, original_pii in session_store.mappings.items():
112+
session_data = self.session_store.get_by_session_id(session_id)
113+
if not session_data:
114+
logger.warning(
115+
"No active PII session found for given session ID. Unable to restore PII."
116+
)
117+
return anonymized_text
118+
119+
for uuid_placeholder, original_pii in session_data.items():
204120
anonymized_text = anonymized_text.replace(uuid_placeholder, original_pii)
205121
return anonymized_text

‎src/codegate/pipeline/pii/manager.py

-84
This file was deleted.

‎src/codegate/pipeline/pii/pii.py

+133-32
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Dict, List, Optional, Tuple
2+
import uuid
23

34
import regex as re
45
import structlog
56
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage, ModelResponse
67
from litellm.types.utils import Delta, StreamingChoices
78

89
from codegate.config import Config
10+
from codegate.db.models import AlertSeverity
911
from codegate.pipeline.base import (
1012
PipelineContext,
1113
PipelineResult,
1214
PipelineStep,
1315
)
1416
from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep
15-
from codegate.pipeline.pii.manager import PiiManager
17+
from codegate.pipeline.pii.analyzer import PiiAnalyzer
18+
from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager
1619
from codegate.pipeline.systemmsg import add_or_update_system_message
1720

1821
logger = structlog.get_logger("codegate")
@@ -25,7 +28,7 @@ class CodegatePii(PipelineStep):
2528
2629
Methods:
2730
__init__:
28-
Initializes the CodegatePii pipeline step and sets up the PiiManager.
31+
Initializes the CodegatePii pipeline step and sets up the SensitiveDataManager.
2932
3033
name:
3134
Returns the name of the pipeline step.
@@ -37,14 +40,15 @@ class CodegatePii(PipelineStep):
3740
Processes the chat completion request to detect and redact PII. Updates the request with
3841
anonymized text and stores PII details in the context metadata.
3942
40-
restore_pii(anonymized_text: str) -> str:
41-
Restores the original PII from the anonymized text using the PiiManager.
43+
restore_pii(session_id: str, anonymized_text: str) -> str:
44+
Restores the original PII from the anonymized text using the SensitiveDataManager.
4245
"""
4346

44-
def __init__(self):
47+
def __init__(self, sensitive_data_manager: SensitiveDataManager):
4548
"""Initialize the CodegatePii pipeline step."""
4649
super().__init__()
47-
self.pii_manager = PiiManager()
50+
self.sensitive_data_manager = sensitive_data_manager
51+
self.analyzer = PiiAnalyzer.get_instance()
4852

4953
@property
5054
def name(self) -> str:
@@ -65,6 +69,68 @@ def _get_redacted_snippet(self, message: str, pii_details: List[Dict[str, Any]])
6569

6670
return message[start:end]
6771

72+
def process_results(
73+
self, session_id: str, text: str, results: List, context: PipelineContext
74+
) -> Tuple[List, str]:
75+
# Track found PII
76+
found_pii = []
77+
78+
# Log each found PII instance and anonymize
79+
anonymized_text = text
80+
for result in results:
81+
pii_value = text[result.start : result.end]
82+
83+
# add to session store
84+
obj = SensitiveData(original=pii_value, service="pii", type=result.entity_type)
85+
uuid_placeholder = self.sensitive_data_manager.store(session_id, obj)
86+
anonymized_text = anonymized_text.replace(pii_value, uuid_placeholder)
87+
88+
# Add to found PII list
89+
pii_info = {
90+
"type": result.entity_type,
91+
"value": pii_value,
92+
"score": result.score,
93+
"start": result.start,
94+
"end": result.end,
95+
"uuid_placeholder": uuid_placeholder,
96+
}
97+
found_pii.append(pii_info)
98+
99+
# Log each PII detection with its UUID mapping
100+
logger.info(
101+
"PII detected and mapped",
102+
pii_type=result.entity_type,
103+
score=f"{result.score:.2f}",
104+
uuid=uuid_placeholder,
105+
# Don't log the actual PII value for security
106+
value_length=len(pii_value),
107+
session_id=session_id,
108+
)
109+
110+
# Log summary of all PII found in this analysis
111+
if found_pii and context:
112+
# Create notification string for alert
113+
notify_string = (
114+
f"**PII Detected** 🔒\n"
115+
f"- Total PII Found: {len(found_pii)}\n"
116+
f"- Types Found: {', '.join(set(p['type'] for p in found_pii))}\n"
117+
)
118+
context.add_alert(
119+
self.name,
120+
trigger_string=notify_string,
121+
severity_category=AlertSeverity.CRITICAL,
122+
)
123+
124+
logger.info(
125+
"PII analysis complete",
126+
total_pii_found=len(found_pii),
127+
pii_types=[p["type"] for p in found_pii],
128+
session_id=session_id,
129+
)
130+
131+
# Return the anonymized text, PII details, and session store
132+
return found_pii, anonymized_text
133+
68134
async def process(
69135
self, request: ChatCompletionRequest, context: PipelineContext
70136
) -> PipelineResult:
@@ -75,33 +141,39 @@ async def process(
75141
total_pii_found = 0
76142
all_pii_details: List[Dict[str, Any]] = []
77143
last_redacted_text = ""
144+
session_id = context.sensitive.session_id
78145

79146
for i, message in enumerate(new_request["messages"]):
80147
if "content" in message and message["content"]:
81148
# This is where analyze and anonymize the text
82149
original_text = str(message["content"])
83-
anonymized_text, pii_details = self.pii_manager.analyze(original_text, context)
84-
85-
if pii_details:
86-
total_pii_found += len(pii_details)
87-
all_pii_details.extend(pii_details)
88-
new_request["messages"][i]["content"] = anonymized_text
89-
90-
# If this is a user message, grab the redacted snippet!
91-
if message.get("role") == "user":
92-
last_redacted_text = self._get_redacted_snippet(
93-
anonymized_text, pii_details
94-
)
150+
results = self.analyzer.analyze(original_text, context)
151+
if results:
152+
pii_details, anonymized_text = self.process_results(
153+
session_id, original_text, results, context
154+
)
155+
156+
if pii_details:
157+
total_pii_found += len(pii_details)
158+
all_pii_details.extend(pii_details)
159+
new_request["messages"][i]["content"] = anonymized_text
160+
161+
# If this is a user message, grab the redacted snippet!
162+
if message.get("role") == "user":
163+
last_redacted_text = self._get_redacted_snippet(
164+
anonymized_text, pii_details
165+
)
95166

96167
logger.info(f"Total PII instances redacted: {total_pii_found}")
97168

98169
# Store the count, details, and redacted text in context metadata
99170
context.metadata["redacted_pii_count"] = total_pii_found
100171
context.metadata["redacted_pii_details"] = all_pii_details
101172
context.metadata["redacted_text"] = last_redacted_text
173+
context.metadata["session_id"] = session_id
102174

103175
if total_pii_found > 0:
104-
context.metadata["pii_manager"] = self.pii_manager
176+
context.metadata["sensitive_data_manager"] = self.sensitive_data_manager
105177

106178
system_message = ChatCompletionSystemMessage(
107179
content=Config.get_config().prompts.pii_redacted,
@@ -113,8 +185,31 @@ async def process(
113185

114186
return PipelineResult(request=new_request, context=context)
115187

116-
def restore_pii(self, anonymized_text: str) -> str:
117-
return self.pii_manager.restore_pii(anonymized_text)
188+
def restore_pii(self, session_id: str, anonymized_text: str) -> str:
189+
"""
190+
Restore the original PII (Personally Identifiable Information) in the given anonymized text.
191+
192+
This method replaces placeholders in the anonymized text with their corresponding original
193+
PII values using the mappings stored in the provided SessionStore.
194+
195+
Args:
196+
anonymized_text (str): The text containing placeholders for PII.
197+
session_id (str): The session id containing mappings of placeholders
198+
to original PII.
199+
200+
Returns:
201+
str: The text with the original PII restored.
202+
"""
203+
session_data = self.sensitive_data_manager.get_by_session_id(session_id)
204+
if not session_data:
205+
logger.warning(
206+
"No active PII session found for given session ID. Unable to restore PII."
207+
)
208+
return anonymized_text
209+
210+
for uuid_placeholder, original_pii in session_data.items():
211+
anonymized_text = anonymized_text.replace(uuid_placeholder, original_pii)
212+
return anonymized_text
118213

119214

120215
class PiiUnRedactionStep(OutputPipelineStep):
@@ -136,12 +231,12 @@ class PiiUnRedactionStep(OutputPipelineStep):
136231
"""
137232

138233
def __init__(self):
139-
self.redacted_pattern = re.compile(r"<([0-9a-f-]{0,36})>")
234+
self.redacted_pattern = re.compile(r"#([0-9a-f-]{0,36})#")
140235
self.complete_uuid_pattern = re.compile(
141236
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
142237
) # noqa: E501
143-
self.marker_start = "<"
144-
self.marker_end = ">"
238+
self.marker_start = "#"
239+
self.marker_end = "#"
145240

146241
@property
147242
def name(self) -> str:
@@ -151,7 +246,7 @@ def _is_complete_uuid(self, uuid_str: str) -> bool:
151246
"""Check if the string is a complete UUID"""
152247
return bool(self.complete_uuid_pattern.match(uuid_str))
153248

154-
async def process_chunk(
249+
async def process_chunk( # noqa: C901
155250
self,
156251
chunk: ModelResponse,
157252
context: OutputPipelineContext,
@@ -162,6 +257,10 @@ async def process_chunk(
162257
return [chunk]
163258

164259
content = chunk.choices[0].delta.content
260+
session_id = input_context.sensitive.session_id
261+
if not session_id:
262+
logger.error("Could not get any session id, cannot process pii")
263+
return [chunk]
165264

166265
# Add current chunk to buffer
167266
if context.prefix_buffer:
@@ -172,13 +271,13 @@ async def process_chunk(
172271
current_pos = 0
173272
result = []
174273
while current_pos < len(content):
175-
start_idx = content.find("<", current_pos)
274+
start_idx = content.find(self.marker_start, current_pos)
176275
if start_idx == -1:
177276
# No more markers!, add remaining content
178277
result.append(content[current_pos:])
179278
break
180279

181-
end_idx = content.find(">", start_idx)
280+
end_idx = content.find(self.marker_end, start_idx + 1)
182281
if end_idx == -1:
183282
# Incomplete marker, buffer the rest
184283
context.prefix_buffer = content[current_pos:]
@@ -190,16 +289,18 @@ async def process_chunk(
190289

191290
# Extract potential UUID if it's a valid format!
192291
uuid_marker = content[start_idx : end_idx + 1]
193-
uuid_value = uuid_marker[1:-1] # Remove < >
292+
uuid_value = uuid_marker[1:-1] # Remove # #
194293

195294
if self._is_complete_uuid(uuid_value):
196295
# Get the PII manager from context metadata
197296
logger.debug(f"Valid UUID found: {uuid_value}")
198-
pii_manager = input_context.metadata.get("pii_manager") if input_context else None
199-
if pii_manager and pii_manager.session_store:
297+
sensitive_data_manager = (
298+
input_context.metadata.get("sensitive_data_manager") if input_context else None
299+
)
300+
if sensitive_data_manager and sensitive_data_manager.session_store:
200301
# Restore original value from PII manager
201302
logger.debug("Attempting to restore PII from UUID marker")
202-
original = pii_manager.session_store.get_pii(uuid_marker)
303+
original = sensitive_data_manager.get_original_value(session_id, uuid_marker)
203304
logger.debug(f"Restored PII: {original}")
204305
result.append(original)
205306
else:

‎src/codegate/pipeline/secrets/gatecrypto.py

-111
This file was deleted.

‎src/codegate/pipeline/secrets/manager.py

-117
This file was deleted.

‎src/codegate/pipeline/secrets/secrets.py

+37-20
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
PipelineStep,
1717
)
1818
from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep
19-
from codegate.pipeline.secrets.manager import SecretsManager
2019
from codegate.pipeline.secrets.signatures import CodegateSignatures, Match
20+
from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager
2121
from codegate.pipeline.systemmsg import add_or_update_system_message
2222

2323
logger = structlog.get_logger("codegate")
@@ -171,25 +171,35 @@ def obfuscate(self, text: str, snippet: Optional[CodeSnippet]) -> tuple[str, Lis
171171
class SecretsEncryptor(SecretsModifier):
172172
def __init__(
173173
self,
174-
secrets_manager: SecretsManager,
174+
sensitive_data_manager: SensitiveDataManager,
175175
context: PipelineContext,
176176
session_id: str,
177177
):
178-
self._secrets_manager = secrets_manager
178+
self._sensitive_data_manager = sensitive_data_manager
179179
self._session_id = session_id
180180
self._context = context
181181
self._name = "codegate-secrets"
182+
182183
super().__init__()
183184

184185
def _hide_secret(self, match: Match) -> str:
185186
# Encrypt and store the value
186-
encrypted_value = self._secrets_manager.store_secret(
187-
match.value,
188-
match.service,
189-
match.type,
190-
self._session_id,
187+
if not self._session_id:
188+
raise ValueError("Session id must be provided")
189+
190+
if not match.value:
191+
raise ValueError("Value must be provided")
192+
if not match.service:
193+
raise ValueError("Service must be provided")
194+
if not match.type:
195+
raise ValueError("Secret type must be provided")
196+
197+
obj = SensitiveData(original=match.value, service=match.service, type=match.type)
198+
uuid_placeholder = self._sensitive_data_manager.store(self._session_id, obj)
199+
logger.debug(
200+
"Stored secret", service=match.service, type=match.type, placeholder=uuid_placeholder
191201
)
192-
return f"REDACTED<${encrypted_value}>"
202+
return f"REDACTED<{uuid_placeholder}>"
193203

194204
def _notify_secret(
195205
self, match: Match, code_snippet: Optional[CodeSnippet], protected_text: List[str]
@@ -251,7 +261,7 @@ def _redact_text(
251261
self,
252262
text: str,
253263
snippet: Optional[CodeSnippet],
254-
secrets_manager: SecretsManager,
264+
sensitive_data_manager: SensitiveDataManager,
255265
session_id: str,
256266
context: PipelineContext,
257267
) -> tuple[str, List[Match]]:
@@ -260,14 +270,14 @@ def _redact_text(
260270
261271
Args:
262272
text: The text to protect
263-
secrets_manager: ..
273+
sensitive_data_manager: ..
264274
session_id: ..
265275
context: The pipeline context to be able to log alerts
266276
Returns:
267277
Tuple containing protected text with encrypted values and the count of redacted secrets
268278
"""
269279
# Find secrets in the text
270-
text_encryptor = SecretsEncryptor(secrets_manager, context, session_id)
280+
text_encryptor = SecretsEncryptor(sensitive_data_manager, context, session_id)
271281
return text_encryptor.obfuscate(text, snippet)
272282

273283
async def process(
@@ -287,8 +297,10 @@ async def process(
287297
if "messages" not in request:
288298
return PipelineResult(request=request, context=context)
289299

290-
secrets_manager = context.sensitive.manager
291-
if not secrets_manager or not isinstance(secrets_manager, SecretsManager):
300+
sensitive_data_manager = context.sensitive.manager
301+
if not sensitive_data_manager or not isinstance(
302+
sensitive_data_manager, SensitiveDataManager
303+
):
292304
raise ValueError("Secrets manager not found in context")
293305
session_id = context.sensitive.session_id
294306
if not session_id:
@@ -305,15 +317,15 @@ async def process(
305317
for i, message in enumerate(new_request["messages"]):
306318
if "content" in message and message["content"]:
307319
redacted_content, secrets_matched = self._redact_message_content(
308-
message["content"], secrets_manager, session_id, context
320+
message["content"], sensitive_data_manager, session_id, context
309321
)
310322
new_request["messages"][i]["content"] = redacted_content
311323
if i > last_assistant_idx:
312324
total_matches += secrets_matched
313325
new_request = self._finalize_redaction(context, total_matches, new_request)
314326
return PipelineResult(request=new_request, context=context)
315327

316-
def _redact_message_content(self, message_content, secrets_manager, session_id, context):
328+
def _redact_message_content(self, message_content, sensitive_data_manager, session_id, context):
317329
# Extract any code snippets
318330
extractor = MessageCodeExtractorFactory.create_snippet_extractor(context.client)
319331
snippets = extractor.extract_snippets(message_content)
@@ -322,7 +334,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id,
322334

323335
for snippet in snippets:
324336
redacted_snippet, secrets_matched = self._redact_text(
325-
snippet, snippet, secrets_manager, session_id, context
337+
snippet, snippet, sensitive_data_manager, session_id, context
326338
)
327339
redacted_snippets[snippet.code] = redacted_snippet
328340
total_matches.extend(secrets_matched)
@@ -336,7 +348,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id,
336348
if start_index > last_end:
337349
non_snippet_part = message_content[last_end:start_index]
338350
redacted_part, secrets_matched = self._redact_text(
339-
non_snippet_part, "", secrets_manager, session_id, context
351+
non_snippet_part, "", sensitive_data_manager, session_id, context
340352
)
341353
non_snippet_parts.append(redacted_part)
342354
total_matches.extend(secrets_matched)
@@ -347,7 +359,7 @@ def _redact_message_content(self, message_content, secrets_manager, session_id,
347359
if last_end < len(message_content):
348360
remaining_text = message_content[last_end:]
349361
redacted_remaining, secrets_matched = self._redact_text(
350-
remaining_text, "", secrets_manager, session_id, context
362+
remaining_text, "", sensitive_data_manager, session_id, context
351363
)
352364
non_snippet_parts.append(redacted_remaining)
353365
total_matches.extend(secrets_matched)
@@ -428,9 +440,14 @@ async def process_chunk(
428440
encrypted_value = match.group(1)
429441
if encrypted_value.startswith("$"):
430442
encrypted_value = encrypted_value[1:]
443+
444+
session_id = input_context.sensitive.session_id
445+
if not session_id:
446+
raise ValueError("Session ID not found in context")
447+
431448
original_value = input_context.sensitive.manager.get_original_value(
449+
session_id,
432450
encrypted_value,
433-
input_context.sensitive.session_id,
434451
)
435452

436453
if original_value is None:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import json
2+
from typing import Dict, Optional
3+
import pydantic
4+
import structlog
5+
from codegate.pipeline.sensitive_data.session_store import SessionStore
6+
7+
logger = structlog.get_logger("codegate")
8+
9+
10+
class SensitiveData(pydantic.BaseModel):
11+
"""Represents sensitive data with additional metadata."""
12+
13+
original: str
14+
service: Optional[str] = None
15+
type: Optional[str] = None
16+
17+
18+
class SensitiveDataManager:
19+
"""Manages encryption, storage, and retrieval of secrets"""
20+
21+
def __init__(self):
22+
self.session_store = SessionStore()
23+
24+
def store(self, session_id: str, value: SensitiveData) -> Optional[str]:
25+
if not session_id or not value.original:
26+
return None
27+
return self.session_store.add_mapping(session_id, value.model_dump_json())
28+
29+
def get_by_session_id(self, session_id: str) -> Optional[Dict]:
30+
if not session_id:
31+
return None
32+
data = self.session_store.get_by_session_id(session_id)
33+
return SensitiveData.model_validate_json(data) if data else None
34+
35+
def get_original_value(self, session_id: str, uuid_placeholder: str) -> Optional[str]:
36+
if not session_id:
37+
return None
38+
secret_entry_json = self.session_store.get_mapping(session_id, uuid_placeholder)
39+
return (
40+
SensitiveData.model_validate_json(secret_entry_json).original
41+
if secret_entry_json
42+
else None
43+
)
44+
45+
def cleanup_session(self, session_id: str):
46+
if session_id:
47+
self.session_store.cleanup_session(session_id)
48+
49+
def cleanup(self):
50+
self.session_store.cleanup()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import Dict, Optional
2+
import uuid
3+
4+
5+
class SessionStore:
6+
"""
7+
A generic session store for managing data protection.
8+
"""
9+
10+
def __init__(self):
11+
self.sessions: Dict[str, Dict[str, str]] = {}
12+
13+
def add_mapping(self, session_id: str, data: str) -> str:
14+
uuid_placeholder = f"#{str(uuid.uuid4())}#"
15+
if session_id not in self.sessions:
16+
self.sessions[session_id] = {}
17+
self.sessions[session_id][uuid_placeholder] = data
18+
return uuid_placeholder
19+
20+
def get_by_session_id(self, session_id: str) -> Optional[Dict]:
21+
return self.sessions.get(session_id, None)
22+
23+
def get_mapping(self, session_id: str, uuid_placeholder: str) -> Optional[str]:
24+
return self.sessions.get(session_id, {}).get(uuid_placeholder)
25+
26+
def cleanup_session(self, session_id: str):
27+
"""Clears all stored mappings for a specific session."""
28+
if session_id in self.sessions:
29+
del self.sessions[session_id]
30+
31+
def cleanup(self):
32+
"""Clears all stored mappings for all sessions."""
33+
self.sessions.clear()

‎src/codegate/providers/copilot/provider.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from codegate.pipeline.base import PipelineContext
1818
from codegate.pipeline.factory import PipelineFactory
1919
from codegate.pipeline.output import OutputPipelineInstance
20-
from codegate.pipeline.secrets.manager import SecretsManager
20+
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
2121
from codegate.providers.copilot.mapping import PIPELINE_ROUTES, VALIDATED_ROUTES, PipelineType
2222
from codegate.providers.copilot.pipeline import (
2323
CopilotChatPipeline,
@@ -200,7 +200,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
200200
self.ca = CertificateAuthority.get_instance()
201201
self.cert_manager = TLSCertDomainManager(self.ca)
202202
self._closing = False
203-
self.pipeline_factory = PipelineFactory(SecretsManager())
203+
self.pipeline_factory = PipelineFactory(SensitiveDataManager())
204204
self.input_pipeline: Optional[CopilotPipeline] = None
205205
self.fim_pipeline: Optional[CopilotPipeline] = None
206206
# the context as provided by the pipeline

‎tests/pipeline/pii/test_analyzer.py

+11-85
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,7 @@
33
import pytest
44
from presidio_analyzer import RecognizerResult
55

6-
from codegate.pipeline.pii.analyzer import PiiAnalyzer, PiiSessionStore
7-
8-
9-
class TestPiiSessionStore:
10-
def test_init_with_session_id(self):
11-
session_id = "test-session"
12-
store = PiiSessionStore(session_id)
13-
assert store.session_id == session_id
14-
assert store.mappings == {}
15-
16-
def test_init_without_session_id(self):
17-
store = PiiSessionStore()
18-
assert isinstance(store.session_id, str)
19-
assert len(store.session_id) > 0
20-
assert store.mappings == {}
21-
22-
def test_add_mapping(self):
23-
store = PiiSessionStore()
24-
pii = "test@example.com"
25-
placeholder = store.add_mapping(pii)
26-
27-
assert placeholder.startswith("<")
28-
assert placeholder.endswith(">")
29-
assert store.mappings[placeholder] == pii
30-
31-
def test_get_pii_existing(self):
32-
store = PiiSessionStore()
33-
pii = "test@example.com"
34-
placeholder = store.add_mapping(pii)
35-
36-
result = store.get_pii(placeholder)
37-
assert result == pii
38-
39-
def test_get_pii_nonexistent(self):
40-
store = PiiSessionStore()
41-
placeholder = "<nonexistent>"
42-
result = store.get_pii(placeholder)
43-
assert result == placeholder
6+
from codegate.pipeline.pii.analyzer import PiiAnalyzer
447

458

469
class TestPiiAnalyzer:
@@ -104,68 +67,31 @@ def test_singleton_pattern(self):
10467
with pytest.raises(RuntimeError, match="Use PiiAnalyzer.get_instance()"):
10568
PiiAnalyzer()
10669

107-
def test_analyze_no_pii(self, analyzer, mock_analyzer_engine):
108-
text = "Hello world"
109-
mock_analyzer_engine.analyze.return_value = []
110-
111-
result_text, found_pii, session_store = analyzer.analyze(text)
112-
113-
assert result_text == text
114-
assert found_pii == []
115-
assert isinstance(session_store, PiiSessionStore)
116-
117-
def test_analyze_with_pii(self, analyzer, mock_analyzer_engine):
118-
text = "My email is test@example.com"
119-
email_pii = RecognizerResult(
120-
entity_type="EMAIL_ADDRESS",
121-
start=12,
122-
end=28,
123-
score=1.0, # EmailRecognizer returns a score of 1.0
124-
)
125-
mock_analyzer_engine.analyze.return_value = [email_pii]
126-
127-
result_text, found_pii, session_store = analyzer.analyze(text)
128-
129-
assert len(found_pii) == 1
130-
pii_info = found_pii[0]
131-
assert pii_info["type"] == "EMAIL_ADDRESS"
132-
assert pii_info["value"] == "test@example.com"
133-
assert pii_info["score"] == 1.0
134-
assert pii_info["start"] == 12
135-
assert pii_info["end"] == 28
136-
assert "uuid_placeholder" in pii_info
137-
# Verify the placeholder was used to replace the PII
138-
placeholder = pii_info["uuid_placeholder"]
139-
assert result_text == f"My email is {placeholder}"
140-
# Verify the mapping was stored
141-
assert session_store.get_pii(placeholder) == "test@example.com"
142-
14370
def test_restore_pii(self, analyzer):
144-
session_store = PiiSessionStore()
14571
original_text = "test@example.com"
146-
placeholder = session_store.add_mapping(original_text)
147-
anonymized_text = f"My email is {placeholder}"
72+
session_id = "session-id"
14873

149-
restored_text = analyzer.restore_pii(anonymized_text, session_store)
74+
placeholder = analyzer.session_store.add_mapping(session_id, original_text)
75+
anonymized_text = f"My email is {placeholder}"
76+
restored_text = analyzer.restore_pii(session_id, anonymized_text)
15077

15178
assert restored_text == f"My email is {original_text}"
15279

15380
def test_restore_pii_multiple(self, analyzer):
154-
session_store = PiiSessionStore()
15581
email = "test@example.com"
15682
phone = "123-456-7890"
157-
email_placeholder = session_store.add_mapping(email)
158-
phone_placeholder = session_store.add_mapping(phone)
83+
session_id = "session-id"
84+
email_placeholder = analyzer.session_store.add_mapping(session_id, email)
85+
phone_placeholder = analyzer.session_store.add_mapping(session_id, phone)
15986
anonymized_text = f"Email: {email_placeholder}, Phone: {phone_placeholder}"
16087

161-
restored_text = analyzer.restore_pii(anonymized_text, session_store)
88+
restored_text = analyzer.restore_pii(session_id, anonymized_text)
16289

16390
assert restored_text == f"Email: {email}, Phone: {phone}"
16491

16592
def test_restore_pii_no_placeholders(self, analyzer):
166-
session_store = PiiSessionStore()
16793
text = "No PII here"
168-
169-
restored_text = analyzer.restore_pii(text, session_store)
94+
session_id = "session-id"
95+
restored_text = analyzer.restore_pii(session_id, text)
17096

17197
assert restored_text == text

‎tests/pipeline/pii/test_pi.py

+12-62
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from litellm import ChatCompletionRequest, ModelResponse
55
from litellm.types.utils import Delta, StreamingChoices
66

7-
from codegate.pipeline.base import PipelineContext
7+
from codegate.pipeline.base import PipelineContext, PipelineSensitiveData
88
from codegate.pipeline.output import OutputPipelineContext
99
from codegate.pipeline.pii.pii import CodegatePii, PiiRedactionNotifier, PiiUnRedactionStep
10+
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
1011

1112

1213
class TestCodegatePii:
@@ -19,8 +20,9 @@ def mock_config(self):
1920
yield mock_config
2021

2122
@pytest.fixture
22-
def pii_step(self, mock_config):
23-
return CodegatePii()
23+
def pii_step(self):
24+
mock_sensitive_data_manager = MagicMock()
25+
return CodegatePii(mock_sensitive_data_manager)
2426

2527
def test_name(self, pii_step):
2628
assert pii_step.name == "codegate-pii"
@@ -51,57 +53,6 @@ async def test_process_no_messages(self, pii_step):
5153
assert result.request == request
5254
assert result.context == context
5355

54-
@pytest.mark.asyncio
55-
async def test_process_with_pii(self, pii_step):
56-
original_text = "My email is test@example.com"
57-
request = ChatCompletionRequest(
58-
model="test-model", messages=[{"role": "user", "content": original_text}]
59-
)
60-
context = PipelineContext()
61-
62-
# Mock the PII manager's analyze method
63-
placeholder = "<test-uuid>"
64-
pii_details = [
65-
{
66-
"type": "EMAIL_ADDRESS",
67-
"value": "test@example.com",
68-
"score": 1.0,
69-
"start": 12,
70-
"end": 27,
71-
"uuid_placeholder": placeholder,
72-
}
73-
]
74-
anonymized_text = f"My email is {placeholder}"
75-
pii_step.pii_manager.analyze = MagicMock(return_value=(anonymized_text, pii_details))
76-
77-
result = await pii_step.process(request, context)
78-
79-
# Verify the user message was anonymized
80-
user_messages = [m for m in result.request["messages"] if m["role"] == "user"]
81-
assert len(user_messages) == 1
82-
assert user_messages[0]["content"] == anonymized_text
83-
84-
# Verify metadata was updated
85-
assert result.context.metadata["redacted_pii_count"] == 1
86-
assert len(result.context.metadata["redacted_pii_details"]) == 1
87-
# The redacted text should be just the placeholder since that's what _get_redacted_snippet returns # noqa: E501
88-
assert result.context.metadata["redacted_text"] == placeholder
89-
assert "pii_manager" in result.context.metadata
90-
91-
# Verify system message was added
92-
system_messages = [m for m in result.request["messages"] if m["role"] == "system"]
93-
assert len(system_messages) == 1
94-
assert system_messages[0]["content"] == "PII has been redacted"
95-
96-
def test_restore_pii(self, pii_step):
97-
anonymized_text = "My email is <test-uuid>"
98-
original_text = "My email is test@example.com"
99-
pii_step.pii_manager.restore_pii = MagicMock(return_value=original_text)
100-
101-
restored = pii_step.restore_pii(anonymized_text)
102-
103-
assert restored == original_text
104-
10556

10657
class TestPiiUnRedactionStep:
10758
@pytest.fixture
@@ -148,7 +99,7 @@ async def test_process_chunk_with_uuid(self, unredaction_step):
14899
StreamingChoices(
149100
finish_reason=None,
150101
index=0,
151-
delta=Delta(content=f"Text with <{uuid}>"),
102+
delta=Delta(content=f"Text with #{uuid}#"),
152103
logprobs=None,
153104
)
154105
],
@@ -157,17 +108,16 @@ async def test_process_chunk_with_uuid(self, unredaction_step):
157108
object="chat.completion.chunk",
158109
)
159110
context = OutputPipelineContext()
160-
input_context = PipelineContext()
111+
manager = SensitiveDataManager()
112+
sensitive = PipelineSensitiveData(manager=manager, session_id="session-id")
113+
input_context = PipelineContext(sensitive=sensitive)
161114

162115
# Mock PII manager in input context
163-
mock_pii_manager = MagicMock()
164-
mock_session = MagicMock()
165-
mock_session.get_pii = MagicMock(return_value="test@example.com")
166-
mock_pii_manager.session_store = mock_session
167-
input_context.metadata["pii_manager"] = mock_pii_manager
116+
mock_sensitive_data_manager = MagicMock()
117+
mock_sensitive_data_manager.get_original_value = MagicMock(return_value="test@example.com")
118+
input_context.metadata["sensitive_data_manager"] = mock_sensitive_data_manager
168119

169120
result = await unredaction_step.process_chunk(chunk, context, input_context)
170-
171121
assert result[0].choices[0].delta.content == "Text with test@example.com"
172122

173123

‎tests/pipeline/pii/test_pii_manager.py

-106
This file was deleted.

‎tests/pipeline/secrets/test_gatecrypto.py

-157
This file was deleted.

‎tests/pipeline/secrets/test_manager.py

-149
This file was deleted.

‎tests/pipeline/secrets/test_secrets.py

+21-21
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77

88
from codegate.pipeline.base import PipelineContext, PipelineSensitiveData
99
from codegate.pipeline.output import OutputPipelineContext
10-
from codegate.pipeline.secrets.manager import SecretsManager
1110
from codegate.pipeline.secrets.secrets import (
1211
SecretsEncryptor,
1312
SecretsObfuscator,
1413
SecretUnredactionStep,
1514
)
1615
from codegate.pipeline.secrets.signatures import CodegateSignatures, Match
16+
from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager
1717

1818

1919
class TestSecretsModifier:
@@ -69,9 +69,11 @@ class TestSecretsEncryptor:
6969
def setup(self, temp_yaml_file):
7070
CodegateSignatures.initialize(temp_yaml_file)
7171
self.context = PipelineContext()
72-
self.secrets_manager = SecretsManager()
72+
self.sensitive_data_manager = SensitiveDataManager()
7373
self.session_id = "test_session"
74-
self.encryptor = SecretsEncryptor(self.secrets_manager, self.context, self.session_id)
74+
self.encryptor = SecretsEncryptor(
75+
self.sensitive_data_manager, self.context, self.session_id
76+
)
7577

7678
def test_hide_secret(self):
7779
# Create a test match
@@ -87,12 +89,12 @@ def test_hide_secret(self):
8789

8890
# Test secret hiding
8991
hidden = self.encryptor._hide_secret(match)
90-
assert hidden.startswith("REDACTED<$")
92+
assert hidden.startswith("REDACTED<")
9193
assert hidden.endswith(">")
9294

9395
# Verify the secret was stored
94-
encrypted_value = hidden[len("REDACTED<$") : -1]
95-
original = self.secrets_manager.get_original_value(encrypted_value, self.session_id)
96+
encrypted_value = hidden[len("REDACTED<") : -1]
97+
original = self.sensitive_data_manager.get_original_value(self.session_id, encrypted_value)
9698
assert original == "AKIAIOSFODNN7EXAMPLE"
9799

98100
def test_obfuscate(self):
@@ -101,7 +103,7 @@ def test_obfuscate(self):
101103
protected, matched_secrets = self.encryptor.obfuscate(text, None)
102104

103105
assert len(matched_secrets) == 1
104-
assert "REDACTED<$" in protected
106+
assert "REDACTED<" in protected
105107
assert "AKIAIOSFODNN7EXAMPLE" not in protected
106108
assert "Other text" in protected
107109

@@ -171,25 +173,24 @@ def setup_method(self):
171173
"""Setup fresh instances for each test"""
172174
self.step = SecretUnredactionStep()
173175
self.context = OutputPipelineContext()
174-
self.secrets_manager = SecretsManager()
176+
self.sensitive_data_manager = SensitiveDataManager()
175177
self.session_id = "test_session"
176178

177179
# Setup input context with secrets manager
178180
self.input_context = PipelineContext()
179181
self.input_context.sensitive = PipelineSensitiveData(
180-
manager=self.secrets_manager, session_id=self.session_id
182+
manager=self.sensitive_data_manager, session_id=self.session_id
181183
)
182184

183185
@pytest.mark.asyncio
184186
async def test_complete_marker_processing(self):
185187
"""Test processing of a complete REDACTED marker"""
186188
# Store a secret
187-
encrypted = self.secrets_manager.store_secret(
188-
"secret_value", "test_service", "api_key", self.session_id
189-
)
189+
obj = SensitiveData(original="secret_value", service="test_service", type="api_key")
190+
encrypted = self.sensitive_data_manager.store(self.session_id, obj)
190191

191192
# Add content with REDACTED marker to buffer
192-
self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text")
193+
self.context.buffer.append(f"Here is the REDACTED<{encrypted}> in text")
193194

194195
# Process a chunk
195196
result = await self.step.process_chunk(
@@ -204,7 +205,7 @@ async def test_complete_marker_processing(self):
204205
async def test_partial_marker_buffering(self):
205206
"""Test handling of partial REDACTED markers"""
206207
# Add partial marker to buffer
207-
self.context.buffer.append("Here is REDACTED<$")
208+
self.context.buffer.append("Here is REDACTED<")
208209

209210
# Process a chunk
210211
result = await self.step.process_chunk(
@@ -218,7 +219,7 @@ async def test_partial_marker_buffering(self):
218219
async def test_invalid_encrypted_value(self):
219220
"""Test handling of invalid encrypted values"""
220221
# Add content with invalid encrypted value
221-
self.context.buffer.append("Here is REDACTED<$invalid_value> in text")
222+
self.context.buffer.append("Here is REDACTED<invalid_value> in text")
222223

223224
# Process chunk
224225
result = await self.step.process_chunk(
@@ -227,7 +228,7 @@ async def test_invalid_encrypted_value(self):
227228

228229
# Should keep the REDACTED marker for invalid values
229230
assert len(result) == 1
230-
assert result[0].choices[0].delta.content == "Here is REDACTED<$invalid_value> in text"
231+
assert result[0].choices[0].delta.content == "Here is REDACTED<invalid_value> in text"
231232

232233
@pytest.mark.asyncio
233234
async def test_missing_context(self):
@@ -271,17 +272,16 @@ async def test_no_markers(self):
271272
async def test_wrong_session(self):
272273
"""Test unredaction with wrong session ID"""
273274
# Store secret with one session
274-
encrypted = self.secrets_manager.store_secret(
275-
"secret_value", "test_service", "api_key", "different_session"
276-
)
275+
obj = SensitiveData(original="test_service", service="api_key", type="different_session")
276+
encrypted = self.sensitive_data_manager.store("different_session", obj)
277277

278278
# Try to unredact with different session
279-
self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text")
279+
self.context.buffer.append(f"Here is the REDACTED<{encrypted}> in text")
280280

281281
result = await self.step.process_chunk(
282282
create_model_response("text"), self.context, self.input_context
283283
)
284284

285285
# Should keep REDACTED marker when session doesn't match
286286
assert len(result) == 1
287-
assert result[0].choices[0].delta.content == f"Here is the REDACTED<${encrypted}> in text"
287+
assert result[0].choices[0].delta.content == f"Here is the REDACTED<{encrypted}> in text"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
from unittest.mock import MagicMock, patch
3+
import pytest
4+
from codegate.pipeline.sensitive_data.manager import SensitiveData, SensitiveDataManager
5+
from codegate.pipeline.sensitive_data.session_store import SessionStore
6+
7+
8+
class TestSensitiveDataManager:
9+
@pytest.fixture
10+
def mock_session_store(self):
11+
"""Mock the SessionStore instance used within SensitiveDataManager."""
12+
return MagicMock(spec=SessionStore)
13+
14+
@pytest.fixture
15+
def manager(self, mock_session_store):
16+
"""Patch SensitiveDataManager to use the mocked SessionStore."""
17+
with patch.object(SensitiveDataManager, "__init__", lambda self: None):
18+
manager = SensitiveDataManager()
19+
manager.session_store = mock_session_store # Manually inject the mock
20+
return manager
21+
22+
def test_store_success(self, manager, mock_session_store):
23+
"""Test storing a SensitiveData object successfully."""
24+
session_id = "session-123"
25+
sensitive_data = SensitiveData(original="secret_value", service="AWS", type="API_KEY")
26+
27+
# Mock session store behavior
28+
mock_session_store.add_mapping.return_value = "uuid-123"
29+
30+
result = manager.store(session_id, sensitive_data)
31+
32+
# Verify correct function calls
33+
mock_session_store.add_mapping.assert_called_once_with(
34+
session_id, sensitive_data.model_dump_json()
35+
)
36+
assert result == "uuid-123"
37+
38+
def test_store_invalid_session_id(self, manager):
39+
"""Test storing data with an invalid session ID (should return None)."""
40+
sensitive_data = SensitiveData(original="secret_value", service="AWS", type="API_KEY")
41+
result = manager.store("", sensitive_data) # Empty session ID
42+
assert result is None
43+
44+
def test_store_missing_original_value(self, manager):
45+
"""Test storing data without an original value (should return None)."""
46+
sensitive_data = SensitiveData(original="", service="AWS", type="API_KEY") # Empty original
47+
result = manager.store("session-123", sensitive_data)
48+
assert result is None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import uuid
2+
import pytest
3+
from codegate.pipeline.sensitive_data.session_store import SessionStore
4+
5+
6+
class TestSessionStore:
7+
@pytest.fixture
8+
def session_store(self):
9+
"""Fixture to create a fresh SessionStore instance before each test."""
10+
return SessionStore()
11+
12+
def test_add_mapping_creates_uuid(self, session_store):
13+
"""Test that add_mapping correctly stores data and returns a UUID."""
14+
session_id = "session-123"
15+
data = "test-data"
16+
17+
uuid_placeholder = session_store.add_mapping(session_id, data)
18+
19+
# Ensure the returned placeholder follows the expected format
20+
assert uuid_placeholder.startswith("#") and uuid_placeholder.endswith("#")
21+
assert len(uuid_placeholder) > 2 # Should have a UUID inside
22+
23+
# Verify data is correctly stored
24+
stored_data = session_store.get_mapping(session_id, uuid_placeholder)
25+
assert stored_data == data
26+
27+
def test_add_mapping_creates_unique_uuids(self, session_store):
28+
"""Ensure multiple calls to add_mapping generate unique UUIDs."""
29+
session_id = "session-123"
30+
data1 = "data1"
31+
data2 = "data2"
32+
33+
uuid_placeholder1 = session_store.add_mapping(session_id, data1)
34+
uuid_placeholder2 = session_store.add_mapping(session_id, data2)
35+
36+
assert uuid_placeholder1 != uuid_placeholder2 # UUIDs must be unique
37+
38+
# Ensure data is correctly stored
39+
assert session_store.get_mapping(session_id, uuid_placeholder1) == data1
40+
assert session_store.get_mapping(session_id, uuid_placeholder2) == data2
41+
42+
def test_get_by_session_id(self, session_store):
43+
"""Test retrieving all stored mappings for a session ID."""
44+
session_id = "session-123"
45+
data1 = "data1"
46+
data2 = "data2"
47+
48+
uuid1 = session_store.add_mapping(session_id, data1)
49+
uuid2 = session_store.add_mapping(session_id, data2)
50+
51+
stored_session_data = session_store.get_by_session_id(session_id)
52+
53+
assert uuid1 in stored_session_data
54+
assert uuid2 in stored_session_data
55+
assert stored_session_data[uuid1] == data1
56+
assert stored_session_data[uuid2] == data2
57+
58+
def test_get_by_session_id_not_found(self, session_store):
59+
"""Test get_by_session_id when session does not exist (should return None)."""
60+
session_id = "non-existent-session"
61+
assert session_store.get_by_session_id(session_id) is None
62+
63+
def test_get_mapping_success(self, session_store):
64+
"""Test retrieving a specific mapping."""
65+
session_id = "session-123"
66+
data = "test-data"
67+
68+
uuid_placeholder = session_store.add_mapping(session_id, data)
69+
70+
assert session_store.get_mapping(session_id, uuid_placeholder) == data
71+
72+
def test_get_mapping_not_found(self, session_store):
73+
"""Test retrieving a mapping that does not exist (should return None)."""
74+
session_id = "session-123"
75+
uuid_placeholder = "#non-existent-uuid#"
76+
77+
assert session_store.get_mapping(session_id, uuid_placeholder) is None
78+
79+
def test_cleanup_session(self, session_store):
80+
"""Test that cleanup_session removes all data for a session ID."""
81+
session_id = "session-123"
82+
session_store.add_mapping(session_id, "test-data")
83+
84+
# Ensure session exists before cleanup
85+
assert session_store.get_by_session_id(session_id) is not None
86+
87+
session_store.cleanup_session(session_id)
88+
89+
# Ensure session is removed after cleanup
90+
assert session_store.get_by_session_id(session_id) is None
91+
92+
def test_cleanup_session_non_existent(self, session_store):
93+
"""Test cleanup_session on a non-existent session (should not raise errors)."""
94+
session_id = "non-existent-session"
95+
session_store.cleanup_session(session_id) # Should not fail
96+
assert session_store.get_by_session_id(session_id) is None
97+
98+
def test_cleanup(self, session_store):
99+
"""Test global cleanup removes all stored sessions."""
100+
session_id1 = "session-1"
101+
session_id2 = "session-2"
102+
103+
session_store.add_mapping(session_id1, "data1")
104+
session_store.add_mapping(session_id2, "data2")
105+
106+
# Ensure sessions exist before cleanup
107+
assert session_store.get_by_session_id(session_id1) is not None
108+
assert session_store.get_by_session_id(session_id2) is not None
109+
110+
session_store.cleanup()
111+
112+
# Ensure all sessions are removed after cleanup
113+
assert session_store.get_by_session_id(session_id1) is None
114+
assert session_store.get_by_session_id(session_id2) is None

‎tests/test_server.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,13 @@
1414

1515
from codegate import __version__
1616
from codegate.pipeline.factory import PipelineFactory
17-
from codegate.pipeline.secrets.manager import SecretsManager
17+
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
1818
from codegate.providers.registry import ProviderRegistry
1919
from codegate.server import init_app
2020
from src.codegate.cli import UvicornServer, cli
2121
from src.codegate.codegate_logging import LogFormat, LogLevel
2222

2323

24-
@pytest.fixture
25-
def mock_secrets_manager():
26-
"""Create a mock secrets manager."""
27-
return MagicMock(spec=SecretsManager)
28-
29-
3024
@pytest.fixture
3125
def mock_provider_registry():
3226
"""Create a mock provider registry."""
@@ -96,9 +90,9 @@ def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) ->
9690
assert response_data["is_latest"] is False
9791

9892

99-
@patch("codegate.pipeline.secrets.manager.SecretsManager")
93+
@patch("codegate.pipeline.sensitive_data.manager.SensitiveDataManager")
10094
@patch("codegate.server.get_provider_registry")
101-
def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_factory) -> None:
95+
def test_provider_registration(mock_registry, mock_pipeline_factory) -> None:
10296
"""Test that all providers are registered correctly."""
10397
init_app(mock_pipeline_factory)
10498

0 commit comments

Comments
 (0)
Please sign in to comment.