Fix function signature, use default random boundary (#3422)

Fixes the function signature of `parseMultipartResponse` and uses the
default random boundary when creating a new multipart response.
This commit is contained in:
Till 2024-09-13 09:39:30 +02:00 committed by GitHub
parent 002fed3cb9
commit ed6d964e5d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 21 deletions

View file

@ -33,7 +33,6 @@ import (
"sync" "sync"
"unicode" "unicode"
"github.com/google/uuid"
"github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/thumbnailer" "github.com/matrix-org/dendrite/mediaapi/thumbnailer"
@ -400,22 +399,16 @@ func (r *downloadRequest) respondFromLocalFile(
} }
func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (int64, error) { func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (int64, error) {
// Update the header to be multipart/mixed; boundary=$randomBoundary
boundary := uuid.NewString()
w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary)
w.Header().Del("Content-Length") // let Go handle the content length
mw := multipart.NewWriter(w) mw := multipart.NewWriter(w)
// Update the header to be multipart/mixed; boundary=$randomBoundary
w.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
w.Header().Del("Content-Length") // let Go handle the content length
defer func() { defer func() {
if err := mw.Close(); err != nil { if err := mw.Close(); err != nil {
r.Logger.WithError(err).Error("Failed to close multipart writer") r.Logger.WithError(err).Error("Failed to close multipart writer")
} }
}() }()
if err := mw.SetBoundary(boundary); err != nil {
return 0, fmt.Errorf("failed to set multipart boundary: %w", err)
}
// JSON object part // JSON object part
jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{ jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"application/json"}, "Content-Type": {"application/json"},
@ -858,7 +851,7 @@ func (r *downloadRequest) fetchRemoteFile(
var reader io.Reader var reader io.Reader
var parseErr error var parseErr error
if isAuthed { if isAuthed {
parseErr, contentLength, reader = parseMultipartResponse(r, resp, maxFileSizeBytes) contentLength, reader, parseErr = parseMultipartResponse(r, resp, maxFileSizeBytes)
} else { } else {
// The reader returned here will be limited either by the Content-Length // The reader returned here will be limited either by the Content-Length
// and/or the configured maximum media size. // and/or the configured maximum media size.
@ -928,48 +921,48 @@ func (r *downloadRequest) fetchRemoteFile(
return types.Path(finalPath), duplicate, nil return types.Path(finalPath), duplicate, nil
} }
func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (error, int64, io.Reader) { func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
_, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil { if err != nil {
return err, 0, nil return 0, nil, err
} }
if params["boundary"] == "" { if params["boundary"] == "" {
return fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin), 0, nil return 0, nil, fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
} }
mr := multipart.NewReader(resp.Body, params["boundary"]) mr := multipart.NewReader(resp.Body, params["boundary"])
// Get the first, JSON, part // Get the first, JSON, part
p, err := mr.NextPart() p, err := mr.NextPart()
if err != nil { if err != nil {
return err, 0, nil return 0, nil, err
} }
defer p.Close() // nolint: errcheck defer p.Close() // nolint: errcheck
if p.Header.Get("Content-Type") != "application/json" { if p.Header.Get("Content-Type") != "application/json" {
return fmt.Errorf("first part of the response must be application/json"), 0, nil return 0, nil, fmt.Errorf("first part of the response must be application/json")
} }
// Try to parse media meta information // Try to parse media meta information
meta := mediaMeta{} meta := mediaMeta{}
if err = json.NewDecoder(p).Decode(&meta); err != nil { if err = json.NewDecoder(p).Decode(&meta); err != nil {
return err, 0, nil return 0, nil, err
} }
defer p.Close() // nolint: errcheck defer p.Close() // nolint: errcheck
// Get the actual media content // Get the actual media content
p, err = mr.NextPart() p, err = mr.NextPart()
if err != nil { if err != nil {
return err, 0, nil return 0, nil, err
} }
redirect := p.Header.Get("Location") redirect := p.Header.Get("Location")
if redirect != "" { if redirect != "" {
return fmt.Errorf("Location header is not yet supported"), 0, nil return 0, nil, fmt.Errorf("Location header is not yet supported")
} }
contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes) contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes)
// For multipart requests, we need to get the Content-Type of the second part, which is the actual media // For multipart requests, we need to get the Content-Type of the second part, which is the actual media
r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type")) r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type"))
return err, contentLength, reader return contentLength, reader, err
} }
// contentDispositionFor returns the Content-Disposition for a given // contentDispositionFor returns the Content-Disposition for a given

View file

@ -35,7 +35,7 @@ func Test_Multipart(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
defer resp.Body.Close() defer resp.Body.Close()
// contentLength is always 0, since there's no Content-Length header on the multipart part. // contentLength is always 0, since there's no Content-Length header on the multipart part.
err, _, reader := parseMultipartResponse(r, resp, 1000) _, reader, err := parseMultipartResponse(r, resp, 1000)
assert.NoError(t, err) assert.NoError(t, err)
gotResponse, err := io.ReadAll(reader) gotResponse, err := io.ReadAll(reader)
assert.NoError(t, err) assert.NoError(t, err)