added extra OAuth2 avatar url download checks
This commit is contained in:
@@ -1,15 +1,20 @@
|
||||
package apis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
@@ -294,9 +299,12 @@ func oauth2Submit(e *core.RecordAuthWithOAuth2RequestEvent, optExternalAuth *cor
|
||||
if mappedField != nil && mappedField.Type() == core.FieldTypeFile {
|
||||
// download the avatar if the mapped field is a file
|
||||
avatarFile, err := func() (*filesystem.File, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
return filesystem.NewFileFromURL(ctx, e.OAuth2User.AvatarURL)
|
||||
|
||||
// the extra checks are not required because the OAuth2 APIs are trusted vendor
|
||||
// but are here to minimize the impact in case the provider is vulnerable
|
||||
return safeFileFromURL(ctx, e.OAuth2User.AvatarURL)
|
||||
}()
|
||||
if err != nil {
|
||||
txApp.Logger().Warn("Failed to retrieve OAuth2 avatar", slog.String("error", err.Error()))
|
||||
@@ -399,3 +407,93 @@ func sendOAuth2RecordCreateRequest(txApp core.App, e *core.RecordAuthWithOAuth2R
|
||||
|
||||
return createdRecord, nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// safeHTTPClient initializes a custom http.Client with extra host checks
|
||||
// to prevent internal network probing requests
|
||||
// (aka. disallow loopback, private, multicast, etc. requests).
|
||||
//
|
||||
// NB! The host checks are not perfect and there are probably edge cases that are not covered,
|
||||
// so if you plan using with untrusted user URL, consider performing additional whitelist checks.
|
||||
//
|
||||
// @todo Evaluate with the refactoring if worth exporting(+tests) and moving under the security package.
|
||||
func safeHTTPClient() (*http.Client, error) {
|
||||
dialer := &net.Dialer{
|
||||
// the same options as in http.DefaultTransport.DialContext
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
|
||||
// check the address right after estrablishing the connection to prevent dns rebinding
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
|
||||
if ip == nil ||
|
||||
ip.IsLoopback() ||
|
||||
ip.IsUnspecified() ||
|
||||
ip.IsPrivate() ||
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() ||
|
||||
ip.IsMulticast() {
|
||||
return fmt.Errorf("address %q is invalid or not allowed", ip.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 180 * time.Second, // can be still cancelled with the request context
|
||||
Transport: &http.Transport{
|
||||
DialContext: dialer.DialContext,
|
||||
// the same options as in http.DefaultTransport
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// safeFileFromURL downloads the file from the specified url (using safeHTTPClient)
|
||||
// and creates a new filesystem.File value from its content (limited to DefaultMaxBodySize).
|
||||
//
|
||||
// @todo Evaluate with the refactoring if worth exporting/replacing filesystem.NewFileFromURL (or redefine as NewUnsafeFileFromURL).
|
||||
func safeFileFromURL(ctx context.Context, url string) (*filesystem.File, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := safeHTTPClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode > 399 {
|
||||
return nil, fmt.Errorf("failed to download url %s (%d)", url, res.StatusCode)
|
||||
}
|
||||
|
||||
body := io.LimitReader(res.Body, DefaultMaxBodySize)
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err = io.Copy(&buf, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return filesystem.NewFileFromBytes(buf.Bytes(), path.Base(url))
|
||||
}
|
||||
|
||||
@@ -45,12 +45,14 @@ func TestRecordAuthWithOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// start a test server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||
localServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||
buf := new(bytes.Buffer)
|
||||
png.Encode(buf, image.Rect(0, 0, 1, 1)) // tiny 1x1 png
|
||||
http.ServeContent(res, req, "test_avatar.png", time.Now(), bytes.NewReader(buf.Bytes()))
|
||||
}))
|
||||
defer server.Close()
|
||||
defer localServer.Close()
|
||||
|
||||
externalImageURL := "https://pocketbase.io/images/logo.svg"
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
@@ -1176,7 +1178,7 @@ func TestRecordAuthWithOAuth2(t *testing.T) {
|
||||
Id: "oauth2_id",
|
||||
Email: "oauth2@example.com",
|
||||
Username: "oauth2_username",
|
||||
AvatarURL: server.URL + "/oauth2_avatar.png",
|
||||
AvatarURL: externalImageURL,
|
||||
},
|
||||
Token: &oauth2.Token{AccessToken: "abc"},
|
||||
}
|
||||
@@ -1208,7 +1210,98 @@ func TestRecordAuthWithOAuth2(t *testing.T) {
|
||||
`"username":"oauth2_username"`,
|
||||
`"verified":true`,
|
||||
`"rel":"0yxhwia2amd8gec"`,
|
||||
`"avatar":"oauth2_avatar_`,
|
||||
`"avatar":"logo_`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOAuth2Request": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 2, // the auth response and from the create request
|
||||
// ---
|
||||
"OnModelCreate": 3, // record + authOrigins + externalAuths
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 3,
|
||||
"OnRecordAfterCreateSuccess": 3,
|
||||
// ---
|
||||
"OnModelUpdate": 1, // created record verified state change
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
// ---
|
||||
"OnModelValidate": 4,
|
||||
"OnRecordValidate": 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "creating user (with mapped OAuth2 fields and local avatarURL->file field; ensures that safeHTTPClient is being used)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-oauth2",
|
||||
Body: strings.NewReader(`{
|
||||
"provider": "test",
|
||||
"code":"123",
|
||||
"redirectURL": "https://example.com",
|
||||
"createData": {
|
||||
"name": "test_name",
|
||||
"emailVisibility": true,
|
||||
"rel": "0yxhwia2amd8gec"
|
||||
}
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
usersCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// register the test provider
|
||||
auth.Providers["test"] = func() auth.Provider {
|
||||
return &oauth2MockProvider{
|
||||
AuthUser: &auth.AuthUser{
|
||||
Id: "oauth2_id",
|
||||
Email: "oauth2@example.com",
|
||||
Username: "oauth2_username",
|
||||
AvatarURL: localServer.URL + "/oauth2_avatar.png", // local/private file download is not allowed
|
||||
},
|
||||
Token: &oauth2.Token{AccessToken: "abc"},
|
||||
}
|
||||
}
|
||||
|
||||
// add the test provider in the collection
|
||||
usersCol.MFA.Enabled = false
|
||||
usersCol.OAuth2.Enabled = true
|
||||
usersCol.OAuth2.Providers = []core.OAuth2ProviderConfig{{
|
||||
Name: "test",
|
||||
ClientId: "123",
|
||||
ClientSecret: "456",
|
||||
}}
|
||||
usersCol.OAuth2.MappedFields = core.OAuth2KnownFields{
|
||||
Username: "name", // should be ignored because of the explicit submitted value
|
||||
Id: "username",
|
||||
AvatarURL: "avatar",
|
||||
}
|
||||
if err := app.Save(usersCol); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"isNew":true`,
|
||||
`"email":"oauth2@example.com"`,
|
||||
`"emailVisibility":true`,
|
||||
`"name":"test_name"`,
|
||||
`"username":"oauth2_username"`,
|
||||
`"verified":true`,
|
||||
`"rel":"0yxhwia2amd8gec"`,
|
||||
`"avatar":"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
@@ -1343,7 +1436,7 @@ func TestRecordAuthWithOAuth2(t *testing.T) {
|
||||
Email: "oauth2@example.com",
|
||||
Username: "tESt2_username", // wouldn't match with existing because the related field index is case-sensitive
|
||||
Name: "oauth2_name",
|
||||
AvatarURL: server.URL + "/oauth2_avatar.png",
|
||||
AvatarURL: localServer.URL + "/oauth2_avatar.png", // allowed because it is not being downloaded
|
||||
},
|
||||
Token: &oauth2.Token{AccessToken: "abc"},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user