diff --git a/Makefile b/Makefile index f2807e66..f7a06390 100644 --- a/Makefile +++ b/Makefile @@ -13,9 +13,15 @@ dropdb: migrateup: migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose up +migrateup1: + migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose up 1 + migratedown: migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose down +migratedown1: + migrate -path db/migration -database "postgresql://root:secret@localhost:5432/simple_bank?sslmode=disable" -verbose down 1 + sqlc: sqlc generate @@ -28,4 +34,4 @@ server: mock: mockgen -package mockdb -destination db/mock/store.go github.com/techschool/simplebank/db/sqlc Store -.PHONY: postgres createdb dropdb migrateup migratedown sqlc test server mock +.PHONY: postgres createdb dropdb migrateup migratedown migrateup1 migratedown1 sqlc test server mock diff --git a/api/account.go b/api/account.go index e5ebeb9f..085bc92f 100644 --- a/api/account.go +++ b/api/account.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/lib/pq" db "github.com/techschool/simplebank/db/sqlc" ) @@ -28,6 +29,13 @@ func (server *Server) createAccount(ctx *gin.Context) { account, err := server.store.CreateAccount(ctx, arg) if err != nil { + if pqErr, ok := err.(*pq.Error); ok { + switch pqErr.Code.Name() { + case "foreign_key_violation", "unique_violation": + ctx.JSON(http.StatusForbidden, errorResponse(err)) + return + } + } ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return } diff --git a/db/migration/000002_add_users.down.sql b/db/migration/000002_add_users.down.sql new file mode 100644 index 00000000..543e10cb --- /dev/null +++ b/db/migration/000002_add_users.down.sql @@ -0,0 +1,5 @@ +ALTER TABLE IF EXISTS "accounts" DROP CONSTRAINT IF EXISTS "owner_currency_key"; + +ALTER TABLE IF EXISTS "accounts" DROP CONSTRAINT IF EXISTS "accounts_owner_fkey"; + +DROP TABLE IF EXISTS "users"; diff --git a/db/migration/000002_add_users.up.sql b/db/migration/000002_add_users.up.sql new file mode 100644 index 00000000..85069e87 --- /dev/null +++ b/db/migration/000002_add_users.up.sql @@ -0,0 +1,12 @@ +CREATE TABLE "users" ( + "username" varchar PRIMARY KEY, + "hashed_password" varchar NOT NULL, + "full_name" varchar NOT NULL, + "email" varchar UNIQUE NOT NULL, + "password_changed_at" timestamptz NOT NULL DEFAULT('0001-01-01 00:00:00+00'), + "created_at" timestamptz NOT NULL DEFAULT (now()) +); + +ALTER TABLE "accounts" ADD FOREIGN KEY ("owner") REFERENCES "users" ("username"); + +ALTER TABLE "accounts" ADD CONSTRAINT "owner_currency_key" UNIQUE ("owner", "currency"); diff --git a/db/mock/store.go b/db/mock/store.go index 2a1b036d..29c9056d 100644 --- a/db/mock/store.go +++ b/db/mock/store.go @@ -94,6 +94,21 @@ func (mr *MockStoreMockRecorder) CreateTransfer(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTransfer", reflect.TypeOf((*MockStore)(nil).CreateTransfer), arg0, arg1) } +// CreateUser mocks base method +func (m *MockStore) CreateUser(arg0 context.Context, arg1 db.CreateUserParams) (db.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUser", arg0, arg1) + ret0, _ := ret[0].(db.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateUser indicates an expected call of CreateUser +func (mr *MockStoreMockRecorder) CreateUser(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockStore)(nil).CreateUser), arg0, arg1) +} + // DeleteAccount mocks base method func (m *MockStore) DeleteAccount(arg0 context.Context, arg1 int64) error { m.ctrl.T.Helper() @@ -168,6 +183,21 @@ func (mr *MockStoreMockRecorder) GetTransfer(arg0, arg1 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTransfer", reflect.TypeOf((*MockStore)(nil).GetTransfer), arg0, arg1) } +// GetUser mocks base method +func (m *MockStore) GetUser(arg0 context.Context, arg1 string) (db.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUser", arg0, arg1) + ret0, _ := ret[0].(db.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUser indicates an expected call of GetUser +func (mr *MockStoreMockRecorder) GetUser(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockStore)(nil).GetUser), arg0, arg1) +} + // ListAccounts mocks base method func (m *MockStore) ListAccounts(arg0 context.Context, arg1 db.ListAccountsParams) ([]db.Account, error) { m.ctrl.T.Helper() diff --git a/db/query/user.sql b/db/query/user.sql new file mode 100644 index 00000000..a3901988 --- /dev/null +++ b/db/query/user.sql @@ -0,0 +1,13 @@ +-- name: CreateUser :one +INSERT INTO users ( + username, + hashed_password, + full_name, + email +) VALUES ( + $1, $2, $3, $4 +) RETURNING *; + +-- name: GetUser :one +SELECT * FROM users +WHERE username = $1 LIMIT 1; diff --git a/db/sqlc/account_test.go b/db/sqlc/account_test.go index 962c51cd..790eaec7 100644 --- a/db/sqlc/account_test.go +++ b/db/sqlc/account_test.go @@ -11,8 +11,10 @@ import ( ) func createRandomAccount(t *testing.T) Account { + user := createRandomUser(t) + arg := CreateAccountParams{ - Owner: util.RandomOwner(), + Owner: user.Username, Balance: util.RandomMoney(), Currency: util.RandomCurrency(), } diff --git a/db/sqlc/models.go b/db/sqlc/models.go index ffcdfa50..b4249902 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -30,3 +30,12 @@ type Transfer struct { Amount int64 `json:"amount"` CreatedAt time.Time `json:"created_at"` } + +type User struct { + Username string `json:"username"` + HashedPassword string `json:"hashed_password"` + FullName string `json:"full_name"` + Email string `json:"email"` + PasswordChangedAt time.Time `json:"password_changed_at"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index bbdfb5f9..82434e77 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -11,11 +11,13 @@ type Querier interface { CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) CreateEntry(ctx context.Context, arg CreateEntryParams) (Entry, error) CreateTransfer(ctx context.Context, arg CreateTransferParams) (Transfer, error) + CreateUser(ctx context.Context, arg CreateUserParams) (User, error) DeleteAccount(ctx context.Context, id int64) error GetAccount(ctx context.Context, id int64) (Account, error) GetAccountForUpdate(ctx context.Context, id int64) (Account, error) GetEntry(ctx context.Context, id int64) (Entry, error) GetTransfer(ctx context.Context, id int64) (Transfer, error) + GetUser(ctx context.Context, username string) (User, error) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) ListEntries(ctx context.Context, arg ListEntriesParams) ([]Entry, error) ListTransfers(ctx context.Context, arg ListTransfersParams) ([]Transfer, error) diff --git a/db/sqlc/user.sql.go b/db/sqlc/user.sql.go new file mode 100644 index 00000000..564f127e --- /dev/null +++ b/db/sqlc/user.sql.go @@ -0,0 +1,64 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: user.sql + +package db + +import ( + "context" +) + +const createUser = `-- name: CreateUser :one +INSERT INTO users ( + username, + hashed_password, + full_name, + email +) VALUES ( + $1, $2, $3, $4 +) RETURNING username, hashed_password, full_name, email, password_changed_at, created_at +` + +type CreateUserParams struct { + Username string `json:"username"` + HashedPassword string `json:"hashed_password"` + FullName string `json:"full_name"` + Email string `json:"email"` +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { + row := q.db.QueryRowContext(ctx, createUser, + arg.Username, + arg.HashedPassword, + arg.FullName, + arg.Email, + ) + var i User + err := row.Scan( + &i.Username, + &i.HashedPassword, + &i.FullName, + &i.Email, + &i.PasswordChangedAt, + &i.CreatedAt, + ) + return i, err +} + +const getUser = `-- name: GetUser :one +SELECT username, hashed_password, full_name, email, password_changed_at, created_at FROM users +WHERE username = $1 LIMIT 1 +` + +func (q *Queries) GetUser(ctx context.Context, username string) (User, error) { + row := q.db.QueryRowContext(ctx, getUser, username) + var i User + err := row.Scan( + &i.Username, + &i.HashedPassword, + &i.FullName, + &i.Email, + &i.PasswordChangedAt, + &i.CreatedAt, + ) + return i, err +} diff --git a/db/sqlc/user_test.go b/db/sqlc/user_test.go new file mode 100644 index 00000000..234a8dbb --- /dev/null +++ b/db/sqlc/user_test.go @@ -0,0 +1,50 @@ +package db + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/techschool/simplebank/util" +) + +func createRandomUser(t *testing.T) User { + arg := CreateUserParams{ + Username: util.RandomOwner(), + HashedPassword: "secret", + FullName: util.RandomOwner(), + Email: util.RandomEmail(), + } + + user, err := testQueries.CreateUser(context.Background(), arg) + require.NoError(t, err) + require.NotEmpty(t, user) + + require.Equal(t, arg.Username, user.Username) + require.Equal(t, arg.HashedPassword, user.HashedPassword) + require.Equal(t, arg.FullName, user.FullName) + require.Equal(t, arg.Email, user.Email) + require.Zero(t, user.PasswordChangedAt) + require.NotZero(t, user.CreatedAt) + + return user +} + +func TestCreateUser(t *testing.T) { + createRandomUser(t) +} + +func TestGetUser(t *testing.T) { + user1 := createRandomUser(t) + user2, err := testQueries.GetUser(context.Background(), user1.Username) + require.NoError(t, err) + require.NotEmpty(t, user2) + + require.Equal(t, user1.Username, user2.Username) + require.Equal(t, user1.HashedPassword, user2.HashedPassword) + require.Equal(t, user1.FullName, user2.FullName) + require.Equal(t, user1.Email, user2.Email) + require.WithinDuration(t, user1.PasswordChangedAt, user2.PasswordChangedAt, time.Second) + require.WithinDuration(t, user1.CreatedAt, user2.CreatedAt, time.Second) +} diff --git a/util/random.go b/util/random.go index 727ca945..0ffd515a 100644 --- a/util/random.go +++ b/util/random.go @@ -1,6 +1,7 @@ package util import ( + "fmt" "math/rand" "strings" "time" @@ -46,3 +47,8 @@ func RandomCurrency() string { n := len(currencies) return currencies[rand.Intn(n)] } + +// RandomEmail generates a random email +func RandomEmail() string { + return fmt.Sprintf("%s@email.com", RandomString(6)) +}