proxy/config: add model level macros (#330)

* proxy/config: add model level macros

Add macros to model configuration. Model macros override macros that are
defined at the global configuration level. They follow the same naming
and value rules as the global macros.

* proxy/config: fix bug with macro reserved name checking

The PORT reserved name was not properly checked

* proxy/config: add tests around model.filters.stripParams

- add check that model.filters.stripParams has no invalid macros
- renamed strip_params to stripParams for camel case consistency
- add legacy code compatibility so  model.filters.strip_params continues to work

* proxy/config: add duplicate removal to model.filters.stripParams

* clean up some doc nits
This commit is contained in:
Benson Wong
2025-09-28 23:32:52 -07:00
committed by GitHub
parent 216c40b951
commit 1f6179110c
5 changed files with 321 additions and 161 deletions

View File

@@ -6,7 +6,6 @@ import (
"os"
"regexp"
"runtime"
"slices"
"sort"
"strconv"
"strings"
@@ -17,101 +16,7 @@ import (
const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelConfig ModelConfig
defaults := rawModelConfig{
Cmd: "",
CmdStop: "",
Proxy: "http://localhost:${PORT}",
Aliases: []string{},
Env: []string{},
CheckEndpoint: "/health",
UnloadAfter: 0,
Unlisted: false,
UseModelName: "",
ConcurrencyLimit: 0,
Name: "",
Description: "",
}
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" {
defaults.CmdStop = "taskkill /f /t /pid ${PID}"
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelConfig(defaults)
return nil
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
type ModelFilters struct {
StripParams string `yaml:"strip_params"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelFilters(defaults)
return nil
}
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" {
continue
}
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
}
type MacroList map[string]string
type GroupConfig struct {
Swap bool `yaml:"swap"`
@@ -156,7 +61,7 @@ type Config struct {
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros map[string]string `yaml:"macros"`
Macros MacroList `yaml:"macros"`
// map aliases to actual model IDs
aliases map[string]string
@@ -240,21 +145,9 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
- name can not be any reserved macros: PORT, MODEL_ID
- macro values must be less than 1024 characters
*/
macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
for macroName, macroValue := range config.Macros {
if len(macroName) >= 64 {
return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName)
}
if !macroNameRegex.MatchString(macroName) {
return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName)
}
if len(macroValue) >= 1024 {
return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName)
}
switch macroName {
case "PORT":
case "MODEL_ID":
return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName)
if err = validateMacro(macroName, macroValue); err != nil {
return Config{}, err
}
}
@@ -273,8 +166,24 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.Cmd = StripComments(modelConfig.Cmd)
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// validate model macros
for macroName, macroValue := range modelConfig.Macros {
if err = validateMacro(macroName, macroValue); err != nil {
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
}
}
// Merge global config and model macros. Model macros take precedence
mergedMacros := make(MacroList)
for k, v := range config.Macros {
mergedMacros[k] = v
}
for k, v := range modelConfig.Macros {
mergedMacros[k] = v
}
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range config.Macros {
for macroName, macroValue := range mergedMacros {
macroSlug := fmt.Sprintf("${%s}", macroName)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
@@ -305,10 +214,11 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
// make sure there are no unknown macros that have not been replaced
macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
"filters.stripParams": modelConfig.Filters.StripParams,
}
for fieldName, fieldValue := range fieldMap {
@@ -458,3 +368,27 @@ func StripComments(cmdStr string) string {
}
return strings.Join(cleanedLines, "\n")
}
var (
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
)
// validateMacro validates macro name and value constraints
func validateMacro(name, value string) error {
if len(name) >= 64 {
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
}
if !macroNameRegex.MatchString(name) {
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
}
if len(value) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
switch name {
case "PORT", "MODEL_ID":
return fmt.Errorf("macro name '%s' is reserved", name)
}
return nil
}