Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change base image to CUDA, change to dspy.Retrieve #11

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
change to dspy.Retrieve
  • Loading branch information
etwk committed Aug 16, 2024
commit 558379c47b8184a7f61abba36ac31c4cb118f6f1
7 changes: 3 additions & 4 deletions src/modules/verdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,19 @@ class GenerateSearchQuery(dspy.Signature):
- does different InputField name other than answer compateble with dspy evaluate
"""
class Verdict(dspy.Module):
def __init__(self, retrieve, passages_per_hop=3, max_hops=3):
def __init__(self, passages_per_hop=3, max_hops=3):
super().__init__()
# self.generate_query = dspy.ChainOfThought(GenerateSearchQuery) # IMPORTANT: solves error `list index out of range`
self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
self.retrieve = retrieve
self.retrieve.k = passages_per_hop
self.retrieve = dspy.Retrieve(k=passages_per_hop)
self.generate_verdict = dspy.ChainOfThought(CheckStatementFaithfulness)
self.max_hops = max_hops

def forward(self, statement):
context = []
for hop in range(self.max_hops):
query = self.generate_query[hop](context=context, statement=statement).query
passages = self.retrieve(query=query, text_only=True)
passages = self.retrieve(query).passages
context = deduplicate(context + passages)

verdict = self.generate_verdict(context=context, statement=statement)
Expand Down
15 changes: 9 additions & 6 deletions src/pipeline/verdict_citation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ def __init__(
):
self.retrieve = LlamaIndexRM(docs=docs)

# loading compiled Verdict
self.context_verdict = Verdict(retrieve=self.retrieve)
optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../optimizers/verdict_MIPROv2.json")
self.context_verdict.load(optimizer_path)

def get(self, statement):
rep = self.context_verdict(statement)
with dspy.context(rm=self.retrieve):
self.context_verdict = Verdict()

# loading compiled Verdict
optimizer_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../optimizers/verdict_MIPROv2.json")
self.context_verdict.load(optimizer_path)

rep = self.context_verdict(statement)

context = rep.context
verdict = rep.answer

Expand Down