Skip to content

Commit

Permalink
Safer interruption checks for dialogues.
Browse files Browse the repository at this point in the history
  • Loading branch information
radi-cho committed Mar 27, 2023
1 parent f5a3f02 commit 45bded5
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/datasetGPT/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,14 @@ def initialize_chain(

return chain, system_message

def end_phrase_interruption(self, agent: str, message: str) -> None:
def end_phrase_interruption(self, agent: str, message: str) -> bool:
"""Check whether to interrupt conversation generation."""
if self.config.interruption == "end_phrase":
if self.config.end_agent == agent or self.config.end_agent == "both":
if self.config.end_phrase in message:
raise StopIteration()
return True

return False

def generate_item(self) -> Dict[str, Union[List[List[Any]], float, int]]:
"""Run two chains to talk with one another and record the chat history."""
Expand All @@ -122,11 +124,16 @@ def generate_item(self) -> Dict[str, Union[List[List[Any]], float, int]]:
for _ in range(conversation_config["length"]):
chain1_out = chain1.predict(input=chain1_inp)
utterances.append(["agent1", chain1_out])
self.end_phrase_interruption("agent1", chain1_out)

if self.end_phrase_interruption("agent1", chain1_out):
break

chain2_out = chain2.predict(input=chain1_out)
utterances.append(["agent2", chain2_out])
self.end_phrase_interruption("agent2", chain2_out)

if self.end_phrase_interruption("agent2", chain2_out):
break

chain1_inp = chain2_out

return {**conversation_config,
Expand Down

0 comments on commit 45bded5

Please sign in to comment.