* add test for splitRequestedModel() * Add `useModelName` parameter to model configuration * add docs to README
This commit is contained in:
@@ -96,7 +96,7 @@ func New(config *Config) *ProxyManager {
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIAudioTranscriptionHandler)
|
||||
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
||||
|
||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||
|
||||
@@ -351,12 +351,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
}
|
||||
|
||||
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||
process, err := pm.swapModel(requestedModel)
|
||||
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
}
|
||||
|
||||
// strip
|
||||
// issue #69 allow custom model names to be sent to upstream
|
||||
if process.config.UseModelName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
profileName, modelName := splitRequestedModel(requestedModel)
|
||||
if profileName != "" {
|
||||
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||
@@ -366,17 +375,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// dechunk it as we already have all the body bytes see issue #11
|
||||
c.Request.Header.Del("transfer-encoding")
|
||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||
|
||||
process.ProxyRequest(c.Writer, c.Request)
|
||||
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
|
||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||
// Create a new buffer for the reconstructed request
|
||||
var requestBuffer bytes.Buffer
|
||||
@@ -410,8 +421,12 @@ func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
|
||||
for _, value := range values {
|
||||
fieldValue := value
|
||||
// If this is the model field and we have a profile, use just the model name
|
||||
if key == "model" && profileName != "" {
|
||||
fieldValue = modelName
|
||||
if key == "model" {
|
||||
if process.config.UseModelName != "" {
|
||||
fieldValue = process.config.UseModelName
|
||||
} else if profileName != "" {
|
||||
fieldValue = modelName
|
||||
}
|
||||
}
|
||||
field, err := multipartWriter.CreateFormField(key)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user