diff --git a/README.md b/README.md
index 9755fd3..ed04355 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
   - `v1/embeddings`
   - `v1/rerank`
   - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
+  - `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
 - ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
 - ✅ Remote log monitoring at `/log`
 - ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
diff --git a/misc/simple-responder/simple-responder.go b/misc/simple-responder/simple-responder.go
index 57ff632..fcf8354 100644
--- a/misc/simple-responder/simple-responder.go
+++ b/misc/simple-responder/simple-responder.go
@@ -67,6 +67,46 @@ func main() {
 		c.String(200, *responseMessage)
 	})
 
+	// issue #41
+	r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
+		// Parse the multipart form
+		if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
+			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
+			return
+		}
+
+		// Get the model from the form values
+		model := c.Request.FormValue("model")
+
+		if model == "" {
+			c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
+			return
+		}
+
+		// Get the file from the form
+		file, _, err := c.Request.FormFile("file")
+		if err != nil {
+			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
+			return
+		}
+		defer file.Close()
+
+		// Read the file content to get its size
+		fileBytes, err := io.ReadAll(file)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
+			return
+		}
+
+		fileSize := len(fileBytes)
+
+		// Return a JSON response with the model and transcription text including file size
+		c.JSON(http.StatusOK, gin.H{
+			"text":  fmt.Sprintf("The length of the file is %d bytes", fileSize),
+			"model": model,
+		})
+	})
+
 	r.GET("/slow-respond", func(c *gin.Context) {
 		echo := c.Query("echo")
 		delay := c.Query("delay")
diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go
index 1cdffa2..68eb798 100644
--- a/proxy/proxymanager.go
+++ b/proxy/proxymanager.go
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"mime/multipart"
 	"net/http"
 	"sort"
 	"strconv"
@@ -95,6 +96,7 @@ func New(config *Config) *ProxyManager {
 
 	// Support audio/speech endpoint
 	pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
+	pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIAudioTranscriptionHandler)
 
 	pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
 
@@ -374,6 +376,113 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
 	}
 }
 
+// proxyOAIAudioTranscriptionHandler proxies multipart/form-data requests for
+// /v1/audio/transcriptions. The model is read from the "model" form field, the
+// matching process is swapped in, and the form is rebuilt (with any profile
+// prefix stripped from the model name) before being handed to ProxyRequest.
+func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
+	// We need to reconstruct the multipart form in any case since the body is consumed
+	// Create a new buffer for the reconstructed request
+	var requestBuffer bytes.Buffer
+	multipartWriter := multipart.NewWriter(&requestBuffer)
+
+	// Parse multipart form
+	if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
+		pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
+		return
+	}
+
+	// Get model parameter from the form
+	requestedModel := c.Request.FormValue("model")
+	if requestedModel == "" {
+		pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
+		return
+	}
+
+	// Swap to the requested model
+	process, err := pm.swapModel(requestedModel)
+	if err != nil {
+		pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
+		return
+	}
+
+	// Get profile name and model name from the requested model
+	profileName, modelName := splitRequestedModel(requestedModel)
+
+	// Copy all form values
+	for key, values := range c.Request.MultipartForm.Value {
+		for _, value := range values {
+			fieldValue := value
+			// If this is the model field and we have a profile, use just the model name
+			if key == "model" && profileName != "" {
+				fieldValue = modelName
+			}
+			field, err := multipartWriter.CreateFormField(key)
+			if err != nil {
+				pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
+				return
+			}
+			if _, err = field.Write([]byte(fieldValue)); err != nil {
+				pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
+				return
+			}
+		}
+	}
+
+	// Copy all files from the original request
+	for key, fileHeaders := range c.Request.MultipartForm.File {
+		for _, fileHeader := range fileHeaders {
+			formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
+			if err != nil {
+				pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
+				return
+			}
+
+			file, err := fileHeader.Open()
+			if err != nil {
+				pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
+				return
+			}
+
+			if _, err = io.Copy(formFile, file); err != nil {
+				file.Close()
+				pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
+				return
+			}
+			file.Close()
+		}
+	}
+
+	// Close the multipart writer to finalize the form
+	if err := multipartWriter.Close(); err != nil {
+		pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
+		return
+	}
+
+	// Create a new request with the reconstructed form data
+	modifiedReq, err := http.NewRequestWithContext(
+		c.Request.Context(),
+		c.Request.Method,
+		c.Request.URL.String(),
+		&requestBuffer,
+	)
+	if err != nil {
+		pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
+		return
+	}
+
+	// Copy the headers from the original request
+	modifiedReq.Header = c.Request.Header.Clone()
+	modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
+	// The rebuilt body has a different boundary (and possibly a shorter model
+	// value) than the original, so the cloned Content-Length is stale. Replace
+	// it with the size of the reconstructed body.
+	modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
+
+	// Use the modified request for proxying
+	process.ProxyRequest(c.Writer, modifiedReq)
+}
+
 func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
 	acceptHeader := c.GetHeader("Accept")
 
diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go
index d09ccf2..470e8d7 100644
--- a/proxy/proxymanager_test.go
+++ b/proxy/proxymanager_test.go
@@ -4,6 +4,8 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"math/rand"
+	"mime/multipart"
 	"net/http"
 	"net/http/httptest"
 	"sync"
@@ -460,3 +462,75 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
 		assert.Empty(t, expectedModels, "unexpected additional models in response")
 	})
 }
+
+// TestProxyManager_AudioTranscriptionHandler checks that multipart transcription
+// requests are proxied upstream and that a profile prefix is stripped from the model.
+func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
+	config := &Config{
+		HealthCheckTimeout: 15,
+		Profiles: map[string][]string{
+			"test": {"TheExpectedModel"},
+		},
+		Models: map[string]ModelConfig{
+			"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
+		},
+	}
+
+	proxy := New(config)
+	defer proxy.StopProcesses()
+
+	testCases := []struct {
+		name        string
+		modelInput  string
+		expectModel string
+	}{
+		{
+			name:        "With Profile Prefix",
+			modelInput:  "test:TheExpectedModel",
+			expectModel: "TheExpectedModel", // Profile prefix should be stripped
+		},
+		{
+			name:        "Without Profile Prefix",
+			modelInput:  "TheExpectedModel",
+			expectModel: "TheExpectedModel", // Should remain the same
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Create a buffer with multipart form data
+			var b bytes.Buffer
+			w := multipart.NewWriter(&b)
+
+			// Add the model field
+			fw, err := w.CreateFormField("model")
+			assert.NoError(t, err)
+			_, err = fw.Write([]byte(tc.modelInput))
+			assert.NoError(t, err)
+
+			// Add a file field
+			fw, err = w.CreateFormFile("file", "test.mp3")
+			assert.NoError(t, err)
+			// Generate random content length between 10 and 20
+			contentLength := rand.Intn(11) + 10 // 10 to 20
+			content := make([]byte, contentLength)
+			_, err = fw.Write(content)
+			assert.NoError(t, err)
+			assert.NoError(t, w.Close()) // Close writes the terminating boundary
+
+			// Create the request with the multipart form data
+			req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
+			req.Header.Set("Content-Type", w.FormDataContentType())
+			rec := httptest.NewRecorder()
+			proxy.HandlerFunc(rec, req)
+
+			// Verify the response
+			assert.Equal(t, http.StatusOK, rec.Code)
+			var response map[string]string
+			err = json.Unmarshal(rec.Body.Bytes(), &response)
+			assert.NoError(t, err)
+			assert.Equal(t, tc.expectModel, response["model"])
+			assert.Equal(t, fmt.Sprintf("The length of the file is %d bytes", contentLength), response["text"]) // matches simple-responder
+		})
+	}
+}