Add Macro-In-Macro Support (#337)

Add full macro-in-macro support so any user defined macro can contain another one as long as it was previously declared in the configuration file.

Fixes #336
Supercedes #335
This commit is contained in:
Benson Wong
2025-10-06 22:57:15 -07:00
committed by GitHub
parent 70930e4e91
commit 00b738cd0f
7 changed files with 718 additions and 71 deletions

View File

@@ -7,7 +7,6 @@ import (
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"github.com/billziss-gh/golib/shlex"
@@ -16,7 +15,60 @@ import (
const DEFAULT_GROUP_ID = "(default)"
type MacroList map[string]any
type MacroEntry struct {
Name string
Value any
}
type MacroList []MacroEntry
// UnmarshalYAML implements custom YAML unmarshaling that preserves macro definition order
func (ml *MacroList) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.MappingNode {
return fmt.Errorf("macros must be a mapping")
}
// yaml.Node.Content for a mapping contains alternating key/value nodes
entries := make([]MacroEntry, 0, len(value.Content)/2)
for i := 0; i < len(value.Content); i += 2 {
keyNode := value.Content[i]
valueNode := value.Content[i+1]
var name string
if err := keyNode.Decode(&name); err != nil {
return fmt.Errorf("failed to decode macro name: %w", err)
}
var val any
if err := valueNode.Decode(&val); err != nil {
return fmt.Errorf("failed to decode macro value for '%s': %w", name, err)
}
entries = append(entries, MacroEntry{Name: name, Value: val})
}
*ml = entries
return nil
}
// Get retrieves a macro value by name
func (ml MacroList) Get(name string) (any, bool) {
for _, entry := range ml {
if entry.Name == name {
return entry.Value, true
}
}
return nil, false
}
// ToMap converts MacroList to a map (for backward compatibility if needed)
func (ml MacroList) ToMap() map[string]any {
result := make(map[string]any, len(ml))
for _, entry := range ml {
result[entry.Name] = entry.Value
}
return result
}
type GroupConfig struct {
Swap bool `yaml:"swap"`
@@ -150,8 +202,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
- name can not be any reserved macros: PORT, MODEL_ID
- macro values must be less than 1024 characters
*/
for macroName, macroValue := range config.Macros {
if err = validateMacro(macroName, macroValue); err != nil {
for _, macro := range config.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, err
}
}
@@ -172,49 +224,88 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
// validate model macros
for macroName, macroValue := range modelConfig.Macros {
if err = validateMacro(macroName, macroValue); err != nil {
for _, macro := range modelConfig.Macros {
if err = validateMacro(macro.Name, macro.Value); err != nil {
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
}
}
// Merge global config and model macros. Model macros take precedence
mergedMacros := make(MacroList)
for k, v := range config.Macros {
mergedMacros[k] = v
}
for k, v := range modelConfig.Macros {
mergedMacros[k] = v
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros))
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
// Add global macros first
mergedMacros = append(mergedMacros, config.Macros...)
// Add model macros (can override global)
for _, entry := range modelConfig.Macros {
// Remove any existing global macro with same name
found := false
for i, existing := range mergedMacros {
if existing.Name == entry.Name {
mergedMacros[i] = entry // Override
found = true
break
}
}
if !found {
mergedMacros = append(mergedMacros, entry)
}
}
mergedMacros["MODEL_ID"] = modelId
// First pass: Substitute user-defined macros in reverse order (LIFO - last defined first)
// This allows later macros to reference earlier ones
for i := len(mergedMacros) - 1; i >= 0; i-- {
entry := mergedMacros[i]
macroSlug := fmt.Sprintf("${%s}", entry.Name)
macroStr := fmt.Sprintf("%v", entry.Value)
// go through model config fields: cmd, cmdStop, proxy, checkEndPoint and replace macros with macro values
for macroName, macroValue := range mergedMacros {
macroSlug := fmt.Sprintf("${%s}", macroName)
// Convert macro value to string for command/string field substitution
macroStr := fmt.Sprintf("%v", macroValue)
// Substitute in command fields
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
// Substitute in metadata (recursive)
if len(modelConfig.Metadata) > 0 {
var err error
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
}
// enforce ${PORT} used in both cmd and proxy
if !strings.Contains(modelConfig.Cmd, "${PORT}") && strings.Contains(modelConfig.Proxy, "${PORT}") {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// Final pass: check if PORT macro is needed after macro expansion
// ${PORT} is a resource on the local machine so a new port is only allocated
// if it is required in either cmd or proxy keys
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
if cmdHasPort || proxyHasPort { // either has it
if !cmdHasPort && proxyHasPort { // but both don't have it
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
// only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily
if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") {
nextPortStr := strconv.Itoa(nextPort)
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr)
// Add PORT macro and substitute it
portEntry := MacroEntry{Name: "PORT", Value: nextPort}
macroSlug := "${PORT}"
macroStr := fmt.Sprintf("%v", nextPort)
// add port to merged macros so it can be used in metadata
mergedMacros["PORT"] = nextPort
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
// Substitute PORT in metadata
if len(modelConfig.Metadata) > 0 {
var err error
result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
}
modelConfig.Metadata = result.(map[string]any)
}
nextPort++
}
@@ -235,19 +326,20 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
if macroName == "PID" && fieldName == "cmdStop" {
continue // this is ok, has to be replaced by process later
}
if _, exists := config.Macros[macroName]; !exists {
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
// Reserved macros are always valid (they should have been substituted already)
if macroName == "PORT" || macroName == "MODEL_ID" {
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
}
// Any other macro is unknown
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
}
}
// Apply macro substitution to metadata
// Check for unknown macros in metadata
if len(modelConfig.Metadata) > 0 {
substitutedMetadata, err := substituteMetadataMacros(modelConfig.Metadata, mergedMacros)
if err != nil {
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
return Config{}, err
}
modelConfig.Metadata = substitutedMetadata.(map[string]any)
}
config.Models[modelId] = modelConfig
@@ -400,6 +492,11 @@ func validateMacro(name string, value any) error {
if len(v) >= 1024 {
return fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", name)
}
// Check for self-reference
macroSlug := fmt.Sprintf("${%s}", name)
if strings.Contains(v, macroSlug) {
return fmt.Errorf("macro '%s' contains self-reference", name)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
// These types are allowed
default:
@@ -414,41 +511,62 @@ func validateMacro(name string, value any) error {
return nil
}
// substituteMetadataMacros recursively substitutes macros in metadata structures
// Direct substitution (key: ${macro}) preserves the macro's type
// Interpolated substitution (key: "text ${macro}") converts to string
func substituteMetadataMacros(value any, macros MacroList) (any, error) {
// validateMetadataForUnknownMacros recursively checks for any remaining macro references in metadata
func validateMetadataForUnknownMacros(value any, modelId string) error {
switch v := value.(type) {
case string:
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
for _, match := range matches {
macroName := match[1]
return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName)
}
return nil
case map[string]any:
for _, val := range v {
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
return err
}
}
return nil
case []any:
for _, val := range v {
if err := validateMetadataForUnknownMacros(val, modelId); err != nil {
return err
}
}
return nil
default:
// Scalar types don't contain macros
return nil
}
}
// substituteMacroInValue recursively substitutes a single macro in a value structure
// This is called once per macro, allowing LIFO substitution order
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
macroSlug := fmt.Sprintf("${%s}", macroName)
macroStr := fmt.Sprintf("%v", macroValue)
switch v := value.(type) {
case string:
// Check if this is a direct macro substitution
if strings.HasPrefix(v, "${") && strings.HasSuffix(v, "}") && strings.Count(v, "${") == 1 {
macroName := v[2 : len(v)-1]
if macroValue, exists := macros[macroName]; exists {
return macroValue, nil
}
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
if v == macroSlug {
return macroValue, nil
}
// Handle string interpolation
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
result := v
for _, match := range matches {
macroName := match[1]
macroValue, exists := macros[macroName]
if !exists {
return nil, fmt.Errorf("unknown macro '${%s}' in metadata", macroName)
}
// Convert macro value to string for interpolation
macroStr := fmt.Sprintf("%v", macroValue)
result = strings.ReplaceAll(result, match[0], macroStr)
if strings.Contains(v, macroSlug) {
return strings.ReplaceAll(v, macroSlug, macroStr), nil
}
return result, nil
return v, nil
case map[string]any:
// Recursively process map values
newMap := make(map[string]any)
for key, val := range v {
newVal, err := substituteMetadataMacros(val, macros)
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}
@@ -460,7 +578,7 @@ func substituteMetadataMacros(value any, macros MacroList) (any, error) {
// Recursively process slice elements
newSlice := make([]any, len(v))
for i, val := range v {
newVal, err := substituteMetadataMacros(val, macros)
newVal, err := substituteMacroInValue(val, macroName, macroValue)
if err != nil {
return nil, err
}

View File

@@ -164,7 +164,7 @@ groups:
LogLevel: "info",
StartPort: 5800,
Macros: MacroList{
"svr-path": "path/to/server",
{"svr-path", "path/to/server"},
},
Hooks: HooksConfig{
OnStartup: HookOnStartup{

View File

@@ -213,7 +213,9 @@ models:
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
if !assert.NoError(t, err) {
t.FailNow()
}
sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd)
assert.NoError(t, err)
assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three --overridden success", strings.Join(sanitizedCmd, " "))
@@ -321,7 +323,7 @@ macros:
models:
model1:
cmd: "${svr-path} --port ${PORT}"
proxy: "http://localhost:${unknownMacro}"
proxy: "http://${unknownMacro}:${PORT}"
`,
},
{
@@ -503,7 +505,9 @@ models:
assert.NoError(t, err)
assert.Equal(t, "/path/to/server -p 9001 -hf model1", strings.Join(sanitizedCmd, " "))
assert.Equal(t, "docker stop ${MODEL_ID}", config.Macros["docker-stop"])
dockerStopMacro, found := config.Macros.Get("docker-stop")
assert.True(t, found)
assert.Equal(t, "docker stop ${MODEL_ID}", dockerStopMacro)
sanitizedCmd2, err := SanitizeCommand(config.Models["model2"].Cmd)
assert.NoError(t, err)

View File

@@ -156,7 +156,7 @@ groups:
LogLevel: "info",
StartPort: 5800,
Macros: MacroList{
"svr-path": "path/to/server",
{"svr-path", "path/to/server"},
},
Models: map[string]ModelConfig{
"model1": {

View File

@@ -0,0 +1,123 @@
package config
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// Test macro-in-macro basic substitution
func TestConfig_MacroInMacroBasic(t *testing.T) {
content := `
startPort: 10000
macros:
"A": "value-A"
"B": "prefix-${A}-suffix"
models:
test:
cmd: echo ${B}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "echo prefix-value-A-suffix", config.Models["test"].Cmd)
}
// Test LIFO substitution order with 3+ macro levels
func TestConfig_MacroInMacroLIFOOrder(t *testing.T) {
content := `
startPort: 10000
macros:
"base": "/models"
"path": "${base}/llama"
"full": "${path}/model.gguf"
models:
test:
cmd: load ${full}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "load /models/llama/model.gguf", config.Models["test"].Cmd)
}
// Test MODEL_ID in global macro used by model
func TestConfig_ModelIdInGlobalMacro(t *testing.T) {
content := `
startPort: 10000
macros:
"podman-llama": "podman run --name ${MODEL_ID} ghcr.io/ggml-org/llama.cpp:server-cuda"
models:
my-model:
cmd: ${podman-llama} -m model.gguf
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "podman run --name my-model ghcr.io/ggml-org/llama.cpp:server-cuda -m model.gguf", config.Models["my-model"].Cmd)
}
// Test model macro overrides global macro in substitution
func TestConfig_ModelMacroOverridesGlobal(t *testing.T) {
content := `
startPort: 10000
macros:
"tag": "global"
"msg": "value-${tag}"
models:
test:
macros:
"tag": "model-level"
cmd: echo ${msg}
proxy: http://localhost:8080
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
assert.Equal(t, "echo value-model-level", config.Models["test"].Cmd)
}
// Test self-reference detection error
func TestConfig_SelfReferenceDetection(t *testing.T) {
content := `
startPort: 10000
macros:
"recursive": "value-${recursive}"
models:
test:
cmd: echo ${recursive}
proxy: http://localhost:8080
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "recursive")
assert.Contains(t, err.Error(), "self-reference")
}
// Test undefined macro reference error
func TestConfig_UndefinedMacroReference(t *testing.T) {
content := `
startPort: 10000
macros:
"A": "value-${UNDEFINED}"
models:
test:
cmd: echo ${A}
proxy: http://localhost:8080
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "UNDEFINED")
}