Add code to use openai_base_url and use OpenAI's model lister function #189

Merged
merged 2 commits into from
Mar 12, 2024
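In short, the change drops the hand-rolled requests call to https://api.openai.com/v1/models and instead builds the OpenAI client from OPENAI_API_KEY plus an optional OPENAI_BASE_URL, then lists models through the SDK's models.list(). A minimal standalone sketch of that pattern, assuming only the environment-variable names and SDK calls shown in the diff (everything else is illustrative, not fabric's actual class):

import os
from openai import OpenAI, APIConnectionError

# OPENAI_API_KEY must be set; OPENAI_BASE_URL is optional and falls back to the public endpoint.
api_key = os.environ["OPENAI_API_KEY"]
base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1/")
client = OpenAI(api_key=api_key, base_url=base_url)

try:
    # models.list() works against any OpenAI-compatible server, not just api.openai.com
    model_ids = [model.id for model in client.models.list().data]
except APIConnectionError as exc:
    raise SystemExit(f"Could not reach {base_url}: {exc}")

print(sorted(model_ids))

Pointing OPENAI_BASE_URL at a local OpenAI-compatible server (LM Studio, for example) is what makes the same listing code return path-style model IDs, which the diff then normalizes.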
86 changes: 38 additions & 48 deletions installer/client/cli/utils.py
@@ -1,6 +1,6 @@
 import requests
 import os
-from openai import OpenAI
+from openai import OpenAI, APIConnectionError
 import asyncio
 import pyperclip
 import sys
@@ -36,12 +36,10 @@ def __init__(self, args, pattern="", env_file="~/.config/fabric/.env"):
         # Expand the tilde to the full path
         env_file = os.path.expanduser(env_file)
         load_dotenv(env_file)
-        try:
-            apikey = os.environ["OPENAI_API_KEY"]
-            self.client = OpenAI()
-            self.client.api_key = apikey
-        except:
-            print("No API key found. Use the --apikey option to set the key")
+        assert 'OPENAI_API_KEY' in os.environ, "Error: OPENAI_API_KEY not found in environment variables. Please run fabric --setup and add the key."
+        api_key = os.environ['OPENAI_API_KEY']
+        base_url = os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/')
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
         self.local = False
         self.config_pattern_directory = config_directory
         self.pattern = pattern
@@ -253,28 +251,23 @@ def fetch_available_models(self):
         fullOllamaList = []
         claudeList = ['claude-3-opus-20240229']
         try:
-            headers = {
-                "Authorization": f"Bearer {self.client.api_key}"
-            }
-            response = requests.get(
-                "https://api.openai.com/v1/models", headers=headers)
-
-            if response.status_code == 200:
-                models = response.json().get("data", [])
-                # Filter only gpt models
-                gpt_models = [model for model in models if model.get(
-                    "id", "").startswith(("gpt"))]
-                # Sort the models alphabetically by their ID
-                sorted_gpt_models = sorted(
-                    gpt_models, key=lambda x: x.get("id"))
-
-                for model in sorted_gpt_models:
-                    gptlist.append(model.get("id"))
+            models = [model.id for model in self.client.models.list().data]
+        except APIConnectionError as e:
+            if getattr(e.__cause__, 'args', [''])[0] == "Illegal header value b'Bearer '":
+                print("Error: Cannot connect to the OpenAI API Server because the API key is not set. Please run fabric --setup and add a key.")
             else:
-                print(f"Failed to fetch models: HTTP {response.status_code}")
-                sys.exit()
-        except:
-            print('No OpenAI API key found. Please run fabric --setup and add the key if you wish to interact with openai')
+                print(f"{e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}")
+            sys.exit()
+        except Exception as e:
+            print(f"Error: {getattr(e.__context__, 'args', [''])[0]}")
+            sys.exit()
+        if "/" in models[0] or "\\" in models[0]:
+            # lmstudio returns full paths to models. Iterate and truncate everything before and including the last slash
+            gptlist = [item[item.rfind("/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models]
+        else:
+            # Keep items that start with "gpt"
+            gptlist = [item for item in models if item.startswith("gpt")]
+        gptlist.sort()
         import ollama
         try:
             default_modelollamaList = ollama.list()['models']
@@ -430,27 +423,24 @@ def __init__(self):
         pass

     def fetch_available_models(self):
-        headers = {
-            "Authorization": f"Bearer {self.openaiapi_key}"
-        }
-
-        response = requests.get(
-            "https://api.openai.com/v1/models", headers=headers)
-
-        if response.status_code == 200:
-            models = response.json().get("data", [])
-            # Filter only gpt models
-            gpt_models = [model for model in models if model.get(
-                "id", "").startswith(("gpt"))]
-            # Sort the models alphabetically by their ID
-            sorted_gpt_models = sorted(
-                gpt_models, key=lambda x: x.get("id"))
-
-            for model in sorted_gpt_models:
-                self.gptlist.append(model.get("id"))
-        else:
-            print(f"Failed to fetch models: HTTP {response.status_code}")
-
+        try:
+            models = [model.id for model in self.client.models.list().data]
+        except APIConnectionError as e:
+            if getattr(e.__cause__, 'args', [''])[0] == "Illegal header value b'Bearer '":
+                print("Error: Cannot connect to the OpenAI API Server because the API key is not set. Please run fabric --setup and add a key.")
+            else:
+                print(f"{e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}")
+            sys.exit()
+        except Exception as e:
+            print(f"Error: {getattr(e.__context__, 'args', [''])[0]}")
+            sys.exit()
+        if "/" in models[0] or "\\" in models[0]:
+            # lmstudio returns full paths to models. Iterate and truncate everything before and including the last slash
+            self.gptlist = [item[item.rfind("/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models]
+        else:
+            # Keep items that start with "gpt"
+            self.gptlist = [item for item in models if item.startswith("gpt")]
+        self.gptlist.sort()
         import ollama
         try:
             default_modelollamaList = ollama.list()['models']
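Both hunks end with the same post-processing of the returned IDs (the "lmstudio returns full paths to models" comment). As a standalone illustration of that normalization only, with a hypothetical function name and made-up sample IDs, while the branching logic mirrors the diff:

def normalize_model_ids(models):
    """Strip path prefixes from local-server IDs; otherwise keep only gpt* models, then sort."""
    if "/" in models[0] or "\\" in models[0]:
        # LM Studio-style servers report filesystem paths; keep only the basename.
        cleaned = [m[m.rfind("/") + 1:] if "/" in m else m[m.rfind("\\") + 1:] for m in models]
    else:
        # A regular OpenAI endpoint: keep only the gpt* families.
        cleaned = [m for m in models if m.startswith("gpt")]
    return sorted(cleaned)

print(normalize_model_ids(["models/mistral/mistral-7b.gguf", "models\\llama\\llama-3.gguf"]))  # hypothetical local IDs
print(normalize_model_ids(["gpt-4", "gpt-3.5-turbo", "dall-e-3"]))  # -> ['gpt-3.5-turbo', 'gpt-4']

Note that the check only inspects models[0], so it assumes a non-empty, homogeneous list; the diff makes the same assumption.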