Merge pull request #1 from stanford-futuredata/sid/updates
Add zero-shot CoT, better logging, implicit aggregation over all columns, and move model params into the model class
liana313 committed Jul 28, 2024
2 parents 54f2a37 + e3350b4 commit 205a3ff
Showing 13 changed files with 171 additions and 112 deletions.
4 changes: 2 additions & 2 deletions docs/quickstart.rst
@@ -53,10 +53,10 @@ This can be achieved by applying a semantic filter followed by a semantic aggregation
 from lotus.models import E5Model, OpenAIModel

 # Configure models for LOTUS
-lm = OpenAIModel()
+lm = OpenAIModel(max_tokens=512)
 rm = E5Model()

-lotus.settings.configure(lm=lm, rm=rm, model_params={"max_tokens": 512})
+lotus.settings.configure(lm=lm, rm=rm)

 # Dataset containing courses and their descriptions/workloads
 data = [
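Note: with this change, inference parameters such as max_tokens are passed to the model constructor instead of to lotus.settings.configure. A minimal sketch of the updated configuration pattern (nothing here beyond what the diff shows):

import lotus
from lotus.models import E5Model, OpenAIModel

# Inference parameters now belong to the model object;
# settings.configure only wires the models together.
lm = OpenAIModel(max_tokens=512)
rm = E5Model()
lotus.settings.configure(lm=lm, rm=rm)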
2 changes: 1 addition & 1 deletion examples/op_examples/agg.py
@@ -6,7 +6,7 @@
 lm = OpenAIModel()
 rm = E5Model()

-lotus.settings.configure(lm=lm, rm=rm, model_params={"max_tokens": 512})
+lotus.settings.configure(lm=lm, rm=rm)
 data = {
     "Course Name": [
         "Probability and Random Processes",
4 changes: 2 additions & 2 deletions examples/op_examples/partition.py
@@ -3,10 +3,10 @@
 import lotus
 from lotus.models import E5Model, OpenAIModel

-lm = OpenAIModel()
+lm = OpenAIModel(max_tokens=2048)
 rm = E5Model()

-lotus.settings.configure(lm=lm, rm=rm, model_params={"max_tokens": 2048})
+lotus.settings.configure(lm=lm, rm=rm)
 data = {
     "Course Name": [
         "Probability and Random Processes",
35 changes: 25 additions & 10 deletions lotus/models/openai_model.py
@@ -8,8 +8,8 @@
 from openai import OpenAI
 from transformers import AutoTokenizer

+import lotus
 from lotus.models.lm import LM
-from lotus.settings import settings

 # Mapping from Databricks model names to their Hugging Face model names for tokenizers
 DBRX_NAME_TO_MODEL = {
@@ -37,7 +37,9 @@ class OpenAIModel(LM):
         model (str): The name of the model to use.
         api_key (Optional[str]): An API key (e.g. from OpenAI or Databricks).
         api_base (Optional[str]): The endpoint of the server.
-        provider (str): Either openai, dbrx, or vllm
+        provider (str): Either openai, dbrx, or vllm.
+        max_batch_size (int): The maximum batch size for the model.
+        max_ctx_len (int): The maximum context length for the model.
         **kwargs (Dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters.
     """

@@ -47,16 +49,20 @@ def __init__(
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         provider: str = "openai",
+        max_batch_size=64,
+        max_ctx_len=4096,
         **kwargs: Dict[str, Any],
     ):
         super().__init__()
         self.provider = provider
         self.use_chat = provider in ["openai", "dbrx"]
+        self.max_batch_size = max_batch_size
+        self.max_ctx_len = max_ctx_len

         self.kwargs = {
             "model": model,
             "temperature": 0.0,
-            "max_tokens": 150,
+            "max_tokens": 512,
             "top_p": 1,
             "n": 1,
             **kwargs,
@@ -113,10 +119,13 @@ def handle_completion_request(self, messages: List, **kwargs):
             Union[List, Tuple[List, List]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments,
             then a list of logprobs is also returned.
         """
-        prompt = [
-            self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
-            for message in messages
-        ]
+        if not isinstance(messages[0], list):
+            prompt = [self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
+        else:
+            prompt = [
+                self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
+                for message in messages
+            ]

         kwargs = {**self.kwargs, **kwargs}
         kwargs["prompt"] = prompt
@@ -194,12 +203,14 @@ def thread_function(idx, messages, kwargs):
     def __call__(
         self, messages_batch: Union[List, List[List]], **kwargs: Dict[str, Any]
     ) -> Union[List, Tuple[List, List]]:
+        lotus.logger.debug(f"OpenAIModel.__call__ messages_batch: {messages_batch}")
+        lotus.logger.debug(f"OpenAIModel.__call__ kwargs: {kwargs}")
         # Bakes max batch size into model call. # TODO: Figure out less hacky way to do this.
-        if isinstance(messages_batch[0], list) and len(messages_batch) > settings.max_batch_size:
+        if isinstance(messages_batch[0], list) and len(messages_batch) > self.max_batch_size:
             text_ret = []
             logprobs_ret = []
-            for i in range(0, len(messages_batch), settings.max_batch_size):
-                res = self(messages_batch[i : i + settings.max_batch_size], **kwargs)
+            for i in range(0, len(messages_batch), self.max_batch_size):
+                res = self(messages_batch[i : i + self.max_batch_size], **kwargs)
                 if kwargs.get("logprobs", False):
                     text, logprobs = res
                     logprobs_ret.extend(logprobs)
@@ -267,3 +278,7 @@ def completion_request(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
             dict: OpenAI completion response.
         """
         return self.client.completions.create(**kwargs).model_dump()
+
+    @property
+    def max_tokens(self) -> int:
+        return self.kwargs["max_tokens"]
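Taken together, these hunks move batching and context-length limits from lotus.settings onto the model itself. A rough usage sketch of the new knobs (the model name and prompts are illustrative, not from this diff):

from lotus.models import OpenAIModel

lm = OpenAIModel(
    model="gpt-3.5-turbo",  # illustrative model name
    max_tokens=512,         # default raised from 150 to 512 in this diff
    max_batch_size=64,      # __call__ splits larger batches into chunks of this size
    max_ctx_len=4096,       # consumed by sem_agg when budgeting prompt tokens
)

# A batch of 200 conversations is processed in chunks of max_batch_size.
batch = [[{"role": "user", "content": f"Reply with the number {i}."}] for i in range(200)]
outputs = lm(batch)
assert len(outputs) == 200

print(lm.max_tokens)  # new property; reads kwargs["max_tokens"] -> 512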
28 changes: 15 additions & 13 deletions lotus/sem_ops/sem_agg.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import List

 import pandas as pd

@@ -11,7 +11,6 @@ def sem_agg(
     model: lotus.models.LM,
     user_instruction: str,
     partition_ids: List[int],
-    **kwargs: Dict[str, Any],
 ) -> str:
     """
     Aggregates multiple documents into a single answer using a model.
@@ -21,7 +20,6 @@ def sem_agg(
         model (lotus.models.LM): The model to use.
         user_instruction (str): The user instruction for aggregation.
         partition_ids (List[int]): The partition ids for the documents. Documents with the same partition id will be aggregated together.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.

     Returns:
         str: The aggregated answer.
@@ -82,10 +80,9 @@ def doc_formatter(tree_level, doc, ctr):
             formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr)
             new_tokens = model.count_tokens(formatted_doc)

-            if (
-                new_tokens + context_tokens + template_tokens
-                > lotus.settings.max_ctx_len - lotus.settings.get_max_tokens()
-            ) or (partition_id != cur_partition_id and not do_fold):
+            if (new_tokens + context_tokens + template_tokens > model.max_ctx_len - model.max_tokens) or (
+                partition_id != cur_partition_id and not do_fold
+            ):
                 # close the current prompt

                 prompt = template.replace("{{docs_str}}", context_str)
@@ -110,7 +107,7 @@ def doc_formatter(tree_level, doc, ctr):
         lotus.logger.debug(f"Prompt added to batch: {prompt}")
         batch.append([{"role": "user", "content": prompt}])
         new_partition_ids.append(cur_partition_id)
-    summaries = model(batch, **kwargs)
+    summaries = model(batch)
     partition_ids = new_partition_ids
     new_partition_ids = []

@@ -136,22 +133,27 @@ def _validate(obj):
     def __call__(
         self,
         user_instruction: str,
+        all_cols: bool = False,
         suffix: str = "_output",
     ) -> pd.DataFrame:
         """
         Applies semantic aggregation over a dataframe.

         Args:
             user_instruction (str): The user instruction for aggregation.
+            all_cols (bool): Whether to use all columns in the dataframe. Defaults to False.
             suffix (Optional[str]): The suffix for the new column. Defaults to "_output".

         Returns:
             pd.DataFrame: The dataframe with the aggregated answer.
         """

-        lotus.logger.debug(user_instruction)
-        col_li = lotus.nl_expression.parse_cols(user_instruction)
-        lotus.logger.debug(col_li)
+        lotus.logger.debug(f"User instruction: {user_instruction}")
+        if all_cols:
+            col_li = list(self._obj.columns)
+        else:
+            col_li = lotus.nl_expression.parse_cols(user_instruction)
+        lotus.logger.debug(f"Columns: {col_li}")

         # check that column exists
         for column in col_li:
@@ -166,15 +168,15 @@ def __call__(
         partition_ids = [0] * len(self._obj)

         df_txt = task_instructions.df2text(self._obj, col_li)
-        lotus.logger.debug(df_txt)
+        lotus.logger.debug(f"df_txt: {df_txt}")
         formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li)
         lotus.logger.debug(f"formatted_usr_instr: {formatted_usr_instr}")

         answer = sem_agg(
             df_txt,
             lotus.settings.lm,
             formatted_usr_instr,
             partition_ids,
-            **lotus.settings.model_params,
         )

         # package answer in a dataframe
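The token budget for each packed prompt now comes from the model itself (model.max_ctx_len - model.max_tokens) rather than from settings, and all_cols lets callers aggregate without naming columns in the instruction. A hedged usage sketch (assumes the accessor is registered as df.sem_agg; the data is illustrative):

import pandas as pd

import lotus
from lotus.models import E5Model, OpenAIModel

lotus.settings.configure(lm=OpenAIModel(max_tokens=512), rm=E5Model())

df = pd.DataFrame({
    "Course Name": ["Probability and Random Processes", "Compilers"],
    "Workload": ["heavy", "moderate"],
})

# With all_cols=True, every column is folded into the aggregation context,
# so the instruction needs no {column} placeholders.
out = df.sem_agg("Summarize the courses and how demanding they are", all_cols=True)
print(out["_output"].iloc[0])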
10 changes: 3 additions & 7 deletions lotus/sem_ops/sem_extract.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Callable, List, Tuple

 import pandas as pd

@@ -13,7 +13,6 @@ def sem_extract(
     model: lotus.models.LM,
     user_instruction: str,
     postprocessor: Callable = extract_postprocess,
-    **kwargs: Dict[str, Any],
 ) -> Tuple:
     """
     Extracts from a list of documents using a model.
@@ -23,7 +22,6 @@ def sem_extract(
         model (lotus.models.LM): The model to use.
         user_instruction (str): The user instruction for extract.
         postprocessor (Optional[Callable]): The postprocessor for the model outputs. Defaults to extract_postprocess.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.

     Returns:
         Tuple: The outputs, raw outputs, and quotes.
@@ -37,12 +35,11 @@ def sem_extract(
         inputs.append(prompt)

     # call model
-    raw_outputs = model(inputs, **kwargs)
-
-    lotus.logger.debug(f"---\n{raw_outputs}\n---")
+    raw_outputs = model(inputs)

     # post process results
     outputs, quotes = postprocessor(raw_outputs)
+    lotus.logger.debug(f"raw_outputs: {raw_outputs}")
     lotus.logger.debug(f"outputs: {outputs}")
     lotus.logger.debug(f"quotes: {quotes}")

@@ -95,7 +92,6 @@ def __call__(
             lotus.settings.lm,
             formatted_usr_instr,
             postprocessor=postprocessor,
-            **lotus.settings.model_params,
         )

         new_df = self._obj
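Since sem_extract no longer forwards **kwargs, generation parameters are whatever the configured model was constructed with. A direct-call sketch (the document text and instruction are illustrative, and the import path is an assumption based on the file shown above):

import lotus
from lotus.models import OpenAIModel
from lotus.sem_ops.sem_extract import sem_extract

lotus.settings.configure(lm=OpenAIModel(max_tokens=512))

docs = ["The course meets twice a week and takes roughly ten hours of work."]
# Per the docstring, returns the postprocessed outputs, the raw model
# outputs, and the supporting quotes.
outputs, raw_outputs, quotes = sem_extract(docs, lotus.settings.lm, "the weekly workload")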
18 changes: 9 additions & 9 deletions lotus/sem_ops/sem_filter.py
@@ -16,8 +16,8 @@ def sem_filter(
     examples_df_txt: Optional[str] = None,
     examples_answers: Optional[List[bool]] = None,
     cot_reasoning: Optional[List[str]] = None,
+    strategy: Optional[str] = None,
     logprobs: bool = False,
-    **kwargs: Dict[str, Any],
 ) -> Tuple:
     """
     Filters a list of documents based on a given user instruction using a language model.
@@ -31,26 +31,26 @@ def sem_filter(
         examples_answers (Optional[List[bool]]): The answers for examples. Defaults to None.
         cot_reasoning (Optional[List[str]]): The reasoning for CoT. Defaults to None.
         logprobs (Optional[bool]): Whether to return log probabilities. Defaults to False.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.

     Returns:
         Tuple: A tuple containing the True/False outputs, raw outputs, explanations, and raw log probabilities (if logprobs=True).
     """
     inputs = []
     for doc in docs:
         prompt = lotus.templates.task_instructions.filter_formatter(
-            doc, user_instruction, examples_df_txt, examples_answers, cot_reasoning
+            doc, user_instruction, examples_df_txt, examples_answers, cot_reasoning, strategy
         )
         lotus.logger.debug(f"input to model: {prompt}")
         inputs.append(prompt)
-    res = model(inputs, logprobs=logprobs, **kwargs)
+    res = model(inputs, logprobs=logprobs)
     if logprobs:
         raw_outputs, raw_logprobs = res
     else:
         raw_outputs = res

     lotus.logger.debug(f"---\n{raw_outputs}\n---")
-    outputs, explanations = filter_postprocess(raw_outputs, default=default, cot_reasoning=cot_reasoning is not None)
+    outputs, explanations = filter_postprocess(
+        raw_outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]
+    )
     lotus.logger.debug(f"outputs: {outputs}")
     lotus.logger.debug(f"raw_outputs: {raw_outputs}")
     lotus.logger.debug(f"explanations: {explanations}")
@@ -158,7 +158,7 @@ def __call__(
             examples_answers=helper_examples_answers,
             cot_reasoning=helper_cot_reasoning,
             logprobs=True,
-            **lotus.settings.model_params,
+            strategy=helper_strategy,
         )

         high_conf_idxs = set()
@@ -195,7 +195,7 @@ def __call__(
                 examples_df_txt=examples_df_txt,
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
-                **lotus.settings.model_params,
+                strategy=strategy,
             )

             for idx, large_idx in enumerate(low_conf_idxs):
@@ -215,7 +215,7 @@ def __call__(
                 examples_df_txt=examples_df_txt,
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
-                **lotus.settings.model_params,
+                strategy=strategy,
             )

         # find indices where output is True
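The strategy parameter is what enables the zero-shot CoT from the commit title: the prompt formatter receives it, and filter_postprocess now keys off strategy in ["cot", "zs-cot"] instead of checking whether demonstrations carried reasoning. A hedged sketch of a zero-shot CoT filter call (assumes the accessor is registered as df.sem_filter and forwards strategy, as the hunks above suggest; the data is illustrative):

import pandas as pd

import lotus
from lotus.models import E5Model, OpenAIModel

lotus.settings.configure(lm=OpenAIModel(max_tokens=512), rm=E5Model())

df = pd.DataFrame({"Course Name": ["Riemannian Geometry", "Intro to Cooking"]})

# "zs-cot" asks the model to reason step by step with no demonstrations;
# explanations are parsed out before the final True/False answer.
filtered = df.sem_filter("{Course Name} requires a math background", strategy="zs-cot")
print(filtered)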