proxy/config: add model level macros (#330)

* proxy/config: add model level macros

Add macros to model configuration. Model macros override macros that are
defined at the global configuration level. They follow the same naming
and value rules as the global macros.

* proxy/config: fix bug with macro reserved name checking

The PORT reserved name was not properly checked

* proxy/config: add tests around model.filters.stripParams

- add check that model.filters.stripParams has no invalid macros
- renamed strip_params to stripParams for camel case consistency
- add legacy code compatibility so  model.filters.strip_params continues to work

* proxy/config: add duplicate removal to model.filters.stripParams

* clean up some doc nits
This commit is contained in:
Benson Wong
2025-09-28 23:32:52 -07:00
committed by GitHub
parent 216c40b951
commit 1f6179110c
5 changed files with 321 additions and 161 deletions

View File

@@ -38,12 +38,19 @@ startPort: 10001
# macros: a dictionary of string substitutions # macros: a dictionary of string substitutions
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - macros are reusable snippets # - macros are reusable snippets
# - used in a model's cmd, cmdStop, proxy and checkEndpoint # - used in a model's cmd, cmdStop, proxy, checkEndpoint, filters.stripParams
# - useful for reducing common configuration settings # - useful for reducing common configuration settings
# - macro names are strings and must be less than 64 characters
# - macro names must match the regex ^[a-zA-Z0-9_-]+$
# - macro names must not be a reserved name: PORT or MODEL_ID
# - macro values must be less than 1024 characters
#
# Important: do not nest macros inside other macros; expansion is single-pass
macros: macros:
"latest-llama": > "latest-llama": >
/path/to/llama-server/llama-server-ec9e0301 /path/to/llama-server/llama-server-ec9e0301
--port ${PORT} --port ${PORT}
"default_ctx": "4096"
# models: a dictionary of model configurations # models: a dictionary of model configurations
# - required # - required
@@ -55,6 +62,13 @@ models:
# keys are the model names used in API requests # keys are the model names used in API requests
"llama": "llama":
# macros: a dictionary of string substitutions specific to this model
# - optional, default: empty dictionary
# - macros defined here override macros defined in the global macros section
# - model level macros follow the same rules as global macros
macros:
"default_ctx": "16384"
# cmd: the command to run to start the inference server. # cmd: the command to run to start the inference server.
# - required # - required
# - it is just a string, similar to what you would run on the CLI # - it is just a string, similar to what you would run on the CLI
@@ -64,6 +78,7 @@ models:
# ${latest-llama} is a macro that is defined above # ${latest-llama} is a macro that is defined above
${latest-llama} ${latest-llama}
--model path/to/llama-8B-Q4_K_M.gguf --model path/to/llama-8B-Q4_K_M.gguf
--ctx-size ${default_ctx}
# name: a display name for the model # name: a display name for the model
# - optional, default: empty string # - optional, default: empty string
@@ -119,15 +134,15 @@ models:
# filters: a dictionary of filter settings # filters: a dictionary of filter settings
# - optional, default: empty dictionary # - optional, default: empty dictionary
# - only strip_params is currently supported # - only stripParams is currently supported
filters: filters:
# strip_params: a comma separated list of parameters to remove from the request # stripParams: a comma separated list of parameters to remove from the request
# - optional, default: "" # - optional, default: ""
# - useful for server side enforcement of sampling parameters # - useful for server side enforcement of sampling parameters
# - the `model` parameter can never be removed # - the `model` parameter can never be removed
# - can be any JSON key in the request body # - can be any JSON key in the request body
# - recommended to stick to sampling parameters # - recommended to stick to sampling parameters
strip_params: "temperature, top_p, top_k" stripParams: "temperature, top_p, top_k"
# concurrencyLimit: overrides the allowed number of active parallel requests to a model # concurrencyLimit: overrides the allowed number of active parallel requests to a model
# - optional, default: 0 # - optional, default: 0

View File

@@ -6,7 +6,6 @@ import (
"os" "os"
"regexp" "regexp"
"runtime" "runtime"
"slices"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -17,101 +16,7 @@ import (
const DEFAULT_GROUP_ID = "(default)" const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct { type MacroList map[string]string
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
}
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelConfig(defaults)
return nil
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
type ModelFilters struct {
StripParams string `yaml:"strip_params"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelFilters(defaults)
return nil
}
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" {
continue
}
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
}
type GroupConfig struct { type GroupConfig struct {
Swap bool `yaml:"swap"` Swap bool `yaml:"swap"`
@@ -156,7 +61,7 @@ type Config struct {
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */ Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint // for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros map[string]string `yaml:"macros"` Macros MacroList `yaml:"macros"`
// map aliases to actual model IDs // map aliases to actual model IDs
aliases map[string]string aliases map[string]string
@@ -240,21 +145,9 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
- name can not be any reserved macros: PORT, MODEL_ID - name can not be any reserved macros: PORT, MODEL_ID
- macro values must be less than 1024 characters - macro values must be less than 1024 characters
*/ */
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
for macroName, macroValue := range config.Macros { for macroName, macroValue := range config.Macros {
if len(macroName) >= 64 { if err = validateMacro(macroName, macroValue); err != nil {
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName) return Config{}, err
}
if !macroNameRegex.MatchString(macroName) {
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
}
if len(macroValue) >= 1024 {
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
}
switch macroName {
case "PORT":
case "MODEL_ID":
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
} }
} }
@@ -273,8 +166,24 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.Cmd = StripComments(modelConfig.Cmd) modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop) modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// validate model macros
for macroName, macroValue := range modelConfig.Macros {
if err = validateMacro(macroName, macroValue); err != nil {
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
}
}
// Merge global config and model macros. Model macros take precedence
mergedMacros := make(MacroList)
for k, v := range config.Macros {
mergedMacros[k] = v
}
for k, v := range modelConfig.Macros {
mergedMacros[k] = v
}
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values // go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range config.Macros { for macroName, macroValue := range mergedMacros {
macroSlug := fmt.Sprintf("${%s}", macroName) macroSlug := fmt.Sprintf("${%s}", macroName)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue) modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue) modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
@@ -309,6 +218,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
"cmdStop": modelConfig.CmdStop, "cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy, "proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint, "checkEndpoint": modelConfig.CheckEndpoint,
"filters.stripParams": modelConfig.Filters.StripParams,
} }
for fieldName, fieldValue := range fieldMap { for fieldName, fieldValue := range fieldMap {
@@ -458,3 +368,27 @@ func StripComments(cmdStr string) string {
} }
return strings.Join(cleanedLines, "\n") return strings.Join(cleanedLines, "\n")
} }
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
)
// validateMacro validates macro name and value constraints
func validateMacro(name, value string) error {
if len(name) >= 64 {
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
}
if !macroNameRegex.MatchString(name) {
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
}
if len(value) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
switch name {
case "PORT", "MODEL_ID":
return fmt.Errorf("macro name '%s' is reserved", name)
}
return nil
}

View File

@@ -65,18 +65,6 @@ models:
assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model") assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model")
} }
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
--arg1 value1 \
--arg2 value2`,
}
args, err := config.SanitizedCommand()
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestConfig_FindConfig(t *testing.T) { func TestConfig_FindConfig(t *testing.T) {
// TODO? // TODO?
@@ -207,15 +195,19 @@ macros:
argOne: "--arg1" argOne: "--arg1"
argTwo: "--arg2" argTwo: "--arg2"
autoPort: "--port ${PORT}" autoPort: "--port ${PORT}"
overriddenByModelMacro: failed
models: models:
model1: model1:
macros:
overriddenByModelMacro: success
cmd: | cmd: |
${svr-path} ${argTwo} ${svr-path} ${argTwo}
# the automatic ${PORT} is replaced # the automatic ${PORT} is replaced
${autoPort} ${autoPort}
${argOne} ${argOne}
--arg3 three --arg3 three
--overridden ${overriddenByModelMacro}
cmdStop: | cmdStop: |
/path/to/stop.sh --port ${PORT} ${argTwo} /path/to/stop.sh --port ${PORT} ${argTwo}
` `
@@ -224,13 +216,68 @@ models:
assert.NoError(t, err) assert.NoError(t, err)
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd) sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " ")) assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop) sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " ")) assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " "))
} }
func TestConfig_MacroReservedNames(t *testing.T) {
tests := []struct {
name string
config string
expectedError string
}{
{
name: "global macro named PORT",
config: `
macros:
PORT: "1111"
`,
expectedError: "macro name 'PORT' is reserved",
},
{
name: "global macro named MODEL_ID",
config: `
macros:
MODEL_ID: model1
`,
expectedError: "macro name 'MODEL_ID' is reserved",
},
{
name: "model macro named PORT",
config: `
models:
model1:
macros:
PORT: 1111
`,
expectedError: "model model1: macro name 'PORT' is reserved",
},
{
name: "model macro named MODEL_ID",
config: `
models:
model1:
macros:
MODEL_ID: model1
`,
expectedError: "model model1: macro name 'MODEL_ID' is reserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.config))
assert.NotNil(t, err)
assert.Equal(t, tt.expectedError, err.Error())
})
}
}
func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) { func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -288,6 +335,20 @@ models:
model1: model1:
cmd: "${svr-path} --port ${PORT}" cmd: "${svr-path} --port ${PORT}"
checkEndpoint: "http://localhost:${unknownMacro}/health" checkEndpoint: "http://localhost:${unknownMacro}/health"
`,
},
{
name: "unknown macro in filters.stripParams",
field: "filters.stripParams",
content: `
startPort: 9990
macros:
svr-path: "path/to/server"
models:
model1:
cmd: "${svr-path} --port ${PORT}"
filters:
stripParams: "model,${unknownMacro}"
`, `,
}, },
} }
@@ -295,38 +356,13 @@ models:
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.content)) _, err := LoadConfigFromReader(strings.NewReader(tt.content))
assert.Error(t, err) if assert.Error(t, err) {
assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field) assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field)
}
//t.Log(err) //t.Log(err)
}) })
} }
} }
func TestConfig_ModelFilters(t *testing.T) {
content := `
macros:
default_strip: "temperature, top_p"
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
strip_params: "model, top_k, ${default_strip}, , ,"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
modelConfig, ok := config.Models["model1"]
if !assert.True(t, ok) {
t.FailNow()
}
// make sure `model` and enmpty strings are not in the list
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
sanitized, err := modelConfig.Filters.SanitizedStripParams()
if assert.NoError(t, err) {
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
}
}
func TestStripComments(t *testing.T) { func TestStripComments(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -0,0 +1,121 @@
package config
import (
"errors"
"runtime"
"slices"
"strings"
)
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
// Macros: see #264
// Model level macros take precedence over the global macros
Macros MacroList `yaml:"macros"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
}
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelConfig(defaults)
return nil
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
type ModelFilters struct {
StripParams string `yaml:"stripParams"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
if err := unmarshal(&defaults); err != nil {
return err
}
// Try to unmarshal with the old field name for backwards compatibility
if defaults.StripParams == "" {
var legacy struct {
StripParams string `yaml:"strip_params"`
}
if legacyErr := unmarshal(&legacy); legacyErr != nil {
return errors.New("failed to unmarshal legacy filters.strip_params: " + legacyErr.Error())
}
defaults.StripParams = legacy.StripParams
}
*m = ModelFilters(defaults)
return nil
}
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
}

View File

@@ -0,0 +1,54 @@
package config
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{
Cmd: `python model1.py \
--arg1 value1 \
--arg2 value2`,
}
args, err := config.SanitizedCommand()
assert.NoError(t, err)
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
}
func TestConfig_ModelFilters(t *testing.T) {
content := `
macros:
default_strip: "temperature, top_p"
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
# macros inserted and list is cleaned of duplicates and empty strings
stripParams: "model, top_k, top_k, temperature, ${default_strip}, , ,"
# check for strip_params (legacy field name) compatibility
legacy:
cmd: path/to/cmd --port ${PORT}
filters:
strip_params: "model, top_k, top_k, temperature, ${default_strip}, , ,"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
for modelId, modelConfig := range config.Models {
t.Run(fmt.Sprintf("Testing macros in filters for model %s", modelId), func(t *testing.T) {
assert.Equal(t, "model, top_k, top_k, temperature, temperature, top_p, , ,", modelConfig.Filters.StripParams)
sanitized, err := modelConfig.Filters.SanitizedStripParams()
if assert.NoError(t, err) {
// model has been removed
// empty strings have been removed
// duplicates have been removed
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
}
})
}
}