diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..93dfaa1 --- /dev/null +++ b/config/config.go @@ -0,0 +1,31 @@ +package config + +import ( + "os" + + "gopkg.in/yaml.v3" +) + +type ModelConfig struct { + Cmd string `yaml:"cmd"` + Proxy string `yaml:"proxy"` +} + +type Config struct { + Models map[string]ModelConfig `yaml:"models"` +} + +func LoadConfig(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var config Config + err = yaml.Unmarshal(data, &config) + if err != nil { + return nil, err + } + + return &config, nil +} diff --git a/go.mod b/go.mod index b9ef41a..ba48cea 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ -module golang-llama-cpp-proxy +module github.com/mostlygeek/go-llama-cpp-proxy go 1.23.0 -require gopkg.in/yaml.v3 v3.0.1 // indirect +require gopkg.in/yaml.v3 v3.0.1 // indirect \ No newline at end of file diff --git a/llama-proxy.go b/llama-proxy.go index b3e79de..445cf99 100644 --- a/llama-proxy.go +++ b/llama-proxy.go @@ -15,44 +15,19 @@ import ( "syscall" "time" - "gopkg.in/yaml.v3" + "github.com/mostlygeek/go-llama-cpp-proxy/config" ) -type ModelConfig struct { - Cmd string `yaml:"cmd"` - Proxy string `yaml:"proxy"` -} - -type Config struct { - Models map[string]ModelConfig `yaml:"models"` -} - type ServiceState struct { sync.Mutex currentCmd *exec.Cmd currentModel string } -func loadConfig(path string) (*Config, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - var config Config - err = yaml.Unmarshal(data, &config) - if err != nil { - return nil, err - } - - return &config, nil -} - func startService(command string) (*exec.Cmd, error) { args := strings.Fields(command) cmd := exec.Command(args[0], args[1:]...) - // write it to the stdout/stderr of the proxy cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -64,7 +39,9 @@ func startService(command string) (*exec.Cmd, error) { return cmd, nil } -func checkHealthEndpoint(client *http.Client, healthURL string, maxDuration time.Duration) error { +func checkHealthEndpoint(healthURL string, maxDuration time.Duration) error { + client := &http.Client{} + startTime := time.Now() for { req, err := http.NewRequest("GET", healthURL, nil) @@ -108,9 +85,7 @@ func checkHealthEndpoint(client *http.Client, healthURL string, maxDuration time } } -func proxyRequest(w http.ResponseWriter, r *http.Request, config *Config, state *ServiceState) { - client := &http.Client{} - +func proxyChatRequest(w http.ResponseWriter, r *http.Request, config *config.Config, state *ServiceState) { // Read the original request body bodyBytes, err := io.ReadAll(r.Body) if err != nil { @@ -153,14 +128,21 @@ func proxyRequest(w http.ResponseWriter, r *http.Request, config *Config, state // Check the /health endpoint healthURL := modelConfig.Proxy + "/health" - err = checkHealthEndpoint(client, healthURL, 30*time.Second) + err = checkHealthEndpoint(healthURL, 30*time.Second) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } } - req, err := http.NewRequest(r.Method, modelConfig.Proxy+r.URL.String(), io.NopCloser(bytes.NewBuffer(bodyBytes))) + // replace r.Body so it can be read again + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + proxyRequest(modelConfig.Proxy, w, r) +} + +func proxyRequest(proxyHost string, w http.ResponseWriter, r *http.Request) { + client := &http.Client{} + req, err := http.NewRequest(r.Method, proxyHost+r.URL.String(), r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -211,7 +193,7 @@ func main() { flag.Parse() // Parse the command-line flags - config, err := loadConfig(*configPath) + config, err := config.LoadConfig(*configPath) if err != nil { fmt.Printf("Error loading config: %v\n", err) os.Exit(1) @@ -221,7 +203,7 @@ func main() { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/v1/chat/completions" { - proxyRequest(w, r, config, serviceState) + proxyChatRequest(w, r, config, serviceState) } else { http.Error(w, "Endpoint not supported", http.StatusNotFound) }