Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion adapter/cockroachdb/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (h *Helper) TearDown() error {
return h.sess.Close()
}

func (h *Helper) TearUp() error {
func (h *Helper) SetUp() error {
var err error

h.sess, err = Open(settings)
Expand Down
4 changes: 2 additions & 2 deletions adapter/mongo/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
SHELL ?= bash

MONGO_VERSION ?= 4
MONGO_SUPPORTED ?= $(MONGO_VERSION) 3
MONGO_VERSION ?= 8
MONGO_SUPPORTED ?= $(MONGO_VERSION) 7
PROJECT ?= $(subst .,_,"upper_mongo_$(MONGO_VERSION)")

DB_HOST ?= 127.0.0.1
Expand Down
139 changes: 36 additions & 103 deletions adapter/mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@
package mongo

import (
"context"
"fmt"
"reflect"
"strings"
"sync"

"reflect"

db "github.com/upper/db/v4"
"github.com/upper/db/v4/internal/adapter"
mgo "gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)

// Collection represents a mongodb collection.
type Collection struct {
parent *Source
collection *mgo.Collection
collection *mongo.Collection
}

var (
Expand Down Expand Up @@ -108,9 +108,15 @@ func compare(field string, cmp *adapter.Comparison) (string, interface{}) {
}
return field, bson.M{"$ne": value}
case adapter.ComparisonOperatorRegExp, adapter.ComparisonOperatorLike:
return field, bson.RegEx{Pattern: value.(string), Options: ""}
return field, bson.M{
"$regex": value.(string),
}
case adapter.ComparisonOperatorNotRegExp, adapter.ComparisonOperatorNotLike:
return field, bson.M{"$not": bson.RegEx{Pattern: value.(string), Options: ""}}
return field, bson.M{
"$not": bson.M{
"$regex": value.(string),
},
}
}

if cmpOp, ok := comparisonOperators[op]; ok {
Expand All @@ -122,8 +128,8 @@ func compare(field string, cmp *adapter.Comparison) (string, interface{}) {
panic(fmt.Sprintf("Unsupported operator %v", op))
}

// compileStatement transforms conditions into something *mgo.Session can
// understand.
// compileStatement transforms upper-db conditions into something that the
// adapter can understand.
func compileStatement(cond db.Cond) bson.M {
conds := bson.M{}

Expand Down Expand Up @@ -170,10 +176,7 @@ func compileStatement(cond db.Cond) bson.M {
return conds
}

// compileConditions compiles terms into something *mgo.Session can
// understand.
func (col *Collection) compileConditions(term interface{}) interface{} {

switch t := term.(type) {
case []interface{}:
values := []interface{}{}
Expand Down Expand Up @@ -208,8 +211,6 @@ func (col *Collection) compileConditions(term interface{}) interface{} {
return nil
}

// compileQuery compiles terms into something that *mgo.Session can
// understand.
func (col *Collection) compileQuery(terms ...interface{}) interface{} {
compiled := col.compileConditions(terms)
if compiled == nil {
Expand All @@ -236,13 +237,12 @@ func (col *Collection) compileQuery(terms ...interface{}) interface{} {

// Name returns the name of the table or tables that form the collection.
func (col *Collection) Name() string {
return col.collection.Name
return col.collection.Name()
}

// Truncate deletes all rows from the table.
func (col *Collection) Truncate() error {
err := col.collection.DropCollection()

err := col.collection.Drop(context.Background())
if err != nil {
return err
}
Expand All @@ -268,100 +268,33 @@ func (col *Collection) UpdateReturning(item interface{}) error {

// Insert inserts a record (map or struct) into the collection.
func (col *Collection) Insert(item interface{}) (db.InsertResult, error) {
var err error
ctx := context.Background()

id := getID(item)

if col.parent.versionAtLeast(2, 6, 0, 0) {
// this breaks MongoDb older than 2.6
if _, err = col.collection.Upsert(bson.M{"_id": id}, item); err != nil {
return nil, err
}
} else {
// Allocating a new ID.
if err = col.collection.Insert(bson.M{"_id": id}); err != nil {
return nil, err
}

// Now append data the user wants to append.
if err = col.collection.Update(bson.M{"_id": id}, item); err != nil {
// Cleanup allocated ID
if err := col.collection.Remove(bson.M{"_id": id}); err != nil {
return nil, err
}
return nil, err
}
res, err := col.collection.InsertOne(ctx, item)
if err != nil {
return nil, err
}

return db.NewInsertResult(id), nil
return db.NewInsertResult(res.InsertedID), nil
}

// Exists returns true if the collection exists.
func (col *Collection) Exists() (bool, error) {
query := col.parent.database.C(`system.namespaces`).Find(map[string]string{`name`: fmt.Sprintf(`%s.%s`, col.parent.database.Name, col.collection.Name)})
count, err := query.Count()
return count > 0, err
}
ctx := context.Background()
mcol := col.parent.database.Collection("system.namespaces")

// Fetches object _id or generates a new one if object doesn't have one or the one it has is invalid
func getID(item interface{}) interface{} {
v := reflect.ValueOf(item) // convert interface to Value
v = reflect.Indirect(v) // convert pointers

switch v.Kind() {
case reflect.Map:
if inItem, ok := item.(map[string]interface{}); ok {
if id, ok := inItem["_id"]; ok {
bsonID, ok := id.(bson.ObjectId)
if ok {
return bsonID
}
}
}
case reflect.Struct:
t := v.Type()

idCacheMutex.RLock()
fieldName, found := idCache[t]
idCacheMutex.RUnlock()

if !found {
for n := 0; n < t.NumField(); n++ {
field := t.Field(n)
if field.PkgPath != "" {
continue // Private field
}

tag := field.Tag.Get("bson")
if tag == "" {
tag = field.Tag.Get("db")
}

if tag == "" {
continue
}

parts := strings.Split(tag, ",")

if parts[0] == "_id" {
fieldName = field.Name
idCacheMutex.RLock()
idCache[t] = fieldName
idCacheMutex.RUnlock()
break
}
}
}
if fieldName != "" {
if bsonID, ok := v.FieldByName(fieldName).Interface().(bson.ObjectId); ok {
if bsonID.Valid() {
return bsonID
}
} else {
return v.FieldByName(fieldName).Interface()
}
}
mcur, err := mcol.Find(ctx, bson.M{
"name": fmt.Sprintf("%s.%s", col.parent.database.Name(), col.collection.Name()),
})
if err != nil {
return false, err
}
defer mcur.Close(ctx)

hasNext := mcur.Next(ctx)
if err := mcur.Err(); err != nil {
return false, err
}

return bson.NewObjectId()
return hasNext, nil
}
28 changes: 19 additions & 9 deletions adapter/mongo/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ import (
"strings"
)

const connectionScheme = `mongodb`
const (
defaultScheme = "mongodb"
srvScheme = "mongodb+srv"
)

// ConnectionURL implements a MongoDB connection struct.
type ConnectionURL struct {
Scheme string
User string
Password string
Host string
Expand All @@ -39,13 +43,12 @@ type ConnectionURL struct {
}

func (c ConnectionURL) String() (s string) {
vv := url.Values{}

if c.Database == "" {
return ""
}

vv := url.Values{}

// Do we have any options?
if c.Options == nil {
c.Options = map[string]string{}
Expand All @@ -67,9 +70,13 @@ func (c ConnectionURL) String() (s string) {
}
}

if c.Scheme == "" {
c.Scheme = defaultScheme
}

// Building URL.
u := url.URL{
Scheme: connectionScheme,
Scheme: c.Scheme,
Path: c.Database,
Host: c.Host,
User: userInfo,
Expand All @@ -80,17 +87,20 @@ func (c ConnectionURL) String() (s string) {
}

// ParseURL parses s into a ConnectionURL struct.
// See https://www.mongodb.com/docs/manual/reference/connection-string/
func ParseURL(s string) (conn ConnectionURL, err error) {
var u *url.URL

if !strings.HasPrefix(s, connectionScheme+"://") {
return conn, fmt.Errorf(`Expecting mongodb:// connection scheme.`)
hasPrefix := strings.HasPrefix(s, defaultScheme+"://") || strings.HasPrefix(s, srvScheme+"://")
if !hasPrefix {
return conn, fmt.Errorf("invalid scheme")
}

if u, err = url.Parse(s); err != nil {
return conn, err
return conn, fmt.Errorf("invalid URL: %v", err)
}

conn.Scheme = u.Scheme
conn.Host = u.Host

// Deleting / from start of the string.
Expand All @@ -108,12 +118,12 @@ func ParseURL(s string) (conn ConnectionURL, err error) {
var vv url.Values

if vv, err = url.ParseQuery(u.RawQuery); err != nil {
return conn, err
return conn, fmt.Errorf("invalid query: %v", err)
}

for k := range vv {
conn.Options[k] = vv.Get(k)
}

return conn, err
return conn, nil
}
43 changes: 23 additions & 20 deletions adapter/mongo/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,30 +70,33 @@ func TestConnectionURL(t *testing.T) {
}

func TestParseConnectionURL(t *testing.T) {
var u ConnectionURL
var s string
var err error
{
const s = "mongodb:///mydatabase"

s = "mongodb:///mydatabase"
u, err := ParseURL(s)
require.NoError(t, err)

u, err = ParseURL(s)
require.NoError(t, err)

assert.Equal(t, "mydatabase", u.Database)
assert.Equal(t, "mydatabase", u.Database)
}

s = "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro"
{
const s = "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro"

u, err = ParseURL(s)
require.NoError(t, err)
u, err := ParseURL(s)
require.NoError(t, err)

assert.Equal(t, "another_database", u.Database)
assert.Equal(t, "foobar", u.Options["cache"])
assert.Equal(t, "ro", u.Options["mode"])
assert.Equal(t, "user", u.User)
assert.Equal(t, "pass", u.Password)
assert.Equal(t, "localhost,1.2.3.4,example.org:1234", u.Host)
assert.Equal(t, "another_database", u.Database)
assert.Equal(t, "foobar", u.Options["cache"])
assert.Equal(t, "ro", u.Options["mode"])
assert.Equal(t, "mongodb", u.Scheme)
assert.Equal(t, "user", u.User)
assert.Equal(t, "pass", u.Password)
assert.Equal(t, "localhost,1.2.3.4,example.org:1234", u.Host)
}

s = "http://example.org"
_, err = ParseURL(s)
require.Error(t, err)
{
const s = "http://example.org"
_, err := ParseURL(s)
require.Error(t, err)
}
}
Loading