diff --git a/config.example.yaml b/config.example.yaml index 381bd0c..8699a03 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -38,12 +38,19 @@ startPort: 10001 # macros: a dictionary of string substitutions # - optional, default: empty dictionary # - macros are reusable snippets -# - used in a model's cmd, cmdStop, proxy and checkEndpoint +# - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams # - useful for reducing common configuration settings +# - macro names are strings and must be less than 64 characters +# - macro names must match the regex ^[a-zA-Z0-9_-]+$ +# - macro names must not be a reserved name: PORT or MODEL_ID +# - macro values must be less than 1024 characters +# +# Important: do not nest macros inside other macros; expansion is single-pass macros: "latest-llama": > /path/to/llama-server/llama-server-ec9e0301 --port ${PORT} + "default_ctx": "4096" # models: a dictionary of model configurations # - required @@ -55,6 +62,13 @@ models: # keys are the model names used in API requests "llama": + # macros: a dictionary of string substitutions specific to this model + # - optional, default: empty dictionary + # - macros defined here override macros defined in the global macros section + # - model level macros follow the same rules as global macros + macros: + "default_ctx": "16384" + # cmd: the command to run to start the inference server. # - required # - it is just a string, similar to what you would run on the CLI @@ -64,6 +78,7 @@ models: # ${latest-llama} is a macro that is defined above ${latest-llama} --model path/to/llama-8B-Q4_K_M.gguf + --ctx-size ${default_ctx} # name: a display name for the model # - optional, default: empty string @@ -119,15 +134,15 @@ models: # filters: a dictionary of filter settings # - optional, default: empty dictionary - # - only strip_params is currently supported + # - only stripParams is currently supported filters: - # strip_params: a comma separated list of parameters to remove from the request + # stripParams: a comma separated list of parameters to remove from the request # - optional, default: "" # - useful for server side enforcement of sampling parameters # - the `model` parameter can never be removed # - can be any JSON key in the request body # - recommended to stick to sampling parameters - strip_params: "temperature, top_p, top_k" + stripParams: "temperature, top_p, top_k" # concurrencyLimit: overrides the allowed number of active parallel requests to a model # - optional, default: 0 diff --git a/proxy/config/config.go b/proxy/config/config.go index 1dc8223..e043202 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -6,7 +6,6 @@ import ( "os" "regexp" "runtime" - "slices" "sort" "strconv" "strings" @@ -17,101 +16,7 @@ import ( const DEFAULT_GROUP_ID = "(default)" -type ModelConfig struct { - Cmd string `yaml:"cmd"` - CmdStop string `yaml:"cmdStop"` - Proxy string `yaml:"proxy"` - Aliases []string `yaml:"aliases"` - Env []string `yaml:"env"` - CheckEndpoint string `yaml:"checkEndpoint"` - UnloadAfter int `yaml:"ttl"` - Unlisted bool `yaml:"unlisted"` - UseModelName string `yaml:"useModelName"` - - // #179 for /v1/models - Name string `yaml:"name"` - Description string `yaml:"description"` - - // Limit concurrency of HTTP requests to process - ConcurrencyLimit int `yaml:"concurrencyLimit"` - - // Model filters see issue #174 - Filters ModelFilters `yaml:"filters"` -} - -func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - type rawModelConfig ModelConfig - defaults := rawModelConfig{ - Cmd: "", - CmdStop: "", - Proxy: "http://localhost:${PORT}", - Aliases: []string{}, - Env: []string{}, - CheckEndpoint: "/health", - UnloadAfter: 0, - Unlisted: false, - UseModelName: "", - ConcurrencyLimit: 0, - Name: "", - Description: "", - } - - // the default cmdStop to taskkill /f /t /pid ${PID} - if runtime.GOOS == "windows" { - defaults.CmdStop = "taskkill /f /t /pid ${PID}" - } - - if err := unmarshal(&defaults); err != nil { - return err - } - - *m = ModelConfig(defaults) - return nil -} - -func (m *ModelConfig) SanitizedCommand() ([]string, error) { - return SanitizeCommand(m.Cmd) -} - -// ModelFilters see issue #174 -type ModelFilters struct { - StripParams string `yaml:"strip_params"` -} - -func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error { - type rawModelFilters ModelFilters - defaults := rawModelFilters{ - StripParams: "", - } - - if err := unmarshal(&defaults); err != nil { - return err - } - - *m = ModelFilters(defaults) - return nil -} - -func (f ModelFilters) SanitizedStripParams() ([]string, error) { - if f.StripParams == "" { - return nil, nil - } - - params := strings.Split(f.StripParams, ",") - cleaned := make([]string, 0, len(params)) - - for _, param := range params { - trimmed := strings.TrimSpace(param) - if trimmed == "model" || trimmed == "" { - continue - } - cleaned = append(cleaned, trimmed) - } - - // sort cleaned - slices.Sort(cleaned) - return cleaned, nil -} +type MacroList map[string]string type GroupConfig struct { Swap bool `yaml:"swap"` @@ -156,7 +61,7 @@ type Config struct { Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */ // for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint - Macros map[string]string `yaml:"macros"` + Macros MacroList `yaml:"macros"` // map aliases to actual model IDs aliases map[string]string @@ -240,21 +145,9 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { - name can not be any reserved macros: PORT, MODEL_ID - macro values must be less than 1024 characters */ - macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) for macroName, macroValue := range config.Macros { - if len(macroName) >= 64 { - return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName) - } - if !macroNameRegex.MatchString(macroName) { - return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName) - } - if len(macroValue) >= 1024 { - return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName) - } - switch macroName { - case "PORT": - case "MODEL_ID": - return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName) + if err = validateMacro(macroName, macroValue); err != nil { + return Config{}, err } } @@ -273,8 +166,24 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { modelConfig.Cmd = StripComments(modelConfig.Cmd) modelConfig.CmdStop = StripComments(modelConfig.CmdStop) + // validate model macros + for macroName, macroValue := range modelConfig.Macros { + if err = validateMacro(macroName, macroValue); err != nil { + return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error()) + } + } + + // Merge global config and model macros. Model macros take precedence + mergedMacros := make(MacroList) + for k, v := range config.Macros { + mergedMacros[k] = v + } + for k, v := range modelConfig.Macros { + mergedMacros[k] = v + } + // go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values - for macroName, macroValue := range config.Macros { + for macroName, macroValue := range mergedMacros { macroSlug := fmt.Sprintf("${%s}", macroName) modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue) modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue) @@ -305,10 +214,11 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { // make sure there are no unknown macros that have not been replaced macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`) fieldMap := map[string]string{ - "cmd": modelConfig.Cmd, - "cmdStop": modelConfig.CmdStop, - "proxy": modelConfig.Proxy, - "checkEndpoint": modelConfig.CheckEndpoint, + "cmd": modelConfig.Cmd, + "cmdStop": modelConfig.CmdStop, + "proxy": modelConfig.Proxy, + "checkEndpoint": modelConfig.CheckEndpoint, + "filters.stripParams": modelConfig.Filters.StripParams, } for fieldName, fieldValue := range fieldMap { @@ -458,3 +368,27 @@ func StripComments(cmdStr string) string { } return strings.Join(cleanedLines, "\n") } + +var ( + macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) +) + +// validateMacro validates macro name and value constraints +func validateMacro(name, value string) error { + if len(name) >= 64 { + return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name) + } + if !macroNameRegex.MatchString(name) { + return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name) + } + if len(value) >= 1024 { + return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name) + } + + switch name { + case "PORT", "MODEL_ID": + return fmt.Errorf("macro name '%s' is reserved", name) + } + + return nil +} diff --git a/proxy/config/config_test.go b/proxy/config/config_test.go index 0bee014..fdb87ee 100644 --- a/proxy/config/config_test.go +++ b/proxy/config/config_test.go @@ -65,18 +65,6 @@ models: assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model") } -func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { - config := &ModelConfig{ - Cmd: `python model1.py \ - --arg1 value1 \ - --arg2 value2`, - } - - args, err := config.SanitizedCommand() - assert.NoError(t, err) - assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args) -} - func TestConfig_FindConfig(t *testing.T) { // TODO? @@ -207,15 +195,19 @@ macros: argOne: "--arg1" argTwo: "--arg2" autoPort: "--port ${PORT}" + overriddenByModelMacro: failed models: model1: + macros: + overriddenByModelMacro: success cmd: | ${svr-path} ${argTwo} # the automatic ${PORT} is replaced ${autoPort} ${argOne} --arg3 three + --overridden ${overriddenByModelMacro} cmdStop: | /path/to/stop.sh --port ${PORT} ${argTwo} ` @@ -224,13 +216,68 @@ models: assert.NoError(t, err) sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd) assert.NoError(t, err) - assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " ")) + assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " ")) sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop) assert.NoError(t, err) assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " ")) } +func TestConfig_MacroReservedNames(t *testing.T) { + + tests := []struct { + name string + config string + expectedError string + }{ + { + name: "global macro named PORT", + config: ` +macros: + PORT: "1111" +`, + expectedError: "macro name 'PORT' is reserved", + }, + { + name: "global macro named MODEL_ID", + config: ` +macros: + MODEL_ID: model1 +`, + expectedError: "macro name 'MODEL_ID' is reserved", + }, + { + name: "model macro named PORT", + config: ` +models: + model1: + macros: + PORT: 1111 +`, + expectedError: "model model1: macro name 'PORT' is reserved", + }, + + { + name: "model macro named MODEL_ID", + config: ` +models: + model1: + macros: + MODEL_ID: model1 +`, + expectedError: "model model1: macro name 'MODEL_ID' is reserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := LoadConfigFromReader(strings.NewReader(tt.config)) + assert.NotNil(t, err) + assert.Equal(t, tt.expectedError, err.Error()) + }) + } +} + func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) { tests := []struct { name string @@ -288,6 +335,20 @@ models: model1: cmd: "${svr-path} --port ${PORT}" checkEndpoint: "http://localhost:${unknownMacro}/health" +`, + }, + { + name: "unknown macro in filters.stripParams", + field: "filters.stripParams", + content: ` +startPort: 9990 +macros: + svr-path: "path/to/server" +models: + model1: + cmd: "${svr-path} --port ${PORT}" + filters: + stripParams: "model,${unknownMacro}" `, }, } @@ -295,38 +356,13 @@ models: for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := LoadConfigFromReader(strings.NewReader(tt.content)) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field) + } //t.Log(err) }) } } - -func TestConfig_ModelFilters(t *testing.T) { - content := ` -macros: - default_strip: "temperature, top_p" -models: - model1: - cmd: path/to/cmd --port ${PORT} - filters: - strip_params: "model, top_k, ${default_strip}, , ," -` - config, err := LoadConfigFromReader(strings.NewReader(content)) - assert.NoError(t, err) - modelConfig, ok := config.Models["model1"] - if !assert.True(t, ok) { - t.FailNow() - } - - // make sure `model` and enmpty strings are not in the list - assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams) - sanitized, err := modelConfig.Filters.SanitizedStripParams() - if assert.NoError(t, err) { - assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized) - } -} - func TestStripComments(t *testing.T) { tests := []struct { name string diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go new file mode 100644 index 0000000..40386f6 --- /dev/null +++ b/proxy/config/model_config.go @@ -0,0 +1,121 @@ +package config + +import ( + "errors" + "runtime" + "slices" + "strings" +) + +type ModelConfig struct { + Cmd string `yaml:"cmd"` + CmdStop string `yaml:"cmdStop"` + Proxy string `yaml:"proxy"` + Aliases []string `yaml:"aliases"` + Env []string `yaml:"env"` + CheckEndpoint string `yaml:"checkEndpoint"` + UnloadAfter int `yaml:"ttl"` + Unlisted bool `yaml:"unlisted"` + UseModelName string `yaml:"useModelName"` + + // #179 for /v1/models + Name string `yaml:"name"` + Description string `yaml:"description"` + + // Limit concurrency of HTTP requests to process + ConcurrencyLimit int `yaml:"concurrencyLimit"` + + // Model filters see issue #174 + Filters ModelFilters `yaml:"filters"` + + // Macros: see #264 + // Model level macros take precedence over the global macros + Macros MacroList `yaml:"macros"` +} + +func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + type rawModelConfig ModelConfig + defaults := rawModelConfig{ + Cmd: "", + CmdStop: "", + Proxy: "http://localhost:${PORT}", + Aliases: []string{}, + Env: []string{}, + CheckEndpoint: "/health", + UnloadAfter: 0, + Unlisted: false, + UseModelName: "", + ConcurrencyLimit: 0, + Name: "", + Description: "", + } + + // the default cmdStop to taskkill /f /t /pid ${PID} + if runtime.GOOS == "windows" { + defaults.CmdStop = "taskkill /f /t /pid ${PID}" + } + + if err := unmarshal(&defaults); err != nil { + return err + } + + *m = ModelConfig(defaults) + return nil +} + +func (m *ModelConfig) SanitizedCommand() ([]string, error) { + return SanitizeCommand(m.Cmd) +} + +// ModelFilters see issue #174 +type ModelFilters struct { + StripParams string `yaml:"stripParams"` +} + +func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error { + type rawModelFilters ModelFilters + defaults := rawModelFilters{ + StripParams: "", + } + + if err := unmarshal(&defaults); err != nil { + return err + } + + // Try to unmarshal with the old field name for backwards compatibility + if defaults.StripParams == "" { + var legacy struct { + StripParams string `yaml:"strip_params"` + } + if legacyErr := unmarshal(&legacy); legacyErr != nil { + return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error()) + } + defaults.StripParams = legacy.StripParams + } + + *m = ModelFilters(defaults) + return nil +} + +func (f ModelFilters) SanitizedStripParams() ([]string, error) { + if f.StripParams == "" { + return nil, nil + } + + params := strings.Split(f.StripParams, ",") + cleaned := make([]string, 0, len(params)) + seen := make(map[string]bool) + + for _, param := range params { + trimmed := strings.TrimSpace(param) + if trimmed == "model" || trimmed == "" || seen[trimmed] { + continue + } + seen[trimmed] = true + cleaned = append(cleaned, trimmed) + } + + // sort cleaned + slices.Sort(cleaned) + return cleaned, nil +} diff --git a/proxy/config/model_config_test.go b/proxy/config/model_config_test.go new file mode 100644 index 0000000..3979d01 --- /dev/null +++ b/proxy/config/model_config_test.go @@ -0,0 +1,54 @@ +package config + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { + config := &ModelConfig{ + Cmd: `python model1.py \ + --arg1 value1 \ + --arg2 value2`, + } + + args, err := config.SanitizedCommand() + assert.NoError(t, err) + assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args) +} + +func TestConfig_ModelFilters(t *testing.T) { + content := ` +macros: + default_strip: "temperature, top_p" +models: + model1: + cmd: path/to/cmd --port ${PORT} + filters: + # macros inserted and list is cleaned of duplicates and empty strings + stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ," + # check for strip_params (legacy field name) compatibility + legacy: + cmd: path/to/cmd --port ${PORT} + filters: + strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ," +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + for modelId, modelConfig := range config.Models { + t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) { + assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams) + sanitized, err := modelConfig.Filters.SanitizedStripParams() + if assert.NoError(t, err) { + // model has been removed + // empty strings have been removed + // duplicates have been removed + assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized) + } + }) + } + +}