diff --git a/.gitignore b/.gitignore index ec8848a..3c8fa18 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ build/ dist/ .vscode .DS_Store +.dev/ diff --git a/README.md b/README.md index fa119ed..7818ea5 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,9 @@ Written in golang, it is very easy to install (single binary with no dependencie - ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - ✅ Automatic unloading of models after timeout by setting a `ttl` - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) -- ✅ Docker and Podman support +- ✅ Reliable Docker and Podman support with `cmdStart` and `cmdStop` - ✅ Full control over server settings per model +- ✅ Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235)) ## How does llama-swap work? @@ -42,9 +43,9 @@ In the most basic configuration llama-swap handles one model at a time. For more ## config.yaml -llama-swap is managed entirely through a yaml configuration file. +llama-swap is managed entirely through a yaml configuration file. -It can be very minimal to start: +It can be very minimal to start: ```yaml models: @@ -55,7 +56,7 @@ models: --port ${PORT} ``` -However, there are many more capabilities that llama-swap supports: +However, there are many more capabilities that llama-swap supports: - `groups` to run multiple models at once - `ttl` to automatically unload models @@ -90,7 +91,7 @@ llama-swap can be installed in multiple ways ### Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap)) -Docker images with llama-swap and llama-server are built nightly. +Docker images with llama-swap and llama-server are built nightly. 
```shell # use CPU inference comes with the example config above @@ -137,10 +138,10 @@ $ docker run -it --rm --runtime nvidia -p 9292:8080 \ ### Homebrew Install (macOS/Linux) -The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh). +The latest release of `llama-swap` can be installed via [Homebrew](https://brew.sh). ```shell -# Set up tap and install formula +# Set up tap and install formula brew tap mostlygeek/llama-swap brew install llama-swap # Run llama-swap diff --git a/config.example.yaml b/config.example.yaml index 806c500..77d68be 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,6 +1,13 @@ # llama-swap YAML configuration example # ------------------------------------- # +# 💡 Tip - Use an LLM with this file! +# ==================================== +# This example configuration is written to be LLM friendly! Try +# copying this file into an LLM and asking it to explain or generate +# sections for you. +# ==================================== +# # - Below are all the available configuration options for llama-swap. # - Settings with a default value, or noted as optional can be omitted. 
# - Settings that are marked required must be in your configuration file @@ -207,3 +214,19 @@ groups: - "forever-modelA" - "forever-modelB" - "forever-modelc" + +# hooks: a dictionary of event triggers and actions +# - optional, default: empty dictionary +# - the only supported hook is on_startup +hooks: + # on_startup: a dictionary of actions to perform on startup + # - optional, default: empty dictionary + # - the only supported action is preload + on_startup: + # preload: a list of model ids to load on startup + # - optional, default: empty list + # - model names must match keys in the models section + # - when preloading multiple models at once, define a group + # otherwise models will be loaded and swapped out + preload: + - "llama" diff --git a/proxy/config.go b/proxy/config.go index ee72747..9447060 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -138,6 +138,14 @@ func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { return nil } +type HooksConfig struct { + OnStartup HookOnStartup `yaml:"on_startup"` +} + +type HookOnStartup struct { + Preload []string `yaml:"preload"` +} + type Config struct { HealthCheckTimeout int `yaml:"healthCheckTimeout"` LogRequests bool `yaml:"logRequests"` @@ -155,6 +163,9 @@ type Config struct { // automatic port assignments StartPort int `yaml:"startPort"` + + // hooks, see: #209 + Hooks HooksConfig `yaml:"hooks"` } func (c *Config) RealModelName(search string) (string, bool) { @@ -330,6 +341,22 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } + // clean up hooks preload + if len(config.Hooks.OnStartup.Preload) > 0 { + var toPreload []string + for _, modelID := range config.Hooks.OnStartup.Preload { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + continue + } + if real, found := config.RealModelName(modelID); found { + toPreload = append(toPreload, real) + } + } + + config.Hooks.OnStartup.Preload = toPreload + } + + return config, nil } diff --git 
a/proxy/config_posix_test.go b/proxy/config_posix_test.go index da49997..122d351 100644 --- a/proxy/config_posix_test.go +++ b/proxy/config_posix_test.go @@ -100,6 +100,9 @@ func TestConfig_LoadPosix(t *testing.T) { content := ` macros: svr-path: "path/to/server" +hooks: + on_startup: + preload: ["model1", "model2"] models: model1: cmd: path/to/cmd --arg1 one @@ -163,6 +166,11 @@ groups: Macros: map[string]string{ "svr-path": "path/to/server", }, + Hooks: HooksConfig{ + OnStartup: HookOnStartup{ + Preload: []string{"model1", "model2"}, + }, + }, Models: map[string]ModelConfig{ "model1": { Cmd: "path/to/cmd --arg1 one", diff --git a/proxy/discardWriter.go b/proxy/discardWriter.go new file mode 100644 index 0000000..8af8c04 --- /dev/null +++ b/proxy/discardWriter.go @@ -0,0 +1,27 @@ +package proxy + +import "net/http" + +// Custom discard writer that implements http.ResponseWriter but just discards everything +type DiscardWriter struct { + header http.Header + status int +} + +func (w *DiscardWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *DiscardWriter) Write(data []byte) (int, error) { + return len(data), nil +} + +func (w *DiscardWriter) WriteHeader(code int) { + w.status = code +} + +// Satisfy the http.Flusher interface for streaming responses +func (w *DiscardWriter) Flush() {} diff --git a/proxy/events.go b/proxy/events.go index 9a4d5fb..11403fc 100644 --- a/proxy/events.go +++ b/proxy/events.go @@ -7,6 +7,7 @@ const ChatCompletionStatsEventID = 0x02 const ConfigFileChangedEventID = 0x03 const LogDataEventID = 0x04 const TokenMetricsEventID = 0x05 +const ModelPreloadedEventID = 0x06 type ProcessStateChangeEvent struct { ProcessName string @@ -48,3 +49,12 @@ type LogDataEvent struct { func (e LogDataEvent) Type() uint32 { return LogDataEventID } + +type ModelPreloadedEvent struct { + ModelName string + Success bool +} + +func (e ModelPreloadedEvent) Type() uint32 { + return 
ModelPreloadedEventID +} diff --git a/proxy/helpers_test.go b/proxy/helpers_test.go index ccea359..bcb5acb 100644 --- a/proxy/helpers_test.go +++ b/proxy/helpers_test.go @@ -13,9 +13,10 @@ import ( ) var ( - nextTestPort int = 12000 - portMutex sync.Mutex - testLogger = NewLogMonitorWriter(os.Stdout) + nextTestPort int = 12000 + portMutex sync.Mutex + testLogger = NewLogMonitorWriter(os.Stdout) + simpleResponderPath = getSimpleResponderPath() ) // Check if the binary exists @@ -69,13 +70,11 @@ func getTestSimpleResponderConfig(expectedMessage string) ModelConfig { } func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig { - binaryPath := getSimpleResponderPath() - // Create a YAML string with just the values we want to set yamlStr := fmt.Sprintf(` cmd: '%s --port %d --silent --respond %s' proxy: "http://127.0.0.1:%d" -`, binaryPath, port, expectedMessage, port) +`, simpleResponderPath, port, expectedMessage, port) var cfg ModelConfig if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil { diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 3f123f6..644851d 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -15,6 +15,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/mostlygeek/llama-swap/event" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -96,6 +97,35 @@ func New(config Config) *ProxyManager { } pm.setupGinEngine() + + // run any startup hooks + if len(config.Hooks.OnStartup.Preload) > 0 { + // do it in the background, don't block startup -- not sure if good idea yet + go func() { + discardWriter := &DiscardWriter{} + for _, realModelName := range config.Hooks.OnStartup.Preload { + proxyLogger.Infof("Preloading model: %s", realModelName) + processGroup, _, err := pm.swapProcessGroup(realModelName) + + if err != nil { + event.Emit(ModelPreloadedEvent{ + ModelName: realModelName, + Success: false, + }) + proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err) + 
continue + } else { + req, _ := http.NewRequest("GET", "/", nil) + processGroup.ProxyRequest(realModelName, discardWriter, req) + event.Emit(ModelPreloadedEvent{ + ModelName: realModelName, + Success: true, + }) + } + } + }() + } + return pm } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 0959eea..e632be5 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/mostlygeek/llama-swap/event" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) @@ -832,3 +833,62 @@ func TestProxyManager_HealthEndpoint(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "OK", rec.Body.String()) } + +func TestProxyManager_StartupHooks(t *testing.T) { + + // using real YAML as the configuration has gotten more complex + // is the right approach as LoadConfigFromReader() does a lot more + // than parse YAML now. Eventually migrate all tests to use this approach + configStr := strings.Replace(` +logLevel: error +hooks: + on_startup: + preload: + - model1 + - model2 +groups: + preloadTestGroup: + swap: false + members: + - model1 + - model2 +models: + model1: + cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model1 + model2: + cmd: ${simpleresponderpath} --port ${PORT} --silent --respond model2 +`, "${simpleresponderpath}", simpleResponderPath, -1) + + // Create a test model configuration + config, err := LoadConfigFromReader(strings.NewReader(configStr)) + if !assert.NoError(t, err, "Invalid configuration") { + return + } + + preloadChan := make(chan ModelPreloadedEvent, 2) // buffer for 2 expected events + + unsub := event.On(func(e ModelPreloadedEvent) { + preloadChan <- e + }) + + defer unsub() + + // Create the proxy which should trigger preloading + proxy := New(config) + defer proxy.StopProcesses(StopWaitForInflightRequest) + + for i := 0; i < 2; i++ { + select { + case <-preloadChan: + case <-time.After(5 * time.Second): + 
t.Fatal("timed out waiting for models to preload") + } + } + // make sure they are both loaded + _, foundGroup := proxy.processGroups["preloadTestGroup"] + if !assert.True(t, foundGroup, "preloadTestGroup should exist") { + return + } + assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model1"].CurrentState()) + assert.Equal(t, StateReady, proxy.processGroups["preloadTestGroup"].processes["model2"].CurrentState()) +}