proxy: add support for user defined metadata in model configs (#333)

Changes: 

- add Metadata key to ModelConfig
- include metadata in /v1/models under meta.llamaswap key
- add recursive macro substitution into Metadata
- change macros at global and model level to be any scalar type

Note: 

This is the first mostly AI generated change to llama-swap. See #333 for notes about the workflow and approach to AI going forward.
This commit is contained in:
Benson Wong
2025-10-04 19:56:41 -07:00
committed by GitHub
parent 1f6179110c
commit 70930e4e91
11 changed files with 807 additions and 25 deletions

View File

@@ -16,7 +16,7 @@ import (
const DEFAULT_GROUP_ID = "(default)"
type MacroList map[string]string
type MacroList map[string]any
type GroupConfig struct {
Swap bool `yaml:"swap"`
@@ -25,6 +25,11 @@ type GroupConfig struct {
Members []string `yaml:"members"`
}
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
)
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
@@ -182,14 +187,18 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
mergedMacros[k] = v
}
mergedMacros["MODEL_ID"] = modelId
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range mergedMacros {
macroSlug := fmt.Sprintf("${%s}", macroName)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
// Convert macro value to string for command/string field substitution
macroStr := fmt.Sprintf("%v", macroValue)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
}
// enforce ${PORT} used in both cmd and proxy
@@ -203,16 +212,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
// add port to merged macros so it can be used in metadata
mergedMacros["PORT"] = nextPort
nextPort++
}
if strings.Contains(modelConfig.Cmd, "${MODEL_ID}") || strings.Contains(modelConfig.CmdStop, "${MODEL_ID}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${MODEL_ID}", modelId)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${MODEL_ID}", modelId)
}
// make sure there are no unknown macros that have not been replaced
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
@@ -222,7 +229,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
for fieldName, fieldValue := range fieldMap {
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
@@ -234,6 +241,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
// Apply macro substitution to metadata
if len(modelConfig.Metadata) > 0 {
substitutedMetadata, err := substituteMetadataMacros(modelConfig.Metadata, mergedMacros)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = substitutedMetadata.(map[string]any)
}
config.Models[modelId] = modelConfig
}
@@ -296,7 +312,7 @@ func AddDefaultGroupToConfig(config Config) Config {
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
for modelName := range config.Models {
foundModel := false
found:
// search for the model in existing groups
@@ -369,20 +385,25 @@ func StripComments(cmdStr string) string {
return strings.Join(cleanedLines, "\n")
}
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
)
// validateMacro validates macro name and value constraints
func validateMacro(name, value string) error {
func validateMacro(name string, value any) error {
if len(name) >= 64 {
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
}
if !macroNameRegex.MatchString(name) {
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
}
if len(value) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
// Validate that value is a scalar type
switch v := value.(type) {
case string:
if len(v) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
// These types are allowed
default:
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
}
switch name {
@@ -392,3 +413,63 @@ func validateMacro(name, value string) error {
return nil
}
// substituteMetadataMacros recursively substitutes macros in metadata structures
// Direct substitution (key: ${macro}) preserves the macro's type
// Interpolated substitution (key: "text ${macro}") converts to string
func substituteMetadataMacros(value any, macros MacroList) (any, error) {
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if strings.HasPrefix(v, "${") && strings.HasSuffix(v, "}") && strings.Count(v, "${") == 1 {
macroName := v[2 : len(v)-1]
if macroValue, exists := macros[macroName]; exists {
return macroValue, nil
}
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
}
// Handle string interpolation
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
result := v
for _, match := range matches {
macroName := match[1]
macroValue, exists := macros[macroName]
if !exists {
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
}
// Convert macro value to string for interpolation
macroStr := fmt.Sprintf("%v", macroValue)
result = strings.ReplaceAll(result, match[0], macroStr)
}
return result, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMetadataMacros(val, macros)
if err != nil {
return nil, err
}
newMap[key] = newVal
}
return newMap, nil
case []any:
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMetadataMacros(val, macros)
if err != nil {
return nil, err
}
newSlice[i] = newVal
}
return newSlice, nil
default:
// Return scalar types as-is
return value, nil
}
}

View File

@@ -163,7 +163,7 @@ groups:
expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
Macros: MacroList{
"svr-path": "path/to/server",
},
Hooks: HooksConfig{

View File

@@ -517,3 +517,243 @@ models:
assert.NoError(t, err)
assert.Equal(t, "/path/to/server -p 9000 -hf author/model:F16", strings.Join(sanitizedCmd3, " "))
}
func TestConfig_TypedMacrosInMetadata(t *testing.T) {
content := `
startPort: 10000
macros:
PORT_NUM: 10001
TEMP: 0.7
ENABLED: true
NAME: "llama model"
CTX: 16384
models:
test-model:
cmd: /path/to/server -p ${PORT}
metadata:
port: ${PORT_NUM}
temperature: ${TEMP}
enabled: ${ENABLED}
model_name: ${NAME}
context: ${CTX}
note: "Running on port ${PORT_NUM} with temp ${TEMP} and context ${CTX}"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
meta := config.Models["test-model"].Metadata
assert.NotNil(t, meta)
// Verify direct substitution preserves types
assert.Equal(t, 10001, meta["port"])
assert.Equal(t, 0.7, meta["temperature"])
assert.Equal(t, true, meta["enabled"])
assert.Equal(t, "llama model", meta["model_name"])
assert.Equal(t, 16384, meta["context"])
// Verify string interpolation converts to string
assert.Equal(t, "Running on port 10001 with temp 0.7 and context 16384", meta["note"])
}
func TestConfig_NestedStructuresInMetadata(t *testing.T) {
content := `
startPort: 10000
macros:
TEMP: 0.7
models:
test-model:
cmd: /path/to/server -p ${PORT}
metadata:
config:
port: ${PORT}
temperature: ${TEMP}
tags: ["model:${MODEL_ID}", "port:${PORT}"]
nested:
deep:
value: ${TEMP}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
meta := config.Models["test-model"].Metadata
assert.NotNil(t, meta)
// Verify nested objects
configMap := meta["config"].(map[string]any)
assert.Equal(t, 10000, configMap["port"])
assert.Equal(t, 0.7, configMap["temperature"])
// Verify arrays
tags := meta["tags"].([]any)
assert.Equal(t, "model:test-model", tags[0])
assert.Equal(t, "port:10000", tags[1])
// Verify deeply nested structures
nested := meta["nested"].(map[string]any)
deep := nested["deep"].(map[string]any)
assert.Equal(t, 0.7, deep["value"])
}
func TestConfig_ModelLevelMacroPrecedenceInMetadata(t *testing.T) {
content := `
startPort: 10000
macros:
TEMP: 0.5
GLOBAL_VAL: "global"
models:
test-model:
cmd: /path/to/server -p ${PORT}
macros:
TEMP: 0.9
LOCAL_VAL: "local"
metadata:
temperature: ${TEMP}
global: ${GLOBAL_VAL}
local: ${LOCAL_VAL}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
meta := config.Models["test-model"].Metadata
assert.NotNil(t, meta)
// Model-level macro should override global
assert.Equal(t, 0.9, meta["temperature"])
// Global macro should be accessible
assert.Equal(t, "global", meta["global"])
// Model-level macro should be accessible
assert.Equal(t, "local", meta["local"])
}
func TestConfig_UnknownMacroInMetadata(t *testing.T) {
content := `
startPort: 10000
models:
test-model:
cmd: /path/to/server -p ${PORT}
metadata:
value: ${UNKNOWN_MACRO}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "test-model")
assert.Contains(t, err.Error(), "UNKNOWN_MACRO")
}
func TestConfig_InvalidMacroType(t *testing.T) {
content := `
startPort: 10000
macros:
INVALID:
nested: value
models:
test-model:
cmd: /path/to/server -p ${PORT}
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "INVALID")
assert.Contains(t, err.Error(), "must be a scalar type")
}
func TestConfig_MacroTypeValidation(t *testing.T) {
tests := []struct {
name string
yaml string
shouldErr bool
}{
{
name: "string macro",
yaml: `
startPort: 10000
macros:
STR: "test"
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: false,
},
{
name: "int macro",
yaml: `
startPort: 10000
macros:
NUM: 42
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: false,
},
{
name: "float macro",
yaml: `
startPort: 10000
macros:
FLOAT: 3.14
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: false,
},
{
name: "bool macro",
yaml: `
startPort: 10000
macros:
BOOL: true
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: false,
},
{
name: "array macro (invalid)",
yaml: `
startPort: 10000
macros:
ARR: [1, 2, 3]
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: true,
},
{
name: "map macro (invalid)",
yaml: `
startPort: 10000
macros:
MAP:
key: value
models:
test-model:
cmd: /path/to/server -p ${PORT}
`,
shouldErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadConfigFromReader(strings.NewReader(tt.yaml))
if tt.shouldErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -155,7 +155,7 @@ groups:
expected := Config{
LogLevel: "info",
StartPort: 5800,
Macros: map[string]string{
Macros: MacroList{
"svr-path": "path/to/server",
},
Models: map[string]ModelConfig{

View File

@@ -31,6 +31,10 @@ type ModelConfig struct {
// Macros: see #264
// Model level macros take precedence over the global macros
Macros MacroList `yaml:"macros"`
// Metadata: see #264
// Arbitrary metadata that can be exposed through the API
Metadata map[string]any `yaml:"metadata"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {