Add /v1/audio/transcriptions support (#41)
* add support for /v1/audio/transcriptions
This commit is contained in:
@@ -18,6 +18,7 @@ Written in golang, it is very easy to install (single binary with no dependancie
|
|||||||
- `v1/embeddings`
|
- `v1/embeddings`
|
||||||
- `v1/rerank`
|
- `v1/rerank`
|
||||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||||
|
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||||
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741))
|
||||||
- ✅ Remote log monitoring at `/log`
|
- ✅ Remote log monitoring at `/log`
|
||||||
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
- ✅ Direct access to upstream HTTP server via `/upstream/:model_id` ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
|
||||||
|
|||||||
@@ -67,6 +67,46 @@ func main() {
|
|||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// issue #41
|
||||||
|
r.POST("/v1/audio/transcriptions", func(c *gin.Context) {
|
||||||
|
// Parse the multipart form
|
||||||
|
if err := c.Request.ParseMultipartForm(10 << 20); err != nil { // 10 MB max memory
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error parsing multipart form: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the model from the form values
|
||||||
|
model := c.Request.FormValue("model")
|
||||||
|
|
||||||
|
if model == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing model parameter"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the file from the form
|
||||||
|
file, _, err := c.Request.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Error getting file: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Read the file content to get its size
|
||||||
|
fileBytes, err := io.ReadAll(file)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error reading file: %s", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fileSize := len(fileBytes)
|
||||||
|
|
||||||
|
// Return a JSON response with the model and transcription text including file size
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||||
|
"model": model,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
r.GET("/slow-respond", func(c *gin.Context) {
|
r.GET("/slow-respond", func(c *gin.Context) {
|
||||||
echo := c.Query("echo")
|
echo := c.Query("echo")
|
||||||
delay := c.Query("delay")
|
delay := c.Query("delay")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -95,6 +96,7 @@ func New(config *Config) *ProxyManager {
|
|||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIAudioTranscriptionHandler)
|
||||||
|
|
||||||
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
||||||
|
|
||||||
@@ -374,6 +376,105 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
|
||||||
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
|
// Create a new buffer for the reconstructed request
|
||||||
|
var requestBuffer bytes.Buffer
|
||||||
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||||
|
|
||||||
|
// Parse multipart form
|
||||||
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get model parameter from the form
|
||||||
|
requestedModel := c.Request.FormValue("model")
|
||||||
|
if requestedModel == "" {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap to the requested model
|
||||||
|
process, err := pm.swapModel(requestedModel)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get profile name and model name from the requested model
|
||||||
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
|
|
||||||
|
// Copy all form values
|
||||||
|
for key, values := range c.Request.MultipartForm.Value {
|
||||||
|
for _, value := range values {
|
||||||
|
fieldValue := value
|
||||||
|
// If this is the model field and we have a profile, use just the model name
|
||||||
|
if key == "model" && profileName != "" {
|
||||||
|
fieldValue = modelName
|
||||||
|
}
|
||||||
|
field, err := multipartWriter.CreateFormField(key)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy all files from the original request
|
||||||
|
for key, fileHeaders := range c.Request.MultipartForm.File {
|
||||||
|
for _, fileHeader := range fileHeaders {
|
||||||
|
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := fileHeader.Open()
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = io.Copy(formFile, file); err != nil {
|
||||||
|
file.Close()
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the multipart writer to finalize the form
|
||||||
|
if err := multipartWriter.Close(); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new request with the reconstructed form data
|
||||||
|
modifiedReq, err := http.NewRequestWithContext(
|
||||||
|
c.Request.Context(),
|
||||||
|
c.Request.Method,
|
||||||
|
c.Request.URL.String(),
|
||||||
|
&requestBuffer,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the headers from the original request
|
||||||
|
modifiedReq.Header = c.Request.Header.Clone()
|
||||||
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||||
|
|
||||||
|
// Use the modified request for proxying
|
||||||
|
process.ProxyRequest(c.Writer, modifiedReq)
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
||||||
acceptHeader := c.GetHeader("Accept")
|
acceptHeader := c.GetHeader("Accept")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -460,3 +462,73 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
assert.Empty(t, expectedModels, "unexpected additional models in response")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"TheExpectedModel"},
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
modelInput string
|
||||||
|
expectModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "With Profile Prefix",
|
||||||
|
modelInput: "test:TheExpectedModel",
|
||||||
|
expectModel: "TheExpectedModel", // Profile prefix should be stripped
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Without Profile Prefix",
|
||||||
|
modelInput: "TheExpectedModel",
|
||||||
|
expectModel: "TheExpectedModel", // Should remain the same
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create a buffer with multipart form data
|
||||||
|
var b bytes.Buffer
|
||||||
|
w := multipart.NewWriter(&b)
|
||||||
|
|
||||||
|
// Add the model field
|
||||||
|
fw, err := w.CreateFormField("model")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = fw.Write([]byte(tc.modelInput))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Add a file field
|
||||||
|
fw, err = w.CreateFormFile("file", "test.mp3")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// Generate random content length between 10 and 20
|
||||||
|
contentLength := rand.Intn(11) + 10 // 10 to 20
|
||||||
|
content := make([]byte, contentLength)
|
||||||
|
_, err = fw.Write(content)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
// Create the request with the multipart form data
|
||||||
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(rec, req)
|
||||||
|
|
||||||
|
// Verify the response
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
var response map[string]string
|
||||||
|
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expectModel, response["model"])
|
||||||
|
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user