Add Filters to Model Configuration (#174)

llama-swap can strip specific keys in JSON requests. This is useful for removing the ability for clients to set sampling parameters like temperature, top_k, top_p, etc.
This commit is contained in:
Benson Wong
2025-06-23 10:52:29 -07:00
committed by GitHub
parent 756193d0dd
commit 4236cec03a
9 changed files with 297 additions and 71 deletions

View File

@@ -6,6 +6,7 @@ import (
"os"
"regexp"
"runtime"
"slices"
"sort"
"strconv"
"strings"
@@ -29,6 +30,9 @@ type ModelConfig struct {
// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
// Model filters see issue #174
Filters ModelFilters `yaml:"filters"`
}
func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
@@ -63,6 +67,46 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
type ModelFilters struct {
StripParams string `yaml:"strip_params"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
if err := unmarshal(&defaults); err != nil {
return err
}
*m = ModelFilters(defaults)
return nil
}
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" {
continue
}
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
}
type GroupConfig struct {
Swap bool `yaml:"swap"`
Exclusive bool `yaml:"exclusive"`
@@ -212,6 +256,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroValue)
}
// enforce ${PORT} used in both cmd and proxy

View File

@@ -83,6 +83,9 @@ models:
assert.Equal(t, "", model1.UseModelName)
assert.Equal(t, 0, model1.ConcurrencyLimit)
}
// default empty filter exists
assert.Equal(t, "", model1.Filters.StripParams)
}
func TestConfig_LoadPosix(t *testing.T) {

View File

@@ -300,3 +300,28 @@ models:
})
}
}
func TestConfig_ModelFilters(t *testing.T) {
content := `
macros:
default_strip: "temperature, top_p"
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
strip_params: "model, top_k, ${default_strip}, , ,"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
modelConfig, ok := config.Models["model1"]
if !assert.True(t, ok) {
t.FailNow()
}
// make sure `model` and enmpty strings are not in the list
assert.Equal(t, "model, top_k, temperature, top_p, , ,", modelConfig.Filters.StripParams)
sanitized, err := modelConfig.Filters.SanitizedStripParams()
if assert.NoError(t, err) {
assert.Equal(t, []string{"temperature", "top_k", "top_p"}, sanitized)
}
}

View File

@@ -80,6 +80,9 @@ models:
assert.Equal(t, "", model1.UseModelName)
assert.Equal(t, 0, model1.ConcurrencyLimit)
}
// default empty filter exists
assert.Equal(t, "", model1.Filters.StripParams)
}
func TestConfig_LoadWindows(t *testing.T) {

View File

@@ -365,6 +365,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
}
}
// issue #174 strip parameters from the JSON body
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
if err != nil { // just log it and continue
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
} else {
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
return
}
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11

View File

@@ -623,3 +623,37 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
}
func TestProxyManager_FiltersStripParams(t *testing.T) {
modelConfig := getTestSimpleResponderConfig("model1")
modelConfig.Filters = ModelFilters{
StripParams: "temperature, model, stream",
}
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
LogLevel: "error",
Models: map[string]ModelConfig{
"model1": modelConfig,
},
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
// `temperature` and `stream` are gone but model remains
assert.Equal(t, `{"model":"model1", "x_param":"123", "y_param":"abc"}`, response["request_body"])
// assert.Nil(t, response["temperature"])
// assert.Equal(t, "123", response["x_param"])
// assert.Equal(t, "abc", response["y_param"])
// t.Logf("%v", response)
}