From 92203d23e0c0d1388f6bae924b15bfdd88c9c70c Mon Sep 17 00:00:00 2001 From: presbrey Date: Sun, 25 Aug 2024 17:20:34 -0400 Subject: [PATCH 1/3] add Farm test, maximize coverage --- farm_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/farm_test.go b/farm_test.go index e443d65..1445f45 100644 --- a/farm_test.go +++ b/farm_test.go @@ -195,4 +195,11 @@ func TestFarmMethods(t *testing.T) { t.Errorf("Expected 2 llamas, got %d", len(llamas)) } }) + + t.Run("TestOllamaFarm", func(t *testing.T) { + llama := farm.First(&ollamafarm.Where{Group: "group1"}) + if llama.Farm() != farm { + t.Error("Expected farm, got nil") + } + }) } From 2b3a28c362930866b4ad1a15db2b4f24c508119c Mon Sep 17 00:00:00 2001 From: presbrey Date: Mon, 2 Sep 2024 22:38:50 -0400 Subject: [PATCH 2/3] feat: add server package with version and models endpoints --- cmd/ollamafarmd/main.go | 25 +++++++++++++++++++++++++ farm.go | 22 ++++++++++++++++++++++ server/server.go | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 cmd/ollamafarmd/main.go create mode 100644 server/server.go diff --git a/cmd/ollamafarmd/main.go b/cmd/ollamafarmd/main.go new file mode 100644 index 0000000..36cf9d6 --- /dev/null +++ b/cmd/ollamafarmd/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "log" + "net/http" + + "github.com/presbrey/ollamafarm" + "github.com/presbrey/ollamafarm/server" +) + +func main() { + farm := ollamafarm.New() + + // Register your Ollama clients here + // For example: + // farm.RegisterURL("http://localhost:11434", nil) + + s := server.NewServer(farm) + + http.HandleFunc("/version", s.VersionHandler) + http.HandleFunc("/models", s.ModelsHandler) + + log.Println("Server starting on :8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/farm.go b/farm.go index 21c0d66..2746e9b 100644 --- a/farm.go +++ b/farm.go @@ -147,3 +147,25 @@ func (f *Farm) ModelCounts(where *Where) map[string]uint { return modelCounts } + +// AllModels 
returns a list of all unique models available across all registered Ollamas. +func (f *Farm) AllModels() []string { + f.mu.RLock() + defer f.mu.RUnlock() + + modelSet := make(map[string]struct{}) + for _, ollama := range f.ollamas { + if !ollama.properties.Offline { + for model := range ollama.models { + modelSet[model] = struct{}{} + } + } + } + + models := make([]string, 0, len(modelSet)) + for model := range modelSet { + models = append(models, model) + } + + return models +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..40a6cc6 --- /dev/null +++ b/server/server.go @@ -0,0 +1,41 @@ +package server + +import ( + "encoding/json" + "net/http" + + "github.com/presbrey/ollamafarm" +) + +type Server struct { + Farm *ollamafarm.Farm +} + +func NewServer(farm *ollamafarm.Farm) *Server { + return &Server{Farm: farm} +} + +func (s *Server) VersionHandler(w http.ResponseWriter, r *http.Request) { + ollama := s.Farm.First(nil) + if ollama == nil { + http.Error(w, "No available Ollama instances", http.StatusServiceUnavailable) + return + } + + ctx := r.Context() + version, err := ollama.Client().Version(ctx) + if err != nil { + http.Error(w, "Failed to get version", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"version": version}) +} + +func (s *Server) ModelsHandler(w http.ResponseWriter, r *http.Request) { + models := s.Farm.AllModels() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(models) +} From d8401f729641f7dce9582e63cd0c431f19f54332 Mon Sep 17 00:00:00 2001 From: presbrey Date: Mon, 2 Sep 2024 22:46:47 -0400 Subject: [PATCH 3/3] server: Add catch-all proxy handler --- cmd/ollamafarmd/main.go | 7 +--- farm.go | 57 +++++++++++++++++++++---- ollama.go | 12 +++++- server/server.go | 93 ++++++++++++++++++++++++++++++++++++++--- types.go | 6 ++- 5 files changed, 153 insertions(+), 22 deletions(-) 
diff --git a/cmd/ollamafarmd/main.go b/cmd/ollamafarmd/main.go index 36cf9d6..b50ce6a 100644 --- a/cmd/ollamafarmd/main.go +++ b/cmd/ollamafarmd/main.go @@ -17,9 +17,6 @@ func main() { s := server.NewServer(farm) - http.HandleFunc("/version", s.VersionHandler) - http.HandleFunc("/models", s.ModelsHandler) - - log.Println("Server starting on :8080") - log.Fatal(http.ListenAndServe(":8080", nil)) + log.Println("Server starting on :11343") + log.Fatal(http.ListenAndServe(":11434", s)) } diff --git a/farm.go b/farm.go index 2746e9b..fbf075f 100644 --- a/farm.go +++ b/farm.go @@ -33,11 +33,11 @@ func NewWithOptions(options *Options) *Farm { } // RegisterClient adds a new Ollama to the Farm if it doesn't already exist. -func (f *Farm) RegisterClient(id string, client *api.Client, properties *Properties) { +func (f *Farm) RegisterClient(name string, client *api.Client, properties *Properties) { f.mu.Lock() defer f.mu.Unlock() - if _, exists := f.ollamas[id]; exists { + if _, exists := f.ollamas[name]; exists { return } @@ -49,18 +49,50 @@ func (f *Farm) RegisterClient(id string, client *api.Client, properties *Propert } ollama := &Ollama{ + name: name, + client: client, farm: f, - models: make(map[string]bool), + models: make(map[string]*api.ListModelResponse), properties: p, } - f.ollamas[id] = ollama + f.ollamas[name] = ollama go ollama.updateTickers() } -// RegisterURL adds a new Ollama to the Farm using the baseURL as the ID. -func (f *Farm) RegisterURL(baseURL string, properties *Properties) error { +// RegisterClientURL adds a new Ollama with the given base URL to the Farm if it doesn't already exist. 
+func (f *Farm) RegisterClientURL(name string, client *api.Client, properties *Properties, url *url.URL) { + f.mu.Lock() + defer f.mu.Unlock() + + if _, exists := f.ollamas[name]; exists { + return + } + + p := Properties{} + if properties != nil { + p.Group = properties.Group + p.Offline = properties.Offline + p.Priority = properties.Priority + } + + ollama := &Ollama{ + name: name, + url: url, + + client: client, + farm: f, + models: make(map[string]*api.ListModelResponse), + properties: p, + } + f.ollamas[name] = ollama + + go ollama.updateTickers() +} + +// RegisterNamedURL adds a new Ollama to the Farm using the given name as the ID. +func (f *Farm) RegisterNamedURL(name, baseURL string, properties *Properties) error { parsedURL, err := url.Parse(baseURL) if err != nil { return err @@ -68,10 +100,19 @@ func (f *Farm) RegisterURL(baseURL string, properties *Properties) error { client := api.NewClient(parsedURL, http.DefaultClient) - f.RegisterClient(parsedURL.String(), client, properties) + f.RegisterClientURL(name, client, properties, parsedURL) return nil } +// RegisterURL adds a new Ollama to the Farm using the baseURL as the ID. +func (f *Farm) RegisterURL(baseURL string, properties *Properties) error { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return err + } + return f.RegisterNamedURL(parsedURL.String(), baseURL, properties) +} + // First returns the first Ollama that matches the given where. 
func (f *Farm) First(where *Where) *Ollama { f.mu.RLock() @@ -120,7 +161,7 @@ func (f *Farm) matchesWhere(ollama *Ollama, where *Where) bool { if where.Group != "" && ollama.properties.Group != where.Group { return false } - if where.Model != "" && !ollama.models[where.Model] { + if where.Model != "" && ollama.models[where.Model] == nil { return false } if where.Offline != ollama.properties.Offline { diff --git a/ollama.go b/ollama.go index 9605dee..95288f5 100644 --- a/ollama.go +++ b/ollama.go @@ -2,11 +2,19 @@ package ollamafarm import ( "context" + "net/url" "time" "github.com/ollama/ollama/api" ) +// BaseURL returns the base URL of the Ollama. +func (ollama *Ollama) BaseURL() *url.URL { + ollama.farm.mu.RLock() + defer ollama.farm.mu.RUnlock() + return ollama.url +} + // Client returns the Ollama client. func (ollama *Ollama) Client() *api.Client { ollama.farm.mu.RLock() @@ -48,11 +56,11 @@ func (ollama *Ollama) updateModels() { ollama.farm.mu.Lock() if err != nil { ollama.properties.Offline = true - ollama.models = make(map[string]bool) + ollama.models = make(map[string]*api.ListModelResponse) } else { ollama.properties.Offline = false for _, model := range listResponse.Models { - ollama.models[model.Name] = true + ollama.models[model.Name] = &model } } ollama.farm.mu.Unlock() diff --git a/server/server.go b/server/server.go index 40a6cc6..e977ad5 100644 --- a/server/server.go +++ b/server/server.go @@ -2,21 +2,32 @@ package server import ( "encoding/json" + "io" + "log" "net/http" "github.com/presbrey/ollamafarm" ) +// Server is an HTTP server that proxies requests to Ollamas on a Farm. type Server struct { - Farm *ollamafarm.Farm + farm *ollamafarm.Farm + mux *http.ServeMux } +// NewServer creates a new Server instance with the given Farm. 
func NewServer(farm *ollamafarm.Farm) *Server { - return &Server{Farm: farm} + s := &Server{farm: farm} + mux := http.NewServeMux() + mux.HandleFunc("/api/tags", s.handleTags) + mux.HandleFunc("/api/version", s.handleVersion) + mux.HandleFunc("/", s.catchAllPost) + s.mux = mux + return s } -func (s *Server) VersionHandler(w http.ResponseWriter, r *http.Request) { - ollama := s.Farm.First(nil) +func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) { + ollama := s.farm.First(nil) if ollama == nil { http.Error(w, "No available Ollama instances", http.StatusServiceUnavailable) return @@ -33,9 +44,79 @@ func (s *Server) VersionHandler(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{"version": version}) } -func (s *Server) ModelsHandler(w http.ResponseWriter, r *http.Request) { - models := s.Farm.AllModels() +func (s *Server) handleTags(w http.ResponseWriter, r *http.Request) { + models := s.farm.AllModels() w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(models) } + +func (s *Server) catchAllPost(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "Invalid JSON body", http.StatusBadRequest) + return + } + + model, ok := body["model"].(string) + if !ok { + http.Error(w, "Missing or invalid 'model' field", http.StatusBadRequest) + return + } + + ollama := s.farm.First(&ollamafarm.Where{Model: model}) + if ollama == nil { + http.Error(w, "No available Ollama instance for the specified model", http.StatusServiceUnavailable) + return + } + + // Create a new request to the selected Ollama + proxyURL := ollama.BaseURL().ResolveReference(r.URL) + proxyReq, err := http.NewRequest(r.Method, proxyURL.String(), r.Body) + if err != nil { + http.Error(w, "Error creating proxy 
request", http.StatusInternalServerError) + return + } + + // Copy headers + for key, values := range r.Header { + for _, value := range values { + proxyReq.Header.Add(key, value) + } + } + + // Send the request to the Ollama instance + resp, err := http.DefaultClient.Do(proxyReq) + if err != nil { + http.Error(w, "Error proxying request", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // Copy the response headers + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Set the status code + w.WriteHeader(resp.StatusCode) + + // Copy the response body + if _, err := io.Copy(w, resp.Body); err != nil { + // We've already started writing the response, so we can't use http.Error here + // Just log the error + log.Printf("Error copying response body: %v", err) + } +} + +// ServeHTTP implements the http.Handler interface. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mux.ServeHTTP(w, r) +} diff --git a/types.go b/types.go index 4a6832b..e67a99a 100644 --- a/types.go +++ b/types.go @@ -2,6 +2,7 @@ package ollamafarm import ( "net/http" + "net/url" "sync" "time" @@ -18,9 +19,12 @@ type Farm struct { // Ollama stores information about an Ollama server. type Ollama struct { + name string + url *url.URL + client *api.Client farm *Farm - models map[string]bool + models map[string]*api.ListModelResponse properties Properties }