diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7d6d7aa..4ee0d28 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,7 @@ jobs: - name: download surrealdb run: curl --proto '=https' --tlsv1.2 -sSf https://install.surrealdb.com | sh -s -- --nightly - name: start surrealdb - run: surreal start memory -A --auth --user root --pass root & + run: surreal start memory -A --user root --pass root & - name: test run: go test -v -cover ./... env: diff --git a/README.md b/README.md index bfdcb2d..b3624a6 100644 --- a/README.md +++ b/README.md @@ -46,97 +46,143 @@ go get github.com/surrealdb/surrealdb.go ## Getting started -In the example below you can see how to connect to a remote instance of SurrealDB, authenticating with the database, and issuing queries for creating, updating, and selecting data from records. - +[//]: # (In the example below you can see how to connect to a remote instance of SurrealDB, authenticating with the database, and issuing queries for creating, updating, and selecting data from records.) +In the example provided below, we are going to connect and authenticate on a SurrealDB server, set the namespace and make several data manipulation requests. > This example requires SurrealDB to be [installed](https://surrealdb.com/install) and running on port 8000. ```go package main import ( - "github.com/surrealdb/surrealdb.go" + "fmt" + surrealdb "github.com/surrealdb/surrealdb.go" + "github.com/surrealdb/surrealdb.go/pkg/models" ) -type User struct { - ID string `json:"id,omitempty"` - Name string `json:"name"` - Surname string `json:"surname"` +type Person struct { + ID *models.RecordID `json:"id,omitempty"` + Name string `json:"name"` + Surname string `json:"surname"` + Location models.GeometryPoint `json:"location"` } func main() { // Connect to SurrealDB - db, err := surrealdb.New("ws://localhost:8000/rpc") + db, err := surrealdb.New("ws://localhost:8000") if err != nil { panic(err) } - authData := &surrealdb.Auth{ - Database: "test", - Namespace: "test", - Username: "root", - Password: "root", - } - if _, err = db.Signin(authData); err != nil { + // Set the namespace and database + if err = db.Use("testNS", "testDB"); err != nil { panic(err) } - if _, err = db.Use("test", "test"); err != nil { + // Sign in to authentication `db` + authData := &surrealdb.Auth{ + Username: "root", // use your setup username + Password: "root", // use your setup password + } + token, err := db.SignIn(authData) + if err != nil { panic(err) } - // Define user struct - user := User{ - Name: "John", - Surname: "Doe", + // Check token validity. This is not necessary if you called `SignIn` before. This authenticates the `db` instance too if sign in was + // not previously called + if err := db.Authenticate(token); err != nil { + panic(err) } - // Insert user - data, err := db.Create("user", user) + // And we can later on invalidate the token if desired + defer func(token string) { + if err := db.Invalidate(); err != nil { + panic(err) + } + }(token) + + // Create an entry + person1, err := surrealdb.Create[Person](db, models.Table("persons"), map[interface{}]interface{}{ + "Name": "John", + "Surname": "Doe", + "Location": models.NewGeometryPoint(-0.11, 22.00), + }) if err != nil { panic(err) } - - // Unmarshal data - createdUser := make([]User, 1) - err = surrealdb.Unmarshal(data, &createdUser) + fmt.Printf("Created person with a map: %+v\n", person1) + + // Or use structs + person2, err := surrealdb.Create[Person](db, models.Table("persons"), Person{ + Name: "John", + Surname: "Doe", + Location: models.NewGeometryPoint(-0.11, 22.00), + }) if err != nil { panic(err) } + fmt.Printf("Created person with a struvt: %+v\n", person2) - // Get user by ID - data, err = db.Select(createdUser[0].ID) + // Get entry by Record ID + person, err := surrealdb.Select[Person, models.RecordID](db, *person1.ID) if err != nil { panic(err) } + fmt.Printf("Selected a person by record id: %+v\n", person) - // Unmarshal data - selectedUser := new(User) - err = surrealdb.Unmarshal(data, &selectedUser) + // Or retrieve the entire table + persons, err := surrealdb.Select[[]Person, models.Table](db, models.Table("persons")) if err != nil { panic(err) } + fmt.Printf("Selected all in persons table: %+v\n", persons) - // Change part/parts of user - changes := map[string]string{"name": "Jane"} - - // Update user - if _, err = db.Update(selectedUser.ID, changes); err != nil { + // Delete an entry by ID + if err = surrealdb.Delete[models.RecordID](db, *person2.ID); err != nil { panic(err) } - if _, err = db.Query("SELECT * FROM $record", map[string]interface{}{ - "record": createdUser[0].ID, - }); err != nil { + // Delete all entries + if err = surrealdb.Delete[models.Table](db, models.Table("persons")); err != nil { panic(err) } - // Delete user by ID - if _, err = db.Delete(selectedUser.ID); err != nil { + // Confirm empty table + persons, err = surrealdb.Select[[]Person](db, models.Table("persons")) + if err != nil { panic(err) } + fmt.Printf("No Selected person: %+v\n", persons) } ``` +### Doing it your way +All Data manipulation methods are handled by an undelying `send` function. This function is +exposed via `db.Send` function if you want to create requests yourself but limited to a selected set of methods. Theses +methods are: +- select +- create +- insert +- upsert +- update +- patch +- delete +- query +```go +type UserSelectResult struct { + Result []Users +} + +var res UserSelectResult +// or var res surrealdb.Result[[]Users] + +err := db.Send(&res, "query", user.ID) +if err != nil { + panic(err) +} + +``` + ### Instructions for running the example - In a new folder, create a file called `main.go` and paste the above code @@ -144,6 +190,100 @@ func main() { - Run `go mod tidy` to download the `surrealdb.go` dependency - Run `go run main.go` to run the example. +## Connection Engines +There are 2 different connection engines you can use to connect to SurrealDb backend. You can do so via Websocket or through HTTP +connections + +### Via Websocket +```go +db, err := surrealdb.New("ws://localhost:8000") +``` +or for a secure connection +```go +db, err := surrealdb.New("wss://localhost:8000") +``` + +### Via HTTP +There are some functions that are not available on RPC when using HTTP but on Websocket. All these except +the "live" endpoint are effectively implemented in the HTTP library and provides the same result as though +it is natively available on HTTP. While using the HTTP connection engine, note that live queries will still +use a websocket connection if the backend supports it +```go +db, err := surrealdb.New("http://localhost:8000") +``` +or for a secure connection +```go +db, err := surrealdb.New("https://localhost:8000") +``` + + +## Data Models +This package facilitates communication between client and the backend service using the Concise +Binary Object Representation (CBOR) format. It streamlines data serialization and deserialization +while ensuring efficient and lightweight communication. The library also provides custom models +tailored to specific Data models recognised by SurrealDb, which cannot be covered by idiomatic go, enabling seamless interaction between +the client and the backend. + +See the [documetation on data models](https://surrealdb.com/docs/surrealql/datamodel) on support data types + +| CBOR Type | Go Representation | Example | +|-------------------|-----------------------------|----------------------------| +| Null | `nil` | `var x interface{} = nil` | +| None | `surrealdb.None` | `map[string]interface{}{"customer": surrealdb.None}` | +| Boolean | `bool` | `true`, `false` | +| Array | `[]interface{}` | `[]MyStruct{item1, item2}` | +| Date/Time | `time.Time` | `time.Now()` | +| Duration | `time.Duration` | `time.Duration(8821356)` | +| UUID (string representation) | `surrealdb.UUID(string)` | `surrealdb.UUID("123e4567-e89b-12d3-a456-426614174000")` | +| UUID (binary representation) | `surrealdb.UUIDBin([]bytes)`| `surrealdb.UUIDBin([]byte{0x01, 0x02, ...}`)` | +| Integer | `uint`, `uint64`, `int`, `int64` | `42`, `uint64(100000)`, `-42`, `int64(-100000)` | +| Floating Point | `float32`, `float64` | `3.14`, `float64(2.71828)` | +| Byte String, Binary Encoded Data | `[]byte` | `[]byte{0x01, 0x02}` | +| Text String | `string` | `"Hello, World!"` | +| Map | `map[interface{}]interface{}` | `map[string]float64{"one": 1.0}` | +| Table name| `surrealdb.Table(name)` | `surrealdb.Table("users")` | +| Record ID| `surrealdb.RecordID{Table: string, ID: interface{}}` | `surrealdb.RecordID{Table: "customers", ID: 1}, surrealdb.NewRecordID("customers", 1)` | +| Geometry Point | `surrealdb.GeometryPoint{Latitude: float64, Longitude: float64}` | `surrealdb.GeometryPoint{Latitude: 11.11, Longitude: 22.22` | +| Geometry Line | `surrealdb.GeometryLine{GeometricPoint1, GeometricPoint2,... }` | | +| Geometry Polygon | `surrealdb.GeometryPolygon{GeometryLine1, GeometryLine2,... }` | | +| Geometry Multipoint | `surrealdb.GeometryMultiPoint{GeometryPoint1, GeometryPoint2,... }` | | +| Geometry MultiLine | `surrealdb.GeometryMultiLine{GeometryLine1, GeometryLine2,... }` | | +| Geometry MultiPolygon | `surrealdb.GeometryMultiPolygon{GeometryPolygon1, GeometryPolygon2,... }` | | +| Geometry Collection| `surrealdb.GeometryMultiPolygon{GeometryPolygon1, GeometryLine2, GeometryPoint3, GeometryMultiPoint4,... }` | | + +## Helper Types +### surrealdb.O +For some methods like create, insert, update, you can pass a map instead of an struct value. An example: +```go +person, err := surrealdb.Create[Person](db, models.Table("persons"), map[interface{}]interface{}{ + "Name": "John", + "Surname": "Doe", + "Location": models.NewGeometryPoint(-0.11, 22.00), +}) +``` +This can be simplified to: +```go +person, err := surrealdb.Create[Person](db, models.Table("persons"), surrealdb.O{ + "Name": "John", + "Surname": "Doe", + "Location": models.NewGeometryPoint(-0.11, 22.00), +}) +``` +Where surrealdb.O is defined below. There is no special advantage in using this other than simplicity/legibility. +```go +type surrealdb.O map[interface{}]interface{} +``` + +### surrealdb.Result[T] +This is useful for the `Send` function where `T` is the expected response type for a request. An example: +```go +var res surrealdb.Result[[]Users] +err := db.Send(&res, "select", model.Table("users")) +if err != nil { + panic(err) +} +fmt.Printf("users: %+v\n", users.R) +``` ## Contributing You can run the Makefile commands to run and build the project @@ -157,28 +297,6 @@ make lint You also need to be running SurrealDB alongside the tests. We recommend using the nightly build, as development may rely on the latest functionality. -## Helper functions - -### Smart Marshal - -SurrealDB Go library supports smart marshal. It means that you can use any type of data as a value in your struct, and the library will automatically convert it to the correct type. - -```go -// User struct is a test struct -user, err := surrealdb.SmartUnmarshal[testUser](surrealdb.SmartMarshal(s.db.Create, user[0])) - -// Can be used without SmartUnmarshal -data, err := surrealdb.SmartMarshal(s.db.Create, user[0]) -``` - -### Smart Unmarshal - -SurrealDB Go library supports smart unmarshal. It means that you can unmarshal any type of data to the generic type provided, and the library will automatically convert it to that type. - -```go -// User struct is a test struct -data, err := surrealdb.SmartUnmarshal[testUser](s.db.Select(user[0].ID)) -``` diff --git a/db.go b/db.go index f73f2da..aaa316e 100644 --- a/db.go +++ b/db.go @@ -1,161 +1,275 @@ package surrealdb import ( + "context" "fmt" + "log/slog" + "net/url" + "os" + "strings" - "github.com/surrealdb/surrealdb.go/pkg/model" - - "github.com/surrealdb/surrealdb.go/pkg/conn" + "github.com/surrealdb/surrealdb.go/pkg/connection" "github.com/surrealdb/surrealdb.go/pkg/constants" + "github.com/surrealdb/surrealdb.go/pkg/logger" + "github.com/surrealdb/surrealdb.go/pkg/models" ) -// DB is a client for the SurrealDB database that holds the connection. -type DB struct { - conn conn.Connection +type VersionData struct { + Version string `json:"version"` + Build string `json:"build"` + Timestamp string `json:"timestamp"` } -// Auth is a struct that holds surrealdb auth data for login. -type Auth struct { - Namespace string `json:"NS,omitempty"` - Database string `json:"DB,omitempty"` - Scope string `json:"SC,omitempty"` - Username string `json:"user,omitempty"` - Password string `json:"pass,omitempty"` +// DB is a client for the SurrealDB database that holds the connection. +type DB struct { + ctx context.Context + con connection.Connection } // New creates a new SurrealDB client. -func New(url string, connection conn.Connection) (*DB, error) { - connection, err := connection.Connect(url) +func New(connectionURL string) (*DB, error) { + u, err := url.ParseRequestURI(connectionURL) + if err != nil { + return nil, err + } + + scheme := u.Scheme + + newParams := connection.NewConnectionParams{ + Marshaler: models.CborMarshaler{}, + Unmarshaler: models.CborUnmarshaler{}, + BaseURL: fmt.Sprintf("%s://%s", u.Scheme, u.Host), + Logger: logger.New(slog.NewTextHandler(os.Stdout, nil)), + } + + var con connection.Connection + if scheme == "http" || scheme == "https" { + con = connection.NewHTTPConnection(newParams) + } else if scheme == "ws" || scheme == "wss" { + con = connection.NewWebSocketConnection(newParams) + } else { + return nil, fmt.Errorf("invalid connection url") + } + + err = con.Connect() if err != nil { return nil, err } - return &DB{connection}, nil + + return &DB{con: con}, nil } // -------------------------------------------------- // Public methods // -------------------------------------------------- +// WithContext +func (db *DB) WithContext(ctx context.Context) *DB { + db.ctx = ctx + return db +} + // Close closes the underlying WebSocket connection. func (db *DB) Close() error { - return db.conn.Close() + return db.con.Close() } -// -------------------------------------------------- - // Use is a method to select the namespace and table to use. -func (db *DB) Use(ns, database string) (interface{}, error) { - return db.send("use", ns, database) +func (db *DB) Use(ns, database string) error { + return db.con.Use(ns, database) } -func (db *DB) Info() (interface{}, error) { - return db.send("info") +func (db *DB) Info() (map[string]interface{}, error) { + var info connection.RPCResponse[map[string]interface{}] + err := db.con.Send(&info, "info") + return *info.Result, err } -// Signup is a helper method for signing up a new user. -func (db *DB) Signup(authData *Auth) (interface{}, error) { - return db.send("signup", authData) +// SignUp is a helper method for signing up a new user. +func (db *DB) SignUp(authData *Auth) (string, error) { + var token connection.RPCResponse[string] + if err := db.con.Send(&token, "signup", authData); err != nil { + return "", err + } + + if err := db.con.Let(constants.AuthTokenKey, token.Result); err != nil { + return "", err + } + + return *token.Result, nil } -// Signin is a helper method for signing in a user. -func (db *DB) Signin(authData *Auth) (interface{}, error) { - return db.send("signin", authData) +// SignIn is a helper method for signing in a user. +func (db *DB) SignIn(authData *Auth) (string, error) { + var token connection.RPCResponse[string] + if err := db.con.Send(&token, "signin", authData); err != nil { + return "", err + } + + if err := db.con.Let(constants.AuthTokenKey, token.Result); err != nil { + return "", err + } + + return *token.Result, nil } -func (db *DB) Invalidate() (interface{}, error) { - return db.send("invalidate") +func (db *DB) Invalidate() error { + if err := db.con.Send(nil, "invalidate"); err != nil { + return err + } + + if err := db.con.Unset(constants.AuthTokenKey); err != nil { + return err + } + + return nil } -func (db *DB) Authenticate(token string) (interface{}, error) { - return db.send("authenticate", token) +func (db *DB) Authenticate(token string) error { + if err := db.con.Send(nil, "authenticate", token); err != nil { + return err + } + + if err := db.con.Let(constants.AuthTokenKey, token); err != nil { + return err + } + + return nil } -// -------------------------------------------------- +func (db *DB) Let(key string, val interface{}) error { + return db.con.Let(key, val) +} -func (db *DB) Live(table string, diff bool) (string, error) { - id, err := db.send("live", table, diff) - return id.(string), err +func (db *DB) Unset(key string) error { + return db.con.Unset(key) } -func (db *DB) Kill(liveQueryID string) (interface{}, error) { - return db.send("kill", liveQueryID) +func (db *DB) Version() (*VersionData, error) { + var ver connection.RPCResponse[VersionData] + if err := db.con.Send(&ver, "version"); err != nil { + return nil, err + } + return ver.Result, nil } -func (db *DB) Let(key string, val interface{}) (interface{}, error) { - return db.send("let", key, val) +func (db *DB) Send(res interface{}, method string, params ...interface{}) error { + allowedSendMethods := []string{"select", "create", "insert", "update", "upsert", "patch", "delete", "query"} + + allowed := false + for i := 0; i < len(allowedSendMethods); i++ { + if strings.EqualFold(allowedSendMethods[i], strings.ToLower(method)) { + allowed = true + break + } + } + + if !allowed { + return fmt.Errorf("provided method is not allowed") + } + + return db.con.Send(&res, method, params...) } -// Query is a convenient method for sending a query to the database. -func (db *DB) Query(sql string, vars interface{}) (interface{}, error) { - return db.send("query", sql, vars) +func (db *DB) LiveNotifications(liveQueryID string) (chan connection.Notification, error) { + return db.con.LiveNotifications(liveQueryID) } -// Select a table or record from the database. -func (db *DB) Select(what string) (interface{}, error) { - return db.send("select", what) +//-------------------------------------------------------------------------------------------------------------------// + +func Kill(db *DB, id string) error { + return db.con.Send(nil, "kill", id) } -// Creates a table or record in the database like a POST request. -func (db *DB) Create(thing string, data interface{}) (interface{}, error) { - return db.send("create", thing, data) +func Live(db *DB, table models.Table, diff bool) (*models.UUID, error) { + var res connection.RPCResponse[models.UUID] + if err := db.con.Send(&res, "live", table, diff); err != nil { + return nil, err + } + + return res.Result, nil } -// Update a table or record in the database like a PUT request. -func (db *DB) Update(what string, data interface{}) (interface{}, error) { - return db.send("update", what, data) +func Query[T any](db *DB, sql string, vars map[string]interface{}) (*[]QueryResult[T], error) { + var res connection.RPCResponse[[]QueryResult[T]] + if err := db.con.Send(&res, "query", sql, vars); err != nil { + return nil, err + } + + return res.Result, nil } -// Merge a table or record in the database like a PATCH request. -func (db *DB) Merge(what string, data interface{}) (interface{}, error) { - return db.send("merge", what, data) +func Create[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) { + var res connection.RPCResponse[TResult] + if err := db.con.Send(&res, "create", what, data); err != nil { + return nil, err + } + + return res.Result, nil } -// Patch applies a series of JSONPatches to a table or record. -func (db *DB) Patch(what string, data []Patch) (interface{}, error) { - return db.send("patch", what, data) +func Select[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat) (*TResult, error) { + var res connection.RPCResponse[TResult] + + if err := db.con.Send(&res, "select", what); err != nil { + return nil, err + } + + return res.Result, nil } -// Delete a table or a row from the database like a DELETE request. -func (db *DB) Delete(what string) (interface{}, error) { - return db.send("delete", what) +func Patch(db *DB, what interface{}, patches []PatchData) (*[]PatchData, error) { + var patchRes connection.RPCResponse[[]PatchData] + err := db.con.Send(&patchRes, "patch", what, patches, true) + return patchRes.Result, err } -// Insert a table or a row from the database like a POST request. -func (db *DB) Insert(what string, data interface{}) (interface{}, error) { - return db.send("insert", what, data) +func Delete[TWhat models.TableOrRecord](db *DB, what TWhat) error { + return db.con.Send(nil, "delete", what) } -// LiveNotifications returns a channel for live query. -func (db *DB) LiveNotifications(liveQueryID string) (chan model.Notification, error) { - return db.conn.LiveNotifications(liveQueryID) +func Upsert(db *DB, what, data interface{}) error { + return db.con.Send(nil, "upsert", what, data) } -// -------------------------------------------------- -// Private methods -// -------------------------------------------------- +// Update a table or record in the database like a PUT request. +func Update[TResult any, TWhat models.TableOrRecord](db *DB, what TWhat, data interface{}) (*TResult, error) { + var res connection.RPCResponse[TResult] + if err := db.con.Send(&res, "update", what, data); err != nil { + return nil, err + } -// send is a helper method for sending a query to the database. -func (db *DB) send(method string, params ...interface{}) (interface{}, error) { - // here we send the args through our websocket connection - resp, err := db.conn.Send(method, params) - if err != nil { - return nil, fmt.Errorf("sending request failed for method '%s': %w", method, err) + return res.Result, nil +} + +// Merge a table or record in the database like a PATCH request. +func Merge[T any](db *DB, what, data interface{}) (*T, error) { + var res connection.RPCResponse[T] + if err := db.con.Send(&res, "merge", what, data); err != nil { + return nil, err } - switch method { - case "select", "create", "update", "merge", "patch", "insert": - return db.resp(method, params, resp) - case "delete": - return nil, nil - default: - return resp, nil + return res.Result, nil +} + +// Insert a table or a row from the database like a POST request. +func Insert[TResult any](db *DB, what models.Table, data interface{}) (*[]TResult, error) { + var res connection.RPCResponse[[]TResult] + if err := db.con.Send(&res, "insert", what, data); err != nil { + return nil, err } + + return res.Result, nil } -// resp is a helper method for parsing the response from a query. -func (db *DB) resp(_ string, _ []interface{}, res interface{}) (interface{}, error) { - if res == nil { - return nil, constants.ErrNoRow +func Relate[T any](db *DB, in, out models.RecordID, relation models.Table, data interface{}) (*T, error) { + var res connection.RPCResponse[T] + if err := db.con.Send(&res, "relate", in, out, relation, data); err != nil { + return nil, err } - return res, nil + return res.Result, nil +} + +func InsertRelation(db *DB, what, data interface{}) error { + return db.con.Send(nil, "insert", what, data) } diff --git a/db_test.go b/db_test.go index def4dd1..f14f8bb 100644 --- a/db_test.go +++ b/db_test.go @@ -1,109 +1,69 @@ package surrealdb_test import ( - "bytes" - "encoding/json" "fmt" - "io" - rawslog "log/slog" "os" "sync" "testing" - "time" - "github.com/stretchr/testify/assert" - "github.com/surrealdb/surrealdb.go/pkg/logger/slog" - "github.com/surrealdb/surrealdb.go/pkg/model" - - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/surrealdb/surrealdb.go" - "github.com/surrealdb/surrealdb.go/pkg/conn/gorilla" - "github.com/surrealdb/surrealdb.go/pkg/constants" - - "github.com/surrealdb/surrealdb.go/pkg/conn" - "github.com/surrealdb/surrealdb.go/pkg/logger" - "github.com/surrealdb/surrealdb.go/pkg/marshal" + "github.com/surrealdb/surrealdb.go/pkg/connection" + "github.com/surrealdb/surrealdb.go/pkg/models" ) -// Default consts and vars for testing +// Default const and vars for testing const ( - defaultURL = "ws://localhost:8000/rpc" + defaultURL = "ws://localhost:8000" ) var currentURL = os.Getenv("SURREALDB_URL") -// +func getURL() string { + if currentURL == "" { + return defaultURL + } + return currentURL +} // TestDBSuite is a test s for the DB struct type SurrealDBTestSuite struct { suite.Suite - db *surrealdb.DB - name string - connImplementations map[string]conn.Connection - logBuffer *bytes.Buffer + db *surrealdb.DB + name string } // a simple user struct for testing type testUser struct { - marshal.Basemodel `table:"test"` - Username string `json:"username,omitempty"` - Password string `json:"password,omitempty"` - ID string `json:"id,omitempty"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + ID *models.RecordID `json:"id,omitempty"` } -// a simple user struct for testing -type testUserWithFriend[I any] struct { - marshal.Basemodel `table:"user_with_friend"` - Username string `json:"username,omitempty"` - Password string `json:"password,omitempty"` - ID string `json:"id,omitempty"` - Friends []I `json:"friends,omitempty"` +// assertContains performs an assertion on a list, asserting that at least one element matches a provided condition. +// All the matching elements are returned from this function, which can be used as a filter. +func assertContains[K any](s *SurrealDBTestSuite, input []K, matcher func(K) bool) []K { + matching := make([]K, 0) + for _, v := range input { + if matcher(v) { + matching = append(matching, v) + } + } + s.NotEmptyf(matching, "Input %+v did not contain matching element", fmt.Sprintf("%+v", input)) + return matching } func TestSurrealDBSuite(t *testing.T) { - SurrealDBSuite := new(SurrealDBTestSuite) - SurrealDBSuite.connImplementations = make(map[string]conn.Connection) - - // Without options - buff := bytes.NewBufferString("") - logData := createLogger(t, buff) - SurrealDBSuite.connImplementations["gorilla"] = gorilla.Create().Logger(logData) - SurrealDBSuite.logBuffer = buff - - // With options - buffOpt := bytes.NewBufferString("") - logDataOpt := createLogger(t, buff) - SurrealDBSuite.connImplementations["gorilla_opt"] = gorilla.Create().SetTimeOut(time.Minute).SetCompression(true).Logger(logDataOpt) - SurrealDBSuite.logBuffer = buffOpt - - RunWsMap(t, SurrealDBSuite) -} - -func createLogger(t *testing.T, writer io.Writer) logger.Logger { - t.Helper() - handler := rawslog.NewJSONHandler(writer, &rawslog.HandlerOptions{Level: rawslog.LevelDebug}) - return slog.New(handler) -} + s := new(SurrealDBTestSuite) -func RunWsMap(t *testing.T, s *SurrealDBTestSuite) { - for wsName := range s.connImplementations { - // Run the test suite - t.Run(wsName, func(t *testing.T) { - s.name = wsName - suite.Run(t, s) - }) - } + s.name = "Test_DB" + suite.Run(t, s) } // SetupTest is called after each test func (s *SurrealDBTestSuite) TearDownTest() { - _, err := s.db.Delete("users") + err := surrealdb.Delete[models.Table](s.db, "users") s.Require().NoError(err) - - if s.logBuffer.Len() > 0 { - s.T().Logf("Log output:\n%s", s.logBuffer.String()) - } } // TearDownSuite is called after the s has finished running @@ -112,268 +72,75 @@ func (s *SurrealDBTestSuite) TearDownSuite() { s.Require().NoError(err) } -func (t testUser) String() (str string, err error) { - byteData, err := json.Marshal(t) - if err != nil { - return - } - str = string(byteData) - return -} - -func (s *SurrealDBTestSuite) createTestDB() *surrealdb.DB { - url := os.Getenv("SURREALDB_URL") - if url == "" { - url = "ws://localhost:8000/rpc" - } - impl := s.connImplementations[s.name] - db := s.openConnection(url, impl) - return db -} - -// openConnection opens a new connection to the database -func (s *SurrealDBTestSuite) openConnection(url string, impl conn.Connection) *surrealdb.DB { - require.NotNil(s.T(), impl) - db, err := surrealdb.New(url, impl) - s.Require().NoError(err) - return db -} - // SetupSuite is called before the s starts running func (s *SurrealDBTestSuite) SetupSuite() { - db := s.createTestDB() - s.Require().NotNil(db) + db, err := surrealdb.New(getURL()) + s.Require().NoError(err, "should not return an error when initializing db") s.db = db - _ = signin(s) - _, err := db.Use("test", "test") - s.Require().NoError(err) + + _ = signIn(s) + + err = db.Use("test", "test") + s.Require().NoError(err, "should not return an error when setting namespace and database") } // Sign with the root user // Can be used with any user -func signin(s *SurrealDBTestSuite) interface{} { +func signIn(s *SurrealDBTestSuite) string { authData := &surrealdb.Auth{ Username: "root", Password: "root", } - signin, err := s.db.Signin(authData) + token, err := s.db.SignIn(authData) s.Require().NoError(err) - return signin + return token } -func (s *SurrealDBTestSuite) TestLiveViaMethod() { - live, err := s.db.Live("users", false) - defer func() { - _, err = s.db.Kill(live) - s.Require().NoError(err) - }() - - notifications, er := s.db.LiveNotifications(live) - // create a user - s.Require().NoError(er) - _, e := s.db.Create("users", map[string]interface{}{ - "username": "johnny", - "password": "123", - }) - s.Require().NoError(e) - notification := <-notifications - s.Require().Equal(model.CreateAction, notification.Action) - s.Require().Equal(live, notification.ID) -} - -func (s *SurrealDBTestSuite) TestLiveWithOptionsViaMethod() { - // create a user - userData, e := s.db.Create("users", map[string]interface{}{ - "username": "johnny", - "password": "123", +func (s *SurrealDBTestSuite) TestSend_AllowedMethods() { + s.Run("Send method should be rejected", func() { + err := s.db.Send(nil, "let") + s.Require().Error(err) }) - s.Require().NoError(e) - var user []testUser - err := marshal.Unmarshal(userData, &user) - s.Require().NoError(err) - live, err := s.db.Live("users", true) - defer func() { - _, err = s.db.Kill(live) + s.Run("Send method should be allowed", func() { + err := s.db.Send(nil, "query", "select * from users") s.Require().NoError(err) - }() - - notifications, er := s.db.LiveNotifications(live) - s.Require().NoError(er) - - // update the user - _, e = s.db.Update(user[0].ID, map[string]interface{}{ - "password": "456", }) - s.Require().NoError(e) - - notification := <-notifications - s.Require().Equal(model.UpdateAction, notification.Action) - s.Require().Equal(live, notification.ID) -} - -func (s *SurrealDBTestSuite) TestLiveViaQuery() { - liveResponse, err := s.db.Query("LIVE SELECT * FROM users", map[string]interface{}{}) - assert.NoError(s.T(), err) - responseArray, ok := liveResponse.([]interface{}) - assert.True(s.T(), ok) - singleResponse := responseArray[0].(map[string]interface{}) - liveIDStruct, ok := singleResponse["result"] - assert.True(s.T(), ok) - liveID := liveIDStruct.(string) - - defer func() { - _, err = s.db.Kill(liveID) - s.Require().NoError(err) - }() - - notifications, er := s.db.LiveNotifications(liveID) - // create a user - s.Require().NoError(er) - _, e := s.db.Create("users", map[string]interface{}{ - "username": "johnny", - "password": "123", - }) - s.Require().NoError(e) - notification := <-notifications - s.Require().Equal(model.CreateAction, notification.Action) - s.Require().Equal(liveID, notification.ID) } func (s *SurrealDBTestSuite) TestDelete() { - userData, err := s.db.Create("users", testUser{ + _, err := surrealdb.Create[testUser](s.db, "users", testUser{ Username: "johnny", Password: "123", }) s.Require().NoError(err) - // unmarshal the data into a user struct - var user []testUser - err = marshal.Unmarshal(userData, &user) - s.Require().NoError(err) - // Delete the users... - _, err = s.db.Delete("users") + err = surrealdb.Delete(s.db, "users") s.Require().NoError(err) } -func (s *SurrealDBTestSuite) TestFetch() { - // Define initial user slice - userSlice := []testUserWithFriend[string]{ - { - ID: "users:arthur", - Username: "arthur", - Password: "deer", - Friends: []string{"users:john"}, - }, - { - ID: "users:john", - Username: "john", - Password: "wolf", - Friends: []string{"users:arthur"}, - }, - } - - // Initialize data using users - for _, v := range userSlice { - data, err := s.db.Create(v.ID, v) - s.NoError(err) - s.NotNil(data) - } - - // User rows are individually fetched - s.Run("Run fetch for individual users", func() { - s.T().Skip("TODO(gh-116) Fetch unimplemented") - for _, v := range userSlice { - res, err := s.db.Query("select * from $table fetch $fetchstr;", map[string]interface{}{ - "record": v.ID, - "fetchstr": "friends.*", - }) - s.NoError(err) - s.NotEmpty(res) - } - }) - - s.Run("Run fetch on hardcoded query", func() { - query := "SELECT * from users:arthur fetch friends.*" - res, err := s.db.Query(query, map[string]interface{}{}) - s.NoError(err) - s.NotEmpty(res) - - userSlice, err := marshal.SmartUnmarshal[testUserWithFriend[testUserWithFriend[interface{}]]](res, err) - s.NoError(err) - - s.Require().Len(userSlice, 1) - s.Require().Len(userSlice[0].Friends, 1) - s.Require().NotEmpty(userSlice[0].Friends[0], 1) - }) - - s.Run("Run fetch on query using map[string]interface{} for thing and fetchString", func() { - s.T().Skip("TODO(gh-116) Fetch unimplemented") - res, err := s.db.Query("select * from $record fetch $fetchstr;", map[string]interface{}{ - "record": "users", - "fetchstr": "friends.*", - }) - s.NoError(err) - s.NotEmpty(res) - }) - - s.Run("Run fetch on query using map[string]interface{} for fetchString", func() { - s.T().Skip("TODO(gh-116) Fetch unimplemented") - res, err := s.db.Query("select * from users fetch $fetchstr;", map[string]interface{}{ - "fetchstr": "friends.*", - }) - s.NoError(err) - s.NotEmpty(res) - }) - - s.Run("Run fetch on query using map[string]interface{} for thing or tableName", func() { - res, err := s.db.Query("select * from $record fetch friends.*;", map[string]interface{}{ - "record": "users:arthur", - }) - s.NoError(err) - s.NotEmpty(res) - - userSlice, err := marshal.SmartUnmarshal[testUserWithFriend[testUserWithFriend[interface{}]]](res, err) - s.NoError(err) - - s.Require().Len(userSlice, 1) - s.Require().Len(userSlice[0].Friends, 1) - s.Require().NotEmpty(userSlice[0].Friends[0], 1) - }) -} - func (s *SurrealDBTestSuite) TestInsert() { s.Run("raw map works", func() { - userData, err := s.db.Insert("user", map[string]interface{}{ + insert, err := surrealdb.Insert[testUser](s.db, "users", map[string]interface{}{ "username": "johnny", "password": "123", }) s.Require().NoError(err) - // unmarshal the data into a user struct - var user []testUser - err = marshal.Unmarshal(userData, &user) - s.Require().NoError(err) - - s.Equal("johnny", user[0].Username) - s.Equal("123", user[0].Password) + s.Equal("johnny", (*insert)[0].Username) + s.Equal("123", (*insert)[0].Password) }) s.Run("Single insert works", func() { - userData, err := s.db.Insert("user", testUser{ + insert, err := surrealdb.Insert[testUser](s.db, "users", testUser{ Username: "johnny", Password: "123", }) s.Require().NoError(err) - // unmarshal the data into a user struct - var user []testUser - err = marshal.Unmarshal(userData, &user) - s.Require().NoError(err) - - s.Equal("johnny", user[0].Username) - s.Equal("123", user[0].Password) + s.Equal("johnny", (*insert)[0].Username) + s.Equal("123", (*insert)[0].Password) }) s.Run("Multiple insert works", func() { @@ -385,120 +152,36 @@ func (s *SurrealDBTestSuite) TestInsert() { Username: "johnny2", Password: "123", }) - userData, err := s.db.Insert("user", userInsert) - s.Require().NoError(err) - - // unmarshal the data into a user struct - var users []testUser - err = marshal.Unmarshal(userData, &users) + insert, err := surrealdb.Insert[testUser](s.db, "users", userInsert) s.Require().NoError(err) - s.Len(users, 2) - - assertContains(s, users, func(user testUser) bool { - return s.Contains(users, user) - }) + s.Len(*insert, 2) }) } -func (s *SurrealDBTestSuite) TestCreate() { - s.Run("raw map works", func() { - userData, err := s.db.Create("users", map[string]interface{}{ - "username": "johnny", - "password": "123", - }) - s.Require().NoError(err) - - // unmarshal the data into a user struct - var userSlice []testUser - err = marshal.Unmarshal(userData, &userSlice) - s.Require().NoError(err) - s.Len(userSlice, 1) - - s.Equal("johnny", userSlice[0].Username) - s.Equal("123", userSlice[0].Password) - }) - - s.Run("Single create works", func() { - userData, err := s.db.Create("users", testUser{ - Username: "johnny", - Password: "123", - }) - s.Require().NoError(err) - - // unmarshal the data into a user struct - var userSlice []testUser - err = marshal.Unmarshal(userData, &userSlice) - s.Require().NoError(err) - s.Len(userSlice, 1) - - s.Equal("johnny", userSlice[0].Username) - s.Equal("123", userSlice[0].Password) +func (s *SurrealDBTestSuite) TestPatch() { + _, err := surrealdb.Create[testUser](s.db, *models.ParseRecordID("users:999"), map[string]interface{}{ + "username": "john999", + "password": "123", }) + s.NoError(err) - s.Run("Multiple creates works", func() { - s.T().Skip("Creating multiple records is not supported yet") - data := make([]testUser, 0) - data = append(data, - testUser{ - Username: "johnny", - Password: "123", - }, - testUser{ - Username: "joe", - Password: "123", - }) - userData, err := s.db.Create("users", data) - s.Require().NoError(err) - - // unmarshal the data into a user struct - var users []testUser - err = marshal.Unmarshal(userData, &users) - s.Require().NoError(err) - - assertContains(s, users, func(user testUser) bool { - return s.Contains(users, user) - }) - }) -} + patches := []surrealdb.PatchData{ + {Op: "add", Path: "nickname", Value: "johnny"}, + {Op: "add", Path: "age", Value: int(44)}, + } -func (s *SurrealDBTestSuite) TestSelect() { - createdUsers, err := s.db.Create("users", testUser{ - Username: "johnnyjohn", - Password: "123", - }) + // Update the user + _, err = surrealdb.Patch(s.db, models.ParseRecordID("users:999"), patches) s.Require().NoError(err) - s.NotEmpty(createdUsers) - var createdUsersUnmarshalled []testUser - s.Require().NoError(marshal.Unmarshal(createdUsers, &createdUsersUnmarshalled)) - s.NotEmpty(createdUsersUnmarshalled) - s.NotEmpty(createdUsersUnmarshalled[0].ID, "The ID should have been set by the database") - - s.Run("Select many with table", func() { - userData, err := s.db.Select("users") - s.Require().NoError(err) - - // unmarshal the data into a user slice - var users []testUser - err = marshal.Unmarshal(userData, &users) - s.NoError(err) - matching := assertContains(s, users, func(item testUser) bool { - return item.Username == "johnnyjohn" - }) - s.GreaterOrEqual(len(matching), 1) - }) - s.Run("Select single record", func() { - userData, err := s.db.Select(createdUsersUnmarshalled[0].ID) - s.Require().NoError(err) + user2, err := surrealdb.Select[map[string]interface{}](s.db, *models.ParseRecordID("users:999")) + s.Require().NoError(err) - // unmarshal the data into a user struct - var user testUser - err = marshal.Unmarshal(userData, &user) - s.Require().NoError(err) + username := (*user2)["username"].(string) + data := (*user2)["age"].(uint64) - s.Equal("johnnyjohn", user.Username) - s.Equal("123", user.Password) - }) + s.Equal("john999", username) // Ensure username hasn't change + s.EqualValues(patches[1].Value, data) } func (s *SurrealDBTestSuite) TestUpdate() { @@ -511,226 +194,146 @@ func (s *SurrealDBTestSuite) TestUpdate() { // create users var createdUsers []testUser for _, v := range users { - createdUser, err := s.db.Create("users", v) + createdUser, err := surrealdb.Create[testUser](s.db, models.Table("users"), v) s.Require().NoError(err) - var tempUserArr []testUser - err = marshal.Unmarshal(createdUser, &tempUserArr) - s.Require().NoError(err) - createdUsers = append(createdUsers, tempUserArr...) + createdUsers = append(createdUsers, *createdUser) } createdUsers[0].Password = newPassword // Update the user - UpdatedUserRaw, err := s.db.Update(createdUsers[0].ID, createdUsers[0]) - s.Require().NoError(err) - - // unmarshal the data into a user struct - var updatedUser testUser - err = marshal.Unmarshal(UpdatedUserRaw, &updatedUser) + updatedUser, err := surrealdb.Update[testUser](s.db, *(createdUsers)[0].ID, createdUsers[0]) s.Require().NoError(err) // Check if password changes s.Equal(newPassword, updatedUser.Password) // select controlUser - controlUserRaw, err := s.db.Select(createdUsers[1].ID) - s.Require().NoError(err) - - // unmarshal the data into a user struct - var controlUser testUser - err = marshal.Unmarshal(controlUserRaw, &controlUser) + controlUser, err := surrealdb.Select[testUser](s.db, *createdUsers[1].ID) s.Require().NoError(err) // check control user is changed or not - s.Equal(createdUsers[1], controlUser) + s.Equal(createdUsers[1], *controlUser) } -func (s *SurrealDBTestSuite) TestUnmarshalRaw() { - username := "johnny" - password := "123" +func (s *SurrealDBTestSuite) TestLiveViaMethod() { + live, err := surrealdb.Live(s.db, "users", false) + s.Require().NoError(err, "should not return error on live request") - // create test user with raw SurrealQL and unmarshal - userData, err := s.db.Query("create users:johnny set Username = $user, Password = $pass", map[string]interface{}{ - "user": username, - "pass": password, - }) - s.Require().NoError(err) + defer func() { + err = surrealdb.Kill(s.db, live.String()) + s.Require().NoError(err) + }() - var userSlice []marshal.RawQuery[testUser] - err = marshal.UnmarshalRaw(userData, &userSlice) - s.Require().NoError(err) - s.Len(userSlice, 1) - s.Equal(userSlice[0].Status, marshal.StatusOK) - s.Equal(username, userSlice[0].Result[0].Username) - s.Equal(password, userSlice[0].Result[0].Password) - - // send query with empty result and unmarshal - userData, err = s.db.Query("select * from users where id = $id", map[string]interface{}{ - "id": "users:jim", - }) + notifications, err := s.db.LiveNotifications(live.String()) s.Require().NoError(err) - err = marshal.UnmarshalRaw(userData, &userSlice) - s.NoError(err) - s.Equal(userSlice[0].Status, marshal.StatusOK) - s.Empty(userSlice[0].Result) -} - -func (s *SurrealDBTestSuite) TestMerge() { - _, err := s.db.Create("users:999", map[string]interface{}{ - "username": "john999", + _, e := surrealdb.Create[testUser](s.db, "users", map[string]interface{}{ + "username": "johnny", "password": "123", }) - s.NoError(err) - - // Update the user - _, err = s.db.Merge("users:999", map[string]string{ - "password": "456", - }) - s.Require().NoError(err) - - user2, err := s.db.Select("users:999") - s.Require().NoError(err) - - username := (user2).(map[string]interface{})["username"].(string) - password := (user2).(map[string]interface{})["password"].(string) + s.Require().NoError(e) - s.Equal("john999", username) // Ensure username hasn't change. - s.Equal("456", password) + notification := <-notifications + fmt.Println(notification) + s.Require().Equal(connection.CreateAction, notification.Action) + s.Require().Equal(live, notification.ID) } -func (s *SurrealDBTestSuite) TestPatch() { - _, err := s.db.Create("users:999", map[string]interface{}{ - "username": "john999", - "password": "123", - }) - s.NoError(err) - - patches := []surrealdb.Patch{ - {Op: "add", Path: "nickname", Value: "johnny"}, - {Op: "add", Path: "age", Value: int(44)}, - } - - // Update the user - _, err = s.db.Patch("users:999", patches) - s.Require().NoError(err) - - user2, err := s.db.Select("users:999") +func (s *SurrealDBTestSuite) TestLiveViaQuery() { + res, err := surrealdb.Query[models.UUID](s.db, "LIVE SELECT * FROM users", map[string]interface{}{}) s.Require().NoError(err) - username := (user2).(map[string]interface{})["username"].(string) - data := (user2).(map[string]interface{})["age"].(float64) + liveID := (*res)[0].Result.String() - s.Equal("john999", username) // Ensure username hasn't change. - s.EqualValues(patches[1].Value, data) -} - -func (s *SurrealDBTestSuite) TestNonRowSelect() { - user := testUser{ - Username: "ElecTwix", - Password: "1234", - ID: "users:notexists", - } + notifications, err := s.db.LiveNotifications(liveID) + s.Require().NoError(err) - _, err := s.db.Select("users:notexists") - s.Equal(err, constants.ErrNoRow) + defer func() { + err = surrealdb.Kill(s.db, liveID) + s.Require().NoError(err) + }() - _, err = marshal.SmartUnmarshal[testUser](s.db.Select("users:notexists")) - s.Equal(err, constants.ErrNoRow) + // create user + _, e := surrealdb.Create[testUser](s.db, "users", map[string]interface{}{ + "username": "johnny", + "password": "123", + }) + s.Require().NoError(e) + notification := <-notifications - _, err = marshal.SmartUnmarshal[testUser](marshal.SmartMarshal(s.db.Select, user)) - s.Equal(err, constants.ErrNoRow) + s.Require().Equal(connection.CreateAction, notification.Action) + s.Require().Equal(liveID, notification.ID.String()) } -func (s *SurrealDBTestSuite) TestSmartUnMarshalQuery() { - user := []testUser{{ - Username: "electwix", - Password: "1234", - }} - - s.Run("raw create query", func() { - QueryStr := "Create users set Username = $user, Password = $pass" - dataArr, err := marshal.SmartUnmarshal[testUser](s.db.Query(QueryStr, map[string]interface{}{ - "user": user[0].Username, - "pass": user[0].Password, - })) - +func (s *SurrealDBTestSuite) TestCreate() { + s.Run("raw map works", func() { + user, err := surrealdb.Create[testUser](s.db, "users", map[string]interface{}{ + "username": "johnny", + "password": "123", + }) s.Require().NoError(err) - s.Equal("electwix", dataArr[0].Username) - user = dataArr - }) - - s.Run("raw select query", func() { - dataArr, err := marshal.SmartUnmarshal[testUser](s.db.Query("Select * from $record", map[string]interface{}{ - "record": user[0].ID, - })) - s.Require().NoError(err) - s.Equal("electwix", dataArr[0].Username) + s.Equal("johnny", user.Username) + s.Equal("123", user.Password) }) - s.Run("select query", func() { - data, err := marshal.SmartUnmarshal[testUser](s.db.Select(user[0].ID)) - + s.Run("Single create works", func() { + user, err := surrealdb.Create[testUser](s.db, "users", testUser{ + Username: "johnny", + Password: "123", + }) s.Require().NoError(err) - s.Equal("electwix", data[0].Username) - }) - - s.Run("select array query", func() { - data, err := marshal.SmartUnmarshal[testUser](s.db.Select("users")) - s.Require().NoError(err) - s.Equal("electwix", data[0].Username) + s.Equal("johnny", user.Username) + s.Equal("123", user.Password) }) - s.Run("delete record query", func() { - data, err := marshal.SmartUnmarshal[testUser](s.db.Delete(user[0].ID)) - + s.Run("Multiple creates works", func() { + s.T().Skip("Creating multiple records is not supported yet") + data := make([]testUser, 0) + data = append(data, + testUser{ + Username: "johnny", + Password: "123", + }, + testUser{ + Username: "joe", + Password: "123", + }) + users, err := surrealdb.Create[[]testUser](s.db, "users", data) s.Require().NoError(err) - s.Len(data, 0) + + assertContains(s, *users, func(user testUser) bool { + return s.Contains(users, user) + }) }) } -func (s *SurrealDBTestSuite) TestSmartMarshalQuery() { - user := []testUser{{ - Username: "electwix", - Password: "1234", - ID: "sometable:someid", - }} - - s.Run("create with SmartMarshal query", func() { - data, err := marshal.SmartUnmarshal[testUser](marshal.SmartMarshal(s.db.Create, user[0])) - s.Require().NoError(err) - s.Len(data, 1) - s.Equal(user[0], data[0]) +func (s *SurrealDBTestSuite) TestSelect() { + createdUser, err := surrealdb.Create[testUser](s.db, "users", testUser{ + Username: "johnnyjohn", + Password: "123", }) + s.Require().NoError(err) + s.NotEmpty(createdUser) - s.Run("select with SmartMarshal query", func() { - data, err := marshal.SmartUnmarshal[testUser](marshal.SmartMarshal(s.db.Select, user[0])) + s.Run("Select many with table", func() { + users, err := surrealdb.Select[[]testUser](s.db, "users") s.Require().NoError(err) - s.Len(data, 1) - s.Equal(user[0], data[0]) - }) - s.Run("update with SmartMarshal query", func() { - user[0].Password = "test123" - data, err := marshal.SmartUnmarshal[testUser](marshal.SmartMarshal(s.db.Update, user[0])) - s.Require().NoError(err) - s.Len(data, 1) - s.Equal(user[0].Password, data[0].Password) + matching := assertContains(s, *users, func(item testUser) bool { + return item.Username == "johnnyjohn" + }) + s.GreaterOrEqual(len(matching), 1) }) - s.Run("delete with SmartMarshal query", func() { - data, err := marshal.SmartMarshal(s.db.Delete, user[0]) + s.Run("Select single record", func() { + user, err := surrealdb.Select[testUser](s.db, *createdUser.ID) s.Require().NoError(err) - s.Nil(data) - }) - s.Run("check if data deleted SmartMarshal query", func() { - data, err := marshal.SmartUnmarshal[testUser](marshal.SmartMarshal(s.db.Select, user[0])) - s.Require().Equal(err, constants.ErrNoRow) - s.Len(data, 0) + s.Equal("johnnyjohn", user.Username) + s.Equal("123", user.Password) }) } @@ -738,18 +341,12 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() { var wg sync.WaitGroup totalGoroutines := 100 - user := testUser{ - Username: "electwix", - Password: "1234", - } - s.Run(fmt.Sprintf("Concurrent select non existent rows %d", totalGoroutines), func() { for i := 0; i < totalGoroutines; i++ { wg.Add(1) go func(j int) { defer wg.Done() - _, err := s.db.Select(fmt.Sprintf("users:%d", j)) - s.Require().Equal(err, constants.ErrNoRow) + _, _ = surrealdb.Select[testUser](s.db, models.NewRecordID("users", j)) }(i) } wg.Wait() @@ -760,7 +357,7 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() { wg.Add(1) go func(j int) { defer wg.Done() - _, err := s.db.Create(fmt.Sprintf("users:%d", j), user) + _, err := surrealdb.Select[testUser](s.db, models.NewRecordID("users", j)) s.Require().NoError(err) }(i) } @@ -772,7 +369,7 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() { wg.Add(1) go func(j int) { defer wg.Done() - _, err := s.db.Select(fmt.Sprintf("users:%d", j)) + _, err := surrealdb.Select[testUser](s.db, models.NewRecordID("users", j)) s.Require().NoError(err) }(i) } @@ -780,33 +377,21 @@ func (s *SurrealDBTestSuite) TestConcurrentOperations() { }) } -func (s *SurrealDBTestSuite) TestConnectionBreak() { - ws := gorilla.Create() - var url string - if currentURL == "" { - url = defaultURL - } else { - url = currentURL - } - - db := s.openConnection(url, ws) - // Close the connection hard from ws - ws.Conn.Close() +func (s *SurrealDBTestSuite) TestMerge() { + _, err := surrealdb.Create[testUser](s.db, *models.ParseRecordID("users:999"), map[string]interface{}{ + "username": "john999", + "password": "123", + }) + s.NoError(err) - // Needs to be return error when the connection is closed or broken - _, err := db.Select("users") - s.Require().Error(err) -} + // Update the user + _, err = surrealdb.Merge[testUser](s.db, *models.ParseRecordID("users:999"), map[string]string{ + "password": "456", + }) + s.Require().NoError(err) -// assertContains performs an assertion on a list, asserting that at least one element matches a provided condition. -// All the matching elements are returned from this function, which can be used as a filter. -func assertContains[K any](s *SurrealDBTestSuite, input []K, matcher func(K) bool) []K { - matching := make([]K, 0) - for _, v := range input { - if matcher(v) { - matching = append(matching, v) - } - } - s.NotEmptyf(matching, "Input %+v did not contain matching element", fmt.Sprintf("%+v", input)) - return matching + user, err := surrealdb.Select[testUser](s.db, *models.ParseRecordID("users:999")) + s.Require().NoError(err) + s.Equal("john999", user.Username) // Ensure username hasn't change. + s.Equal("456", user.Password) } diff --git a/go.mod b/go.mod index 43ef31e..493ba2e 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/surrealdb/surrealdb.go go 1.20 require ( + github.com/fxamacker/cbor/v2 v2.7.0 + github.com/gofrs/uuid v4.4.0+incompatible github.com/gorilla/websocket v1.5.0 github.com/stretchr/testify v1.8.4 ) @@ -12,6 +14,7 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/x448/float16 v0.8.4 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ab9f4d8..9617d5d 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,10 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= +github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= @@ -18,6 +22,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/benchmark/benchmark_test.go b/internal/benchmark/benchmark_test.go index 03532ce..998001c 100644 --- a/internal/benchmark/benchmark_test.go +++ b/internal/benchmark/benchmark_test.go @@ -4,21 +4,20 @@ import ( "fmt" "testing" - "github.com/surrealdb/surrealdb.go" - "github.com/surrealdb/surrealdb.go/internal/mock" - "github.com/surrealdb/surrealdb.go/pkg/marshal" + "github.com/surrealdb/surrealdb.go/pkg/models" + + surrealdb "github.com/surrealdb/surrealdb.go" ) // a simple user struct for testing type testUser struct { - marshal.Basemodel `table:"test"` - Username string `json:"username,omitempty"` - Password string `json:"password,omitempty"` - ID string `json:"id,omitempty"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + ID string `json:"id,omitempty"` } func SetupMockDB() (*surrealdb.DB, error) { - return surrealdb.New("", mock.Create()) + return surrealdb.New("") } func BenchmarkCreate(b *testing.B) { @@ -38,7 +37,7 @@ func BenchmarkCreate(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { // error is ignored for benchmarking purposes. - db.Create(users[i].ID, users[i]) //nolint:errcheck + surrealdb.Create[testUser](db, models.Table("users"), users[i]) //nolint:errcheck } } @@ -51,6 +50,6 @@ func BenchmarkSelect(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { // error is ignored for benchmarking purposes. - db.Select("users:bob") //nolint:errcheck + surrealdb.Select[testUser](db, models.NewRecordID("users", "bob")) //nolint:errcheck } } diff --git a/internal/codec/codec.go b/internal/codec/codec.go new file mode 100644 index 0000000..0d0ff40 --- /dev/null +++ b/internal/codec/codec.go @@ -0,0 +1,21 @@ +package codec + +import "io" + +type Encoder interface { + Encode(v interface{}) error +} + +type Decoder interface { + Decode(v interface{}) error +} + +type Marshaler interface { + Marshal(v interface{}) ([]byte, error) + NewEncoder(w io.Writer) Encoder +} + +type Unmarshaler interface { + Unmarshal(data []byte, dst interface{}) error + NewDecoder(r io.Reader) Decoder +} diff --git a/internal/mock/mock.go b/internal/mock/mock.go index a3b7e3e..ff75858 100644 --- a/internal/mock/mock.go +++ b/internal/mock/mock.go @@ -3,15 +3,14 @@ package mock import ( "errors" - "github.com/surrealdb/surrealdb.go/pkg/conn" - "github.com/surrealdb/surrealdb.go/pkg/model" + "github.com/surrealdb/surrealdb.go/pkg/connection" ) type ws struct { } -func (w *ws) Connect(url string) (conn.Connection, error) { - return w, nil +func (w *ws) Connect(url string) error { + return nil } func (w *ws) Send(method string, params []interface{}) (interface{}, error) { @@ -22,7 +21,7 @@ func (w *ws) Close() error { return nil } -func (w *ws) LiveNotifications(id string) (chan model.Notification, error) { +func (w *ws) LiveNotifications(id string) (chan connection.Notification, error) { return nil, errors.New("live queries are unimplemented for mocks") } diff --git a/pkg/rand/rand.go b/internal/rand/rand.go similarity index 100% rename from pkg/rand/rand.go rename to internal/rand/rand.go diff --git a/internal/rpc/rpc.go b/internal/rpc/rpc.go deleted file mode 100644 index 38285fc..0000000 --- a/internal/rpc/rpc.go +++ /dev/null @@ -1,33 +0,0 @@ -package rpc - -// RPCError represents a JSON-RPC error -type RPCError struct { - Code int `json:"code" msgpack:"code"` - Message string `json:"message,omitempty" msgpack:"message,omitempty"` -} - -func (r *RPCError) Error() string { - return r.Message -} - -// RPCRequest represents an incoming JSON-RPC request -type RPCRequest struct { - ID interface{} `json:"id" msgpack:"id"` - Async bool `json:"async,omitempty" msgpack:"async,omitempty"` - Method string `json:"method,omitempty" msgpack:"method,omitempty"` - Params []interface{} `json:"params,omitempty" msgpack:"params,omitempty"` -} - -// RPCResponse represents an outgoing JSON-RPC response -type RPCResponse struct { - ID interface{} `json:"id" msgpack:"id"` - Error *RPCError `json:"error,omitempty" msgpack:"error,omitempty"` - Result interface{} `json:"result,omitempty" msgpack:"result,omitempty"` -} - -// RPCNotification represents an outgoing JSON-RPC notification -type RPCNotification struct { - ID interface{} `json:"id" msgpack:"id"` - Method string `json:"method,omitempty" msgpack:"method,omitempty"` - Params []interface{} `json:"params,omitempty" msgpack:"params,omitempty"` -} diff --git a/internal/util/util.go b/internal/util/util.go new file mode 100644 index 0000000..96a3096 --- /dev/null +++ b/internal/util/util.go @@ -0,0 +1,20 @@ +package util + +import ( + "reflect" +) + +func IsSlice(value any) bool { + return reflect.TypeOf(value).Kind() == reflect.Slice +} + +func ExistsInSlice(value any, checkList []any) bool { + exists := false + for i := 0; i < len(checkList); i++ { + if checkList[i] == value { + exists = true + break + } + } + return exists +} diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go deleted file mode 100644 index dc619ce..0000000 --- a/pkg/conn/conn.go +++ /dev/null @@ -1,10 +0,0 @@ -package conn - -import "github.com/surrealdb/surrealdb.go/pkg/model" - -type Connection interface { - Connect(url string) (Connection, error) - Send(method string, params []interface{}) (interface{}, error) - Close() error - LiveNotifications(id string) (chan model.Notification, error) -} diff --git a/pkg/conn/gorilla/gorilla.go b/pkg/conn/gorilla/gorilla.go deleted file mode 100644 index 4dcbf2c..0000000 --- a/pkg/conn/gorilla/gorilla.go +++ /dev/null @@ -1,374 +0,0 @@ -package gorilla - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "net" - "reflect" - "strconv" - "sync" - "time" - - "github.com/surrealdb/surrealdb.go/pkg/model" - - gorilla "github.com/gorilla/websocket" - "github.com/surrealdb/surrealdb.go/internal/rpc" - "github.com/surrealdb/surrealdb.go/pkg/conn" - "github.com/surrealdb/surrealdb.go/pkg/logger" - "github.com/surrealdb/surrealdb.go/pkg/rand" -) - -const ( - // RequestIDLength size of id sent on WS request - RequestIDLength = 16 - // CloseMessageCode identifier the message id for a close request - CloseMessageCode = 1000 - // DefaultTimeout timeout in seconds - DefaultTimeout = 30 -) - -type Option func(ws *WebSocket) error - -type WebSocket struct { - Conn *gorilla.Conn - connLock sync.Mutex - Timeout time.Duration - Option []Option - logger logger.Logger - - responseChannels map[string]chan rpc.RPCResponse - responseChannelsLock sync.RWMutex - - notificationChannels map[string]chan model.Notification - notificationChannelsLock sync.RWMutex - - closeChan chan int - closeError error -} - -func Create() *WebSocket { - return &WebSocket{ - Conn: nil, - closeChan: make(chan int), - responseChannels: make(map[string]chan rpc.RPCResponse), - notificationChannels: make(map[string]chan model.Notification), - Timeout: DefaultTimeout * time.Second, - } -} - -func (ws *WebSocket) Connect(url string) (conn.Connection, error) { - dialer := gorilla.DefaultDialer - dialer.EnableCompression = true - - connection, _, err := dialer.Dial(url, nil) - if err != nil { - return nil, err - } - - ws.Conn = connection - - for _, option := range ws.Option { - if err := option(ws); err != nil { - return ws, err - } - } - - go ws.initialize() - return ws, nil -} - -func (ws *WebSocket) SetTimeOut(timeout time.Duration) *WebSocket { - ws.Option = append(ws.Option, func(ws *WebSocket) error { - ws.Timeout = timeout - return nil - }) - return ws -} - -// If path is empty it will use os.stdout/os.stderr -func (ws *WebSocket) Logger(logData logger.Logger) *WebSocket { - ws.logger = logData - return ws -} - -func (ws *WebSocket) RawLogger(logData logger.Logger) *WebSocket { - ws.logger = logData - return ws -} - -func (ws *WebSocket) SetCompression(compress bool) *WebSocket { - ws.Option = append(ws.Option, func(ws *WebSocket) error { - ws.Conn.EnableWriteCompression(compress) - return nil - }) - return ws -} - -func (ws *WebSocket) Close() error { - ws.connLock.Lock() - defer ws.connLock.Unlock() - close(ws.closeChan) - err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(CloseMessageCode, "")) - if err != nil { - return err - } - - return ws.Conn.Close() -} - -func (ws *WebSocket) LiveNotifications(liveQueryID string) (chan model.Notification, error) { - c, err := ws.createNotificationChannel(liveQueryID) - if err != nil { - ws.logger.Error(err.Error()) - } - return c, err -} - -var ( - ErrIDInUse = errors.New("id already in use") - ErrTimeout = errors.New("timeout") - ErrInvalidResponseID = errors.New("invalid response id") -) - -func (ws *WebSocket) createResponseChannel(id string) (chan rpc.RPCResponse, error) { - ws.responseChannelsLock.Lock() - defer ws.responseChannelsLock.Unlock() - - if _, ok := ws.responseChannels[id]; ok { - return nil, fmt.Errorf("%w: %v", ErrIDInUse, id) - } - - ch := make(chan rpc.RPCResponse) - ws.responseChannels[id] = ch - - return ch, nil -} - -func (ws *WebSocket) createNotificationChannel(liveQueryID string) (chan model.Notification, error) { - ws.notificationChannelsLock.Lock() - defer ws.notificationChannelsLock.Unlock() - - if _, ok := ws.notificationChannels[liveQueryID]; ok { - return nil, fmt.Errorf("%w: %v", ErrIDInUse, liveQueryID) - } - - ch := make(chan model.Notification) - ws.notificationChannels[liveQueryID] = ch - - return ch, nil -} - -func (ws *WebSocket) removeResponseChannel(id string) { - ws.responseChannelsLock.Lock() - defer ws.responseChannelsLock.Unlock() - delete(ws.responseChannels, id) -} - -func (ws *WebSocket) getResponseChannel(id string) (chan rpc.RPCResponse, bool) { - ws.responseChannelsLock.RLock() - defer ws.responseChannelsLock.RUnlock() - ch, ok := ws.responseChannels[id] - return ch, ok -} - -func (ws *WebSocket) getLiveChannel(id string) (chan model.Notification, bool) { - ws.notificationChannelsLock.RLock() - defer ws.notificationChannelsLock.RUnlock() - ch, ok := ws.notificationChannels[id] - return ch, ok -} - -func (ws *WebSocket) Send(method string, params []interface{}) (interface{}, error) { - select { - case <-ws.closeChan: - return nil, ws.closeError - default: - } - - id := rand.String(RequestIDLength) - request := &rpc.RPCRequest{ - ID: id, - Method: method, - Params: params, - } - - responseChan, err := ws.createResponseChannel(id) - if err != nil { - return nil, err - } - defer ws.removeResponseChannel(id) - - if err := ws.write(request); err != nil { - return nil, err - } - - timeout := time.After(ws.Timeout) - - select { - case <-timeout: - return nil, ErrTimeout - case res, open := <-responseChan: - if !open { - return nil, errors.New("channel closed") - } - if res.ID != id { - return nil, ErrInvalidResponseID - } - if res.Error != nil { - return nil, res.Error - } - return res.Result, nil - } -} - -func (ws *WebSocket) read(v interface{}) error { - _, data, err := ws.Conn.ReadMessage() - if err != nil { - return err - } - return json.Unmarshal(data, v) -} - -func (ws *WebSocket) write(v interface{}) error { - data, err := json.Marshal(v) - if err != nil { - return err - } - - ws.connLock.Lock() - defer ws.connLock.Unlock() - return ws.Conn.WriteMessage(gorilla.TextMessage, data) -} - -func (ws *WebSocket) initialize() { - for { - select { - case <-ws.closeChan: - return - default: - var res rpc.RPCResponse - err := ws.read(&res) - if err != nil { - shouldExit := ws.handleError(err) - if shouldExit { - return - } - continue - } - go ws.handleResponse(res) - } - } -} - -func (ws *WebSocket) handleError(err error) bool { - if errors.Is(err, net.ErrClosed) { - ws.closeError = net.ErrClosed - return true - } - if gorilla.IsUnexpectedCloseError(err) { - ws.closeError = io.ErrClosedPipe - <-ws.closeChan - return true - } - - ws.logger.Error(err.Error()) - return false -} - -func (ws *WebSocket) handleResponse(res rpc.RPCResponse) { - if res.ID != nil && res.ID != "" { - // Try to resolve message as response to query - responseChan, ok := ws.getResponseChannel(fmt.Sprintf("%v", res.ID)) - if !ok { - err := fmt.Errorf("unavailable ResponseChannel %+v", res.ID) - ws.logger.Error(err.Error()) - return - } - defer close(responseChan) - responseChan <- res - } else { - // Try to resolve response as live query notification - mappedRes, _ := res.Result.(map[string]interface{}) - resolvedID, ok := mappedRes["id"] - if !ok { - err := fmt.Errorf("response did not contain an 'id' field") - - ws.logger.Error(err.Error(), "result", fmt.Sprint(res.Result)) - return - } - var notification model.Notification - err := unmarshalMapToStruct(mappedRes, ¬ification) - if err != nil { - ws.logger.Error(err.Error(), "result", fmt.Sprint(res.Result)) - return - } - LiveNotificationChan, ok := ws.getLiveChannel(notification.ID) - if !ok { - err := fmt.Errorf("unavailable ResponseChannel %+v", resolvedID) - ws.logger.Error(err.Error(), "result", fmt.Sprint(res.Result)) - return - } - LiveNotificationChan <- notification - } -} - -func unmarshalMapToStruct(data map[string]interface{}, outStruct interface{}) error { - outValue := reflect.ValueOf(outStruct) - if outValue.Kind() != reflect.Ptr || outValue.Elem().Kind() != reflect.Struct { - return fmt.Errorf("outStruct must be a pointer to a struct") - } - - structValue := outValue.Elem() - structType := structValue.Type() - - for i := 0; i < structValue.NumField(); i++ { - field := structType.Field(i) - fieldName := field.Name - jsonTag := field.Tag.Get("json") - if jsonTag != "" { - fieldName = jsonTag - } - mapValue, ok := data[fieldName] - if !ok { - return fmt.Errorf("missing field in map: %s", fieldName) - } - - fieldValue := structValue.Field(i) - if !fieldValue.CanSet() { - return fmt.Errorf("cannot set field: %s", fieldName) - } - - if mapValue == nil { - // Handle nil values appropriately for your struct fields - // For simplicity, we skip nil values in this example - continue - } - - // Type conversion based on the field type - switch fieldValue.Kind() { - case reflect.String: - fieldValue.SetString(fmt.Sprint(mapValue)) - case reflect.Int: - intVal, err := strconv.Atoi(fmt.Sprint(mapValue)) - if err != nil { - return err - } - fieldValue.SetInt(int64(intVal)) - case reflect.Bool: - boolVal, err := strconv.ParseBool(fmt.Sprint(mapValue)) - if err != nil { - return err - } - fieldValue.SetBool(boolVal) - case reflect.Interface: - fieldValue.Set(reflect.ValueOf(mapValue)) - // Add cases for other types as needed - default: - return fmt.Errorf("unsupported field type: %s", fieldName) - } - } - - return nil -} diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go new file mode 100644 index 0000000..89c5936 --- /dev/null +++ b/pkg/connection/connection.go @@ -0,0 +1,119 @@ +package connection + +import ( + "fmt" + "sync" + + "github.com/surrealdb/surrealdb.go/internal/codec" + "github.com/surrealdb/surrealdb.go/pkg/constants" + "github.com/surrealdb/surrealdb.go/pkg/logger" + "github.com/surrealdb/surrealdb.go/pkg/models" +) + +type LiveHandler interface { + Kill(id string) error + Live(table models.Table, diff bool) (*models.UUID, error) +} + +type Connection interface { + Connect() error + Close() error + Send(res interface{}, method string, params ...interface{}) error + Use(namespace string, database string) error + Let(key string, value interface{}) error + Unset(key string) error + LiveNotifications(id string) (chan Notification, error) +} + +type NewConnectionParams struct { + Marshaler codec.Marshaler + Unmarshaler codec.Unmarshaler + BaseURL string + Logger logger.Logger +} + +type BaseConnection struct { + baseURL string + marshaler codec.Marshaler + unmarshaler codec.Unmarshaler + logger logger.Logger + + responseChannels map[string]chan []byte + responseChannelsLock sync.RWMutex + + notificationChannels map[string]chan Notification + notificationChannelsLock sync.RWMutex +} + +func (bc *BaseConnection) createResponseChannel(id string) (chan []byte, error) { + bc.responseChannelsLock.Lock() + defer bc.responseChannelsLock.Unlock() + + if _, ok := bc.responseChannels[id]; ok { + return nil, fmt.Errorf("%w: %v", constants.ErrIDInUse, id) + } + + ch := make(chan []byte) + bc.responseChannels[id] = ch + + return ch, nil +} + +func (bc *BaseConnection) createNotificationChannel(liveQueryID string) (chan Notification, error) { + bc.notificationChannelsLock.Lock() + defer bc.notificationChannelsLock.Unlock() + + if _, ok := bc.notificationChannels[liveQueryID]; ok { + return nil, fmt.Errorf("%w: %v", constants.ErrIDInUse, liveQueryID) + } + + ch := make(chan Notification) + bc.notificationChannels[liveQueryID] = ch + + return ch, nil +} + +func (bc *BaseConnection) removeResponseChannel(id string) { + bc.responseChannelsLock.Lock() + defer bc.responseChannelsLock.Unlock() + delete(bc.responseChannels, id) +} + +func (bc *BaseConnection) getResponseChannel(id string) (chan []byte, bool) { + bc.responseChannelsLock.RLock() + defer bc.responseChannelsLock.RUnlock() + ch, ok := bc.responseChannels[id] + return ch, ok +} + +func (bc *BaseConnection) getLiveChannel(id string) (chan Notification, bool) { + bc.notificationChannelsLock.RLock() + defer bc.notificationChannelsLock.RUnlock() + ch, ok := bc.notificationChannels[id] + + return ch, ok +} + +func (bc *BaseConnection) preConnectionChecks() error { + if bc.baseURL == "" { + return constants.ErrNoBaseURL + } + + if bc.marshaler == nil { + return constants.ErrNoMarshaler + } + + if bc.unmarshaler == nil { + return constants.ErrNoUnmarshaler + } + + return nil +} + +func (bc *BaseConnection) LiveNotifications(liveQueryID string) (chan Notification, error) { + c, err := bc.createNotificationChannel(liveQueryID) + if err != nil { + bc.logger.Error(err.Error()) + } + return c, err +} diff --git a/pkg/connection/connection_test.go b/pkg/connection/connection_test.go new file mode 100644 index 0000000..c03a263 --- /dev/null +++ b/pkg/connection/connection_test.go @@ -0,0 +1,116 @@ +package connection + +import ( + "log/slog" + "os" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/surrealdb/surrealdb.go/pkg/constants" + "github.com/surrealdb/surrealdb.go/pkg/logger" + "github.com/surrealdb/surrealdb.go/pkg/models" +) + +type testUser struct { + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + ID *models.RecordID `json:"id,omitempty"` +} + +type ConnectionTestSuite struct { + suite.Suite + name string + connImplementations map[string]Connection +} + +func TestConnectionTestSuite(t *testing.T) { + ts := new(ConnectionTestSuite) + ts.connImplementations = make(map[string]Connection) + + ts.connImplementations["ws"] = NewWebSocketConnection(NewConnectionParams{ + BaseURL: "ws://localhost:8000", + Marshaler: models.CborMarshaler{}, + Unmarshaler: models.CborUnmarshaler{}, + Logger: logger.New(slog.NewTextHandler(os.Stdout, nil)), + }) + + ts.connImplementations["http"] = NewHTTPConnection(NewConnectionParams{ + BaseURL: "http://localhost:8000", + Marshaler: models.CborMarshaler{}, + Unmarshaler: models.CborUnmarshaler{}, + Logger: logger.New(slog.NewTextHandler(os.Stdout, nil)), + }) + + for wsName := range ts.connImplementations { + // Run the test suite + t.Run(wsName, func(t *testing.T) { + ts.name = wsName + suite.Run(t, ts) + }) + } +} + +// SetupSuite is called before the s starts running +func (s *ConnectionTestSuite) SetupSuite() { + con := s.connImplementations[s.name] + + // connect + err := con.Connect() + s.Require().NoError(err) + + // set namespace, database + err = con.Use("test", "test") + s.Require().NoError(err) + + // sign in + var token RPCResponse[string] + err = con.Send(&token, "signin", map[string]interface{}{ + "user": "root", + "pass": "root", + }) + s.Require().NoError(err) + _ = con.Let(constants.AuthTokenKey, *token.Result) +} + +func (s *ConnectionTestSuite) TearDownSuite() { + con := s.connImplementations[s.name] + err := con.Close() + s.Require().NoError(err) +} + +func (s *ConnectionTestSuite) Test_CRUD() { + con := s.connImplementations[s.name] + + var createRes RPCResponse[testUser] + err := con.Send(&createRes, "create", "users", map[string]interface{}{ + "username": "remi", + "password": "password", + }) + s.Require().NoError(err) + + s.Assert().Equal(createRes.Result.Username, "remi") + s.Assert().Equal(createRes.Result.Password, "password") + + var selectRes RPCResponse[testUser] + err = con.Send(&selectRes, "select", createRes.Result.ID) + s.Require().NoError(err) + + s.Assert().Equal(createRes.Result.Username, "remi") + s.Assert().Equal(createRes.Result.Password, "password") + + userToUpdate := createRes.Result + userToUpdate.Password = "newpassword" + var updateRes RPCResponse[testUser] + err = con.Send(&updateRes, "update", userToUpdate.ID, userToUpdate) + s.Require().NoError(err) + + s.Assert().Equal(userToUpdate.ID, updateRes.Result.ID) + s.Assert().Equal(updateRes.Result.Password, "newpassword") + + err = con.Send(nil, "delete", userToUpdate.ID) + s.Require().NoError(err) + + var selectRes1 RPCResponse[testUser] + err = con.Send(&selectRes1, "select", createRes.Result.ID) + s.Require().NoError(err) +} diff --git a/pkg/connection/http.go b/pkg/connection/http.go new file mode 100644 index 0000000..09258d9 --- /dev/null +++ b/pkg/connection/http.go @@ -0,0 +1,171 @@ +package connection + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/surrealdb/surrealdb.go/internal/rand" + "github.com/surrealdb/surrealdb.go/pkg/constants" +) + +type HTTPConnection struct { + BaseConnection + + httpClient *http.Client + variables sync.Map +} + +func NewHTTPConnection(p NewConnectionParams) *HTTPConnection { + con := HTTPConnection{ + BaseConnection: BaseConnection{ + marshaler: p.Marshaler, + unmarshaler: p.Unmarshaler, + baseURL: p.BaseURL, + }, + } + + if con.httpClient == nil { + con.httpClient = &http.Client{ + Timeout: constants.DefaultHTTPTimeout, // Set a default timeout to avoid hanging requests + } + } + + return &con +} + +func (h *HTTPConnection) Connect() error { + ctx := context.TODO() + if err := h.preConnectionChecks(); err != nil { + return err + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, h.baseURL+"/health", http.NoBody) + if err != nil { + return err + } + _, err = h.MakeRequest(httpReq) + if err != nil { + return err + } + + return nil +} + +func (h *HTTPConnection) Close() error { + return nil +} + +func (h *HTTPConnection) SetTimeout(timeout time.Duration) *HTTPConnection { + h.httpClient.Timeout = timeout + return h +} + +func (h *HTTPConnection) SetHTTPClient(client *http.Client) *HTTPConnection { + h.httpClient = client + return h +} + +func (h *HTTPConnection) Send(dest any, method string, params ...interface{}) error { + if h.baseURL == "" { + return constants.ErrNoBaseURL + } + + request := &RPCRequest{ + ID: rand.String(constants.RequestIDLength), + Method: method, + Params: params, + } + reqBody, err := h.marshaler.Marshal(request) + + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, h.baseURL+"/rpc", bytes.NewBuffer(reqBody)) + if err != nil { + return err + } + req.Header.Set("Accept", "application/cbor") + req.Header.Set("Content-Type", "application/cbor") + + if namespace, ok := h.variables.Load("namespace"); ok { + req.Header.Set("Surreal-NS", namespace.(string)) + } else { + return constants.ErrNoNamespaceOrDB + } + + if database, ok := h.variables.Load("database"); ok { + req.Header.Set("Surreal-DB", database.(string)) + } else { + return constants.ErrNoNamespaceOrDB + } + + if token, ok := h.variables.Load(constants.AuthTokenKey); ok { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + } + + respData, err := h.MakeRequest(req) + if err != nil { + return err + } + + var rpcRes RPCResponse[interface{}] + if err := h.unmarshaler.Unmarshal(respData, &rpcRes); err != nil { + return err + } + if rpcRes.Error != nil { + return rpcRes.Error + } + + if dest != nil { + return h.unmarshaler.Unmarshal(respData, dest) + } + + return nil +} + +func (h *HTTPConnection) MakeRequest(req *http.Request) ([]byte, error) { + resp, err := h.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making HTTP request: %w", err) + } + defer resp.Body.Close() + + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return respBytes, nil + } + + var errorResponse RPCResponse[any] + err = h.unmarshaler.Unmarshal(respBytes, &errorResponse) + if err != nil { + panic(err) + } + return nil, errorResponse.Error +} + +func (h *HTTPConnection) Use(namespace, database string) error { + h.variables.Store("namespace", namespace) + h.variables.Store("database", database) + + return nil +} + +func (h *HTTPConnection) Let(key string, value interface{}) error { + h.variables.Store(key, value) + return nil +} + +func (h *HTTPConnection) Unset(key string) error { + h.variables.Delete(key) + return nil +} diff --git a/pkg/connection/http_test.go b/pkg/connection/http_test.go new file mode 100644 index 0000000..2981ac7 --- /dev/null +++ b/pkg/connection/http_test.go @@ -0,0 +1,77 @@ +package connection + +import ( + "bytes" + "context" + "encoding/base64" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/surrealdb/surrealdb.go/pkg/models" +) + +type RoundTripFunc func(req *http.Request) *http.Response + +func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +// NewTestClient returns *http.Client with Transport replaced to avoid making real calls +func NewTestClient(fn RoundTripFunc) *http.Client { + return &http.Client{ + Transport: fn, + } +} + +type HTTPTestSuite struct { + suite.Suite + name string +} + +func TestHttpTestSuite(t *testing.T) { + ts := new(HTTPTestSuite) + ts.name = "HTTP Test Suite" + + suite.Run(t, ts) +} + +// SetupSuite is called before the s starts running +func (s *HTTPTestSuite) SetupSuite() { + +} + +func (s *HTTPTestSuite) TearDownSuite() { + +} + +func (s *HTTPTestSuite) TestMockClientEngine_MakeRequest() { + ctx := context.TODO() + + httpClient := NewTestClient(func(req *http.Request) *http.Response { + s.Assert().Equal(req.URL.String(), "http://test.surreal/rpc") + + respBody, _ := base64.StdEncoding.DecodeString("omJpZHAwSEtnRlZsZXFTQnVjYlpEZWVycm9yomRjb2RlGG9nbWVzc2FnZXNUaGVyZSB3YXMgYSBwcm9ibGVt") + return &http.Response{ + StatusCode: 400, + // Send response to be tested + Body: io.NopCloser(bytes.NewReader(respBody)), + // Must be set to non-nil value or it panics + Header: make(http.Header), + } + }) + + p := NewConnectionParams{ + BaseURL: "http://test.surreal", + Marshaler: models.CborMarshaler{}, + Unmarshaler: models.CborUnmarshaler{}, + } + + httpEngine := NewHTTPConnection(p) + httpEngine.SetHTTPClient(httpClient) + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://test.surreal/rpc", http.NoBody) + _, err := httpEngine.MakeRequest(req) + s.Require().Error(err, "should return error for status code 400") +} diff --git a/pkg/connection/model.go b/pkg/connection/model.go new file mode 100644 index 0000000..c0b8d0d --- /dev/null +++ b/pkg/connection/model.go @@ -0,0 +1,61 @@ +package connection + +// RPCError represents a JSON-RPC error +type RPCError struct { + Code int `json:"code" msgpack:"code"` + Message string `json:"message,omitempty" msgpack:"message,omitempty"` + Description string `json:"description,omitempty" msgpack:"message,omitempty"` +} + +func (r RPCError) Error() string { + if r.Description != "" { + return r.Description + } + return r.Message +} + +// RPCRequest represents an incoming JSON-RPC request +type RPCRequest struct { + ID interface{} `json:"id" msgpack:"id"` + Method string `json:"method,omitempty" msgpack:"method,omitempty"` + Params []interface{} `json:"params,omitempty" msgpack:"params,omitempty"` +} + +// RPCResponse represents an outgoing JSON-RPC response +type RPCResponse[T any] struct { + ID interface{} `json:"id" msgpack:"id"` + Error *RPCError `json:"error,omitempty" msgpack:"error,omitempty"` + Result *T `json:"result,omitempty" msgpack:"result,omitempty"` +} + +// RPCNotification represents an outgoing JSON-RPC notification +type RPCNotification struct { + ID interface{} `json:"id" msgpack:"id"` + Method string `json:"method,omitempty" msgpack:"method,omitempty"` + Params []interface{} `json:"params,omitempty" msgpack:"params,omitempty"` +} + +type RPCFunction string + +var ( + Use RPCFunction = "use" + Info RPCFunction = "info" + SignUp RPCFunction = "signup" + SignIn RPCFunction = "signin" + Authenticate RPCFunction = "authenticate" + Invalidate RPCFunction = "invalidate" + Let RPCFunction = "let" + Unset RPCFunction = "unset" + Live RPCFunction = "live" + Kill RPCFunction = "kill" + Query RPCFunction = "query" + Select RPCFunction = "select" + Create RPCFunction = "create" + Insert RPCFunction = "insert" + Update RPCFunction = "update" + Upsert RPCFunction = "upsert" + Relate RPCFunction = "relate" + Merge RPCFunction = "merge" + Patch RPCFunction = "patch" + Delete RPCFunction = "delete" +) diff --git a/pkg/connection/notification.go b/pkg/connection/notification.go new file mode 100644 index 0000000..4df0bbd --- /dev/null +++ b/pkg/connection/notification.go @@ -0,0 +1,16 @@ +package connection + +import "github.com/surrealdb/surrealdb.go/pkg/models" + +type Notification struct { + ID *models.UUID `json:"id,omitempty"` + Action Action `json:"action"` + Result interface{} `json:"result"` +} +type Action string + +const ( + CreateAction Action = "CREATE" + UpdateAction Action = "UPDATE" + DeleteAction Action = "DELETE" +) diff --git a/pkg/connection/ws.go b/pkg/connection/ws.go new file mode 100644 index 0000000..e89f257 --- /dev/null +++ b/pkg/connection/ws.go @@ -0,0 +1,275 @@ +package connection + +import ( + "errors" + "fmt" + "io" + "log/slog" + "net" + "os" + "sync" + "time" + + "github.com/surrealdb/surrealdb.go/internal/rand" + "github.com/surrealdb/surrealdb.go/pkg/constants" + "github.com/surrealdb/surrealdb.go/pkg/logger" + + gorilla "github.com/gorilla/websocket" +) + +type Option func(ws *WebSocketConnection) error + +type WebSocketConnection struct { + BaseConnection + + Conn *gorilla.Conn + connLock sync.Mutex + Timeout time.Duration + Option []Option + logger logger.Logger + + closeChan chan int + closeError error +} + +func NewWebSocketConnection(p NewConnectionParams) *WebSocketConnection { + return &WebSocketConnection{ + BaseConnection: BaseConnection{ + baseURL: p.BaseURL, + + marshaler: p.Marshaler, + unmarshaler: p.Unmarshaler, + + responseChannels: make(map[string]chan []byte), + notificationChannels: make(map[string]chan Notification), + }, + + Conn: nil, + closeChan: make(chan int), + Timeout: constants.DefaultWSTimeout, + logger: logger.New(slog.NewJSONHandler(os.Stdout, nil)), + } +} + +func (ws *WebSocketConnection) Connect() error { + if err := ws.preConnectionChecks(); err != nil { + return err + } + + dialer := gorilla.DefaultDialer + dialer.EnableCompression = true + dialer.Subprotocols = append(dialer.Subprotocols, "cbor") + + connection, res, err := dialer.Dial(fmt.Sprintf("%s/rpc", ws.baseURL), nil) + if err != nil { + return err + } + defer res.Body.Close() + + ws.Conn = connection + + for _, option := range ws.Option { + if err := option(ws); err != nil { + return err + } + } + + go ws.initialize() + return nil +} + +func (ws *WebSocketConnection) SetTimeOut(timeout time.Duration) *WebSocketConnection { + ws.Option = append(ws.Option, func(ws *WebSocketConnection) error { + ws.Timeout = timeout + return nil + }) + return ws +} + +// If path is empty it will use os.stdout/os.stderr +func (ws *WebSocketConnection) Logger(logData logger.Logger) *WebSocketConnection { + ws.logger = logData + return ws +} + +func (ws *WebSocketConnection) RawLogger(logData logger.Logger) *WebSocketConnection { + ws.logger = logData + return ws +} + +func (ws *WebSocketConnection) SetCompression(compress bool) *WebSocketConnection { + ws.Option = append(ws.Option, func(ws *WebSocketConnection) error { + ws.Conn.EnableWriteCompression(compress) + return nil + }) + return ws +} + +func (ws *WebSocketConnection) Close() error { + ws.connLock.Lock() + defer ws.connLock.Unlock() + close(ws.closeChan) + err := ws.Conn.WriteMessage(gorilla.CloseMessage, gorilla.FormatCloseMessage(constants.CloseMessageCode, "")) + if err != nil { + return err + } + + return ws.Conn.Close() +} + +func (ws *WebSocketConnection) Use(namespace, database string) error { + err := ws.Send(nil, "use", namespace, database) + if err != nil { + return err + } + + return nil +} + +func (ws *WebSocketConnection) Let(key string, value interface{}) error { + return ws.Send(nil, "let", key, value) +} + +func (ws *WebSocketConnection) Unset(key string) error { + return ws.Send(nil, "unset", key) +} + +func (ws *WebSocketConnection) Send(dest interface{}, method string, params ...interface{}) error { + select { + case <-ws.closeChan: + return ws.closeError + default: + } + + id := rand.String(constants.RequestIDLength) + request := &RPCRequest{ + ID: id, + Method: method, + Params: params, + } + + responseChan, err := ws.createResponseChannel(id) + if err != nil { + return err + } + defer ws.removeResponseChannel(id) + + if err := ws.write(request); err != nil { + return err + } + timeout := time.After(ws.Timeout) + + select { + case <-timeout: + return constants.ErrTimeout + case resBytes, open := <-responseChan: + if !open { + return errors.New("channel closed") + } + if dest != nil { + return ws.unmarshaler.Unmarshal(resBytes, dest) + } + return nil + } +} + +func (ws *WebSocketConnection) write(v interface{}) error { + data, err := ws.marshaler.Marshal(v) + if err != nil { + return err + } + + ws.connLock.Lock() + defer ws.connLock.Unlock() + return ws.Conn.WriteMessage(gorilla.BinaryMessage, data) +} + +func (ws *WebSocketConnection) initialize() { + for { + select { + case <-ws.closeChan: + return + default: + _, data, err := ws.Conn.ReadMessage() + if err != nil { + shouldExit := ws.handleError(err) + if shouldExit { + return + } + continue + } + go ws.handleResponse(data) + } + } +} + +func (ws *WebSocketConnection) handleError(err error) bool { + if errors.Is(err, net.ErrClosed) { + ws.closeError = net.ErrClosed + return true + } + if gorilla.IsUnexpectedCloseError(err) { + ws.closeError = io.ErrClosedPipe + <-ws.closeChan + return true + } + + ws.logger.Error(err.Error()) + return false +} + +func (ws *WebSocketConnection) handleResponse(res []byte) { + var rpcRes RPCResponse[interface{}] + if err := ws.unmarshaler.Unmarshal(res, &rpcRes); err != nil { + panic(err) + } + + if rpcRes.Error != nil { + err := fmt.Errorf("rpc request err %w", rpcRes.Error) + ws.logger.Error(err.Error()) + return + } + + if rpcRes.ID != nil && rpcRes.ID != "" { + // Try to resolve message as response to query + responseChan, ok := ws.getResponseChannel(fmt.Sprintf("%v", rpcRes.ID)) + if !ok { + err := fmt.Errorf("unavailable ResponseChannel %+v", rpcRes.ID) + ws.logger.Error(err.Error()) + return + } + defer close(responseChan) + responseChan <- res + } else { + // todo: find a surefire way to confirm a notification + + var notificationRes RPCResponse[Notification] + if err := ws.unmarshaler.Unmarshal(res, ¬ificationRes); err != nil { + panic(err) + } + + if notificationRes.Result.ID == nil { + err := fmt.Errorf("response did not contain an 'id' field") + ws.logger.Error(err.Error(), "result", fmt.Sprint(rpcRes.Result)) + return + } + + channelID := notificationRes.Result.ID + + LiveNotificationChan, ok := ws.getLiveChannel(channelID.String()) + if !ok { + err := fmt.Errorf("unavailable ResponseChannel %+v", channelID.String()) + ws.logger.Error(err.Error(), "result", fmt.Sprint(rpcRes.Result)) + return + } + + var notification RPCResponse[Notification] + if err := ws.unmarshaler.Unmarshal(res, ¬ification); err != nil { + err := fmt.Errorf("error unmarshalling notification %+v", channelID.String()) + ws.logger.Error(err.Error(), "result", fmt.Sprint(rpcRes.Result)) + return + } + + LiveNotificationChan <- *notification.Result + } +} diff --git a/pkg/connection/ws_test.go b/pkg/connection/ws_test.go new file mode 100644 index 0000000..e15ea1f --- /dev/null +++ b/pkg/connection/ws_test.go @@ -0,0 +1,28 @@ +package connection + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type WsTestSuite struct { + suite.Suite + name string +} + +func TestWsTestSuite(t *testing.T) { + ts := new(WsTestSuite) + ts.name = "WS Test Suite" + + suite.Run(t, ts) +} + +// SetupSuite is called before the s starts running +func (s *WsTestSuite) SetupSuite() { + +} + +func (s *WsTestSuite) TearDownSuite() { + +} diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 3935c5f..b2f946c 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -1,10 +1,20 @@ package constants -import "errors" +import "time" -// Errors var ( - InvalidResponse = errors.New("invalid SurrealDB response") //nolint:stylecheck - ErrQuery = errors.New("error occurred processing the SurrealDB query") - ErrNoRow = errors.New("error no row") + AuthTokenKey = "auth_token" +) + +const ( + // RequestIDLength size of id sent on WS request + RequestIDLength = 16 + // CloseMessageCode identifier the message id for a close request + CloseMessageCode = 1000 + // DefaultTimeout timeout in seconds + DefaultWSTimeout = 30 * time.Second + + DefaultHTTPTimeout = 10 * time.Second + + OneSecondToNanoSecond = 1_000_000_000 ) diff --git a/pkg/constants/errors.go b/pkg/constants/errors.go new file mode 100644 index 0000000..62f7b59 --- /dev/null +++ b/pkg/constants/errors.go @@ -0,0 +1,19 @@ +package constants + +import "errors" + +// Errors +var ( + InvalidResponse = errors.New("invalid SurrealDB response") //nolint:stylecheck + ErrQuery = errors.New("error occurred processing the SurrealDB query") + ErrNoRow = errors.New("error no row") +) +var ( + ErrIDInUse = errors.New("id already in use") + ErrTimeout = errors.New("timeout") + ErrNoBaseURL = errors.New("base url not set") + ErrNoMarshaler = errors.New("marshaler is not set") + ErrNoUnmarshaler = errors.New("unmarshaler is not set") + ErrNoNamespaceOrDB = errors.New("namespace or database or both are not set") + ErrMethodNotAvailable = errors.New("method not available on this connection") +) diff --git a/pkg/logger/slog/slog.go b/pkg/logger/slog.go similarity index 97% rename from pkg/logger/slog/slog.go rename to pkg/logger/slog.go index a8e58e6..2734a08 100644 --- a/pkg/logger/slog/slog.go +++ b/pkg/logger/slog.go @@ -1,4 +1,4 @@ -package slog +package logger import ( "log/slog" diff --git a/pkg/logger/slog/slog_test.go b/pkg/logger/slog_test.go similarity index 94% rename from pkg/logger/slog/slog_test.go rename to pkg/logger/slog_test.go index f86c01b..9d0ccba 100644 --- a/pkg/logger/slog/slog_test.go +++ b/pkg/logger/slog_test.go @@ -1,4 +1,4 @@ -package slog_test +package logger import ( "bytes" @@ -10,7 +10,6 @@ import ( rawslog "log/slog" "github.com/stretchr/testify/require" - "github.com/surrealdb/surrealdb.go/pkg/logger/slog" ) type testMethod struct { @@ -37,7 +36,7 @@ func TestLogger(t *testing.T) { // level needs to be set to debug for log all handler := rawslog.NewJSONHandler(buffer, &rawslog.HandlerOptions{Level: rawslog.LevelDebug}) - logger := slog.New(handler) + logger := New(handler) testMethods := []testMethod{ {fn: logger.Error, level: rawslog.LevelError}, diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go deleted file mode 100644 index 6e3e1ae..0000000 --- a/pkg/marshal/marshal.go +++ /dev/null @@ -1,164 +0,0 @@ -package marshal - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "reflect" - - "github.com/surrealdb/surrealdb.go/pkg/constants" - "github.com/surrealdb/surrealdb.go/pkg/util" -) - -const StatusOK = "OK" - -// Used for RawQuery Unmarshaling -type RawQuery[I any] struct { - Status string `json:"status"` - Time string `json:"time"` - Result []I `json:"result"` - Detail string `json:"detail"` -} - -// Unmarshal loads a SurrealDB response into a struct. -func Unmarshal(data, v interface{}) (err error) { - var jsonBytes []byte - if util.IsSlice(v) { - assertedData, ok := data.([]interface{}) - if !ok { - return fmt.Errorf("failed to deserialise response to slice: %w", constants.InvalidResponse) - } - jsonBytes, err = json.Marshal(assertedData) - if err != nil { - return fmt.Errorf("failed to deserialise response '%+v' to slice: %w", assertedData, constants.InvalidResponse) - } - } else { - jsonBytes, err = json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to deserialise response '%+v' to object: %w", data, err) - } - } - if err != nil { - return err - } - - err = json.Unmarshal(jsonBytes, v) - if err != nil { - return fmt.Errorf("failed unmarshaling jsonBytes '%+v': %w", jsonBytes, err) - } - return nil -} - -// UnmarshalRaw loads a raw SurrealQL response returned by Query into a struct. Queries that return with results will -// return ok = true, and queries that return with no results will return ok = false. -func UnmarshalRaw[I any](rawData interface{}, v *[]RawQuery[I]) (err error) { - data, err := json.Marshal(rawData) - if err != nil { - return - } - err = json.Unmarshal(data, &v) - if err != nil { - return - } - for _, v := range *v { - if v.Status != StatusOK { - err = errors.Join(err, fmt.Errorf("status: %s, detail: %s", v.Status, v.Detail)) - } - } - return -} - -// SmartUnmarshal using generics for return desired type. -// Supports both raw and normal queries. -func SmartUnmarshal[I any](respond interface{}, wrapperError error) (outputs []I, err error) { - // Handle delete - if respond == nil || wrapperError != nil { - return outputs, wrapperError - } - data, err := json.Marshal(respond) - if err != nil { - return outputs, err - } - // Needed for checking fields - decoder := json.NewDecoder(bytes.NewReader(data)) - decoder.DisallowUnknownFields() - if _, isArr := respond.([]interface{}); !isArr { - // Non Arr Normal - var output I - err = decoder.Decode(&output) - if err == nil { - outputs = append(outputs, output) - } - } else { - // Arr Normal - if err = decoder.Decode(&outputs); err != nil { - // Arr Raw - var rawArr []RawQuery[I] - if err = json.Unmarshal(data, &rawArr); err == nil { - outputs = make([]I, 0) - for _, raw := range rawArr { - if raw.Status != StatusOK { - err = errors.Join(err, errors.New(raw.Status)) - } else { - outputs = append(outputs, raw.Result...) - } - } - } - } - } - return outputs, err -} - -// Used for define table name, it has no value. -type Basemodel struct{} - -// Smart Marshal Errors -var ( - ErrNotStruct = errors.New("data is not struct") - ErrNotValidFunc = errors.New("invalid function") -) - -// Smartmarshal can be used with all DB methods with generics and type safety. -// This handles errors and can use any struct tag with `BaseModel` type. -// Warning: "ID" field is case sensitive and expect string. -// Upon failure, the following will happen -// 1. If there are some ID on struct it will fill the table with the ID -// 2. If there are struct tags of the type `Basemodel`, it will use those values instead -// 3. If everything above fails or the IDs do not exist, SmartUnmarshal will use the struct name as the table name. -func SmartMarshal[I any](inputfunc interface{}, data I) (output interface{}, err error) { - var table string - datatype := reflect.TypeOf(data) - datavalue := reflect.ValueOf(data) - if datatype.Kind() == reflect.Pointer { - datatype = datatype.Elem() - datavalue = datavalue.Elem() - } - if datatype.Kind() == reflect.Struct { - if _, ok := datavalue.Field(0).Interface().(Basemodel); ok { - if temptable, ok := datatype.Field(0).Tag.Lookup("table"); ok { - table = temptable - } else { - table = reflect.TypeOf(data).Name() - } - } - if id, ok := datatype.FieldByName("ID"); ok { - if id.Type.Kind() == reflect.String { - if str, ok := datavalue.FieldByName("ID").Interface().(string); ok { - if str != "" { - table = str - } - } - } - } - } else { - return nil, ErrNotStruct - } - if function, ok := inputfunc.(func(thing string, data interface{}) (interface{}, error)); ok { - return function(table, data) - } - if function, ok := inputfunc.(func(thing string) (interface{}, error)); ok { - return function(table) - } - return nil, ErrNotValidFunc -} diff --git a/pkg/model/notification.go b/pkg/model/notification.go deleted file mode 100644 index cf2c280..0000000 --- a/pkg/model/notification.go +++ /dev/null @@ -1,14 +0,0 @@ -package model - -type Notification struct { - ID string `json:"id"` - Action Action `json:"action"` - Result interface{} `json:"result"` -} -type Action string - -const ( - CreateAction Action = "CREATE" - UpdateAction Action = "UPDATE" - DeleteAction Action = "DELETE" -) diff --git a/pkg/models/cbor.go b/pkg/models/cbor.go new file mode 100644 index 0000000..5531a68 --- /dev/null +++ b/pkg/models/cbor.go @@ -0,0 +1,125 @@ +package models + +import ( + "io" + "reflect" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/surrealdb/surrealdb.go/internal/codec" +) + +type CustomCBORTag uint64 + +var ( + NoneTag CustomCBORTag = 6 + TableNameTag CustomCBORTag = 7 + RecordIDTag CustomCBORTag = 8 + UUIDStringTag CustomCBORTag = 9 + DecimalStringTag CustomCBORTag = 10 + DateTimeCompactString CustomCBORTag = 12 + DurationStringTag CustomCBORTag = 13 + DurationCompactTag CustomCBORTag = 14 + BinaryUUIDTag CustomCBORTag = 37 + GeometryPointTag CustomCBORTag = 88 + GeometryLineTag CustomCBORTag = 89 + GeometryPolygonTag CustomCBORTag = 90 + GeometryMultiPointTag CustomCBORTag = 91 + GeometryMultiLineTag CustomCBORTag = 92 + GeometryMultiPolygonTag CustomCBORTag = 93 + GeometryCollectionTag CustomCBORTag = 94 +) + +func registerCborTags() cbor.TagSet { + customTags := map[CustomCBORTag]interface{}{ + GeometryPointTag: GeometryPoint{}, + GeometryLineTag: GeometryLine{}, + GeometryPolygonTag: GeometryPolygon{}, + GeometryMultiPointTag: GeometryMultiPoint{}, + GeometryMultiLineTag: GeometryMultiLine{}, + GeometryMultiPolygonTag: GeometryMultiPolygon{}, + GeometryCollectionTag: GeometryCollection{}, + + TableNameTag: Table(""), + //UUIDStringTag: UUID(""), + DecimalStringTag: Decimal(""), + BinaryUUIDTag: UUID{}, + NoneTag: CustomNil{}, + + DateTimeCompactString: CustomDateTime(time.Now()), + DurationStringTag: CustomDurationStr("2w"), + //DurationCompactTag: CustomDuration(0), + } + + tags := cbor.NewTagSet() + for tag, customType := range customTags { + err := tags.Add( + cbor.TagOptions{EncTag: cbor.EncTagRequired, DecTag: cbor.DecTagRequired}, + reflect.TypeOf(customType), + uint64(tag), + ) + if err != nil { + panic(err) + } + } + + return tags +} + +type CborMarshaler struct { +} + +func (c CborMarshaler) Marshal(v interface{}) ([]byte, error) { + v = replacerBeforeEncode(v) + em := getCborEncoder() + return em.Marshal(v) +} + +func (c CborMarshaler) NewEncoder(w io.Writer) codec.Encoder { + em := getCborEncoder() + return em.NewEncoder(w) +} + +type CborUnmarshaler struct { +} + +func (c CborUnmarshaler) Unmarshal(data []byte, dst interface{}) error { + dm := getCborDecoder() + err := dm.Unmarshal(data, dst) + if err != nil { + return err + } + + replacerAfterDecode(&dst) + return nil +} + +func (c CborUnmarshaler) NewDecoder(r io.Reader) codec.Decoder { + dm := getCborDecoder() + return dm.NewDecoder(r) +} + +func getCborEncoder() cbor.EncMode { + tags := registerCborTags() + em, err := cbor.EncOptions{ + Time: cbor.TimeRFC3339, + TimeTag: cbor.EncTagRequired, + }.EncModeWithTags(tags) + if err != nil { + panic(err) + } + + return em +} + +func getCborDecoder() cbor.DecMode { + tags := registerCborTags() + dm, err := cbor.DecOptions{ + TimeTagToAny: cbor.TimeTagToTime, + }.DecModeWithTags(tags) + if err != nil { + panic(err) + } + + return dm +} diff --git a/pkg/models/cbor_test.go b/pkg/models/cbor_test.go new file mode 100644 index 0000000..85150bc --- /dev/null +++ b/pkg/models/cbor_test.go @@ -0,0 +1,96 @@ +package models + +import ( + "fmt" + "testing" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" +) + +func TestForGeometryPoint(t *testing.T) { + em := getCborEncoder() + dm := getCborDecoder() + + gp := NewGeometryPoint(12.23, 45.65) + encoded, err := em.Marshal(gp) + assert.Nil(t, err, "Should not encounter an error while encoding") + + decoded := GeometryPoint{} + err = dm.Unmarshal(encoded, &decoded) + + assert.Nil(t, err, "Should not encounter an error while decoding") + assert.Equal(t, gp, decoded) +} + +func TestForGeometryLine(t *testing.T) { + em := getCborEncoder() + dm := getCborDecoder() + + gp1 := NewGeometryPoint(12.23, 45.65) + gp2 := NewGeometryPoint(23.34, 56.75) + gp3 := NewGeometryPoint(33.45, 86.99) + + gl := GeometryLine{gp1, gp2, gp3} + + encoded, err := em.Marshal(gl) + assert.Nil(t, err, "Should not encounter an error while encoding") + + decoded := GeometryLine{} + err = dm.Unmarshal(encoded, &decoded) + assert.Nil(t, err, "Should not encounter an error while decoding") + assert.Equal(t, gl, decoded) +} + +func TestForGeometryPolygon(t *testing.T) { + em := getCborEncoder() + dm := getCborDecoder() + + gl1 := GeometryLine{NewGeometryPoint(12.23, 45.65), NewGeometryPoint(23.33, 44.44)} + gl2 := GeometryLine{GeometryPoint{12.23, 45.65}, GeometryPoint{23.33, 44.44}} + gl3 := GeometryLine{NewGeometryPoint(12.23, 45.65), NewGeometryPoint(23.33, 44.44)} + gp := GeometryPolygon{gl1, gl2, gl3} + + encoded, err := em.Marshal(gp) + assert.Nil(t, err, "Should not encounter an error while encoding") + + decoded := GeometryPolygon{} + err = dm.Unmarshal(encoded, &decoded) + + assert.Nil(t, err, "Should not encounter an error while decoding") + assert.Equal(t, gp, decoded) +} + +func TestForRequestPayload(t *testing.T) { + em := getCborEncoder() + + params := []interface{}{ + "SELECT marketing, count() FROM $tb GROUP BY marketing", + map[string]interface{}{ + "tb": Table("person"), + "line": GeometryLine{NewGeometryPoint(11.11, 22.22), NewGeometryPoint(33.33, 44.44)}, + "datetime": time.Now(), + "testNone": None, + "testNil": nil, + "duration": time.Duration(340), + // "custom_duration": CustomDuration(340), + "custom_datetime": CustomDateTime(time.Now()), + }, + } + + requestPayload := map[string]interface{}{ + "id": "2", + "method": "query", + "params": params, + } + + encoded, err := em.Marshal(requestPayload) + + assert.Nil(t, err, "should not return an error while encoding payload") + + diagStr, err := cbor.Diagnose(encoded) + assert.Nil(t, err, "should not return an error while diagnosing payload") + + fmt.Println(diagStr) +} diff --git a/pkg/models/duration.go b/pkg/models/duration.go new file mode 100644 index 0000000..2ba5aa3 --- /dev/null +++ b/pkg/models/duration.go @@ -0,0 +1,43 @@ +package models + +import ( + "time" + + "github.com/surrealdb/surrealdb.go/pkg/constants" + + "github.com/fxamacker/cbor/v2" +) + +type CustomDuration time.Duration + +type CustomDurationStr string + +func (d *CustomDuration) MarshalCBOR() ([]byte, error) { + enc := getCborEncoder() + + totalNS := time.Duration(*d).Nanoseconds() + s := totalNS / constants.OneSecondToNanoSecond + ns := totalNS % constants.OneSecondToNanoSecond + + return enc.Marshal(cbor.Tag{ + Number: uint64(DurationCompactTag), + Content: [2]int64{s, ns}, + }) +} + +func (d *CustomDuration) UnmarshalCBOR(data []byte) error { + dec := getCborDecoder() + + var temp [2]interface{} + err := dec.Unmarshal(data, &temp) + if err != nil { + return err + } + + s := temp[0].(int64) + ns := temp[1].(int64) + + *d = CustomDuration(time.Duration((float64(s) * constants.OneSecondToNanoSecond) + float64(ns))) + + return nil +} diff --git a/pkg/models/geometry.go b/pkg/models/geometry.go new file mode 100644 index 0000000..cbed82e --- /dev/null +++ b/pkg/models/geometry.go @@ -0,0 +1,54 @@ +package models + +import "github.com/fxamacker/cbor/v2" + +type GeometryPoint struct { + Latitude float64 + Longitude float64 +} + +func NewGeometryPoint(latitude, longitude float64) GeometryPoint { + return GeometryPoint{ + Latitude: latitude, Longitude: longitude, + } +} + +func (gp *GeometryPoint) GetCoordinates() [2]float64 { + return [2]float64{gp.Latitude, gp.Longitude} +} + +func (gp *GeometryPoint) MarshalCBOR() ([]byte, error) { + enc := getCborEncoder() + + return enc.Marshal(cbor.Tag{ + Number: uint64(GeometryPointTag), + Content: gp.GetCoordinates(), + }) +} + +func (gp *GeometryPoint) UnmarshalCBOR(data []byte) error { + dec := getCborDecoder() + + var temp [2]float64 + err := dec.Unmarshal(data, &temp) + if err != nil { + return err + } + + gp.Latitude = temp[0] + gp.Longitude = temp[1] + + return nil +} + +type GeometryLine []GeometryPoint + +type GeometryPolygon []GeometryLine + +type GeometryMultiPoint []GeometryPoint + +type GeometryMultiLine []GeometryLine + +type GeometryMultiPolygon []GeometryPolygon + +type GeometryCollection []any diff --git a/pkg/models/record_id.go b/pkg/models/record_id.go new file mode 100644 index 0000000..a52f31e --- /dev/null +++ b/pkg/models/record_id.go @@ -0,0 +1,60 @@ +package models + +import ( + "fmt" + "strings" + + "github.com/fxamacker/cbor/v2" +) + +type RecordID struct { + Table string + ID any +} + +type RecordIDType interface { + ~int | ~string | []any | map[string]any +} + +func ParseRecordID(idStr string) *RecordID { + expectedLen := 2 + bits := strings.Split(idStr, ":") + if len(bits) != expectedLen { + panic(fmt.Errorf("invalid id string. Expected format is 'tablename:indentifier'")) + } + return &RecordID{ + Table: bits[0], ID: bits[1], + } +} + +func NewRecordID(tableName string, id any) RecordID { + return RecordID{Table: tableName, ID: id} +} + +func (r *RecordID) MarshalCBOR() ([]byte, error) { + enc := getCborEncoder() + + return enc.Marshal(cbor.Tag{ + Number: uint64(RecordIDTag), + Content: []interface{}{r.Table, r.ID}, + }) +} + +func (r *RecordID) UnmarshalCBOR(data []byte) error { + dec := getCborDecoder() + + var temp []interface{} + err := dec.Unmarshal(data, &temp) + if err != nil { + return err + } + + r.Table = temp[0].(string) + r.ID = temp[1].(string) + + return nil +} + +func (r *RecordID) String() string { + return fmt.Sprintf("%s:%s", r.Table, r.ID) +} diff --git a/pkg/models/replacer.go b/pkg/models/replacer.go new file mode 100644 index 0000000..cec3b4b --- /dev/null +++ b/pkg/models/replacer.go @@ -0,0 +1,60 @@ +package models + +import ( + "reflect" + "time" +) + +func replacerBeforeEncode(value interface{}) interface{} { + valueType := reflect.TypeOf(value) + valueKind := valueType.Kind() + + if valueType == reflect.TypeOf(time.Duration(0)) { + oldVal := value.(time.Duration) + newValue := CustomDuration(oldVal.Nanoseconds()) + return newValue + } + + if valueKind == reflect.Map { + oldValue := value.(map[string]interface{}) + newValue := make(map[interface{}]interface{}) + for k, v := range oldValue { + newKey := replacerBeforeEncode(k) + newVal := replacerBeforeEncode(v) + newValue[newKey] = newVal + } + + return newValue + } + + // todo: handle slices + + return value +} + +func replacerAfterDecode(value interface{}) interface{} { + valueType := reflect.TypeOf(value) + valueKind := valueType.Kind() + + if valueType == reflect.TypeOf(CustomDuration(0)) { + oldVal := value.(CustomDuration) + newValue := time.Duration(oldVal) + return newValue + } + + if valueKind == reflect.Map { + oldValue := value.(map[string]interface{}) + newValue := make(map[interface{}]interface{}) + for k, v := range oldValue { + newKey := replacerAfterDecode(k) + newVal := replacerAfterDecode(v) + newValue[newKey] = newVal + } + + return newValue + } + + // todo: handle slices + + return value +} diff --git a/pkg/models/replacer_test.go b/pkg/models/replacer_test.go new file mode 100644 index 0000000..12137dd --- /dev/null +++ b/pkg/models/replacer_test.go @@ -0,0 +1,19 @@ +package models + +import ( + "fmt" + "testing" + "time" +) + +func TestReplacerBeForeEncode(t *testing.T) { + d := map[string]interface{}{ + "duration": time.Duration(2000), + "nested": map[string]interface{}{ + "duration": time.Duration(3000), + }, + } + + newD := replacerBeforeEncode(d) + fmt.Println(newD) +} diff --git a/pkg/models/types.go b/pkg/models/types.go new file mode 100644 index 0000000..14c541c --- /dev/null +++ b/pkg/models/types.go @@ -0,0 +1,77 @@ +package models + +import ( + "time" + + "github.com/surrealdb/surrealdb.go/pkg/constants" + + "github.com/fxamacker/cbor/v2" + "github.com/gofrs/uuid" +) + +type TableOrRecord interface { + string | Table | RecordID | []Table | []RecordID +} + +type Table string + +// type UUID string + +// type UUIDBin []byte +type UUID struct { + uuid.UUID +} + +type Decimal string + +type CustomDateTime time.Time + +func (d *CustomDateTime) MarshalCBOR() ([]byte, error) { + enc := getCborEncoder() + + totalNS := time.Time(*d).Nanosecond() + + s := totalNS / constants.OneSecondToNanoSecond + ns := totalNS % constants.OneSecondToNanoSecond + + return enc.Marshal(cbor.Tag{ + Number: uint64(DateTimeCompactString), + Content: [2]int64{int64(s), int64(ns)}, + }) +} + +func (d *CustomDateTime) UnmarshalCBOR(data []byte) error { + dec := getCborDecoder() + + var temp [2]interface{} + err := dec.Unmarshal(data, &temp) + if err != nil { + return err + } + + s := temp[0].(int64) + ns := temp[1].(int64) + + *d = CustomDateTime(time.Unix(s, ns)) + + return nil +} + +type CustomNil struct { +} + +func (c *CustomNil) MarshalCBOR() ([]byte, error) { + enc := getCborEncoder() + + return enc.Marshal(cbor.Tag{ + Number: uint64(NoneTag), + Content: nil, + }) +} + +func (c *CustomNil) UnMarshalCBOR(data []byte) error { + *c = CustomNil{} + return nil +} + +var None = CustomNil{} diff --git a/pkg/util/util.go b/pkg/util/util.go deleted file mode 100644 index dfc5ea0..0000000 --- a/pkg/util/util.go +++ /dev/null @@ -1,7 +0,0 @@ -package util - -import "reflect" - -func IsSlice(value interface{}) bool { - return reflect.TypeOf(value).Kind() == reflect.Slice -} diff --git a/project.json b/project.json index 78903fa..6eff79e 100644 --- a/project.json +++ b/project.json @@ -1,5 +1,5 @@ { "name": "surrealdb", - "version": "0.1.1", + "version": "1.0.0-beta", "signGitTag": false } diff --git a/types.go b/types.go index a2f3e1d..e4d6889 100644 --- a/types.go +++ b/types.go @@ -1,8 +1,42 @@ package surrealdb +import "github.com/surrealdb/surrealdb.go/pkg/models" + // Patch represents a patch object set to MODIFY a record -type Patch struct { +type PatchData struct { Op string `json:"op"` Path string `json:"path"` Value any `json:"value"` } + +type QueryResult[T any] struct { + Status string `json:"status"` + Time string `json:"time"` + Result T `json:"result"` +} + +type QueryStatement[TResult any] struct { + SQL string + Vars map[string]interface{} +} + +type Relation[T any] struct { + ID string `json:"id"` + In models.RecordID `json:"in"` + Out models.RecordID `json:"out"` +} + +// Auth is a struct that holds surrealdb auth data for login. +type Auth struct { + Namespace string `json:"NS,omitempty"` + Database string `json:"DB,omitempty"` + Scope string `json:"SC,omitempty"` + Username string `json:"user,omitempty"` + Password string `json:"pass,omitempty"` +} + +type O map[interface{}]interface{} + +type Result[T any] struct { + T any +}