proxy: add support for user defined metadata in model configs (#333)
Changes: (1) add a Metadata key to ModelConfig; (2) include metadata in /v1/models under the meta.llamaswap key; (3) add recursive macro substitution in Metadata; (4) change macros at the global and model level to accept any scalar type. Note: This is the first mostly AI-generated change to llama-swap. See #333 for notes about the workflow and approach to AI going forward.
This commit is contained in:
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type MacroList map[string]string
|
||||
type MacroList map[string]any
|
||||
|
||||
type GroupConfig struct {
|
||||
Swap bool `yaml:"swap"`
|
||||
@@ -25,6 +25,11 @@ type GroupConfig struct {
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
@@ -182,14 +187,18 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
mergedMacros[k] = v
|
||||
}
|
||||
|
||||
mergedMacros["MODEL_ID"] = modelId
|
||||
|
||||
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
|
||||
for macroName, macroValue := range mergedMacros {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
|
||||
// Convert macro value to string for command/string field substitution
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
}
|
||||
|
||||
// enforce ${PORT} used in both cmd and proxy
|
||||
@@ -203,16 +212,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
|
||||
|
||||
// add port to merged macros so it can be used in metadata
|
||||
mergedMacros["PORT"] = nextPort
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
if strings.Contains(modelConfig.Cmd, "${MODEL_ID}") || strings.Contains(modelConfig.CmdStop, "${MODEL_ID}") {
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${MODEL_ID}", modelId)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${MODEL_ID}", modelId)
|
||||
}
|
||||
|
||||
// make sure there are no unknown macros that have not been replaced
|
||||
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
@@ -222,7 +229,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPattern.FindAllStringSubmatch(fieldValue, -1)
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
@@ -234,6 +241,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Apply macro substitution to metadata
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
substitutedMetadata, err := substituteMetadataMacros(modelConfig.Metadata, mergedMacros)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = substitutedMetadata.(map[string]any)
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
@@ -296,7 +312,7 @@ func AddDefaultGroupToConfig(config Config) Config {
|
||||
}
|
||||
} else {
|
||||
// iterate over existing group members and add non-grouped models into the default group
|
||||
for modelName, _ := range config.Models {
|
||||
for modelName := range config.Models {
|
||||
foundModel := false
|
||||
found:
|
||||
// search for the model in existing groups
|
||||
@@ -369,20 +385,25 @@ func StripComments(cmdStr string) string {
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
)
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name, value string) error {
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
if len(value) >= 1024 {
|
||||
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if len(v) >= 1024 {
|
||||
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
@@ -392,3 +413,63 @@ func validateMacro(name, value string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// substituteMetadataMacros recursively substitutes macros in metadata structures
|
||||
// Direct substitution (key: ${macro}) preserves the macro's type
|
||||
// Interpolated substitution (key: "text ${macro}") converts to string
|
||||
func substituteMetadataMacros(value any, macros MacroList) (any, error) {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if strings.HasPrefix(v, "${") && strings.HasSuffix(v, "}") && strings.Count(v, "${") == 1 {
|
||||
macroName := v[2 : len(v)-1]
|
||||
if macroValue, exists := macros[macroName]; exists {
|
||||
return macroValue, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
|
||||
}
|
||||
|
||||
// Handle string interpolation
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
result := v
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
macroValue, exists := macros[macroName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
|
||||
}
|
||||
// Convert macro value to string for interpolation
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
result = strings.ReplaceAll(result, match[0], macroStr)
|
||||
}
|
||||
return result, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMetadataMacros(val, macros)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMetadataMacros(val, macros)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user