added extra OAuth2 avatar url download checks

This commit is contained in:
Gani Georgiev
2026-04-02 19:55:05 +03:00
parent 5cb66bd52f
commit cb44d9e716
4 changed files with 203 additions and 12 deletions

View File

@@ -1,15 +1,20 @@
package apis
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"maps"
"net"
"net/http"
"path"
"strings"
"syscall"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
@@ -294,9 +299,12 @@ func oauth2Submit(e *core.RecordAuthWithOAuth2RequestEvent, optExternalAuth *cor
if mappedField != nil && mappedField.Type() == core.FieldTypeFile {
// download the avatar if the mapped field is a file
avatarFile, err := func() (*filesystem.File, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return filesystem.NewFileFromURL(ctx, e.OAuth2User.AvatarURL)
// the extra checks are not required because the OAuth2 APIs are trusted vendor
// but are here to minimize the impact in case the provider is vulnerable
return safeFileFromURL(ctx, e.OAuth2User.AvatarURL)
}()
if err != nil {
txApp.Logger().Warn("Failed to retrieve OAuth2 avatar", slog.String("error", err.Error()))
@@ -399,3 +407,93 @@ func sendOAuth2RecordCreateRequest(txApp core.App, e *core.RecordAuthWithOAuth2R
return createdRecord, nil
}
// -------------------------------------------------------------------
// safeHTTPClient initializes a custom http.Client with extra host checks
// to prevent internal network probing requests
// (aka. disallow loopback, private, multicast, etc. requests).
//
// NB! The host checks are not perfect and there are probably edge cases that are not covered,
// so if you plan using with untrusted user URL, consider performing additional whitelist checks.
//
// @todo Evaluate with the refactoring if worth exporting(+tests) and moving under the security package.
func safeHTTPClient() (*http.Client, error) {
dialer := &net.Dialer{
// the same options as in http.DefaultTransport.DialContext
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
// check the address right after estrablishing the connection to prevent dns rebinding
Control: func(network, address string, c syscall.RawConn) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return err
}
ip := net.ParseIP(host)
if ip == nil ||
ip.IsLoopback() ||
ip.IsUnspecified() ||
ip.IsPrivate() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsMulticast() {
return fmt.Errorf("address %q is invalid or not allowed", ip.String())
}
return nil
},
}
client := &http.Client{
Timeout: 180 * time.Second, // can be still cancelled with the request context
Transport: &http.Transport{
DialContext: dialer.DialContext,
// the same options as in http.DefaultTransport
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}
return client, nil
}
// safeFileFromURL downloads the file from the specified url (using safeHTTPClient)
// and creates a new filesystem.File value from its content (limited to DefaultMaxBodySize).
//
// @todo Evaluate with the refactoring if worth exporting/replacing filesystem.NewFileFromURL (or redefine as NewUnsafeFileFromURL).
func safeFileFromURL(ctx context.Context, url string) (*filesystem.File, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
client, err := safeHTTPClient()
if err != nil {
return nil, err
}
res, err := client.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode > 399 {
return nil, fmt.Errorf("failed to download url %s (%d)", url, res.StatusCode)
}
body := io.LimitReader(res.Body, DefaultMaxBodySize)
var buf bytes.Buffer
if _, err = io.Copy(&buf, body); err != nil {
return nil, err
}
return filesystem.NewFileFromBytes(buf.Bytes(), path.Base(url))
}