commit b63b81b121f1b521231a20e2d32e81d7d696a926
Author: Benson Wong
Date:   Thu Oct 3 20:20:01 2024 -0700

    first commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..062e1a6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+.aider*
+.env
+build/
\ No newline at end of file
diff --git a/bin/simple-responder/simple-responder.go b/bin/simple-responder/simple-responder.go
new file mode 100644
index 0000000..aea99fe
--- /dev/null
+++ b/bin/simple-responder/simple-responder.go
@@ -0,0 +1,37 @@
+package main
+
+import (
+	"flag"
+	"fmt"
+	"net/http"
+)
+
+func main() {
+	// Define a command-line flag for the port
+	port := flag.String("port", "8080", "port to listen on")
+
+	// Define a command-line flag for the response message
+	responseMessage := flag.String("respond", "hi", "message to respond with")
+
+	flag.Parse() // Parse the command-line flags
+
+	// Set up the handler function using the provided response message
+	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		fmt.Fprintln(w, *responseMessage)
+	})
+
+	// Set up the /health endpoint handler function
+	http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		response := `{"status": "ok"}`
+		w.Write([]byte(response))
+	})
+
+	address := ":" + *port // Address with the specified port
+	fmt.Printf("Server is listening on port %s\n", *port)
+
+	// Start the server and log any error if it occurs
+	if err := http.ListenAndServe(address, nil); err != nil {
+		fmt.Printf("Error starting server: %s\n", err)
+	}
+}
diff --git a/config.example.yaml b/config.example.yaml
new file mode 100644
index 0000000..7fa8dda
--- /dev/null
+++ b/config.example.yaml
@@ -0,0 +1,7 @@
+models:
+  "llama":
+    cmd: "models/llama-server-osx --port 8999 -m models/Llama-3.2-1B-Instruct-Q4_K_M.gguf"
+    proxy: "http://127.0.0.1:8999"
+  "qwen":
+    cmd: "models/llama-server-osx --port 8999 -m models/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf "
+    proxy: "http://127.0.0.1:8999"
\ No newline at end of file
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..b9ef41a
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module golang-llama-cpp-proxy
+
+go 1.23.0
+
+require gopkg.in/yaml.v3 v3.0.1 // indirect
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..4bc0337
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,3 @@
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/llama-proxy.go b/llama-proxy.go
new file mode 100644
index 0000000..b855f7a
--- /dev/null
+++ b/llama-proxy.go
@@ -0,0 +1,215 @@
+package main
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"os/exec"
+	"strings"
+	"sync"
+	"syscall"
+	"time"
+
+	"gopkg.in/yaml.v3"
+)
+
+type ModelConfig struct {
+	Cmd   string `yaml:"cmd"`
+	Proxy string `yaml:"proxy"`
+}
+
+type Config struct {
+	Models map[string]ModelConfig `yaml:"models"`
+}
+
+type ServiceState struct {
+	sync.Mutex
+	currentCmd   *exec.Cmd
+	currentModel string
+}
+
+func loadConfig(path string) (*Config, error) {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil, err
+	}
+
+	var config Config
+	err = yaml.Unmarshal(data, &config)
+	if err != nil {
+		return nil, err
+	}
+
+	return &config, nil
+}
+
+func startService(command string) (*exec.Cmd, error) {
+	args := strings.Fields(command)
+	cmd := exec.Command(args[0], args[1:]...)
+
+	// write it to the stdout/stderr of the proxy
+	cmd.Stdout = os.Stdout
+	cmd.Stderr = os.Stderr
+
+	err := cmd.Start()
+	if err != nil {
+		return nil, err
+	}
+
+	return cmd, nil
+}
+
+func checkHealthEndpoint(client *http.Client, healthURL string, maxDuration time.Duration) error {
+	startTime := time.Now()
+	for {
+		req, err := http.NewRequest("GET", healthURL, nil)
+		if err != nil {
+			return err
+		}
+
+		// Set request timeout
+		ctx, cancel := context.WithTimeout(req.Context(), 250*time.Millisecond)
+		defer cancel()
+
+		// Execute the request with the context
+		req = req.WithContext(ctx)
+		resp, err := client.Do(req)
+		if err != nil {
+			// Log error and check elapsed time before retrying
+			if time.Since(startTime) >= maxDuration {
+				return fmt.Errorf("failed to get a healthy response from: %s", healthURL)
+			}
+
+			// Wait a second before retrying
+			time.Sleep(time.Second)
+			continue
+		}
+
+		// Close response body
+		defer resp.Body.Close()
+
+		// Check if we got a 200 OK response
+		if resp.StatusCode == http.StatusOK {
+			return nil // Health check succeeded
+		}
+
+		// Check elapsed time before retrying
+		if time.Since(startTime) >= maxDuration {
+			return fmt.Errorf("failed to get a healthy response from: %s", healthURL)
+		}
+
+		// Wait a second before retrying
+		time.Sleep(time.Second)
+	}
+}
+
+func proxyRequest(w http.ResponseWriter, r *http.Request, config *Config, state *ServiceState) {
+	client := &http.Client{}
+
+	// Read the original request body
+	bodyBytes, err := io.ReadAll(r.Body)
+	if err != nil {
+		http.Error(w, "Invalid JSON", http.StatusBadRequest)
+		return
+	}
+
+	var requestBody map[string]interface{}
+	if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
+		http.Error(w, "Invalid JSON", http.StatusBadRequest)
+		return
+	}
+
+	model, ok := requestBody["model"].(string)
+	if !ok {
+		http.Error(w, "Missing or invalid 'model' key", http.StatusBadRequest)
+		return
+	}
+
+	modelConfig, ok := config.Models[model]
+	if !ok {
+		http.Error(w, "Model not found in configuration", http.StatusNotFound)
+		return
+	}
+
+	err = error(nil)
+	state.Lock()
+	defer state.Unlock()
+
+	if state.currentModel != model {
+		if state.currentCmd != nil {
+			state.currentCmd.Process.Signal(syscall.SIGTERM)
+		}
+		state.currentCmd, err = startService(modelConfig.Cmd)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		state.currentModel = model
+
+		// Check the /health endpoint
+		healthURL := modelConfig.Proxy + "/health"
+		err = checkHealthEndpoint(client, healthURL, 30*time.Second)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusServiceUnavailable)
+			return
+		}
+	}
+
+	req, err := http.NewRequest(r.Method, modelConfig.Proxy+r.URL.String(), io.NopCloser(bytes.NewBuffer(bodyBytes)))
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	req.Header = r.Header
+	resp, err := client.Do(req)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadGateway)
+		return
+	}
+	defer resp.Body.Close()
+
+	for k, vv := range resp.Header {
+		for _, v := range vv {
+			w.Header().Add(k, v)
+		}
+	}
+	w.WriteHeader(resp.StatusCode)
+
+	io.Copy(w, resp.Body)
+}
+
+func main() {
+	// Define a command-line flag for the port
+	configPath := flag.String("config", "config.yaml", "config file name")
+	listenStr := flag.String("listen", ":8080", "listen ip/port")
+
+	flag.Parse() // Parse the command-line flags
+
+	config, err := loadConfig(*configPath)
+	if err != nil {
+		fmt.Printf("Error loading config: %v\n", err)
+		os.Exit(1)
+	}
+
+	serviceState := &ServiceState{}
+
+	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == "/v1/chat/completions" {
+			proxyRequest(w, r, config, serviceState)
+		} else {
+			http.Error(w, "Endpoint not supported", http.StatusNotFound)
+		}
+	})
+
+	fmt.Println("Proxy server started on :8080")
+	if err := http.ListenAndServe(*listenStr, nil); err != nil {
+		fmt.Printf("Error starting server: %v\n", err)
+		os.Exit(1)
+	}
+}
diff --git a/models/.gitignore b/models/.gitignore
new file mode 100644
index 0000000..e5af87e
--- /dev/null
+++ b/models/.gitignore
@@ -0,0 +1,3 @@
+*
+!.gitignore
+!README.md
\ No newline at end of file
diff --git a/models/README.md b/models/README.md
new file mode 100644
index 0000000..459fc21
--- /dev/null
+++ b/models/README.md
@@ -0,0 +1,7 @@
+TODO improve these docs
+
+1. Download a llama-server suitable for your architecture
+1. Fetch some small models for testing / swapping between
+   - `huggingface-cli download bartowski/Qwen2.5-1.5B-Instruct-GGUF --include "Qwen2.5-1.5B-Instruct-Q4_K_M.gguf" --local-dir ./`
+   - `huggingface-cli download bartowski/Llama-3.2-1B-Instruct-GGUF --include "Llama-3.2-1B-Instruct-Q4_K_M.gguf" --local-dir ./`
+1. Create a new config.yaml file (see `config.example.yaml`) pointing to the models
\ No newline at end of file
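
The proxy in this commit only accepts requests to /v1/chat/completions and chooses which llama-server process to run from the "model" field of the JSON body, swapping backends (SIGTERM the old process, start the new one, poll its /health endpoint for up to 30 seconds) whenever that field changes. As a rough illustration of how it could be exercised once the setup steps above are done, here is a minimal client sketch in Go; it is not part of the commit and assumes the proxy is running with the default -listen address of :8080 and a config.yaml defining a model named "llama" as in config.example.yaml. The OpenAI-style messages payload is an assumption about what llama-server accepts, not something defined in this repo.

// example-client.go (hypothetical, not part of the commit above)
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// The "model" field is what llama-proxy.go inspects to decide which
	// backend process to start (or swap to) before forwarding the request.
	payload := map[string]interface{}{
		"model": "llama", // must match a key under "models:" in config.yaml
		"messages": []map[string]string{
			{"role": "user", "content": "Say hello in one short sentence."},
		},
	}

	body, err := json.Marshal(payload)
	if err != nil {
		panic(err)
	}

	// The first request for a model can take a while: the proxy starts the
	// backend and waits for its /health check before proxying the request.
	resp, err := http.Post("http://127.0.0.1:8080/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Status)
	fmt.Println(string(out))
}

Sending a later request with "model": "qwen" should then cause the proxy to terminate the running llama-server and bring up the other model before forwarding.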