allow init guardrails with output parsing logic
ishaan-jaff committed Sep 4, 2024
1 parent f1111f9 commit 4ab8e52
Showing 2 changed files with 29 additions and 66 deletions.
73 changes: 7 additions & 66 deletions litellm/proxy/guardrails/guardrail_hooks/presidio.py
@@ -48,23 +48,24 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
# Class variables or attributes
def __init__(
self,
logging_only: Optional[bool] = None,
mock_testing: bool = False,
mock_redacted_text: Optional[dict] = None,
presidio_analyzer_api_base: Optional[str] = None,
presidio_anonymizer_api_base: Optional[str] = None,
output_parse_pii: Optional[bool] = False,
presidio_ad_hoc_recognizers: Optional[str] = None,
**kwargs,
):
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation

self.mock_redacted_text = mock_redacted_text
self.logging_only = logging_only
self.output_parse_pii = output_parse_pii or False
if mock_testing is True: # for testing purposes only
return

ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers
ad_hoc_recognizers = presidio_ad_hoc_recognizers
if ad_hoc_recognizers is not None:
try:
with open(ad_hoc_recognizers, "r") as file:
@@ -225,69 +226,9 @@ async def async_pre_call_hook(
"""

try:
if (
self.logging_only is True
): # only modify the logging obj data (done by async_logging_hook)
return data
permissions = user_api_key_dict.permissions
output_parse_pii = permissions.get(
"output_parse_pii", litellm.output_parse_pii
) # allow key to turn on/off output parsing for pii
no_pii = permissions.get(
"no-pii", None
) # allow key to turn on/off pii masking (if user is allowed to set pii controls, then they can override the key defaults)

if no_pii is None:
# check older way of turning on/off pii
no_pii = not permissions.get("pii", True)

content_safety = data.get("content_safety", None)
verbose_proxy_logger.debug("content_safety: %s", content_safety)
## Request-level turn on/off PII controls ##
if content_safety is not None and isinstance(content_safety, dict):
# pii masking ##
if (
content_safety.get("no-pii", None) is not None
and content_safety.get("no-pii") == True
):
# check if user allowed to turn this off
if permissions.get("allow_pii_controls", False) == False:
raise HTTPException(
status_code=400,
detail={
"error": "Not allowed to set PII controls per request"
},
)
else: # user allowed to turn off pii masking
no_pii = content_safety.get("no-pii")
if not isinstance(no_pii, bool):
raise HTTPException(
status_code=400,
detail={"error": "no_pii needs to be a boolean value"},
)
## pii output parsing ##
if content_safety.get("output_parse_pii", None) is not None:
# check if user allowed to turn this off
if permissions.get("allow_pii_controls", False) == False:
raise HTTPException(
status_code=400,
detail={
"error": "Not allowed to set PII controls per request"
},
)
else: # user allowed to turn on/off pii output parsing
output_parse_pii = content_safety.get("output_parse_pii")
if not isinstance(output_parse_pii, bool):
raise HTTPException(
status_code=400,
detail={
"error": "output_parse_pii needs to be a boolean value"
},
)

if no_pii is True: # turn off pii masking
return data

presidio_config = self.get_presidio_settings_from_request_data(data)

if call_type == "completion": # /chat/completions requests
@@ -299,7 +240,7 @@ async def async_pre_call_hook(
tasks.append(
self.check_pii(
text=m["content"],
output_parse_pii=output_parse_pii,
output_parse_pii=self.output_parse_pii,
presidio_config=presidio_config,
)
)
@@ -372,9 +313,9 @@ async def async_post_call_success_hook(
Output parse the response object to replace the masked tokens with user sent values
"""
verbose_proxy_logger.debug(
f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
)
if litellm.output_parse_pii == False:
if self.output_parse_pii == False:
return response

if isinstance(response, ModelResponse) and not isinstance(
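
Note on the presidio.py change above: the guardrail now takes output_parse_pii and presidio_ad_hoc_recognizers as constructor arguments and reads self.output_parse_pii in its hooks, instead of consulting the module-level litellm.output_parse_pii / litellm.presidio_ad_hoc_recognizers settings. A minimal sketch of a direct instantiation, assuming illustrative names, endpoints, and file paths that are not part of this commit:

from litellm.proxy.guardrails.guardrail_hooks.presidio import (
    _OPTIONAL_PresidioPIIMasking,
)

# Pre-call PII masking instance; per-instance flags replace the old global reads.
presidio_guardrail = _OPTIONAL_PresidioPIIMasking(
    guardrail_name="presidio-pre-guard",                   # hypothetical guardrail name
    event_hook="pre_call",                                  # mirrors litellm_params["mode"] in init_guardrails.py
    output_parse_pii=True,                                  # stored as self.output_parse_pii
    presidio_analyzer_api_base="http://localhost:5002",     # hypothetical analyzer endpoint
    presidio_anonymizer_api_base="http://localhost:5001",   # hypothetical anonymizer endpoint
    presidio_ad_hoc_recognizers="./recognizers.json",       # hypothetical path; the file is read at init unless mock_testing=True
)
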
22 changes: 22 additions & 0 deletions litellm/proxy/guardrails/init_guardrails.py
@@ -11,6 +11,7 @@
# v2 implementation
from litellm.types.guardrails import (
Guardrail,
GuardrailEventHooks,
GuardrailItem,
GuardrailItemSpec,
LakeraCategoryThresholds,
@@ -104,6 +105,10 @@ def init_guardrails_v2(
api_base=litellm_params_data.get("api_base"),
guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
guardrailVersion=litellm_params_data.get("guardrailVersion"),
output_parse_pii=litellm_params_data.get("output_parse_pii"),
presidio_ad_hoc_recognizers=litellm_params_data.get(
"presidio_ad_hoc_recognizers"
),
)

if (
@@ -173,7 +178,24 @@ def init_guardrails_v2(
_presidio_callback = _OPTIONAL_PresidioPIIMasking(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
output_parse_pii=litellm_params["output_parse_pii"],
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
)

if litellm_params["output_parse_pii"] is True:
_success_callback = _OPTIONAL_PresidioPIIMasking(
output_parse_pii=True,
guardrail_name=guardrail["guardrail_name"],
event_hook=GuardrailEventHooks.post_call.value,
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
)

litellm.callbacks.append(_success_callback) # type: ignore

litellm.callbacks.append(_presidio_callback) # type: ignore
elif (
isinstance(litellm_params["guardrail"], str)
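
The init_guardrails.py change above forwards output_parse_pii and presidio_ad_hoc_recognizers from a guardrail's litellm_params into the Presidio hook, and when output_parse_pii is true it also registers a second post_call instance that restores the original values in the response. A rough sketch of a config entry of the shape this code reads, in Python-dict form; the guardrail name and recognizer path are illustrative, and the exact v2 guardrails schema should be checked against the LiteLLM docs:

guardrails_config = [
    {
        "guardrail_name": "presidio-pre-guard",                    # hypothetical name
        "litellm_params": {
            "guardrail": "presidio",                               # routes to _OPTIONAL_PresidioPIIMasking
            "mode": "pre_call",                                    # becomes event_hook on the masking instance
            "output_parse_pii": True,                              # also registers a post_call un-masking instance
            "presidio_ad_hoc_recognizers": "./recognizers.json",   # hypothetical recognizer file
        },
    },
]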
