diff --git a/README.md b/README.md index 028c5f1..0d5e40b 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,11 @@ python3 -m venv .venv # activate the virtual env . .venv/bin/activate -pip install langroid +# install with `hf-embeddings` extra to be able to use sentence_transformers embeddings +pip install "langroid[hf-embeddings]" + +# or to update an existing installation: +pip install --upgrade "langroid[hf-embeddings]" ``` diff --git a/examples/basic/chat-search.py b/examples/basic/chat-search.py index 5435e72..3b7cd39 100644 --- a/examples/basic/chat-search.py +++ b/examples/basic/chat-search.py @@ -1,11 +1,25 @@ +""" +This is a basic example of a chatbot that uses the GoogleSearchTool: +when the LLM doesn't know the answer to a question, it will use the tool to +search the web for relevant results, and then use the results to answer the +question. + +NOTE: running this example requires setting the GOOGLE_API_KEY and GOOGLE_CSE_ID +environment variables in your `.env` file, as explained in the +[README](https://github.com/langroid/langroid#gear-installation-and-setup). +""" + import typer from rich import print from rich.prompt import Prompt +from pydantic import BaseSettings +from dotenv import load_dotenv from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig from langroid.agent.task import Task -from langroid.agent.stateless_tools.google_search_tool import GoogleSearchTool -from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig +from langroid.language_models.base import LocalModelConfig +from langroid.agent.tools.google_search_tool import GoogleSearchTool +from langroid.language_models.openai_gpt import OpenAIGPTConfig from langroid.utils.configuration import set_global, Settings from langroid.utils.logging import setup_colored_logging @@ -15,10 +29,27 @@ setup_colored_logging() -def chat() -> None: +class CLIOptions(BaseSettings): + local: bool = False + api_base: str = "http://localhost:8000/v1" + local_model: str = "" + local_ctx: int = 2048 + # use completion endpoint for chat? + # if so, we should format chat->prompt ourselves, if we know the required syntax + completion: bool = False + + class Config: + extra = "forbid" + env_prefix = "" + + +def chat(opts: CLIOptions) -> None: print( """ - [blue]Welcome to the basic chatbot! + [blue]Welcome to the Google Search chatbot! + I will try to answer your questions, relying on (summaries of links from) + Google Search when needed. + Enter x or q to quit at any point. """ ) @@ -27,11 +58,32 @@ def chat() -> None: default="Default: 'You are a helpful assistant'", ) + load_dotenv() + + # create the appropriate OpenAIGPTConfig depending on local model or not + + if opts.local or opts.local_model: + # assumes local endpoint is either the default http://localhost:8000/v1 + # or if not, it has been set in the .env file as the value of + # OPENAI_LOCAL.API_BASE + local_model_config = LocalModelConfig( + api_base=opts.api_base, + model=opts.local_model, + context_length=opts.local_ctx, + use_completion_for_chat=opts.completion, + ) + llm_config = OpenAIGPTConfig( + local=local_model_config, + timeout=180, + ) + else: + # defaults to chat_model = OpenAIChatModel.GPT4 + llm_config = OpenAIGPTConfig() + config = ChatAgentConfig( system_message=sys_msg, - llm=OpenAIGPTConfig( - chat_model=OpenAIChatModel.GPT4, - ), + llm=llm_config, + vecdb=None, ) agent = ChatAgent(config) agent.enable_message(GoogleSearchTool) @@ -39,12 +91,27 @@ def chat() -> None: agent, system_message=""" You are a helpful assistant. You will try your best to answer my questions. - If you don't know you can use up to 2 results from the `web search` - tool/function-call to help you with answering the question. + If you cannot answer from your own knowledge, you can use up to 5 + results from the `web_search` tool/function-call to help you with + answering the question. Be very concise in your responses, use no more than 1-2 sentences. + When you answer based on a web search, First show me your answer, + and then show me the SOURCE(s) and EXTRACT(s) to justify your answer, + in this format: + + + SOURCE: https://www.wikihow.com/Be-a-Good-Assistant-Manager + EXTRACT: Be a Good Assistant ... requires good leadership skills. + + SOURCE: ... + EXTRACT: ... + + For the EXTRACT, ONLY show up to first 3 words, and last 3 words. """, ) - task.run() + # local models do not like the first message to be empty + user_message = "Hello." if (opts.local or opts.local_model) else None + task.run(user_message) @app.command() @@ -55,6 +122,19 @@ def main( cache_type: str = typer.Option( "redis", "--cachetype", "-ct", help="redis or momento" ), + local: bool = typer.Option(False, "--local", "-l", help="use local llm"), + local_model: str = typer.Option( + "", "--local_model", "-lm", help="local model path" + ), + api_base: str = typer.Option( + "http://localhost:8000/v1", "--api_base", "-api", help="local model api base" + ), + local_ctx: int = typer.Option( + 2048, "--local_ctx", "-lc", help="local llm context size (default 2048)" + ), + completion: bool = typer.Option( + False, "--completion", "-c", help="use completion endpoint for chat" + ), ) -> None: set_global( Settings( @@ -64,7 +144,14 @@ def main( cache_type=cache_type, ) ) - chat() + opts = CLIOptions( + local=local, + api_base=api_base, + local_model=local_model, + local_ctx=local_ctx, + completion=completion, + ) + chat(opts) if __name__ == "__main__": diff --git a/examples/data-qa/sql-chat/demo.db b/examples/data-qa/sql-chat/demo.db new file mode 100644 index 0000000..a6b515b Binary files /dev/null and b/examples/data-qa/sql-chat/demo.db differ diff --git a/examples/data-qa/sql-chat/demo.json b/examples/data-qa/sql-chat/demo.json new file mode 100644 index 0000000..4a487c1 --- /dev/null +++ b/examples/data-qa/sql-chat/demo.json @@ -0,0 +1,25 @@ +{ + "departments": { + "description": "The 'departments' table holds details about the various departments. It relates to the 'employees' table via a foreign key in the 'employees' table.", + "columns": { + "id": "A unique identifier for a department. This ID is used as a foreign key in the 'employees' table.", + "name": "The name of the department." + } + }, + "employees": { + "description": "The 'employees' table contains information about the employees. It relates to the 'departments' and 'sales' tables via foreign keys.", + "columns": { + "id": "A unique identifier for an employee. This ID is used as a foreign key in the 'sales' table.", + "name": "The name of the employee.", + "department_id": "The ID of the department the employee belongs to. This is a foreign key referencing the 'id' in the 'departments' table." + } + }, + "sales": { + "description": "The 'sales' table keeps a record of all sales made by employees. It relates to the 'employees' table via a foreign key.", + "columns": { + "id": "A unique identifier for a sale.", + "amount": "The amount of the sale in eastern Caribbean dollars (XCD).", + "employee_id": "The ID of the employee who made the sale. This is a foreign key referencing the 'id' in the 'employees' table." + } + } +} \ No newline at end of file diff --git a/examples/data-qa/sql-chat/sql_chat.py b/examples/data-qa/sql-chat/sql_chat.py new file mode 100644 index 0000000..b243c67 --- /dev/null +++ b/examples/data-qa/sql-chat/sql_chat.py @@ -0,0 +1,208 @@ +""" +Example showing how to chat with a SQL database. + +Note if you are using this with a postgres db, you will need to: + +(a) Install PostgreSQL dev libraries for your platform, e.g. + - `sudo apt-get install libpq-dev` on Ubuntu, + - `brew install postgresql` on Mac, etc. +(b) langroid with the postgres extra, e.g. `pip install langroid[postgres]` + or `poetry add langroid[postgres]` or `poetry install -E postgres`. + If this gives you an error, try `pip install psycopg2-binary` in your virtualenv. +""" +import typer +from rich import print +from rich.prompt import Prompt +from typing import Dict, Any +import json +import os +from pydantic import BaseSettings + +from sqlalchemy import create_engine, inspect +from sqlalchemy.engine import Engine +from prettytable import PrettyTable + +from utils import get_database_uri, fix_uri +from langroid.agent.special.sql.sql_chat_agent import ( + SQLChatAgent, + SQLChatAgentConfig, +) +from langroid.agent.task import Task +from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig +from langroid.utils.configuration import set_global, Settings +from langroid.utils.logging import setup_colored_logging +import logging + +logger = logging.getLogger(__name__) + + +app = typer.Typer() + +setup_colored_logging() + + +def create_descriptions_file(filepath: str, engine: Engine) -> None: + """ + Create an empty descriptions JSON file for SQLAlchemy tables. + + This function inspects the database, generates a template for table and + column descriptions, and writes that template to a new JSON file. + + Args: + filepath: The path to the file where the descriptions should be written. + engine: The SQLAlchemy Engine connected to the database to describe. + + Raises: + FileExistsError: If the file at `filepath` already exists. + + Returns: + None + """ + if os.path.exists(filepath): + raise FileExistsError(f"File {filepath} already exists.") + + inspector = inspect(engine) + descriptions: Dict[str, Dict[str, Any]] = {} + + for table_name in inspector.get_table_names(): + descriptions[table_name] = { + "description": "", + "columns": {col["name"]: "" for col in inspector.get_columns(table_name)}, + } + + with open(filepath, "w") as f: + json.dump(descriptions, f, indent=4) + + +def load_context_descriptions(engine: Engine) -> dict: + """ + Ask the user for a path to a JSON file and load context descriptions from it. + + Returns: + dict: The context descriptions, or an empty dictionary if the user decides to skip this step. + """ + + while True: + filepath = Prompt.ask( + "[blue]Enter the path to your context descriptions file. \n" + "('n' to create a NEW file, 's' to SKIP, or Hit enter to use DEFAULT) ", + default="examples/data-qa/sql-chat/demo.json", + ) + + if filepath.strip() == "s": + return {} + + if filepath.strip() == "n": + filepath = Prompt.ask( + "[blue]To create a new context description file, enter the path", + default="examples/data-qa/sql-chat/description.json", + ) + print(f"[blue]Creating new context description file at {filepath}...") + create_descriptions_file(filepath, engine) + print( + f"[blue] Please fill in the descriptions in {filepath}, " + f"then try again." + ) + + # Try to load the file + if not os.path.exists(filepath): + print(f"[red]The file '{filepath}' does not exist. Please try again.") + continue + + try: + with open(filepath, "r") as file: + return json.load(file) + except json.JSONDecodeError: + print( + f"[red]The file '{filepath}' is not a valid JSON file. Please try again." + ) + + +class CLIOptions(BaseSettings): + fn_api: bool = True # whether to use function-calling instead of langroid Tools + schema_tools: bool = False # whether to use schema tools + + +def chat(opts: CLIOptions) -> None: + print("[blue]Welcome to the SQL database chatbot!\n") + database_uri = Prompt.ask( + """ + [blue]Enter the URI for your SQL database + (type 'i' for interactive, or hit enter for default) + """, + default="sqlite:///examples/data-qa/sql-chat/demo.db", + ) + + if database_uri == "i": + database_uri = get_database_uri() + + database_uri = fix_uri(database_uri) + logger.warning(f"Using database URI: {database_uri}") + + # Create engine and inspector + engine = create_engine(database_uri) + inspector = inspect(engine) + + context_descriptions = load_context_descriptions(engine) + + # Get table names + table_names = inspector.get_table_names() + + for table_name in table_names: + print(f"[blue]Table: {table_name}") + + # Create a new table for the columns + table = PrettyTable() + table.field_names = ["Column Name", "Type"] + + # Get the columns for the table + columns = inspector.get_columns(table_name) + for column in columns: + table.add_row([column["name"], column["type"]]) + + print(table) + + agent = SQLChatAgent( + config=SQLChatAgentConfig( + database_uri=database_uri, + use_tools=not opts.fn_api, + use_functions_api=opts.fn_api, + context_descriptions=context_descriptions, # Add context descriptions to the config + use_schema_tools=opts.schema_tools, + llm=OpenAIGPTConfig( + chat_model=OpenAIChatModel.GPT4, + ), + ) + ) + task = Task(agent) + task.run() + + +@app.command() +def main( + debug: bool = typer.Option(False, "--debug", "-d", help="debug mode"), + no_stream: bool = typer.Option(False, "--nostream", "-ns", help="no streaming"), + nocache: bool = typer.Option(False, "--nocache", "-nc", help="don't use cache"), + tools: bool = typer.Option( + False, "--tools", "-t", help="use langroid tools instead of function-calling" + ), + cache_type: str = typer.Option( + "redis", "--cachetype", "-ct", help="redis or momento" + ), + schema_tools: bool = typer.Option( + False, "--schema_tools", "-st", help="use schema tools" + ), +) -> None: + set_global( + Settings( + debug=debug, + cache=not nocache, + stream=not no_stream, + cache_type=cache_type, + ) + ) + chat(CLIOptions(fn_api=not tools, schema_tools=schema_tools)) + + +if __name__ == "__main__": + app() diff --git a/examples/data-qa/sql-chat/utils.py b/examples/data-qa/sql-chat/utils.py new file mode 100644 index 0000000..addf7df --- /dev/null +++ b/examples/data-qa/sql-chat/utils.py @@ -0,0 +1,95 @@ +from rich import print +from rich.prompt import Prompt +import urllib.parse + +from langroid.parsing.utils import closest_string +import logging + +logger = logging.getLogger(__name__) + + +DEFAULT_PORTS = dict( + postgresql=5432, + mysql=3306, + mariadb=3306, + mssql=1433, + oracle=1521, + mongodb=27017, + redis=6379, +) + + +def fix_uri(uri: str) -> str: + """Fixes a URI by percent-encoding the username and password.""" + + if "%" in uri: + return uri # already %-encoded, so don't do anything + # Split by '://' + scheme_part, rest_of_uri = uri.split("://", 1) + + # Get the final '@' (assuming only the last '@' is the separator for user info) + last_at_index = rest_of_uri.rfind("@") + userinfo_part = rest_of_uri[:last_at_index] + rest_of_uri_after_at = rest_of_uri[last_at_index + 1 :] + + if ":" not in userinfo_part: + return uri + # Split userinfo by ':' to get username and password + username, password = userinfo_part.split(":", 1) + + # Percent-encode the username and password + username = urllib.parse.quote(username) + password = urllib.parse.quote(password) + + # Construct the fixed URI + fixed_uri = f"{scheme_part}://{username}:{password}@{rest_of_uri_after_at}" + + return fixed_uri + + +def _create_database_uri( + scheme: str, + username: str, + password: str, + hostname: str, + port: int, + databasename: str, +) -> str: + """Generates a database URI based on provided parameters.""" + username = urllib.parse.quote_plus(username) + password = urllib.parse.quote_plus(password) + port_str = f":{port}" if port else "" + return f"{scheme}://{username}:{password}@{hostname}{port_str}/{databasename}" + + +def get_database_uri() -> str: + """Main function to gather input and print the database URI.""" + scheme_input = Prompt.ask("Enter the database type (e.g., postgresql, mysql)") + scheme = closest_string(scheme_input, list(DEFAULT_PORTS.keys())) + + # Handle if no close match is found. + if scheme == "No match found": + print(f"No close match found for '{scheme_input}'. Please verify your input.") + return + + username = Prompt.ask("Enter the database username") + password = Prompt.ask("Enter the database password", password=True) + hostname = Prompt.ask("Enter the database hostname") + + # Inform user of default port, and let them choose to override or leave blank + default_port = DEFAULT_PORTS.get(scheme, "") + port_msg = ( + f"Enter the database port " + f"(hit enter to use default: {default_port} or specify another value)" + ) + + port = Prompt.ask(port_msg, default=default_port) + if not port: # If user pressed enter without entering anything + port = default_port + port = int(port) + + databasename = Prompt.ask("Enter the database name") + + uri = _create_database_uri(scheme, username, password, hostname, port, databasename) + print(f"Your {scheme.upper()} URI is:\n{uri}") + return uri diff --git a/examples/data-qa/table_chat.py b/examples/data-qa/table_chat.py index 9b8f869..8f1b3f4 100644 --- a/examples/data-qa/table_chat.py +++ b/examples/data-qa/table_chat.py @@ -19,7 +19,10 @@ def chat() -> None: print("[blue]Welcome to the tabular-data chatbot!\n") - path = Prompt.ask("[blue]Enter a local path or URL to a tabular dataset") + path = Prompt.ask( + "[blue]Enter a local path or URL to a tabular dataset (hit enter to use default)\n", + default="https://raw.githubusercontent.com/fivethirtyeight/data/master/airline-safety/airline-safety.csv" + ) agent = TableChatAgent( config=TableChatAgentConfig( diff --git a/pyproject.toml b/pyproject.toml index adffb1f..40ef262 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [{include = "examples"}] [tool.poetry.dependencies] python = ">=3.11,<3.12" -langroid = {version="^0.1.73", extras = ["hf-embeddings"]} +langroid = {version="^0.1.77", extras = ["hf-embeddings"]} [build-system]