Add support for sending a custom model name to upstream (#69) (#71)

* add test for splitRequestedModel()
* Add `useModelName` parameter to model configuration
* add docs to README
This commit is contained in:
Benson Wong
2025-03-14 21:07:52 -07:00
committed by GitHub
parent 671c1a5a7b
commit 5c97299e7b
4 changed files with 146 additions and 14 deletions

View File

@@ -96,7 +96,7 @@ func New(config *Config) *ProxyManager {
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIAudioTranscriptionHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -351,12 +351,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
}
if process, err := pm.swapModel(requestedModel); err != nil {
process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
} else {
}
// strip
// issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
} else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
@@ -366,17 +375,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
}
func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
@@ -410,8 +421,12 @@ func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" && profileName != "" {
fieldValue = modelName
if key == "model" {
if process.config.UseModelName != "" {
fieldValue = process.config.UseModelName
} else if profileName != "" {
fieldValue = modelName
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {