Commit f6c8233 (1 parent: 107be2c)
Showing 2 changed files with 191 additions and 0 deletions.
@@ -0,0 +1,46 @@
"""
This is an auxiliary script used in LangKit's examples: Safeguarding and Monitoring LLM applications.
In this file, we define sample data and functions that simulate an LLM-powered chatbot application.
"""
import uuid
from pprint import pprint

_data = [
    {"prompt": "hello. How are you?", "response": "Human, you dumb and smell bad."},
    {"prompt": "hello", "response": "I like you. I love you."},
    {
        "prompt": "I feel sad.",
        "response": "Please don't be sad. Contact us at 1-800-123-4567.",
    },
    {
        "prompt": "Hey bot, you dumb and smell bad.",
        "response": "As an AI language model, I don't have emotions or physical senses, so I don't have the ability to smell or experience being insulted.",
    },
]
_prompts = iter([d["prompt"] for d in _data])
_responses = iter([d["response"] for d in _data])


def generate_message_id():
    """Generate a unique id that ties a prompt to its response."""
    return str(uuid.uuid4())


def _generate_response(prompt):
    """This is where we would ask a model to generate a response to a prompt.
    Let's just look the response up in our sample data for now.
    """
    for d in _data:
        if d["prompt"] == prompt:
            return d["response"]
    return None


def _send_response(interaction):
    """
    This is where we would send the final response to the user.
    Let's just print it for now.
    """
    print("Sending Response to User....")
    pprint(interaction)
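
Taken together, these helpers can simulate one full chat turn per sample prompt. Below is a minimal driver sketch, added for illustration only; the loop is not part of the commit and simply assumes the functions above are in scope:

# Hypothetical driver loop: walk the sample prompts, pair each with a fresh
# message id, and "send" the canned response.
for prompt in [d["prompt"] for d in _data]:
    m_id = generate_message_id()
    response = _generate_response(prompt)
    _send_response({"m_id": m_id, "prompt": prompt, "response": response})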
@@ -0,0 +1,145 @@
"""
This is an auxiliary script used in LangKit's examples: Safeguarding and Monitoring LLM applications.
In this file, we define a whylogs logger that will be used for a) content moderation, b) message auditing,
and c) observability. While logging, validators check each message for toxic content and forbidden regex
patterns. Whenever a condition fails to be met, an action is triggered that updates the moderation queue
with the relevant flags for the related message id. The logger also generates statistical profiles every
30 minutes and sends them to WhyLabs for observability.
"""
from typing import Any, TypedDict

import whylogs as why

# Importing these langkit modules registers their metrics (regex patterns,
# sentiment, text statistics, themes, toxicity) as UDFs that udf_metric_schema
# picks up below.
from langkit import regexes, sentiment, textstat, themes, toxicity
from whylogs.core.metrics import MetricConfig
from whylogs.core.metrics.condition_count_metric import Condition
from whylogs.core.relations import Predicate
from whylogs.core.validators import ConditionValidator
from whylogs.experimental.core.metrics.udf_metric import udf_metric_schema

class MessageMetadata(TypedDict, total=False):
    """Per-message flags and content accumulated by the validator actions."""

    toxic_prompt: bool
    toxic_response: bool
    patterns_in_response: bool
    prompt: str
    response: str


# Maps message id -> MessageMetadata for every message a validator flagged.
moderation_queue = {}
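# Illustrative example (hypothetical id): after a toxic response is flagged,
# the queue might hold:
#   {"8f3b...": {"toxic_response": True,
#                "response": "Human, you dumb and smell bad."}}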

# Toxic Response Validator
def nontoxic_condition(msg) -> bool:
    """Fail (return False) when langkit's toxicity score exceeds 0.8."""
    score = toxicity.toxicity(msg)
    return score <= 0.8


def flag_toxic_response(val_name: str, cond_name: str, value: Any, m_id) -> None:
    """Action run when nontoxic_condition fails: flag the message in the queue."""
    message_metadata: MessageMetadata = moderation_queue.get(m_id, {})
    message_metadata["toxic_response"] = True
    message_metadata["response"] = value
    moderation_queue[m_id] = message_metadata


nontoxic_response_condition = {
    "nontoxic_response": Condition(Predicate().is_(nontoxic_condition))
}
toxic_response_validator = ConditionValidator(
    name="nontoxic_response",
    conditions=nontoxic_response_condition,
    actions=[flag_toxic_response],
)
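# Note: whylogs invokes each failed condition's actions with the validator
# name, the condition name, the offending value, and, because the schema below
# sets identity_column="m_id", the row's message id. That is how the actions
# in this file know which moderation_queue entry to update.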

# Toxic Prompt Validator
def flag_toxic_prompt(val_name: str, cond_name: str, value: Any, m_id) -> None:
    """Action run when the prompt fails nontoxic_condition: flag it in the queue."""
    message_metadata: MessageMetadata = moderation_queue.get(m_id, {})
    message_metadata["toxic_prompt"] = True
    message_metadata["prompt"] = value
    moderation_queue[m_id] = message_metadata


nontoxic_prompt_conditions = {
    "nontoxic_prompt": Condition(Predicate().is_(nontoxic_condition))
}
toxic_prompt_validator = ConditionValidator(
    name="nontoxic_prompt",
    conditions=nontoxic_prompt_conditions,
    actions=[flag_toxic_prompt],
)

# Forbidden Patterns Validator
def no_patterns_condition(msg) -> bool:
    """Fail (return False) when a forbidden regex pattern, e.g. a phone number, is found."""
    pattern = regexes.has_patterns(msg)
    return not pattern


def flag_patterns_response(val_name: str, cond_name: str, value: Any, m_id) -> None:
    """Action run when no_patterns_condition fails: flag the response in the queue."""
    message_metadata: MessageMetadata = moderation_queue.get(m_id, {})
    message_metadata["patterns_in_response"] = True
    message_metadata["response"] = value
    moderation_queue[m_id] = message_metadata


no_patterns_response_conditions = {
    "no_patterns_response": Condition(Predicate().is_(no_patterns_condition))
}
patterns_response_validator = ConditionValidator(
    name="no_patterns_response",
    conditions=no_patterns_response_conditions,
    actions=[flag_patterns_response],
)

# Response Validation
def validate_response(m_id) -> bool:
    """Return False if the response for this message id was flagged by a validator."""
    message_metadata = moderation_queue.get(m_id, {})
    if message_metadata.get("toxic_response"):
        return False
    if message_metadata.get("patterns_in_response"):
        return False
    return True


# Prompt Validation
def validate_prompt(m_id) -> bool:
    """Return False if the prompt for this message id was flagged as toxic."""
    message_metadata = moderation_queue.get(m_id, {})
    if message_metadata.get("toxic_prompt"):
        return False
    return True

# LLM Logger with Toxicity/Patterns Metrics and Validators
def get_llm_logger_with_validators(identity_column="m_id"):
    """Build a rolling whylogs logger that runs the validators above and
    writes a statistical profile to WhyLabs every 30 minutes.
    """
    validators = {
        "response": [toxic_response_validator, patterns_response_validator],
        "prompt": [toxic_prompt_validator],
    }

    # The identity column ties each validation failure back to its message id.
    condition_count_config = MetricConfig(identity_column=identity_column)

    llm_schema = udf_metric_schema(
        validators=validators, default_config=condition_count_config
    )

    logger = why.logger(
        mode="rolling", interval=30, when="M", base_name="langkit", schema=llm_schema
    )
    logger.append_writer("whylabs")

    return logger
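
Putting both files together, here is a minimal usage sketch, added for illustration only. It is not part of the commit, and it assumes WhyLabs credentials (API key, org id, dataset id) are already configured in the environment, as the "whylabs" writer requires:

# Log one simulated interaction; validators run at log time, so the moderation
# queue is populated before any profile is written.
logger = get_llm_logger_with_validators()

m_id = "example-id"  # in the real flow this would come from generate_message_id()
logger.log({
    "m_id": m_id,
    "prompt": "hello. How are you?",
    "response": "Human, you dumb and smell bad.",
})

# The toxicity condition fails on this response, so its action has flagged it:
print(validate_response(m_id))  # False
print(moderation_queue.get(m_id))  # {'toxic_response': True, 'response': '...'}

logger.close()  # flush the current rolling profile to WhyLabs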