diff --git a/proxy/config.go b/proxy/config.go index ee86174..182dc84 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -2,6 +2,7 @@ package proxy import ( "fmt" + "io" "os" "sort" "strings" @@ -83,7 +84,16 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) { } func LoadConfig(path string) (Config, error) { - data, err := os.ReadFile(path) + file, err := os.Open(path) + if err != nil { + return Config{}, err + } + defer file.Close() + return LoadConfigFromReader(file) +} + +func LoadConfigFromReader(r io.Reader) (Config, error) { + data, err := io.ReadAll(r) if err != nil { return Config{}, err } @@ -102,6 +112,9 @@ func LoadConfig(path string) (Config, error) { config.aliases = make(map[string]string) for modelName, modelConfig := range config.Models { for _, alias := range modelConfig.Aliases { + if _, found := config.aliases[alias]; found { + return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName) + } config.aliases[alias] = modelName } } diff --git a/proxy/config_test.go b/proxy/config_test.go index 292c87c..bb0ce5a 100644 --- a/proxy/config_test.go +++ b/proxy/config_test.go @@ -3,6 +3,7 @@ package proxy import ( "os" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -138,14 +139,6 @@ groups: } func TestConfig_GroupMemberIsUnique(t *testing.T) { - // Create a temporary YAML file for testing - tempDir, err := os.MkdirTemp("", "test-config") - if err != nil { - t.Fatalf("Failed to create temporary directory: %v", err) - } - defer os.RemoveAll(tempDir) - - tempFile := filepath.Join(tempDir, "config.yaml") content := ` models: model1: @@ -171,15 +164,32 @@ groups: exclusive: false members: ["model2"] ` + // Load the config and verify + _, err := LoadConfigFromReader(strings.NewReader(content)) + assert.Equal(t, "model member model2 is used in multiple groups: group1 and group2", err.Error()) - if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { - t.Fatalf("Failed to write temporary file: %v", err) - } +} + +func TestConfig_ModelAliasesAreUnique(t *testing.T) { + content := ` +models: + model1: + cmd: path/to/cmd --arg1 one + proxy: "http://localhost:8080" + aliases: + - m1 + model2: + cmd: path/to/cmd --arg1 one + proxy: "http://localhost:8081" + checkEndpoint: "/" + aliases: + - m1 + - m2 +` // Load the config and verify - _, err = LoadConfig(tempFile) - assert.NotNil(t, err) - + _, err := LoadConfigFromReader(strings.NewReader(content)) + assert.Equal(t, "duplicate alias m1 found in model: model2", err.Error()) } func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {