proxy: add support for user defined metadata in model configs (#333)

Changes: 

- add Metadata key to ModelConfig
- include metadata in /v1/models under meta.llamaswap key
- add recursive macro substitution into Metadata
- change macros at global and model level to be any scalar type

Note: 

This is the first mostly AI generated change to llama-swap. See #333 for notes about the workflow and approach to AI going forward.
This commit is contained in:
Benson Wong
2025-10-04 19:56:41 -07:00
committed by GitHub
parent 1f6179110c
commit 70930e4e91
11 changed files with 807 additions and 25 deletions

View File

@@ -16,7 +16,7 @@ import (
const DEFAULT_GROUP_ID = "(default)"
type MacroList map[string]string
type MacroList map[string]any
type GroupConfig struct {
Swap bool `yaml:"swap"`
@@ -25,6 +25,11 @@ type GroupConfig struct {
Members []string `yaml:"members"`
}
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
)
// set default values for GroupConfig
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawGroupConfig GroupConfig
@@ -182,14 +187,18 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
mergedMacros[k] = v
}
mergedMacros["MODEL_ID"] = modelId
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range mergedMacros {
macroSlug := fmt.Sprintf("${%s}", macroName)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
// Convert macro value to string for command/string field substitution
macroStr := fmt.Sprintf("%v", macroValue)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
}
// enforce ${PORT} used in both cmd and proxy
@@ -203,16 +212,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
// add port to merged macros so it can be used in metadata
mergedMacros["PORT"] = nextPort
nextPort++
}
if strings.Contains(modelConfig.Cmd, "${MODEL_ID}") || strings.Contains(modelConfig.CmdStop, "${MODEL_ID}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${MODEL_ID}", modelId)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${MODEL_ID}", modelId)
}
// make sure there are no unknown macros that have not been replaced
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
@@ -222,7 +229,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
for fieldName, fieldValue := range fieldMap {
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
for _, match := range matches {
macroName := match[1]
if macroName == "PID" && fieldName == "cmdStop" {
@@ -234,6 +241,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
// Apply macro substitution to metadata
if len(modelConfig.Metadata) > 0 {
substitutedMetadata, err := substituteMetadataMacros(modelConfig.Metadata, mergedMacros)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = substitutedMetadata.(map[string]any)
}
config.Models[modelId] = modelConfig
}
@@ -296,7 +312,7 @@ func AddDefaultGroupToConfig(config Config) Config {
}
} else {
// iterate over existing group members and add non-grouped models into the default group
for modelName, _ := range config.Models {
for modelName := range config.Models {
foundModel := false
found:
// search for the model in existing groups
@@ -369,20 +385,25 @@ func StripComments(cmdStr string) string {
return strings.Join(cleanedLines, "\n")
}
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
)
// validateMacro validates macro name and value constraints
func validateMacro(name, value string) error {
func validateMacro(name string, value any) error {
if len(name) >= 64 {
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
}
if !macroNameRegex.MatchString(name) {
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
}
if len(value) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
// Validate that value is a scalar type
switch v := value.(type) {
case string:
if len(v) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
// These types are allowed
default:
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
}
switch name {
@@ -392,3 +413,63 @@ func validateMacro(name, value string) error {
return nil
}
// substituteMetadataMacros recursively substitutes macros in metadata structures
// Direct substitution (key: ${macro}) preserves the macro's type
// Interpolated substitution (key: "text ${macro}") converts to string
func substituteMetadataMacros(value any, macros MacroList) (any, error) {
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if strings.HasPrefix(v, "${") && strings.HasSuffix(v, "}") && strings.Count(v, "${") == 1 {
macroName := v[2 : len(v)-1]
if macroValue, exists := macros[macroName]; exists {
return macroValue, nil
}
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
}
// Handle string interpolation
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
result := v
for _, match := range matches {
macroName := match[1]
macroValue, exists := macros[macroName]
if !exists {
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
}
// Convert macro value to string for interpolation
macroStr := fmt.Sprintf("%v", macroValue)
result = strings.ReplaceAll(result, match[0], macroStr)
}
return result, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMetadataMacros(val, macros)
if err != nil {
return nil, err
}
newMap[key] = newVal
}
return newMap, nil
case []any:
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMetadataMacros(val, macros)
if err != nil {
return nil, err
}
newSlice[i] = newVal
}
return newSlice, nil
default:
// Return scalar types as-is
return value, nil
}
}