move config to its own package

This commit is contained in:
Benson Wong
2024-10-03 21:08:11 -07:00
parent cb576fb178
commit f44faf5a93
3 changed files with 49 additions and 36 deletions

31
config/config.go Normal file
View File

@@ -0,0 +1,31 @@
package config
import (
"os"
"gopkg.in/yaml.v3"
)
type ModelConfig struct {
Cmd string `yaml:"cmd"`
Proxy string `yaml:"proxy"`
}
type Config struct {
Models map[string]ModelConfig `yaml:"models"`
}
func LoadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var config Config
err = yaml.Unmarshal(data, &config)
if err != nil {
return nil, err
}
return &config, nil
}

4
go.mod
View File

@@ -1,5 +1,5 @@
module golang-llama-cpp-proxy module github.com/mostlygeek/go-llama-cpp-proxy
go 1.23.0 go 1.23.0
require gopkg.in/yaml.v3 v3.0.1 // indirect require gopkg.in/yaml.v3 v3.0.1 // indirect

View File

@@ -15,44 +15,19 @@ import (
"syscall" "syscall"
"time" "time"
"gopkg.in/yaml.v3" "github.com/mostlygeek/go-llama-cpp-proxy/config"
) )
type ModelConfig struct {
Cmd string `yaml:"cmd"`
Proxy string `yaml:"proxy"`
}
type Config struct {
Models map[string]ModelConfig `yaml:"models"`
}
type ServiceState struct { type ServiceState struct {
sync.Mutex sync.Mutex
currentCmd *exec.Cmd currentCmd *exec.Cmd
currentModel string currentModel string
} }
func loadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var config Config
err = yaml.Unmarshal(data, &config)
if err != nil {
return nil, err
}
return &config, nil
}
func startService(command string) (*exec.Cmd, error) { func startService(command string) (*exec.Cmd, error) {
args := strings.Fields(command) args := strings.Fields(command)
cmd := exec.Command(args[0], args[1:]...) cmd := exec.Command(args[0], args[1:]...)
// write it to the stdout/stderr of the proxy
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@@ -64,7 +39,9 @@ func startService(command string) (*exec.Cmd, error) {
return cmd, nil return cmd, nil
} }
func checkHealthEndpoint(client *http.Client, healthURL string, maxDuration time.Duration) error { func checkHealthEndpoint(healthURL string, maxDuration time.Duration) error {
client := &http.Client{}
startTime := time.Now() startTime := time.Now()
for { for {
req, err := http.NewRequest("GET", healthURL, nil) req, err := http.NewRequest("GET", healthURL, nil)
@@ -108,9 +85,7 @@ func checkHealthEndpoint(client *http.Client, healthURL string, maxDuration time
} }
} }
func proxyRequest(w http.ResponseWriter, r *http.Request, config *Config, state *ServiceState) { func proxyChatRequest(w http.ResponseWriter, r *http.Request, config *config.Config, state *ServiceState) {
client := &http.Client{}
// Read the original request body // Read the original request body
bodyBytes, err := io.ReadAll(r.Body) bodyBytes, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@@ -153,14 +128,21 @@ func proxyRequest(w http.ResponseWriter, r *http.Request, config *Config, state
// Check the /health endpoint // Check the /health endpoint
healthURL := modelConfig.Proxy + "/health" healthURL := modelConfig.Proxy + "/health"
err = checkHealthEndpoint(client, healthURL, 30*time.Second) err = checkHealthEndpoint(healthURL, 30*time.Second)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable) http.Error(w, err.Error(), http.StatusServiceUnavailable)
return return
} }
} }
req, err := http.NewRequest(r.Method, modelConfig.Proxy+r.URL.String(), io.NopCloser(bytes.NewBuffer(bodyBytes))) // replace r.Body so it can be read again
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
proxyRequest(modelConfig.Proxy, w, r)
}
func proxyRequest(proxyHost string, w http.ResponseWriter, r *http.Request) {
client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyHost+r.URL.String(), r.Body)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@@ -211,7 +193,7 @@ func main() {
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
config, err := loadConfig(*configPath) config, err := config.LoadConfig(*configPath)
if err != nil { if err != nil {
fmt.Printf("Error loading config: %v\n", err) fmt.Printf("Error loading config: %v\n", err)
os.Exit(1) os.Exit(1)
@@ -221,7 +203,7 @@ func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v1/chat/completions" { if r.URL.Path == "/v1/chat/completions" {
proxyRequest(w, r, config, serviceState) proxyChatRequest(w, r, config, serviceState)
} else { } else {
http.Error(w, "Endpoint not supported", http.StatusNotFound) http.Error(w, "Endpoint not supported", http.StatusNotFound)
} }