From 8335ba8434aeceaa719cdf49b3b25a4dc8b3d9da Mon Sep 17 00:00:00 2001 From: Ruslan Semagin Date: Thu, 11 Jan 2024 14:56:38 +0300 Subject: [PATCH] mongodb: Connection URL parser improved --- adapter/mongo/connection.go | 17 +++++++---------- adapter/mongo/connection_test.go | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/adapter/mongo/connection.go b/adapter/mongo/connection.go index 98728078..0f4b1828 100644 --- a/adapter/mongo/connection.go +++ b/adapter/mongo/connection.go @@ -27,10 +27,9 @@ import ( "strings" ) -const connectionScheme = `mongodb` - // ConnectionURL implements a MongoDB connection struct. type ConnectionURL struct { + Scheme string User string Password string Host string @@ -39,11 +38,6 @@ type ConnectionURL struct { } func (c ConnectionURL) String() (s string) { - - if c.Database == "" { - return "" - } - vv := url.Values{} // Do we have any options? @@ -69,7 +63,7 @@ func (c ConnectionURL) String() (s string) { // Building URL. u := url.URL{ - Scheme: connectionScheme, + Scheme: c.Scheme, Path: c.Database, Host: c.Host, User: userInfo, @@ -80,17 +74,20 @@ func (c ConnectionURL) String() (s string) { } // ParseURL parses s into a ConnectionURL struct. +// See https://www.mongodb.com/docs/manual/reference/connection-string/ func ParseURL(s string) (conn ConnectionURL, err error) { var u *url.URL - if !strings.HasPrefix(s, connectionScheme+"://") { - return conn, fmt.Errorf(`Expecting mongodb:// connection scheme.`) + hasPrefix := strings.HasPrefix(s, "mongodb://") || strings.HasPrefix(s, "mongodb+srv://") + if !hasPrefix { + return conn, fmt.Errorf("invalid scheme") } if u, err = url.Parse(s); err != nil { return conn, err } + conn.Scheme = u.Scheme conn.Host = u.Host // Deleting / from start of the string. diff --git a/adapter/mongo/connection_test.go b/adapter/mongo/connection_test.go index 4b21726f..ca0992bb 100644 --- a/adapter/mongo/connection_test.go +++ b/adapter/mongo/connection_test.go @@ -102,6 +102,10 @@ func TestParseConnectionURL(t *testing.T) { t.Fatal(err) } + if u.Scheme != "mongodb" { + t.Fatal("Invalid scheme") + } + if u.Database != "another_database" { t.Fatal("Failed to get database.") } @@ -132,4 +136,18 @@ func TestParseConnectionURL(t *testing.T) { t.Fatal("Expecting error.") } + s = "mongodb+srv://myDatabaseUser:D1fficultP%40ssw0rd@db1.example.net:27017,db2.example.net:2500/?replicaSet=test&connectTimeoutMS=300000" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.Scheme != "mongodb+srv" { + t.Fatal("Invalid scheme") + } + + if u.Database != "" { + t.Fatal("Invalid database") + } + }