diff --git a/docs/quickstart.rst b/docs/quickstart.rst
index 9355240..fe7e99d 100644
--- a/docs/quickstart.rst
+++ b/docs/quickstart.rst
@@ -53,10 +53,10 @@ This can be achieved by applying a semantic filter followed by a semantic aggreg
     from lotus.models import E5Model, OpenAIModel
 
     # Configure models for LOTUS
-    lm = OpenAIModel()
+    lm = OpenAIModel(max_tokens=512)
     rm = E5Model()
 
-    lotus.settings.configure(lm=lm, rm=rm, model_params={"max_tokens": 512})
+    lotus.settings.configure(lm=lm, rm=rm)
 
     # Dataset containing courses and their descriptions/workloads
     data = [
diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py
index c80c95a..add1711 100644
--- a/examples/op_examples/agg.py
+++ b/examples/op_examples/agg.py
@@ -6,7 +6,7 @@
 lm = OpenAIModel()
 rm = E5Model()
 
-lotus.settings.configure(lm=lm, rm=rm, model_params={"max_tokens": 512})
+lotus.settings.configure(lm=lm, rm=rm)
 data = {
     "Course Name": [
         "Probability and Random Processes",
diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py
index 026246e..ca42d17 100644
--- a/examples/op_examples/partition.py
+++ b/examples/op_examples/partition.py
@@ -3,10 +3,10 @@
 import lotus
 from lotus.models import E5Model, OpenAIModel
 
-lm = OpenAIModel()
+lm = OpenAIModel(max_tokens=2048)
 rm = E5Model()
 
-lotus.settings.configure(lm=lm, rm=rm, model_params={"max_tokens": 2048})
+lotus.settings.configure(lm=lm, rm=rm)
 data = {
     "Course Name": [
         "Probability and Random Processes",
diff --git a/lotus/models/openai_model.py b/lotus/models/openai_model.py
index 1542046..bf0726e 100644
--- a/lotus/models/openai_model.py
+++ b/lotus/models/openai_model.py
@@ -8,8 +8,8 @@
 from openai import OpenAI
 from transformers import AutoTokenizer
 
+import lotus
 from lotus.models.lm import LM
-from lotus.settings import settings
 
 # Mapping from Databricks model names to their Hugging Face model names for tokenizers
 DBRX_NAME_TO_MODEL = {
@@ -37,7 +37,9 @@ class OpenAIModel(LM):
         model (str): The name of the model to use.
         api_key (Optional[str]): An API key (e.g. from OpenAI or Databricks).
         api_base (Optional[str]): The endpoint of the server.
-        provider (str): Either openai, dbrx, or vllm
+        provider (str): Either openai, dbrx, or vllm.
+        max_batch_size (int): The maximum batch size for the model.
+        max_ctx_len (int): The maximum context length for the model.
         **kwargs (Dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters.
     """
 
@@ -47,16 +49,20 @@ def __init__(
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         provider: str = "openai",
+        max_batch_size=64,
+        max_ctx_len=4096,
         **kwargs: Dict[str, Any],
     ):
         super().__init__()
         self.provider = provider
         self.use_chat = provider in ["openai", "dbrx"]
+        self.max_batch_size = max_batch_size
+        self.max_ctx_len = max_ctx_len
 
         self.kwargs = {
             "model": model,
             "temperature": 0.0,
-            "max_tokens": 150,
+            "max_tokens": 512,
             "top_p": 1,
             "n": 1,
             **kwargs,
@@ -113,10 +119,13 @@ def handle_completion_request(self, messages: List, **kwargs):
             Union[List, Tuple[List, List]]: A list of outputs for each prompt in the batch.
                 If logprobs is specified in the keyword arguments, then a list of logprobs is also returned.
""" - prompt = [ - self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in messages - ] + if not isinstance(messages[0], list): + prompt = [self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)] + else: + prompt = [ + self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in messages + ] kwargs = {**self.kwargs, **kwargs} kwargs["prompt"] = prompt @@ -194,12 +203,14 @@ def thread_function(idx, messages, kwargs): def __call__( self, messages_batch: Union[List, List[List]], **kwargs: Dict[str, Any] ) -> Union[List, Tuple[List, List]]: + lotus.logger.debug(f"OpenAIModel.__call__ messages_batch: {messages_batch}") + lotus.logger.debug(f"OpenAIModel.__call__ kwargs: {kwargs}") # Bakes max batch size into model call. # TODO: Figure out less hacky way to do this. - if isinstance(messages_batch[0], list) and len(messages_batch) > settings.max_batch_size: + if isinstance(messages_batch[0], list) and len(messages_batch) > self.max_batch_size: text_ret = [] logprobs_ret = [] - for i in range(0, len(messages_batch), settings.max_batch_size): - res = self(messages_batch[i : i + settings.max_batch_size], **kwargs) + for i in range(0, len(messages_batch), self.max_batch_size): + res = self(messages_batch[i : i + self.max_batch_size], **kwargs) if kwargs.get("logprobs", False): text, logprobs = res logprobs_ret.extend(logprobs) @@ -267,3 +278,7 @@ def completion_request(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]: dict: OpenAI completion response. """ return self.client.completions.create(**kwargs).model_dump() + + @property + def max_tokens(self) -> int: + return self.kwargs["max_tokens"] diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 5318517..498d8cd 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import List import pandas as pd @@ -11,7 +11,6 @@ def sem_agg( model: lotus.models.LM, user_instruction: str, partition_ids: List[int], - **kwargs: Dict[str, Any], ) -> str: """ Aggregates multiple documents into a single answer using a model. @@ -21,7 +20,6 @@ def sem_agg( model (lotus.models.LM): The model to use. user_instruction (str): The user instruction for aggregation. partition_ids (List[int]): The partition ids for the documents. Documents with the same partition id will be aggregated together. - **kwargs (Dict[str, Any]): Additional keyword arguments. Returns: str: The aggregated answer. 
@@ -82,10 +80,9 @@ def doc_formatter(tree_level, doc, ctr):
             formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr)
             new_tokens = model.count_tokens(formatted_doc)
 
-            if (
-                new_tokens + context_tokens + template_tokens
-                > lotus.settings.max_ctx_len - lotus.settings.get_max_tokens()
-            ) or (partition_id != cur_partition_id and not do_fold):
+            if (new_tokens + context_tokens + template_tokens > model.max_ctx_len - model.max_tokens) or (
+                partition_id != cur_partition_id and not do_fold
+            ):
                 # close the current prompt
 
                 prompt = template.replace("{{docs_str}}", context_str)
@@ -110,7 +107,7 @@ def doc_formatter(tree_level, doc, ctr):
         lotus.logger.debug(f"Prompt added to batch: {prompt}")
         batch.append([{"role": "user", "content": prompt}])
         new_partition_ids.append(cur_partition_id)
-        summaries = model(batch, **kwargs)
+        summaries = model(batch)
         partition_ids = new_partition_ids
         new_partition_ids = []
 
@@ -136,6 +133,7 @@ def _validate(obj):
     def __call__(
         self,
         user_instruction: str,
+        all_cols: bool = False,
         suffix: str = "_output",
     ) -> pd.DataFrame:
         """
         Args:
             user_instruction (str): The user instruction for aggregation.
+            all_cols (bool): Whether to use all columns in the dataframe. Defaults to False.
             suffix (Optional[str]): The suffix for the new column. Defaults to "_output".
 
         Returns:
             pd.DataFrame: The dataframe with the aggregated answer.
         """
-        lotus.logger.debug(user_instruction)
-        col_li = lotus.nl_expression.parse_cols(user_instruction)
-        lotus.logger.debug(col_li)
+        lotus.logger.debug(f"User instruction: {user_instruction}")
+        if all_cols:
+            col_li = list(self._obj.columns)
+        else:
+            col_li = lotus.nl_expression.parse_cols(user_instruction)
+        lotus.logger.debug(f"Columns: {col_li}")
 
         # check that column exists
         for column in col_li:
@@ -166,15 +168,15 @@ def __call__(
             partition_ids = [0] * len(self._obj)
 
         df_txt = task_instructions.df2text(self._obj, col_li)
-        lotus.logger.debug(df_txt)
+        lotus.logger.debug(f"df_txt: {df_txt}")
         formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li)
+        lotus.logger.debug(f"formatted_usr_instr: {formatted_usr_instr}")
 
         answer = sem_agg(
             df_txt,
             lotus.settings.lm,
             formatted_usr_instr,
             partition_ids,
-            **lotus.settings.model_params,
         )
 
         # package answer in a dataframe
diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py
index 61c528f..f1f863b 100644
--- a/lotus/sem_ops/sem_extract.py
+++ b/lotus/sem_ops/sem_extract.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Callable, List, Tuple
 
 import pandas as pd
 
@@ -13,7 +13,6 @@ def sem_extract(
     model: lotus.models.LM,
     user_instruction: str,
     postprocessor: Callable = extract_postprocess,
-    **kwargs: Dict[str, Any],
 ) -> Tuple:
     """
     Extracts from a list of documents using a model.
@@ -23,7 +22,6 @@ def sem_extract(
         model (lotus.models.LM): The model to use.
         user_instruction (str): The user instruction for extract.
         postprocessor (Optional[Callable]): The postprocessor for the model outputs. Defaults to extract_postprocess.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.
 
     Returns:
        Tuple: The outputs, raw outputs, and quotes.
@@ -37,12 +35,11 @@ def sem_extract(
         inputs.append(prompt)
 
     # call model
-    raw_outputs = model(inputs, **kwargs)
-
-    lotus.logger.debug(f"---\n{raw_outputs}\n---")
+    raw_outputs = model(inputs)
 
     # post process results
     outputs, quotes = postprocessor(raw_outputs)
+    lotus.logger.debug(f"raw_outputs: {raw_outputs}")
     lotus.logger.debug(f"outputs: {outputs}")
     lotus.logger.debug(f"quotes: {quotes}")
 
@@ -95,7 +92,6 @@ def __call__(
             lotus.settings.lm,
             formatted_usr_instr,
             postprocessor=postprocessor,
-            **lotus.settings.model_params,
         )
 
         new_df = self._obj
diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index 36fe0f0..0ba9526 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -16,8 +16,8 @@ def sem_filter(
     examples_df_txt: Optional[str] = None,
     examples_answers: Optional[List[bool]] = None,
     cot_reasoning: Optional[List[str]] = None,
+    strategy: Optional[str] = None,
     logprobs: bool = False,
-    **kwargs: Dict[str, Any],
 ) -> Tuple:
     """
     Filters a list of documents based on a given user instruction using a language model.
@@ -31,7 +31,6 @@ def sem_filter(
         examples_answers (Optional[List[bool]]): The answers for examples. Defaults to None.
         cot_reasoning (Optional[List[str]]): The reasoning for CoT. Defaults to None.
         logprobs (Optional[bool]): Whether to return log probabilities. Defaults to False.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.
 
     Returns:
         Tuple: A tuple containing the True/False outputs, raw outputs, explanations, and raw log probabilities (if logprobs=True).
@@ -39,18 +38,19 @@ def sem_filter(
     inputs = []
     for doc in docs:
         prompt = lotus.templates.task_instructions.filter_formatter(
-            doc, user_instruction, examples_df_txt, examples_answers, cot_reasoning
+            doc, user_instruction, examples_df_txt, examples_answers, cot_reasoning, strategy
         )
         lotus.logger.debug(f"input to model: {prompt}")
         inputs.append(prompt)
 
-    res = model(inputs, logprobs=logprobs, **kwargs)
+    res = model(inputs, logprobs=logprobs)
     if logprobs:
         raw_outputs, raw_logprobs = res
     else:
         raw_outputs = res
-    lotus.logger.debug(f"---\n{raw_outputs}\n---")
-    outputs, explanations = filter_postprocess(raw_outputs, default=default, cot_reasoning=cot_reasoning is not None)
+    outputs, explanations = filter_postprocess(
+        raw_outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]
+    )
     lotus.logger.debug(f"outputs: {outputs}")
     lotus.logger.debug(f"raw_outputs: {raw_outputs}")
     lotus.logger.debug(f"explanations: {explanations}")
@@ -158,7 +158,7 @@ def __call__(
             examples_answers=helper_examples_answers,
             cot_reasoning=helper_cot_reasoning,
             logprobs=True,
-            **lotus.settings.model_params,
+            strategy=helper_strategy,
         )
 
         high_conf_idxs = set()
@@ -195,7 +195,7 @@ def __call__(
                 examples_df_txt=examples_df_txt,
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
-                **lotus.settings.model_params,
+                strategy=strategy,
             )
 
             for idx, large_idx in enumerate(low_conf_idxs):
@@ -215,7 +215,7 @@ def __call__(
                 examples_df_txt=examples_df_txt,
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
-                **lotus.settings.model_params,
+                strategy=strategy,
            )
 
        # find indices where output is True
diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py
index e7cc3f7..1db09c5 100644
--- a/lotus/sem_ops/sem_join.py
+++ b/lotus/sem_ops/sem_join.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
 
@@ -21,7 +21,7 @@ def sem_join(
     examples_answers: Optional[List[bool]] = None,
     cot_reasoning: Optional[List[str]] = None,
     default: bool = True,
-    **kwargs: Dict[str, Any],
+    strategy: Optional[str] = None,
 ) -> Tuple:
     """
     Joins two series using a model.
@@ -39,7 +39,6 @@ def sem_join(
         examples_answers (Optional[List[bool]]): The answers for examples. Defaults to None.
         cot_reasoning (Optional[List[str]]): The reasoning for CoT. Defaults to None.
         default (bool): The default value for the join in case of parsing errors. Defaults to True.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.
 
     Returns:
         Tuple: The join results, filter outputs, all raw outputs, and all explanations.
@@ -62,7 +61,7 @@ def sem_join(
             examples_answers=examples_answers,
             cot_reasoning=cot_reasoning,
             default=default,
-            **kwargs,
+            strategy=strategy,
         )
         filter_outputs.extend(outputs)
         all_raw_outputs.extend(raw_outputs)
@@ -95,7 +94,7 @@ def sem_join_cascade(
     examples_answers: Optional[List[bool]] = None,
     cot_reasoning: Optional[List[str]] = None,
     default: bool = True,
-    **kwargs: Dict[str, Any],
+    strategy: Optional[str] = None,
 ) -> List[str]:
     filter_outputs = []
     all_raw_outputs = []
@@ -117,7 +116,7 @@ def sem_join_cascade(
             cot_reasoning=cot_reasoning,
             default=default,
             logprobs=True,
-            **kwargs,
+            strategy=strategy,
         )
 
         high_conf_idxs = set()
@@ -150,7 +149,7 @@ def sem_join_cascade(
             examples_df_txt=examples_df_txt,
             examples_answers=examples_answers,
             cot_reasoning=cot_reasoning,
-            **lotus.settings.model_params,
+            strategy=strategy,
         )
 
         outputs, raw_outputs, explanations = (
@@ -315,7 +314,7 @@ def __call__(
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
                 default=default,
-                **lotus.settings.model_params,
+                strategy=strategy,
            )
        else:
            join_results, filter_outputs, all_raw_outputs, all_explanations = sem_join(
@@ -331,7 +330,7 @@ def __call__(
                examples_answers=examples_answers,
                cot_reasoning=cot_reasoning,
                default=default,
-                **lotus.settings.model_params,
+                strategy=strategy,
            )
        lotus.logger.debug(f"join_results: {join_results}")
        lotus.logger.debug(f"all_raw_outputs: {all_raw_outputs}")
diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py
index 7ca11b9..99a5d94 100644
--- a/lotus/sem_ops/sem_map.py
+++ b/lotus/sem_ops/sem_map.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import pandas as pd
 
@@ -16,7 +16,7 @@ def sem_map(
     examples_df_txt: Optional[str] = None,
     examples_answers: Optional[List[str]] = None,
     cot_reasoning: Optional[List[str]] = None,
-    **kwargs: Dict[str, Any],
+    strategy: Optional[str] = None,
 ) -> Tuple:
     """
     Maps a list of documents to a list of outputs using a model.
@@ -29,7 +29,6 @@ def sem_map(
         examples_df_txt (Optional[str]: The text for examples. Defaults to None.
         examples_answers (Optional[List[str]]): The answers for examples. Defaults to None.
         cot_reasoning (Optional[List[str]]): The reasoning for CoT. Defaults to None.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.
 
     Returns:
        Tuple: The outputs, raw outputs, and explanations.
@@ -38,19 +37,18 @@ def sem_map(
     inputs = []
     for doc in docs:
         prompt = lotus.templates.task_instructions.map_formatter(
-            doc, user_instruction, examples_df_txt, examples_answers, cot_reasoning
+            doc, user_instruction, examples_df_txt, examples_answers, cot_reasoning, strategy=strategy
         )
         lotus.logger.debug(f"input to model: {prompt}")
         lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
         inputs.append(prompt)
 
     # call model
-    raw_outputs = model(inputs, **kwargs)
-
-    lotus.logger.debug(f"---\n{raw_outputs}\n---")
+    raw_outputs = model(inputs)
 
     # post process results
-    outputs, explanations = postprocessor(raw_outputs, cot_reasoning=cot_reasoning is not None)
+    outputs, explanations = postprocessor(raw_outputs, cot_reasoning=strategy in ["cot", "zs-cot"])
+    lotus.logger.debug(f"raw_outputs: {raw_outputs}")
     lotus.logger.debug(f"outputs: {outputs}")
     lotus.logger.debug(f"explanations: {explanations}")
 
@@ -125,7 +123,7 @@ def __call__(
             examples_df_txt=examples_df_txt,
             examples_answers=examples_answers,
             cot_reasoning=cot_reasoning,
-            **lotus.settings.model_params,
+            strategy=strategy,
         )
 
         new_df = self._obj
diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py
index b703de6..6da1e4a 100644
--- a/lotus/sem_ops/sem_topk.py
+++ b/lotus/sem_ops/sem_topk.py
@@ -9,14 +9,23 @@
 from lotus.templates import task_instructions
 
 
-def get_match_prompt_binary(doc1, doc2, user_instruction):
-    sys_prompt = (
-        "Your job is to to select and return the most relevant document to the user's question.\n"
-        "Carefully read the user's question and the two documents provided below.\n"
-        'Respond only with the label of the document such as "Document NUMBER".\n'
-        "NUMBER must be either 1 or 2, depending on which document is most relevant.\n"
-        'You must pick a number and cannot say things like "None" or "Neither"'
-    )
+def get_match_prompt_binary(doc1, doc2, user_instruction, strategy=None):
+    if strategy == "zs-cot":
+        sys_prompt = (
+            "Your job is to select and return the most relevant document to the user's question.\n"
+            "Carefully read the user's question and the two documents provided below.\n"
+            'First give your reasoning. Then you MUST end your output with "Answer: Document 1 or Document 2"\n'
+            'You must pick a number and cannot say things like "None" or "Neither"\n'
+            'Remember to explicitly state "Answer:" at the end before your choice.'
+        )
+    else:
+        sys_prompt = (
+            "Your job is to select and return the most relevant document to the user's question.\n"
+            "Carefully read the user's question and the two documents provided below.\n"
+            'Respond only with the label of the document such as "Document NUMBER".\n'
+            "NUMBER must be either 1 or 2, depending on which document is most relevant.\n"
+            'You must pick a number and cannot say things like "None" or "Neither"'
+        )
 
     prompt = f"Question: {user_instruction}\n\n"
     for idx, doc in enumerate([doc1, doc2]):
@@ -28,6 +37,7 @@ def get_match_prompt_binary(doc1, doc2, user_instruction):
 
 
 def parse_ans_binary(answer):
+    lotus.logger.debug(f"Response from model: {answer}")
     try:
         matches = list(re.finditer(r"Document[\s*](\d+)", answer, re.IGNORECASE))
         if len(matches) == 0:
@@ -42,27 +52,27 @@ def parse_ans_binary(answer):
     return True
 
 
-def compare_batch_binary(pairs, user_instruction, **kwargs):
+def compare_batch_binary(pairs, user_instruction, strategy=None):
     match_prompts = []
     results = []
     tokens = 0
     for doc1, doc2 in pairs:
-        match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction))
+        match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy))
         tokens += lotus.settings.lm.count_tokens(match_prompts[-1])
 
-    results = lotus.settings.lm(match_prompts, **kwargs)
+    results = lotus.settings.lm(match_prompts)
     results = list(map(parse_ans_binary, results))
     return results, tokens
 
 
-def compare_batch_binary_cascade(pairs, user_instruction, cascade_threshold, **kwargs):
+def compare_batch_binary_cascade(pairs, user_instruction, cascade_threshold, strategy=None):
     match_prompts = []
     small_tokens = 0
     for doc1, doc2 in pairs:
-        match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction))
+        match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy))
         small_tokens += lotus.settings.helper_lm.count_tokens(match_prompts[-1])
 
-    results, helper_logprobs = lotus.settings.helper_lm(match_prompts, logprobs=True, **kwargs)
+    results, helper_logprobs = lotus.settings.helper_lm(match_prompts, logprobs=True)
     helper_tokens, helper_confidences = lotus.settings.helper_lm.format_logprobs_for_cascade(helper_logprobs)
 
     parsed_results = []
@@ -89,7 +99,7 @@ def compare_batch_binary_cascade(pairs, user_instruction, cascade_threshold, **k
             large_match_prompts.append(match_prompts[i])
             large_tokens += lotus.settings.lm.count_tokens(large_match_prompts[-1])
 
-        results = lotus.settings.lm(large_match_prompts, **kwargs)
+        results = lotus.settings.lm(large_match_prompts)
         for idx, res in enumerate(results):
             new_idx = low_conf_idxs[idx]
             parsed_res = parse_ans_binary(res)
@@ -102,7 +112,7 @@ def compare_batch_binary_cascade(pairs, user_instruction, cascade_threshold, **k
 def llm_naive_sort(
     docs: List[str],
     user_instruction: str,
-    **kwargs: Dict[str, Any],
+    strategy: Optional[str] = None,
 ) -> Tuple[List[int], Dict[str, Any]]:
     """
     Sorts the documents using a naive quadratic method.
@@ -110,7 +120,6 @@ def llm_naive_sort(
     Args:
         docs (List[str]): The list of documents to sort.
         user_instruction (str): The user instruction for sorting.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.
 
     Returns:
        Tuple[List[int], Dict[str, Any]]: The indexes of the top k documents and stats.
@@ -122,7 +131,7 @@ def llm_naive_sort(
             pairs.append((docs[i], docs[j]))
 
     llm_calls = len(pairs)
-    comparisons, tokens = compare_batch_binary(pairs, user_instruction, **kwargs)
+    comparisons, tokens = compare_batch_binary(pairs, user_instruction, strategy=strategy)
     votes = [0] * N
     idx = 0
     for i in range(N):
@@ -144,8 +153,8 @@ def llm_quicksort(
     user_instruction: str,
     k: int,
     embedding: bool = False,
+    strategy: Optional[str] = None,
     cascade_threshold=None,
-    **kwargs: Dict[str, Any],
 ) -> Tuple[List[int], Dict[str, Any]]:
     """
     Sorts the documents using quicksort.
@@ -156,7 +165,6 @@ def llm_quicksort(
         k (int): The number of documents to return.
         embedding (bool): Whether to use embedding optimization.
         cascade_threshold (Optional[float]): The confidence threshold for cascading to a larger model.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.
 
     Returns:
         Tuple[List[int], Dict[str, Any]]: The indexes of the top k documents and stats
@@ -192,12 +200,15 @@ def partition(indexes, low, high, k):
         pairs = [(docs[indexes[j]], pivot) for j in range(low, high)]
 
         if cascade_threshold is None:
-            comparisons, tokens = compare_batch_binary(pairs, user_instruction, **kwargs)
+            comparisons, tokens = compare_batch_binary(pairs, user_instruction, strategy=strategy)
             stats["total_tokens"] += tokens
             stats["total_llm_calls"] += len(pairs)
         else:
             comparisons, small_tokens, large_tokens, num_large_calls = compare_batch_binary_cascade(
-                pairs, user_instruction, cascade_threshold, **kwargs
+                pairs,
+                user_instruction,
+                cascade_threshold,
+                strategy=strategy,
             )
             stats["total_small_tokens"] += small_tokens
             stats["total_large_tokens"] += large_tokens
@@ -236,15 +247,15 @@ class HeapDoc:
     num_calls = 0
     total_tokens = 0
+    strategy = None
 
-    def __init__(self, doc, user_instruction, idx, kwargs):
+    def __init__(self, doc, user_instruction, idx):
         self.doc = doc
         self.user_instruction = user_instruction
         self.idx = idx
-        self.kwargs = kwargs
 
     def __lt__(self, other):
-        prompt = get_match_prompt_binary(self.doc, other.doc, self.user_instruction, **self.kwargs)
+        prompt = get_match_prompt_binary(self.doc, other.doc, self.user_instruction, strategy=self.strategy)
         HeapDoc.num_calls += 1
         HeapDoc.total_tokens += lotus.settings.lm.count_tokens(prompt)
         result = lotus.settings.lm(prompt)
@@ -252,7 +263,10 @@ def __lt__(self, other):
 
 
 def llm_heapsort(
-    docs: List[str], user_instruction: str, k: int, **kwargs: Dict[str, Any]
+    docs: List[str],
+    user_instruction: str,
+    k: int,
+    strategy: Optional[str] = None,
 ) -> Tuple[List[int], Dict[str, Any]]:
     """
     Sorts the documents using a heap.
@@ -261,15 +275,15 @@ def llm_heapsort(
         docs (List[str]): The list of documents to sort.
         user_instruction (str): The user instruction for sorting.
         k (int): The number of documents to return.
-        **kwargs (Dict[str, Any]): Additional keyword arguments.
 
    Returns:
        Tuple[List[int], Dict[str, Any]]: The indexes of the top k documents and stats.
""" HeapDoc.num_calls = 0 HeapDoc.total_tokens = 0 + HeapDoc.strategy = strategy N = len(docs) - heap = [HeapDoc(docs[idx], user_instruction, idx, kwargs) for idx in range(N)] + heap = [HeapDoc(docs[idx], user_instruction, idx) for idx in range(N)] heap = heapq.nsmallest(k, heap) indexes = [heapq.heappop(heap).idx for _ in range(len(heap))] @@ -294,10 +308,10 @@ def __call__( user_instruction: str, K: int, method: str = "quick", + strategy: Optional[str] = None, group_by: Optional[List[str]] = None, cascade_threshold: Optional[float] = None, return_stats: bool = False, - **kwargs, ) -> Union[pd.DataFrame, Tuple[pd.DataFrame, Dict[str, Any]]]: """ Sorts the DataFrame based on the user instruction and returns the top K rows. @@ -309,14 +323,13 @@ def __call__( group_by (Optional[List[str]]): The columns to group by before sorting. Each group will be sorted separately. cascade_threshold (Optional[float]): The confidence threshold for cascading to a larger model. return_stats (bool): Whether to return stats. - **kwargs (Dict[str, Any]): Additional keyword arguments. Returns: Union[pd.DataFrame, Tuple[pd.DataFrame, Dict[str, Any]]]: The sorted DataFrame. If return_stats is True, returns a tuple with the sorted DataFrame and stats """ - lotus.logger.debug(user_instruction) + lotus.logger.debug(f"Sorting DataFrame with user instruction: {user_instruction}") col_li = lotus.nl_expression.parse_cols(user_instruction) - lotus.logger.debug(col_li) + lotus.logger.debug(f"Columns: {col_li}") # check that column exists for column in col_li: @@ -333,10 +346,10 @@ def __call__( user_instruction, K, method=method, + strategy=strategy, group_by=None, cascade_threshold=cascade_threshold, return_stats=return_stats, - **kwargs, ) if return_stats: @@ -360,7 +373,7 @@ def __call__( ) df_txt = task_instructions.df2text(self._obj, col_li) - lotus.logger.debug(df_txt) + lotus.logger.debug(f"df_txt: {df_txt}") formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li) if method in ["quick", "quick-sem"]: @@ -369,16 +382,16 @@ def __call__( formatted_usr_instr, K, embedding=method == "quick-sem", + strategy=strategy, cascade_threshold=cascade_threshold, - **kwargs, ) elif method == "heap": - sorted_order, stats = llm_heapsort(df_txt, formatted_usr_instr, K, **kwargs) + sorted_order, stats = llm_heapsort(df_txt, formatted_usr_instr, K, strategy=strategy) elif method == "naive": sorted_order, stats = llm_naive_sort( df_txt, formatted_usr_instr, - **kwargs, + strategy=strategy, ) else: raise ValueError(f"Method {method} not recognized") diff --git a/lotus/settings.py b/lotus/settings.py index 46b0a59..7cbe76b 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -52,23 +52,11 @@ def __new__(cls): lm=None, helper_lm=None, rm=None, - model_params={}, - max_batch_size=16, - max_ctx_len=4096, ) cls._instance.__append(config) return cls._instance - def get_max_tokens(self): - model_params = self.config["model_params"] - if "max_tokens" in model_params: - return model_params["max_tokens"] - if "max_new_tokens" in model_params: - return model_params["max_new_tokens"] - - raise ValueError("max_tokens or max_new_tokens not found in model_params") - @property def config(self): thread_id = threading.get_ident() diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index b05c6b4..d93d7d2 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -40,15 +40,35 @@ def filter_formatter_cot( return messages +def filter_formatter_zs_cot( + 
+    df_text: str,
+    user_instruction: str,
+) -> List[str]:
+    sys_instruction = (
+        "The user will provide a claim and some relevant context.\n"
+        "Your job is to determine whether the claim is true for the given context.\n"
+        'First give your reasoning. Then you MUST end your output with "Answer: True or False"'
+    )
+    messages = [
+        {"role": "system", "content": sys_instruction},
+    ]
+
+    messages.append({"role": "user", "content": f"Context:\n{df_text}\n\nClaim: {user_instruction}"})
+    return messages
+
+
 def filter_formatter(
     df_text: str,
     user_instruction: str,
     examples_df_text: Optional[List[str]] = None,
     examples_answer: Optional[List[str]] = None,
     cot_reasoning: Optional[List[str]] = None,
+    strategy: Optional[str] = None,
 ) -> List[str]:
     if cot_reasoning:
         return filter_formatter_cot(df_text, user_instruction, examples_df_text, examples_answer, cot_reasoning)
+    elif strategy == "zs-cot":
+        return filter_formatter_zs_cot(df_text, user_instruction)
 
     sys_instruction = (
         "The user will povide a claim and some relevant context.\n"
@@ -117,15 +137,40 @@ def map_formatter_cot(
     return messages
 
 
+def map_formatter_zs_cot(
+    df_text: str,
+    user_instruction: str,
+) -> List[str]:
+    sys_instruction = (
+        "The user will provide an instruction and some relevant context.\n"
+        "Your job is to answer the user's instruction given the context.\n"
+        'First give your reasoning. Then you MUST end your output with "Answer: your answer"'
+    )
+    messages = [
+        {"role": "system", "content": sys_instruction},
+    ]
+
+    messages.append(
+        {
+            "role": "user",
+            "content": f"Context:\n{df_text}\n\nInstruction: {user_instruction}",
+        }
+    )
+    return messages
+
+
 def map_formatter(
     df_text: str,
     user_instruction: str,
     examples_df_text: Optional[List[str]] = None,
     examples_answer: Optional[List[str]] = None,
     cot_reasoning: Optional[List[str]] = None,
+    strategy: Optional[str] = None,
 ) -> List[str]:
     if cot_reasoning:
         return map_formatter_cot(df_text, user_instruction, examples_df_text, examples_answer, cot_reasoning)
+    elif strategy == "zs-cot":
+        return map_formatter_zs_cot(df_text, user_instruction)
 
     sys_instruction = (
         "The user will povide an instruction and some relevant context.\n"
diff --git a/pyproject.toml b/pyproject.toml
index 709b415..314b2d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "lotus-ai"
-version = "0.1.3"
+version = "0.1.4"
 description = "lotus"
 readme = "README.md"
 authors = [
@@ -67,3 +67,6 @@ line-ending = "auto"
 [tool.ruff.lint.per-file-ignores]
 "**/{docs}/*" = ["ALL"]
 "**__init__.py" = ["ALL"]
+
+[tool.setuptools]
+packages = ["lotus"]
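The snippet below is not part of the patch; it is a minimal sketch of how the pieces changed above fit together, assuming the `sem_topk` DataFrame accessor registered by LOTUS and a toy DataFrame whose second course name is made up for illustration.

```python
import pandas as pd

import lotus
from lotus.models import E5Model, OpenAIModel

# Generation and batching limits now live on the model object itself;
# lotus.settings.configure() no longer accepts model_params.
lm = OpenAIModel(max_tokens=512, max_batch_size=64, max_ctx_len=4096)
rm = E5Model()
lotus.settings.configure(lm=lm, rm=rm)

# Toy data; "Optimization Methods in Engineering" is a hypothetical entry.
df = pd.DataFrame(
    {
        "Course Name": [
            "Probability and Random Processes",
            "Optimization Methods in Engineering",
        ]
    }
)

# strategy="zs-cot" routes the operator through the new zero-shot CoT prompts
# added to task_instructions.py and sem_topk.py.
top_df = df.sem_topk("Which {Course Name} is most related to statistics?", K=1, strategy="zs-cot")
print(top_df)
```

The same `strategy` keyword is threaded through `sem_filter`, `sem_map`, and `sem_join` in this patch, replacing the old `lotus.settings.model_params` pass-through.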