diff --git a/.circleci/check_golint.sh b/.circleci/check_golint.sh index 55a82b936..9da543e79 100755 --- a/.circleci/check_golint.sh +++ b/.circleci/check_golint.sh @@ -7,7 +7,7 @@ # and skip issues about `_` in packages' names since we won't rename packages just to make this issue disappear result=$(golint ./... | grep -v "\.pb\.go\|don't use an underscore in package name" | tee /dev/stderr | wc -l) -if [[ $result -gt 6 ]]; then +if [[ $result -gt 5 ]]; then # too many golint issues echo "Too many golint issues: $result" exit 1; diff --git a/.circleci/check_gotest.sh b/.circleci/check_gotest.sh index 7a3a20ca3..5e0b43e33 100755 --- a/.circleci/check_gotest.sh +++ b/.circleci/check_gotest.sh @@ -14,6 +14,15 @@ if [ -z "$GO_VERSIONS" ]; then GO_VERSIONS="$(readlink $GOROOT)" fi +# for local run +if [ -z "$GO_VERSIONS" ]; then + echo 'Run tests with local golang executable' + go test -tags="${TEST_BUILD_TAGS}" ${TEST_EXTRA_BUILD_FLAGS} ./...; + status="$?" + exit $status +fi + +# for circleci run for go_version in $GO_VERSIONS; do export GOROOT="/usr/local/lib/go/$go_version" diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md index 759661494..414c6b10e 100644 --- a/CHANGELOG_DEV.md +++ b/CHANGELOG_DEV.md @@ -1,3 +1,9 @@ +## 0.93.0 - 2022-03-28 +- Transparent decryption with replacing type's metadata +- Extend `encryptor_config` with new settings: `data_type=[int32|int64|str|bytes]` and `default_data_value: ` +- Support values in text format from Postgresql's binary protocol +- Refactored internals of data encoding/decoding, protocol processing, saving session related data + ## 0.93.0 - 2022-03-23 - Remove autogeneration of poison keys on the first access (but keep in poisonrecordmaker). - Add warning on enabled poison detection if keys are not generated. diff --git a/Makefile b/Makefile index 651303f59..b55784a6b 100644 --- a/Makefile +++ b/Makefile @@ -192,8 +192,19 @@ build_protobuf: @protoc --go_out=`pwd` --go-grpc_out=`pwd` \ --go_opt=module=github.com/cossacklabs/acra \ --go-grpc_opt=module=github.com/cossacklabs/acra \ - -Icmd/acra-translator/grpc_api \ - cmd/acra-translator/grpc_api/*.proto + -Ipseudonymization/common \ + pseudonymization/common/*.proto + @protoc --go_out=`pwd` --go-grpc_out=`pwd` \ + --go_opt=module=github.com/cossacklabs/acra \ + --go-grpc_opt=module=github.com/cossacklabs/acra \ + -Icmd/acra-translator/grpc_api \ + cmd/acra-translator/grpc_api/*.proto + @protoc --go_out=`pwd` --go-grpc_out=`pwd` \ + --go_opt=module=github.com/cossacklabs/acra \ + --go-grpc_opt=module=github.com/cossacklabs/acra \ + -Iencryptor/config/common \ + encryptor/config/common/*.proto + @python3 -m grpc_tools.protoc -Icmd/acra-translator/grpc_api --proto_path=. --python_out=tests/ --grpc_python_out=tests/ cmd/acra-translator/grpc_api/*.proto ## Build the application in the subdirectory (default) diff --git a/benchmarks/common/db.go b/benchmarks/common/db.go index 10208c76e..7e11c8dfc 100644 --- a/benchmarks/common/db.go +++ b/benchmarks/common/db.go @@ -17,6 +17,7 @@ package common import ( "database/sql" + // import driver for connect function _ "github.com/jackc/pgx/v4/stdlib" "github.com/sirupsen/logrus" "os" diff --git a/cmd/acra-server/acra-server.go b/cmd/acra-server/acra-server.go index 46272bc82..a0d08a817 100644 --- a/cmd/acra-server/acra-server.go +++ b/cmd/acra-server/acra-server.go @@ -318,7 +318,7 @@ func realMain() error { } serverConfig.SetKeyStore(keyStore) - log.Infof("Keystore init OK") + log.WithField("path", *keysDir).Infof("Keystore init OK") if err := crypto.InitRegistry(keyStore); err != nil { log.WithError(err).Errorln("Can't initialize crypto registry") diff --git a/cmd/acra-server/common/client_session.go b/cmd/acra-server/common/client_session.go index a1f192999..d6098731d 100644 --- a/cmd/acra-server/common/client_session.go +++ b/cmd/acra-server/common/client_session.go @@ -36,6 +36,7 @@ type ClientSession struct { logger *log.Entry statements base.PreparedStatementRegistry protocolState interface{} + data map[string]interface{} } var sessionCounter uint32 @@ -47,8 +48,35 @@ func NewClientSession(ctx context.Context, config *Config, connection net.Conn) sessionID := atomic.AddUint32(&sessionCounter, 1) logger := logging.GetLoggerFromContext(ctx) logger = logger.WithField("session_id", sessionID) + session := &ClientSession{connection: connection, config: config, ctx: ctx, logger: logger, + data: make(map[string]interface{}, 8)} ctx = logging.SetLoggerToContext(ctx, logger) - return &ClientSession{connection: connection, config: config, ctx: ctx, logger: logger}, nil + ctx = base.SetClientSessionToContext(ctx, session) + session.ctx = ctx + return session, nil + +} + +// SetData save session related data by key +func (clientSession *ClientSession) SetData(key string, data interface{}) { + clientSession.data[key] = data +} + +// GetData return session related data by key and true otherwise nil, false +func (clientSession *ClientSession) GetData(key string) (interface{}, bool) { + value, ok := clientSession.data[key] + return value, ok +} + +// DeleteData delete session related data by key +func (clientSession *ClientSession) DeleteData(key string) { + delete(clientSession.data, key) +} + +// HasData return true if session has data by key +func (clientSession *ClientSession) HasData(key string) bool { + _, ok := clientSession.data[key] + return ok } // Logger returns session's logger. diff --git a/cmd/acra-server/common/client_session_test.go b/cmd/acra-server/common/client_session_test.go new file mode 100644 index 000000000..08ca90cb5 --- /dev/null +++ b/cmd/acra-server/common/client_session_test.go @@ -0,0 +1,73 @@ +package common + +import ( + "context" + "reflect" + "testing" +) + +func TestClientSession_Data(t *testing.T) { + session, err := NewClientSession(context.TODO(), nil, nil) + if err != nil { + t.Fatal(err) + } + type testcase struct { + key string + data interface{} + } + testcases := []testcase{ + {`binary key`, []byte(`binary data`)}, + {`string key`, `string value`}, + {`int key`, 123}, + {`struct key`, testcase{`123`, `123`}}, + } + overwriteValue := `some value that will overwrite existing value` + for _, tcase := range testcases { + if session.HasData(tcase.key) { + t.Fatal("session should not have value of not used key") + } + value, ok := session.GetData(tcase.key) + if ok { + t.Fatal("session should not have value of not used key") + } + if value != nil { + t.Fatal("session should return nil for not existing keys") + } + session.SetData(tcase.key, tcase.data) + if !session.HasData(tcase.key) { + t.Fatal("session hasn't value of existing key") + } + value, ok = session.GetData(tcase.key) + if !ok { + t.Fatal("session hasn't value of of existing key") + } + if !reflect.DeepEqual(tcase.data, value) { + t.Fatal("session returned another value") + } + + // overwrite value and check that it successfully overwritten + session.SetData(tcase.key, overwriteValue) + if !session.HasData(tcase.key) { + t.Fatal("session hasn't value of existing key") + } + value, ok = session.GetData(tcase.key) + if !ok { + t.Fatal("session hasn't value of of existing key") + } + if !reflect.DeepEqual(overwriteValue, value) { + t.Fatal("session returned another value") + } + + session.DeleteData(tcase.key) + if session.HasData(tcase.key) { + t.Fatal("session should not have value of not used key") + } + value, ok = session.GetData(tcase.key) + if ok { + t.Fatal("session should not have value of not used key") + } + if value != nil { + t.Fatal("session should return nil for not existing keys") + } + } +} diff --git a/cmd/acra-server/common/listener.go b/cmd/acra-server/common/listener.go index a1065ff5a..7cb72a992 100644 --- a/cmd/acra-server/common/listener.go +++ b/cmd/acra-server/common/listener.go @@ -184,18 +184,18 @@ func (server *SServer) handleClientSession(clientID []byte, clientSession *Clien accessContext := base.NewAccessContext(base.WithClientID(clientID), base.WithZoneMode(server.config.GetWithZone())) // subscribe on clientID changes after switching connection to TLS and using ClientID from TLS certificates proxy.AddClientIDObserver(accessContext) - + clientSession.ctx = base.SetAccessContextToContext(clientSession.ctx, accessContext) server.backgroundWorkersSync.Add(1) go func() { defer server.backgroundWorkersSync.Done() defer recoverConnection(sessionLogger.WithField("function", "ProxyClientConnection"), sessionCloseToCloser(clientSession.Close)) - proxy.ProxyClientConnection(base.SetAccessContextToContext(clientSession.ctx, accessContext), proxyErrCh) + proxy.ProxyClientConnection(clientSession.ctx, proxyErrCh) }() server.backgroundWorkersSync.Add(1) go func() { defer server.backgroundWorkersSync.Done() defer recoverConnection(sessionLogger.WithField("function", "ProxyDatabaseConnection"), sessionCloseToCloser(clientSession.Close)) - proxy.ProxyDatabaseConnection(base.SetAccessContextToContext(clientSession.ctx, accessContext), proxyErrCh) + proxy.ProxyDatabaseConnection(clientSession.ctx, proxyErrCh) }() proxyErr := <-proxyErrCh diff --git a/cmd/acra-translator/grpc_api/Readme.md b/cmd/acra-translator/grpc_api/Readme.md index cc1dc7b24..6ecb2d35b 100644 --- a/cmd/acra-translator/grpc_api/Readme.md +++ b/cmd/acra-translator/grpc_api/Readme.md @@ -1,10 +1,9 @@ # Install grpc dependencies ``` -# from https://github.com/grpc/grpc-go -go get -u github.com/golang/protobuf/{proto,protoc-gen-go} -go get -u google.golang.org/grpc +# from https://developers.google.com/protocol-buffers/docs/gotutorial +go install google.golang.org/protobuf/cmd/protoc-gen-go@latest ``` To recompile proto file run from root of acra repository: ``` -protoc --go_out=plugins=grpc:. cmd/acra-translator/grpc_api/api.proto +make build_protobuf ``` diff --git a/configs/acra-log-verifier.yaml b/configs/acra-log-verifier.yaml deleted file mode 100644 index 8bda12735..000000000 --- a/configs/acra-log-verifier.yaml +++ /dev/null @@ -1,58 +0,0 @@ -version: 0.85.0 -# path to audit log file to verify -audit_log_file: - -# Expected format of audit log file(s) that should be verified: plaintext, json or CEF -audit_log_file_format: plaintext - -# path to list of audit log files to verify -audit_log_file_list: - -# don't fail validation if some audit logs cannot be opened -audit_log_missing_ok: false - -# path to config -config_file: - -# debug mode (shows order of input files) -d: false - -# dump config -dump_config: false - -# Generate with yaml config markdown text file with descriptions of all args -generate_markdown_args_table: false - -# Folder from which will be loaded keys -keys_dir: .acrakeys - -# Logging format of this tool: plaintext, json or CEF -logging_format: plaintext - -# Number of Redis database for keys -redis_db_keys: 0 - -# : used to connect to Redis -redis_host_port: - -# Password to Redis database -redis_password: - -# Connection string (http://x.x.x.x:yyyy) for loading ACRA_MASTER_KEY from HashiCorp Vault -vault_connection_api_string: - -# KV Secret Path (secret/) for reading ACRA_MASTER_KEY from HashiCorp Vault -vault_secrets_path: secret/ - -# Path to CA certificate for HashiCorp Vault certificate validation -vault_tls_ca_path: - -# Path to client TLS certificate for reading ACRA_MASTER_KEY from HashiCorp Vault -vault_tls_client_cert: - -# Path to private key of the client TLS certificate for reading ACRA_MASTER_KEY from HashiCorp Vault -vault_tls_client_key: - -# Use TLS to encrypt transport with HashiCorp Vault -vault_tls_transport_enable: false - diff --git a/crypto/envelope_detector.go b/crypto/envelope_detector.go index 451b3b4f0..87605b869 100644 --- a/crypto/envelope_detector.go +++ b/crypto/envelope_detector.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/cossacklabs/acra/acrablock" "github.com/cossacklabs/acra/acrastruct" + "github.com/cossacklabs/acra/decryptor/base" "github.com/cossacklabs/acra/utils" "github.com/sirupsen/logrus" ) @@ -60,6 +61,7 @@ func (recognizer *EnvelopeDetector) OnColumn(ctx context.Context, inBuffer []byt return ctx, inBuffer, nil } outBuffer := make([]byte, 0, len(inBuffer)) + changed := false // inline mode inIndex := 0 for { @@ -96,6 +98,7 @@ func (recognizer *EnvelopeDetector) OnColumn(ctx context.Context, inBuffer []byt if !bytes.Equal(processedData, container) { outBuffer = append(outBuffer, processedData...) inIndex += n + changed = true break } @@ -109,6 +112,9 @@ func (recognizer *EnvelopeDetector) OnColumn(ctx context.Context, inBuffer []byt } // copy left bytes outBuffer = append(outBuffer, inBuffer[inIndex:]...) + if changed { + return base.MarkDecryptedContext(ctx), outBuffer, nil + } return ctx, outBuffer, nil } @@ -194,6 +200,9 @@ func (wrapper *OldContainerDetectorWrapper) OnColumn(ctx context.Context, inBuff if err != nil { return ctx, inBuffer, err } + if !bytes.Equal(inBuffer, outBuffer) { + return base.MarkDecryptedContext(ctx), outBuffer, nil + } return ctx, outBuffer, nil } diff --git a/crypto/envelope_detector_test.go b/crypto/envelope_detector_test.go index 725729e21..27e559f55 100644 --- a/crypto/envelope_detector_test.go +++ b/crypto/envelope_detector_test.go @@ -183,10 +183,13 @@ func TestOldContainerDetectorWrapper(t *testing.T) { }, } for _, tcase := range testCases { - _, outBuffer, err := containerDetector.OnColumn(base.SetAccessContextToContext(context.Background(), accessContext), tcase.input) + ctx, outBuffer, err := containerDetector.OnColumn(base.SetAccessContextToContext(context.Background(), accessContext), tcase.input) if err != nil { t.Fatal("OnColumn error ", err) } + if !base.IsDecryptedFromContext(ctx) { + t.Fatal("Expects decrypted data") + } if !bytes.Equal(outBuffer, tcase.expected) { t.Fatal("outBuffer is not equals to expected", err) diff --git a/decryptor/base/clientSession.go b/decryptor/base/clientSession.go new file mode 100644 index 000000000..041dd6a6c --- /dev/null +++ b/decryptor/base/clientSession.go @@ -0,0 +1,40 @@ +package base + +import ( + "context" + "net" +) + +// ClientSession is a connection between the client and the database, mediated by AcraServer. +type ClientSession interface { + Context() context.Context + ClientConnection() net.Conn + DatabaseConnection() net.Conn + + PreparedStatementRegistry() PreparedStatementRegistry + SetPreparedStatementRegistry(registry PreparedStatementRegistry) + + ProtocolState() interface{} + SetProtocolState(state interface{}) + GetData(string) (interface{}, bool) + SetData(string, interface{}) + DeleteData(string) + HasData(string) bool +} + +type sessionContextKey struct{} + +// SetClientSessionToContext return context with saved ClientSession +func SetClientSessionToContext(ctx context.Context, session ClientSession) context.Context { + return context.WithValue(ctx, sessionContextKey{}, session) +} + +// ClientSessionFromContext return saved ClientSession from context or nil +func ClientSessionFromContext(ctx context.Context) ClientSession { + value := ctx.Value(sessionContextKey{}) + session, ok := value.(ClientSession) + if ok { + return session + } + return nil +} diff --git a/decryptor/base/clientSession_test.go b/decryptor/base/clientSession_test.go new file mode 100644 index 000000000..e124a1b4e --- /dev/null +++ b/decryptor/base/clientSession_test.go @@ -0,0 +1,67 @@ +package base + +import ( + "context" + "net" + "reflect" + "testing" +) + +type sessionStub struct{} + +func (s sessionStub) Context() context.Context { + panic("implement me") +} + +func (s sessionStub) ClientConnection() net.Conn { + panic("implement me") +} + +func (s sessionStub) DatabaseConnection() net.Conn { + panic("implement me") +} + +func (s sessionStub) PreparedStatementRegistry() PreparedStatementRegistry { + panic("implement me") +} + +func (s sessionStub) SetPreparedStatementRegistry(registry PreparedStatementRegistry) { + panic("implement me") +} + +func (s sessionStub) ProtocolState() interface{} { + panic("implement me") +} + +func (s sessionStub) SetProtocolState(state interface{}) { + panic("implement me") +} + +func (s sessionStub) GetData(s2 string) (interface{}, bool) { + panic("implement me") +} + +func (s sessionStub) SetData(s2 string, i interface{}) { + panic("implement me") +} + +func (s sessionStub) DeleteData(s2 string) { + panic("implement me") +} + +func (s sessionStub) HasData(s2 string) bool { + panic("implement me") +} + +func TestSetClientSessionToContext(t *testing.T) { + session := sessionStub{} + ctx := context.Background() + if value := ClientSessionFromContext(ctx); value != nil { + t.Fatal("Unexpected session value from empty context") + } + ctx = SetClientSessionToContext(ctx, session) + value := ClientSessionFromContext(ctx) + if !reflect.DeepEqual(value, session) { + t.Fatal("Returned incorrect session value") + } +} diff --git a/decryptor/base/decryptionNotification.go b/decryptor/base/decryptionNotification.go index a86142548..76502507d 100644 --- a/decryptor/base/decryptionNotification.go +++ b/decryptor/base/decryptionNotification.go @@ -61,14 +61,12 @@ type DecryptionSubscriber interface { // ColumnDecryptionNotifier interface to subscribe/unsubscribe on OnColumn events type ColumnDecryptionNotifier interface { - SubscribeOnColumnDecryption(i int, subscriber DecryptionSubscriber) SubscribeOnAllColumnsDecryption(subscriber DecryptionSubscriber) Unsubscribe(DecryptionSubscriber) } // ColumnDecryptionObserver is a simple ColumnDecryptionNotifier implementation. type ColumnDecryptionObserver struct { - perColumn map[int][]DecryptionSubscriber allColumns []DecryptionSubscriber } @@ -76,22 +74,10 @@ type ColumnDecryptionObserver struct { func NewColumnDecryptionObserver() ColumnDecryptionObserver { // Reserve some memory for a typical amount of subscribers. return ColumnDecryptionObserver{ - perColumn: make(map[int][]DecryptionSubscriber, 10), allColumns: make([]DecryptionSubscriber, 0, 5), } } -// SubscribeOnColumnDecryption subscribes for notifications about the column, indexed from left to right starting with zero. -func (o *ColumnDecryptionObserver) SubscribeOnColumnDecryption(column int, subscriber DecryptionSubscriber) { - subscribers := o.perColumn[column] - for _, existing := range subscribers { - if existing == subscriber { - return - } - } - o.perColumn[column] = append(subscribers, subscriber) -} - // SubscribeOnAllColumnsDecryption subscribes for notifications on each column. func (o *ColumnDecryptionObserver) SubscribeOnAllColumnsDecryption(subscriber DecryptionSubscriber) { o.allColumns = append(o.allColumns, subscriber) @@ -99,14 +85,6 @@ func (o *ColumnDecryptionObserver) SubscribeOnAllColumnsDecryption(subscriber De // Unsubscribe a subscriber from all notifications. func (o *ColumnDecryptionObserver) Unsubscribe(subscriber DecryptionSubscriber) { - for column, observers := range o.perColumn { - for i, existing := range observers { - if existing == subscriber { - o.perColumn[column] = append(observers[:i], observers[i+1:]...) - break - } - } - } for i, existing := range o.allColumns { if existing == subscriber { o.allColumns = append(o.allColumns[:i], o.allColumns[i+1:]...) @@ -121,14 +99,6 @@ func (o *ColumnDecryptionObserver) Unsubscribe(subscriber DecryptionSubscriber) func (o *ColumnDecryptionObserver) OnColumnDecryption(ctx context.Context, column int, data []byte) ([]byte, error) { var err error // Avoid creating a map entry if it does not exist. - subscribers, _ := o.perColumn[column] - for _, subscriber := range subscribers { - ctx, data, err = subscriber.OnColumn(ctx, data) - if err != nil { - logrus.WithField("subscriber", subscriber.ID()).WithError(err).Errorln("OnColumn error") - return data, err - } - } for _, subscriber := range o.allColumns { ctx, data, err = subscriber.OnColumn(ctx, data) if err != nil { @@ -138,3 +108,15 @@ func (o *ColumnDecryptionObserver) OnColumnDecryption(ctx context.Context, colum } return data, nil } + +type decryptedCtxKey struct{} + +// MarkDecryptedContext save flag in context that data was decrypted +func MarkDecryptedContext(ctx context.Context) context.Context { + return context.WithValue(ctx, decryptedCtxKey{}, true) +} + +// IsDecryptedFromContext return true if data was decrypted related to context +func IsDecryptedFromContext(ctx context.Context) bool { + return ctx.Value(decryptedCtxKey{}) != nil +} diff --git a/decryptor/base/decryptionNotification_test.go b/decryptor/base/decryptionNotification_test.go new file mode 100644 index 000000000..2d071449b --- /dev/null +++ b/decryptor/base/decryptionNotification_test.go @@ -0,0 +1,17 @@ +package base + +import ( + "context" + "testing" +) + +func TestMarkDecryptedContext(t *testing.T) { + ctx := context.Background() + if IsDecryptedFromContext(ctx) { + t.Fatal("Unexpected decrypted flag") + } + ctx = MarkDecryptedContext(ctx) + if !IsDecryptedFromContext(ctx) { + t.Fatal("Expects decrypted flag") + } +} diff --git a/decryptor/base/mocks.go b/decryptor/base/mocks.go new file mode 100644 index 000000000..1bdecb0c7 --- /dev/null +++ b/decryptor/base/mocks.go @@ -0,0 +1,4 @@ +package base + +//go:generate mockery --name ClientSession +//go:generate mockery --name BoundValue diff --git a/decryptor/base/mocks/BoundValue.go b/decryptor/base/mocks/BoundValue.go new file mode 100644 index 000000000..99b556f71 --- /dev/null +++ b/decryptor/base/mocks/BoundValue.go @@ -0,0 +1,119 @@ +// Code generated by mockery v2.8.0. DO NOT EDIT. + +package mocks + +import ( + base "github.com/cossacklabs/acra/decryptor/base" + config "github.com/cossacklabs/acra/encryptor/config" + + mock "github.com/stretchr/testify/mock" +) + +// BoundValue is an autogenerated mock type for the BoundValue type +type BoundValue struct { + mock.Mock +} + +// Copy provides a mock function with given fields: +func (_m *BoundValue) Copy() base.BoundValue { + ret := _m.Called() + + var r0 base.BoundValue + if rf, ok := ret.Get(0).(func() base.BoundValue); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(base.BoundValue) + } + } + + return r0 +} + +// Encode provides a mock function with given fields: +func (_m *BoundValue) Encode() ([]byte, error) { + ret := _m.Called() + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Format provides a mock function with given fields: +func (_m *BoundValue) Format() base.BoundValueFormat { + ret := _m.Called() + + var r0 base.BoundValueFormat + if rf, ok := ret.Get(0).(func() base.BoundValueFormat); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(base.BoundValueFormat) + } + + return r0 +} + +// GetData provides a mock function with given fields: setting +func (_m *BoundValue) GetData(setting config.ColumnEncryptionSetting) ([]byte, error) { + ret := _m.Called(setting) + + var r0 []byte + if rf, ok := ret.Get(0).(func(config.ColumnEncryptionSetting) []byte); ok { + r0 = rf(setting) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(config.ColumnEncryptionSetting) error); ok { + r1 = rf(setting) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetType provides a mock function with given fields: +func (_m *BoundValue) GetType() byte { + ret := _m.Called() + + var r0 byte + if rf, ok := ret.Get(0).(func() byte); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(byte) + } + + return r0 +} + +// SetData provides a mock function with given fields: newData, setting +func (_m *BoundValue) SetData(newData []byte, setting config.ColumnEncryptionSetting) error { + ret := _m.Called(newData, setting) + + var r0 error + if rf, ok := ret.Get(0).(func([]byte, config.ColumnEncryptionSetting) error); ok { + r0 = rf(newData, setting) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/decryptor/base/mocks/ClientSession.go b/decryptor/base/mocks/ClientSession.go new file mode 100644 index 000000000..f675bec76 --- /dev/null +++ b/decryptor/base/mocks/ClientSession.go @@ -0,0 +1,155 @@ +// Code generated by mockery v2.8.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + base "github.com/cossacklabs/acra/decryptor/base" + + mock "github.com/stretchr/testify/mock" + + net "net" +) + +// ClientSession is an autogenerated mock type for the ClientSession type +type ClientSession struct { + mock.Mock +} + +// ClientConnection provides a mock function with given fields: +func (_m *ClientSession) ClientConnection() net.Conn { + ret := _m.Called() + + var r0 net.Conn + if rf, ok := ret.Get(0).(func() net.Conn); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + return r0 +} + +// Context provides a mock function with given fields: +func (_m *ClientSession) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// DatabaseConnection provides a mock function with given fields: +func (_m *ClientSession) DatabaseConnection() net.Conn { + ret := _m.Called() + + var r0 net.Conn + if rf, ok := ret.Get(0).(func() net.Conn); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + return r0 +} + +// DeleteData provides a mock function with given fields: _a0 +func (_m *ClientSession) DeleteData(_a0 string) { + _m.Called(_a0) +} + +// GetData provides a mock function with given fields: _a0 +func (_m *ClientSession) GetData(_a0 string) (interface{}, bool) { + ret := _m.Called(_a0) + + var r0 interface{} + if rf, ok := ret.Get(0).(func(string) interface{}); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(_a0) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// HasData provides a mock function with given fields: _a0 +func (_m *ClientSession) HasData(_a0 string) bool { + ret := _m.Called(_a0) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// PreparedStatementRegistry provides a mock function with given fields: +func (_m *ClientSession) PreparedStatementRegistry() base.PreparedStatementRegistry { + ret := _m.Called() + + var r0 base.PreparedStatementRegistry + if rf, ok := ret.Get(0).(func() base.PreparedStatementRegistry); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(base.PreparedStatementRegistry) + } + } + + return r0 +} + +// ProtocolState provides a mock function with given fields: +func (_m *ClientSession) ProtocolState() interface{} { + ret := _m.Called() + + var r0 interface{} + if rf, ok := ret.Get(0).(func() interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + return r0 +} + +// SetData provides a mock function with given fields: _a0, _a1 +func (_m *ClientSession) SetData(_a0 string, _a1 interface{}) { + _m.Called(_a0, _a1) +} + +// SetPreparedStatementRegistry provides a mock function with given fields: registry +func (_m *ClientSession) SetPreparedStatementRegistry(registry base.PreparedStatementRegistry) { + _m.Called(registry) +} + +// SetProtocolState provides a mock function with given fields: state +func (_m *ClientSession) SetProtocolState(state interface{}) { + _m.Called(state) +} diff --git a/decryptor/base/observers.go b/decryptor/base/observers.go index 40fb57cb9..d6c4954d1 100644 --- a/decryptor/base/observers.go +++ b/decryptor/base/observers.go @@ -70,7 +70,7 @@ type BoundValue interface { Format() BoundValueFormat Copy() BoundValue SetData(newData []byte, setting config.ColumnEncryptionSetting) error - GetData(setting config.ColumnEncryptionSetting) []byte + GetData(setting config.ColumnEncryptionSetting) ([]byte, error) Encode() ([]byte, error) GetType() byte } diff --git a/decryptor/base/proxy.go b/decryptor/base/proxy.go index 90d4cbccd..2fea060c1 100644 --- a/decryptor/base/proxy.go +++ b/decryptor/base/proxy.go @@ -110,19 +110,6 @@ type Proxy interface { ProxyDatabaseConnection(context.Context, chan<- ProxyError) } -// ClientSession is a connection between the client and the database, mediated by AcraServer. -type ClientSession interface { - Context() context.Context - ClientConnection() net.Conn - DatabaseConnection() net.Conn - - PreparedStatementRegistry() PreparedStatementRegistry - SetPreparedStatementRegistry(registry PreparedStatementRegistry) - - ProtocolState() interface{} - SetProtocolState(state interface{}) -} - // TLSConnectionWrapper used by proxy to wrap raw connections to TLS when intercepts client/database request about switching to TLS // Reuse network.ConnectionWrapper to explicitly force TLS usage by name type TLSConnectionWrapper interface { @@ -223,3 +210,13 @@ func (p ProxyError) Unwrap() error { func (p ProxyError) InterruptSide() string { return p.interruptSide } + +// OnlyDefaultEncryptorSettings returns true if config contains settings only for transparent decryption that works by default +func OnlyDefaultEncryptorSettings(store config.TableSchemaStore) bool { + storeMask := store.GetGlobalSettingsMask() + return storeMask&(config.SettingSearchFlag| + config.SettingMaskingFlag| + config.SettingTokenizationFlag| + config.SettingDefaultDataValueFlag| + config.SettingDataTypeFlag) == 0 +} diff --git a/decryptor/mysql/packet.go b/decryptor/mysql/packet.go index 10a4b1e91..161a66372 100644 --- a/decryptor/mysql/packet.go +++ b/decryptor/mysql/packet.go @@ -182,7 +182,11 @@ func (packet *Packet) SetParameters(values []base.BoundValue) (err error) { // and we need to get result tokenization value to set signed/unsigned byte switch Type(boundType) { case TypeLong, TypeLongLong: - intValue, err := strconv.ParseInt(string(values[i].GetData(nil)), 10, 64) + data, err := values[i].GetData(nil) + if err != nil { + return err + } + intValue, err := strconv.ParseInt(string(data), 10, 64) if err != nil { return err } diff --git a/decryptor/mysql/prepared_statements.go b/decryptor/mysql/prepared_statements.go index a77722656..b496f6d5b 100644 --- a/decryptor/mysql/prepared_statements.go +++ b/decryptor/mysql/prepared_statements.go @@ -209,8 +209,8 @@ func (m *mysqlBoundValue) SetData(newData []byte, setting config.ColumnEncryptio } // GetData return BoundValue using ColumnEncryptionSetting if provided -func (m *mysqlBoundValue) GetData(_ config.ColumnEncryptionSetting) []byte { - return m.textData +func (m *mysqlBoundValue) GetData(_ config.ColumnEncryptionSetting) ([]byte, error) { + return m.textData, nil } // Encode format result BoundValue data diff --git a/decryptor/mysql/prepared_statements_test.go b/decryptor/mysql/prepared_statements_test.go index 177e286db..d13de3e10 100644 --- a/decryptor/mysql/prepared_statements_test.go +++ b/decryptor/mysql/prepared_statements_test.go @@ -14,16 +14,23 @@ func TestNewMysqlCopyTextBoundValue(t *testing.T) { sourceData[0] = 22 - if reflect.DeepEqual(sourceData, boundValue.GetData(nil)) { + value, err := boundValue.GetData(nil) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(sourceData, value) { t.Fatal("BoundValue data should not be equal to sourceData") } }) t.Run("nil data provided", func(t *testing.T) { boundValue := NewMysqlCopyTextBoundValue(nil, base.BinaryFormat, TypeBlob) - + value, err := boundValue.GetData(nil) + if err != nil { + t.Fatal(err) + } // we need to validate that textData is nil if nil was provided - required for handling NULL values - if boundValue.GetData(nil) != nil { + if value != nil { t.Fatal("BoundValue data should be nil") } }) diff --git a/decryptor/mysql/proxy.go b/decryptor/mysql/proxy.go index 4b02d11d6..e26eb5a12 100644 --- a/decryptor/mysql/proxy.go +++ b/decryptor/mysql/proxy.go @@ -65,7 +65,7 @@ func (factory *proxyFactory) New(clientID []byte, clientSession base.ClientSessi schemaStore := factory.setting.TableSchemaStore() storeMask := schemaStore.GetGlobalSettingsMask() // register only if masking/tokenization/searching will be used - if storeMask&(config.SettingSearchFlag|config.SettingMaskingFlag|config.SettingTokenizationFlag) > 0 { + if !base.OnlyDefaultEncryptorSettings(schemaStore) { // register Query processor first before other processors because it match SELECT queries for ColumnEncryptorConfig structs // and store it in AccessContext for next decryptions/encryptions and all other processors rely on that // use nil dataEncryptor to avoid extra computations diff --git a/decryptor/mysql/proxy_test.go b/decryptor/mysql/proxy_test.go index a86d11c1f..c6d38075b 100644 --- a/decryptor/mysql/proxy_test.go +++ b/decryptor/mysql/proxy_test.go @@ -157,6 +157,22 @@ func (*tableSchemaStore) GetGlobalSettingsMask() config.SettingMask { type stubSession struct{} +func (s stubSession) GetData(s2 string) (interface{}, bool) { + panic("implement me") +} + +func (s stubSession) SetData(s2 string, i interface{}) { + panic("implement me") +} + +func (s stubSession) DeleteData(s2 string) { + panic("implement me") +} + +func (s stubSession) HasData(s2 string) bool { + panic("implement me") +} + func (stubSession) Context() context.Context { return context.TODO() } diff --git a/decryptor/mysql/response_proxy.go b/decryptor/mysql/response_proxy.go index cdaf094e8..66135743b 100644 --- a/decryptor/mysql/response_proxy.go +++ b/decryptor/mysql/response_proxy.go @@ -215,11 +215,6 @@ func NewMysqlProxy(session base.ClientSession, parser *sqlparser.Parser, setting }, nil } -// SubscribeOnColumnDecryption subscribes for OnColumn notifications about the column, indexed from left to right starting with zero. -func (handler *Handler) SubscribeOnColumnDecryption(i int, subscriber base.DecryptionSubscriber) { - handler.decryptionObserver.SubscribeOnColumnDecryption(i, subscriber) -} - // SubscribeOnAllColumnsDecryption subscribes for OnColumn notifications on each column. func (handler *Handler) SubscribeOnAllColumnsDecryption(subscriber base.DecryptionSubscriber) { handler.decryptionObserver.SubscribeOnAllColumnsDecryption(subscriber) diff --git a/decryptor/postgresql/data_encoder.go b/decryptor/postgresql/data_encoder.go new file mode 100644 index 000000000..db49424c9 --- /dev/null +++ b/decryptor/postgresql/data_encoder.go @@ -0,0 +1,249 @@ +package postgresql + +import ( + "context" + "encoding/base64" + "encoding/binary" + "github.com/cossacklabs/acra/decryptor/base" + "github.com/cossacklabs/acra/encryptor" + "github.com/cossacklabs/acra/encryptor/config" + common2 "github.com/cossacklabs/acra/encryptor/config/common" + "github.com/cossacklabs/acra/logging" + "github.com/cossacklabs/acra/utils" + "github.com/sirupsen/logrus" + "strconv" +) + +// PgSQLDataEncoderProcessor implements processor and encode binary/text values before sending to app +type PgSQLDataEncoderProcessor struct{} + +// NewPgSQLDataEncoderProcessor return new data encoder to text/binary format +func NewPgSQLDataEncoderProcessor() (*PgSQLDataEncoderProcessor, error) { + return &PgSQLDataEncoderProcessor{}, nil +} + +// ID return name of processor +func (p *PgSQLDataEncoderProcessor) ID() string { + return "PgSQLDataEncoderProcessor" +} + +// here we process encryption/tokenization results before send it to a client +// acra decrypts or de-tokenize SQL literals, so we should convert string SQL literals to binary format +// if client expects int, then parse INT literals and convert to binary 4/8 byte format +// if expects bytes, then pass as is +// if expects string, then leave as is if it is valid string or encode to hex +// if it is encrypted data then we return default values or as is if applicable (binary data) +func (p *PgSQLDataEncoderProcessor) encodeBinary(ctx context.Context, data []byte, setting config.ColumnEncryptionSetting, columnInfo base.ColumnInfo, logger *logrus.Entry) (context.Context, []byte, error) { + if len(data) == 0 { + return ctx, data, nil + } + switch setting.GetEncryptedDataType() { + case common2.EncryptedType_String: + if !base.IsDecryptedFromContext(ctx) { + if newVal := setting.GetDefaultDataValue(); newVal != nil { + return ctx, []byte(*newVal), nil + } + } + return ctx, data, nil + case common2.EncryptedType_Bytes: + if !base.IsDecryptedFromContext(ctx) { + if newVal := setting.GetDefaultDataValue(); newVal != nil { + binValue, err := base64.StdEncoding.DecodeString(*newVal) + if err == nil { + return ctx, binValue, nil + } + logger.WithError(err).Errorln("Can't decode base64 default value") + } + } + return ctx, data, nil + case common2.EncryptedType_Int32, common2.EncryptedType_Int64: + size := 8 + if setting.GetEncryptedDataType() == common2.EncryptedType_Int32 { + size = 4 + } + // convert back from text to binary + value, err := strconv.ParseInt(string(data), 10, 64) + // we don't return error to not cause connection drop on Acra side and pass it to app to deal with it + if err != nil { + if newVal := setting.GetDefaultDataValue(); newVal != nil { + value, err = strconv.ParseInt(*newVal, 10, 64) + if err != nil { + logger.WithError(err).Errorln("Can't parse default integer value") + return ctx, data, err + } + } else { + logger.WithError(err).Errorln("Can't decode int value and no default value") + return ctx, data, nil + } + } + newData := make([]byte, size) + switch size { + case 4: + binary.BigEndian.PutUint32(newData, uint32(value)) + break + case 8: + binary.BigEndian.PutUint64(newData, uint64(value)) + break + } + return ctx, newData, nil + } + + return ctx, data, nil +} + +// encodeText converts data according to Text format received after decryption/de-tokenization according to ColumnEncryptionSetting +// binary -> hex encoded +// string/email -> string if valid UTF8/ASCII otherwise hex encoded +// integers as is +// not decrypted data that left in binary format we replace with default values (integers) or encode to hex (binary, strings) +func (p *PgSQLDataEncoderProcessor) encodeText(ctx context.Context, data []byte, setting config.ColumnEncryptionSetting, columnInfo base.ColumnInfo, logger *logrus.Entry) (context.Context, []byte, error) { + logger = logger.WithField("column", setting.ColumnName()).WithField("decrypted", base.IsDecryptedFromContext(ctx)) + logger.Debugln("Encode text") + if len(data) == 0 { + return ctx, data, nil + } + switch setting.GetEncryptedDataType() { + case common2.EncryptedType_String: + if !base.IsDecryptedFromContext(ctx) { + if newVal := setting.GetDefaultDataValue(); newVal != nil { + logger.WithField("data", string(data)).WithField("default", *newVal).Debugln("Change with default") + return ctx, []byte(*newVal), nil + } + } + case common2.EncryptedType_Bytes: + if !base.IsDecryptedFromContext(ctx) { + if newVal := setting.GetDefaultDataValue(); newVal != nil { + binValue, err := base64.StdEncoding.DecodeString(*newVal) + if err != nil { + return ctx, data, err + } + // override and encode at end of function + data = binValue + } + } + case common2.EncryptedType_Int32, common2.EncryptedType_Int64: + _, err := strconv.ParseInt(string(data), 10, 64) + // if it's valid string literal and decrypted, return as is + if err == nil { + return ctx, data, nil + } + // if it's encrypted binary, then it is binary array that is invalid int literal + if !base.IsDecryptedFromContext(ctx) { + if newVal := setting.GetDefaultDataValue(); newVal != nil { + logger.Debugln("Return default value") + return ctx, []byte(*newVal), nil + } + } + logger.Warningln("Can't decode int value and no default value") + return ctx, data, nil + } + if utils.IsPrintablePostgresqlString(data) { + return ctx, data, nil + } + return ctx, utils.PgEncodeToHex(data), nil +} + +// OnColumn encode binary value to text and back. Should be before and after tokenizer processor +func (p *PgSQLDataEncoderProcessor) OnColumn(ctx context.Context, data []byte) (context.Context, []byte, error) { + columnSetting, ok := encryptor.EncryptionSettingFromContext(ctx) + if !ok { + // for case when data encrypted with acrastructs on app's side and used without any encryption setting + columnSetting = &config.BasicColumnEncryptionSetting{} + } + logger := logging.GetLoggerFromContext(ctx) + columnInfo, ok := base.ColumnInfoFromContext(ctx) + if !ok { + logger.WithField("processor", "PgSQLDataEncoderProcessor").Warningln("No column info in ctx") + // we can't do anything + return ctx, data, nil + } + if columnInfo.IsBinaryFormat() { + return p.encodeBinary(ctx, data, columnSetting, columnInfo, logger) + } + return p.encodeText(ctx, data, columnSetting, columnInfo, logger) +} + +// PgSQLDataDecoderProcessor implements processor and decode binary/text values from DB +type PgSQLDataDecoderProcessor struct{} + +// NewPgSQLDataDecoderProcessor return new data decoder from text/binary format from database side +func NewPgSQLDataDecoderProcessor() (*PgSQLDataDecoderProcessor, error) { + return &PgSQLDataDecoderProcessor{}, nil +} + +// ID return name of processor +func (p *PgSQLDataDecoderProcessor) ID() string { + return "PgSQLDataDecoderProcessor" +} + +func (p *PgSQLDataDecoderProcessor) decodeBinary(ctx context.Context, data []byte, setting config.ColumnEncryptionSetting, columnInfo base.ColumnInfo, logger *logrus.Entry) (context.Context, []byte, error) { + var newData [8]byte + // convert from binary to text literal because tokenizer expects int value as string literal + switch setting.GetEncryptedDataType() { + case common2.EncryptedType_Int32, common2.EncryptedType_Int64: + // We decode only tokenized data because it should be valid 4/8 byte values + // If it is encrypted integers then we will see here encrypted blob that cannot be decoded and should be decrypted + // in next handlers. So we return value as is + + // acra operates over string SQL values so here we expect valid int binary values that we should + // convert to string SQL value + if len(data) == 4 { + // if high byte is 0xff then it is negative number and we should fill all previous bytes with 0xx too + // otherwise with zeroes + if data[0] == 0xff { + copy(newData[:4], []byte{0xff, 0xff, 0xff, 0xff}) + copy(newData[4:], data) + } else { + // extend int32 from 4 bytes to int64 with zeroes + copy(newData[:4], []byte{0, 0, 0, 0}) + copy(newData[4:], data) + } + // we accept here only 4 or 8 byte values + } else if len(data) != 8 { + return ctx, data, nil + } else { + copy(newData[:], data) + } + value := binary.BigEndian.Uint64(newData[:]) + return ctx, []byte(strconv.FormatInt(int64(value), 10)), nil + } + // binary and string values in binary format we return as is because it is encrypted blob + return ctx, data, nil +} + +// decodeText converts data from text format for decryptors/de-tokenizers according to ColumnEncryptionSetting +// hex/octal binary -> raw binary data +func (p *PgSQLDataDecoderProcessor) decodeText(ctx context.Context, data []byte, setting config.ColumnEncryptionSetting, columnInfo base.ColumnInfo, logger *logrus.Entry) (context.Context, []byte, error) { + if config.IsBinaryDataOperation(setting) { + // decryptor operates over blobs so all data types will be encrypted as hex/octal string values that we should + // decode before decryption + decodedData, err := utils.DecodeEscaped(data) + if err != nil { + logger.WithError(err).Errorln("Can't decode binary data for decryption") + return ctx, data, nil + } + return ctx, decodedData, nil + } + // all other non-binary data should be valid SQL literals like integers or strings and Acra works with them as is + return ctx, data, nil +} + +// OnColumn encode binary value to text and back. Should be before and after tokenizer processor +func (p *PgSQLDataDecoderProcessor) OnColumn(ctx context.Context, data []byte) (context.Context, []byte, error) { + columnSetting, ok := encryptor.EncryptionSettingFromContext(ctx) + if !ok { + // for case when data encrypted with acrastructs on app's side and used without any encryption setting + columnSetting = &config.BasicColumnEncryptionSetting{} + } + logger := logging.GetLoggerFromContext(ctx) + columnInfo, ok := base.ColumnInfoFromContext(ctx) + if !ok { + logger.WithField("processor", "PgSQLDataDecoderProcessor").Warningln("No column info in ctx") + // we can't do anything + return ctx, data, nil + } + if columnInfo.IsBinaryFormat() { + return p.decodeBinary(ctx, data, columnSetting, columnInfo, logger) + } + return p.decodeText(ctx, data, columnSetting, columnInfo, logger) +} diff --git a/decryptor/postgresql/data_encoder_test.go b/decryptor/postgresql/data_encoder_test.go new file mode 100644 index 000000000..e41e5f5c6 --- /dev/null +++ b/decryptor/postgresql/data_encoder_test.go @@ -0,0 +1,468 @@ +package postgresql + +import ( + "bytes" + "context" + "errors" + "github.com/cossacklabs/acra/decryptor/base" + "github.com/cossacklabs/acra/encryptor" + "github.com/cossacklabs/acra/encryptor/config" + "github.com/cossacklabs/acra/logging" + "github.com/cossacklabs/acra/pseudonymization/common" + "github.com/cossacklabs/acra/utils" + "github.com/sirupsen/logrus" + "strconv" + "strings" + "testing" +) + +// TestEncodingDecodingProcessorBinaryIntData checks decoding binary INT values to string SQL literals and back +func TestEncodingDecodingProcessorBinaryIntData(t *testing.T) { + type testcase struct { + binValue []byte + stringValue []byte + encodeErr error + decodeErr error + binarySize int + } + testcases := []testcase{ + // int32 without errors + {binValue: []byte{0, 0, 0, 0}, stringValue: []byte("0"), encodeErr: nil, decodeErr: nil, binarySize: 4}, + {binValue: []byte{255, 255, 255, 255}, stringValue: []byte("-1"), encodeErr: nil, decodeErr: nil, binarySize: 4}, + {binValue: []byte{0, 0, 0, 128}, stringValue: []byte("128"), encodeErr: nil, decodeErr: nil, binarySize: 4}, + {binValue: []byte{255, 255, 255, 128}, stringValue: []byte("-128"), encodeErr: nil, decodeErr: nil, binarySize: 4}, + + // int64 without errors + {binValue: []byte{0, 0, 0, 0, 0, 0, 0, 0}, stringValue: []byte("0"), encodeErr: nil, decodeErr: nil, binarySize: 8}, + {binValue: []byte{255, 255, 255, 255, 255, 255, 255, 255}, stringValue: []byte("-1"), encodeErr: nil, decodeErr: nil, binarySize: 8}, + {binValue: []byte{0, 0, 0, 0, 0, 0, 0, 128}, stringValue: []byte("128"), encodeErr: nil, decodeErr: nil, binarySize: 8}, + {binValue: []byte{255, 255, 255, 255, 255, 255, 255, 128}, stringValue: []byte("-128"), encodeErr: nil, decodeErr: nil, binarySize: 8}, + } + sizeToTokenType := map[int]string{ + 4: "int32", + 8: "int64", + // set correct values for incorrect sizes + 3: "int32", + 7: "int64", + } + + encoder, err := NewPgSQLDataEncoderProcessor() + if err != nil { + t.Fatal(err) + } + decoder, err := NewPgSQLDataDecoderProcessor() + if err != nil { + t.Fatal(err) + } + for i, tcase := range testcases { + // use -1 as invalid binary size that should be ignored + columnInfo := base.NewColumnInfo(0, "", true, -1) + accessContext := &base.AccessContext{} + accessContext.SetColumnInfo(columnInfo) + ctx := base.SetAccessContextToContext(context.Background(), accessContext) + testSetting := config.BasicColumnEncryptionSetting{ + Tokenized: true, + DataType: sizeToTokenType[tcase.binarySize], + TokenType: sizeToTokenType[tcase.binarySize]} + ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) + ctx, strData, err := decoder.OnColumn(ctx, tcase.binValue) + if err != tcase.decodeErr { + t.Fatalf("[%d] Expect %s, took %s\n", i, tcase.decodeErr, err) + } + if !bytes.Equal(tcase.stringValue, strData) { + t.Fatalf("[%d] Expect '%s', took '%s'\n", i, tcase.stringValue, strData) + } + _, binData, err := encoder.OnColumn(ctx, strData) + if err != tcase.encodeErr { + t.Fatalf("[%d] Expect %s, took %s\n", i, tcase.encodeErr, err) + } + // we check that start value == final value only if err == nil and check success whole encoding/decoding + if err == nil { + if !bytes.Equal(binData, tcase.binValue) { + t.Fatalf("[%d] Expect '%s', took '%s'\n", i, binData, tcase.binValue) + } + } else { + // if was error then decoded data should be the same as encoded + if !bytes.Equal(binData, tcase.stringValue) { + t.Fatalf("[%d] Expect '%s', took '%s'\n", i, tcase.stringValue, binData) + } + } + } +} + +func TestSkipWithoutSetting(t *testing.T) { + encoder, err := NewPgSQLDataEncoderProcessor() + if err != nil { + t.Fatal(err) + } + decoder, err := NewPgSQLDataDecoderProcessor() + if err != nil { + t.Fatal(err) + } + testData := []byte("some data") + for _, subscriber := range []base.DecryptionSubscriber{encoder, decoder} { + // without column setting data + ctx := context.Background() + _, data, err := subscriber.OnColumn(ctx, testData) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, testData) { + t.Fatal("Result data should be the same") + } + // test without column info + columnSetting := &config.BasicColumnEncryptionSetting{} + ctx = encryptor.NewContextWithEncryptionSetting(ctx, columnSetting) + _, data, err = subscriber.OnColumn(ctx, testData) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, testData) { + t.Fatal("Result data should be the same") + } + } +} + +func TestTextMode(t *testing.T) { + encoder, err := NewPgSQLDataEncoderProcessor() + if err != nil { + t.Fatal(err) + } + decoder, err := NewPgSQLDataDecoderProcessor() + if err != nil { + t.Fatal(err) + } + + type testcase struct { + input []byte + decodedData []byte + encodedData []byte + decodeErr error + encodeErr error + setting config.ColumnEncryptionSetting + logMessage string + } + strDefaultValue := "123" + testcases := []testcase{ + // decoder expects valid string and pass as is, so no errors. but on encode operation it expects valid int literal + {input: []byte("some data"), decodedData: []byte("some data"), encodedData: []byte("some data"), + decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: true, DataType: "int32"}, + logMessage: `Can't decode int value and no default value`}, + + {input: []byte("123"), decodedData: []byte("123"), encodedData: []byte("123"), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: true, DataType: "int32"}}, + + // encryption/decryption integer data, not tokenization + {input: []byte("some data"), decodedData: []byte("some data"), encodedData: []byte("some data"), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "int32"}, + logMessage: `Can't decode int value and no default value`}, + + // encryption/decryption integer data, not tokenization + {input: []byte("some data"), decodedData: []byte("some data"), encodedData: []byte(strDefaultValue), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "int32", DefaultDataValue: &strDefaultValue}}, + + // invalid binary hex value that should be returned as is. Also encoded into hex due to invalid hex value + {input: []byte("\\xTT"), decodedData: []byte("\\xTT"), encodedData: []byte("\\x5c785454"), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "bytes"}}, + // printable valid value returned as is + {input: []byte("valid string"), decodedData: []byte("valid string"), encodedData: []byte("valid string"), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "bytes"}}, + {input: []byte("valid string"), decodedData: []byte("valid string"), encodedData: []byte("valid string"), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "str"}}, + + // empty values + {input: []byte{}, decodedData: []byte{}, encodedData: []byte{}, decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "bytes"}}, + // empty values + {input: nil, decodedData: nil, encodedData: nil, decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "bytes"}}, + // empty values + {input: []byte{}, decodedData: []byte{}, encodedData: []byte{}, decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "str"}}, + // empty values + {input: nil, decodedData: nil, encodedData: nil, decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{Tokenized: false, DataType: "str"}}, + } + + columnInfo := base.NewColumnInfo(0, "", false, 4) + accessContext := &base.AccessContext{} + accessContext.SetColumnInfo(columnInfo) + ctx := base.SetAccessContextToContext(context.Background(), accessContext) + logger := logrus.New() + entry := logrus.NewEntry(logger) + logBuffer := &bytes.Buffer{} + logger.SetOutput(logBuffer) + ctx = logging.SetLoggerToContext(ctx, entry) + for i, tcase := range testcases { + logBuffer.Reset() + ctx = encryptor.NewContextWithEncryptionSetting(ctx, tcase.setting) + _, decodedData, decodeErr := decoder.OnColumn(ctx, tcase.input) + if err != nil { + t.Fatal(err) + } + if decodeErr != tcase.decodeErr { + t.Fatalf("[%d] Incorrect decode error. Expect %s, took %s\n", i, tcase.decodeErr, decodeErr) + } + if !bytes.Equal(decodedData, tcase.decodedData) { + t.Fatalf("[%d] Result data should be the same\n", i) + } + _, encodedData, encodeErr := encoder.OnColumn(ctx, decodedData) + if encodeErr != tcase.encodeErr && !errors.As(encodeErr, &tcase.encodeErr) { + t.Fatalf("[%d] Incorrect encode error. Expect %s, took %s\n", i, tcase.encodeErr.Error(), encodeErr.Error()) + } + if !bytes.Equal(encodedData, tcase.encodedData) { + t.Fatalf("[%d] Result data should be the same\n", i) + } + if len(tcase.logMessage) > 0 && !strings.Contains(logBuffer.String(), tcase.logMessage) { + t.Fatal("Log buffer doesn't contain expected message") + } + } +} + +func TestBinaryMode(t *testing.T) { + encoder, err := NewPgSQLDataEncoderProcessor() + if err != nil { + t.Fatal(err) + } + decoder, err := NewPgSQLDataDecoderProcessor() + if err != nil { + t.Fatal(err) + } + + type testcase struct { + input []byte + decodedData []byte + encodedData []byte + decodeErr error + encodeErr error + setting config.ColumnEncryptionSetting + logMessage string + } + strDefaultValue := "1" + testcases := []testcase{ + // decoder expects valid string and pass as is, so no errors. but on encode operation it expects valid int literal + {input: []byte("some data"), decodedData: []byte("some data"), encodedData: []byte("some data"), + decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{DataType: "int32"}, + logMessage: `Can't decode int value and no default value`}, + + {input: []byte{0, 0, 0, 1}, decodedData: []byte("1"), encodedData: []byte{0, 0, 0, 1}, decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{DataType: "int32"}}, + + // encryption/decryption integer data, not tokenization + {input: []byte("some data"), decodedData: []byte("some data"), encodedData: []byte("some data"), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{DataType: "int32"}, + logMessage: `Can't decode int value and no default value`}, + + // encryption/decryption integer data, not tokenization + {input: []byte("some data"), decodedData: []byte("some data"), encodedData: []byte{0, 0, 0, 1}, decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{DataType: "int32", DefaultDataValue: &strDefaultValue}}, + + // invalid binary hex value that should be returned as is. Also encoded into hex due to invalid hex value + {input: []byte("\\xTT"), decodedData: []byte("\\xTT"), encodedData: []byte("\\xTT"), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{DataType: "bytes"}}, + // printable valid value returned as is + {input: []byte("valid string"), decodedData: []byte("valid string"), encodedData: []byte("valid string"), decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{DataType: "bytes"}}, + + // empty values + {input: []byte{}, decodedData: []byte{}, encodedData: []byte{}, decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{DataType: "bytes"}}, + // empty values + {input: nil, decodedData: nil, encodedData: nil, decodeErr: nil, encodeErr: nil, + setting: &config.BasicColumnEncryptionSetting{DataType: "bytes"}}, + } + + columnInfo := base.NewColumnInfo(0, "", true, 4) + accessContext := &base.AccessContext{} + accessContext.SetColumnInfo(columnInfo) + ctx := base.SetAccessContextToContext(context.Background(), accessContext) + logger := logrus.New() + entry := logrus.NewEntry(logger) + logBuffer := &bytes.Buffer{} + logger.SetOutput(logBuffer) + ctx = logging.SetLoggerToContext(ctx, entry) + for i, tcase := range testcases { + logBuffer.Reset() + ctx = encryptor.NewContextWithEncryptionSetting(ctx, tcase.setting) + _, decodedData, decodeErr := decoder.OnColumn(ctx, tcase.input) + if err != nil { + t.Fatal(err) + } + if decodeErr != tcase.decodeErr { + t.Fatalf("[%d] Incorrect decode error. Expect %s, took %s\n", i, tcase.decodeErr, decodeErr) + } + if !bytes.Equal(decodedData, tcase.decodedData) { + t.Fatalf("[%d] Result data should be the same\n", i) + } + _, encodedData, encodeErr := encoder.OnColumn(ctx, decodedData) + if encodeErr != tcase.encodeErr && !errors.As(encodeErr, &tcase.encodeErr) { + t.Fatalf("[%d] Incorrect encode error. Expect %s, took %s\n", i, tcase.encodeErr.Error(), encodeErr.Error()) + } + if !bytes.Equal(encodedData, tcase.encodedData) { + t.Fatalf("[%d] Result data should be the same\n", i) + } + if len(tcase.logMessage) > 0 && !strings.Contains(logBuffer.String(), tcase.logMessage) { + t.Fatalf("[%d] Log buffer doesn't contain expected message\n", i) + } + } +} + +func TestEncodingDecodingTextFormat(t *testing.T) { + encoder, err := NewPgSQLDataEncoderProcessor() + if err != nil { + t.Fatal(err) + } + decoder, err := NewPgSQLDataDecoderProcessor() + if err != nil { + t.Fatal(err) + } + type testcase struct { + inputValue []byte + outputValue []byte + binValue []byte + tokenType common.TokenType + } + testcases := []testcase{ + {inputValue: []byte(`valid string`), outputValue: []byte(`valid string`), binValue: []byte(`valid string`), tokenType: common.TokenType_String}, + {inputValue: []byte(`valid string`), outputValue: []byte(`valid string`), binValue: []byte(`valid string`), tokenType: common.TokenType_Email}, + // input hex encoded value that looks like a valid string should be returned as string literal + {inputValue: []byte(`\x76616c696420737472696e67`), outputValue: []byte(`valid string`), binValue: []byte(`valid string`), tokenType: common.TokenType_Bytes}, + + // max int32 + {inputValue: []byte(`2147483647`), outputValue: []byte(`2147483647`), binValue: []byte(`2147483647`), tokenType: common.TokenType_Int32}, + {inputValue: []byte(`-2147483648`), outputValue: []byte(`-2147483648`), binValue: []byte(`-2147483648`), tokenType: common.TokenType_Int32}, + // max int64 + {inputValue: []byte(`9223372036854775807`), outputValue: []byte(`9223372036854775807`), binValue: []byte(`9223372036854775807`), tokenType: common.TokenType_Int64}, + {inputValue: []byte(`-9223372036854775808`), outputValue: []byte(`-9223372036854775808`), binValue: []byte(`-9223372036854775808`), tokenType: common.TokenType_Int64}, + } + accessContext := &base.AccessContext{} + // use -1 as invalid binary size that should be ignored + columnInfo := base.NewColumnInfo(0, "", false, -1) + accessContext.SetColumnInfo(columnInfo) + ctx := base.SetAccessContextToContext(context.Background(), accessContext) + envelopeValue := config.CryptoEnvelopeTypeAcraBlock + // assign value with pointer and change value in the loop below + testSetting := config.BasicColumnEncryptionSetting{CryptoEnvelope: &envelopeValue} + ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) + for i, tcase := range testcases { + columnInfo = base.NewColumnInfo(0, "", false, len(tcase.inputValue)) + accessContext.SetColumnInfo(columnInfo) + testSetting.TokenType, err = tcase.tokenType.ToConfigString() + if err != nil { + t.Fatal(err) + } + _, binValue, err := decoder.OnColumn(ctx, tcase.inputValue) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(binValue, tcase.binValue) { + t.Fatalf("[%d] Expect binary value %s, took %s\n", i, string(tcase.binValue), string(binValue)) + } + _, textValue, err := encoder.OnColumn(ctx, binValue) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(textValue, tcase.outputValue) { + t.Fatalf("[%d] Expect text %s, took %s\n", i, string(textValue), string(tcase.inputValue)) + } + } +} + +func TestSkipWithoutColumnInfo(t *testing.T) { + encoder, err := NewPgSQLDataEncoderProcessor() + if err != nil { + t.Fatal(err) + } + decoder, err := NewPgSQLDataDecoderProcessor() + if err != nil { + t.Fatal(err) + } + testData := []byte("some data") + accessContext := &base.AccessContext{} + ctx := base.SetAccessContextToContext(context.Background(), accessContext) + testSetting := config.BasicColumnEncryptionSetting{Tokenized: true, TokenType: "int32"} + ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) + for _, subscriber := range []base.DecryptionSubscriber{encoder, decoder} { + _, data, err := subscriber.OnColumn(ctx, testData) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, testData) { + t.Fatal("Result data should be the same") + } + } +} + +func TestFailedEncodingInvalidTextValue(t *testing.T) { + encoder, err := NewPgSQLDataEncoderProcessor() + if err != nil { + t.Fatal(err) + } + columnInfo := base.NewColumnInfo(0, "", true, 4) + accessContext := &base.AccessContext{} + accessContext.SetColumnInfo(columnInfo) + ctx := base.SetAccessContextToContext(context.Background(), accessContext) + testSetting := config.BasicColumnEncryptionSetting{DataType: "int32"} + ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) + testData := []byte("asdas") + // without default value + _, data, err := encoder.OnColumn(ctx, testData) + if err != nil { + t.Fatal("Expects nil on encode error") + } + + // invalid int32 valid value + strValue := utils.BytesToString(testData) + testSetting = config.BasicColumnEncryptionSetting{DataType: "int32", DefaultDataValue: &strValue} + ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) + _, data, err = encoder.OnColumn(ctx, testData) + numErr, ok := err.(*strconv.NumError) + if !ok { + t.Fatal("Expect strconv.NumError") + } + if numErr.Err != strconv.ErrSyntax { + t.Fatalf("Expect ErrSyntax, took %s\n", numErr.Err) + } + if !bytes.Equal(data, testData) { + t.Fatal("Result data should be the same") + } +} + +func TestFailedEncodingInvalidBinaryValue(t *testing.T) { + encoder, err := NewPgSQLDataEncoderProcessor() + if err != nil { + t.Fatal(err) + } + columnInfo := base.NewColumnInfo(0, "", true, 4) + accessContext := &base.AccessContext{} + accessContext.SetColumnInfo(columnInfo) + ctx := base.SetAccessContextToContext(context.Background(), accessContext) + testSetting := config.BasicColumnEncryptionSetting{DataType: "bytes"} + ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) + testData := []byte("invalid base64 value") + // without default value + _, data, err := encoder.OnColumn(ctx, testData) + if err != nil { + t.Fatal("Expects nil on encode error") + } + + // invalid int32 valid value + strValue := utils.BytesToString(testData) + testSetting = config.BasicColumnEncryptionSetting{DataType: "bytes", DefaultDataValue: &strValue} + ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) + logger := logrus.New() + entry := logrus.NewEntry(logger) + logBuffer := &bytes.Buffer{} + logger.SetOutput(logBuffer) + ctx = logging.SetLoggerToContext(ctx, entry) + _, data, err = encoder.OnColumn(ctx, testData) + if !bytes.Contains(logBuffer.Bytes(), []byte("Can't decode base64 default value")) { + t.Fatal("Expects warning about failed decoding") + } + if !bytes.Equal(data, testData) { + t.Fatal("Result data should be the same") + } +} diff --git a/decryptor/postgresql/packet_handler.go b/decryptor/postgresql/packet_handler.go index 1128b9f4b..09e7da2b4 100644 --- a/decryptor/postgresql/packet_handler.go +++ b/decryptor/postgresql/packet_handler.go @@ -8,8 +8,9 @@ import ( "io" "github.com/cossacklabs/acra/decryptor/base" + "github.com/cossacklabs/acra/encryptor" "github.com/cossacklabs/acra/logging" - "github.com/cossacklabs/acra/utils" + "github.com/jackc/pgx/pgproto3" "github.com/sirupsen/logrus" ) @@ -68,7 +69,7 @@ func (packet *PacketHandler) updatePacketLength(newLength int) { } // updateDataFromColumns check that any column's data was changed and update packet length and data block with new data -func (packet *PacketHandler) updateDataFromColumns() { +func (packet *PacketHandler) updateDataFromColumns(queryDataItems []*encryptor.QueryDataItem) { columnsDataChanged := false // check is any column was changed for i := 0; i < packet.columnCount; i++ { @@ -78,25 +79,22 @@ func (packet *PacketHandler) updateDataFromColumns() { } } if columnsDataChanged { + packet.descriptionBuf.Reset() + var columnCountBuf [2]byte + binary.BigEndian.PutUint16(columnCountBuf[:], uint16(packet.columnCount)) + packet.descriptionBuf.Write(columnCountBuf[:]) + + for i := 0; i < packet.columnCount; i++ { + column := packet.Columns[i] + packet.descriptionBuf.Write(column.LengthBuf[:]) + packet.descriptionBuf.Write(column.data) + } // column length buffer wasn't included to column length value and should be accumulated too // + 2 is column count buffer newDataLength := packet.columnCount*4 + 2 for i := 0; i < packet.columnCount; i++ { newDataLength += packet.Columns[i].Length() } - packet.descriptionBuf.Reset() - packet.descriptionBuf.Grow(newDataLength) - - columnCountBuf := make([]byte, 2) - binary.BigEndian.PutUint16(columnCountBuf, uint16(packet.columnCount)) - packet.descriptionBuf.Write(columnCountBuf) - - for i := 0; i < packet.columnCount; i++ { - packet.descriptionBuf.Write(packet.Columns[i].LengthBuf[:]) - if !packet.Columns[i].IsNull() { - packet.descriptionBuf.Write(packet.Columns[i].data.Encoded()) - } - } packet.updatePacketLength(newDataLength) } } @@ -137,14 +135,14 @@ func (packet *PacketHandler) sendMessageType() error { // ColumnData hold column length and data type ColumnData struct { LengthBuf [4]byte - data *utils.DecodedData + data []byte changed bool isNull bool } // GetData return raw data, decoded from db format to binary func (column *ColumnData) GetData() []byte { - return column.data.Data() + return column.data } // Length return column length converted from LengthBuf @@ -179,13 +177,13 @@ const ( func (column *ColumnData) readData(reader io.Reader, format base.BoundValueFormat) error { length := column.Length() if int32(length) == NullColumnValue { - column.data = utils.WrapRawDataAsDecoded(nil) + column.data = nil column.isNull = true return nil } column.isNull = false if length == 0 { - column.data = utils.WrapRawDataAsDecoded(nil) + column.data = nil return nil } data := make([]byte, length) @@ -196,15 +194,7 @@ func (column *ColumnData) readData(reader io.Reader, format base.BoundValueForma if err != nil { return err } - if format == base.TextFormat { - column.data, err = utils.DecodeEscaped(data) - if err != nil && err != utils.ErrDecodeOctalString { - return err - } - } else { - // do nothing with binary data - column.data = utils.WrapRawDataAsDecoded(data) - } + column.data = data // ignore utils.ErrDecodeOctalString err = nil @@ -214,11 +204,13 @@ func (column *ColumnData) readData(reader io.Reader, format base.BoundValueForma // SetData to column and update LengthBuf with new size func (column *ColumnData) SetData(newData []byte) { column.changed = true - if column.data == nil { - column.data = utils.WrapRawDataAsDecoded(newData) - } - column.data.Set(newData) - binary.BigEndian.PutUint32(column.LengthBuf[:], uint32(len(column.data.Encoded()))) + column.data = newData + binary.BigEndian.PutUint32(column.LengthBuf[:], uint32(len(column.data))) +} + +// SetDataLength set into LengthBuf +func (column *ColumnData) SetDataLength(length uint32) { + binary.BigEndian.PutUint32(column.LengthBuf[:], length) } // parseColumns split whole data row packet into separate columns data @@ -268,6 +260,16 @@ func (packet *PacketHandler) readMessageType() error { return base.CheckReadWrite(n, 1, err) } +// IsRowDescription return true if packet has RowDescription type +func (packet *PacketHandler) IsRowDescription() bool { + return packet.messageType[0] == RowDescriptionType +} + +// IsParameterDescription return true if packet has ParameterDescription type +func (packet *PacketHandler) IsParameterDescription() bool { + return packet.messageType[0] == ParameterDescriptionType +} + // IsDataRow return true if packet has DataRow type func (packet *PacketHandler) IsDataRow() bool { return packet.messageType[0] == DataRowMessageType @@ -344,6 +346,24 @@ func (packet *PacketHandler) GetExecuteData() (*ExecutePacket, error) { return execute, nil } +// GetRowDescriptionData return parsed RowDescription packet +func (packet *PacketHandler) GetRowDescriptionData() (*pgproto3.RowDescription, error) { + rowDescription := &pgproto3.RowDescription{} + if err := rowDescription.Decode(packet.descriptionBufferCopy()); err != nil { + return nil, err + } + return rowDescription, nil +} + +// GetParameterDescriptionData return parsed ParameterDescription packet +func (packet *PacketHandler) GetParameterDescriptionData() (*pgproto3.ParameterDescription, error) { + parameterDescription := &pgproto3.ParameterDescription{} + if err := parameterDescription.Decode(packet.descriptionBufferCopy()); err != nil { + return nil, err + } + return parameterDescription, nil +} + // ReplaceQuery query in packet with new query and update packet length func (packet *PacketHandler) ReplaceQuery(newQuery string) { if packet.IsSimpleQuery() { diff --git a/decryptor/postgresql/packet_handler_test.go b/decryptor/postgresql/packet_handler_test.go index 67be99fcb..4906edd4e 100644 --- a/decryptor/postgresql/packet_handler_test.go +++ b/decryptor/postgresql/packet_handler_test.go @@ -180,7 +180,7 @@ func TestColumnData_readData(t *testing.T) { columnLength uint32 format base.BoundValueFormat } - t.Run("Binary encoding", func(t *testing.T) { + t.Run("Binary data read", func(t *testing.T) { testCases := []testCase{ // valid hex encoded value {[]byte("\\xaabb"), []byte("\\xaabb"), []byte("\\xaabb"), 6, base.BinaryFormat}, @@ -190,11 +190,11 @@ func TestColumnData_readData(t *testing.T) { {[]byte{1, 2, 3}, []byte{1, 2, 3}, []byte{1, 2, 3}, 3, base.BinaryFormat}, // valid hex encoded value decoded to 2 digits - {[]byte("\\xaabb"), []byte{170, 187}, []byte("\\xaabb"), 6, base.TextFormat}, + {[]byte("\\xaabb"), []byte("\\xaabb"), []byte("\\xaabb"), 6, base.TextFormat}, // valid hex encoded value decoded to 2 digits - {[]byte("\\x"), []byte{}, []byte("\\x"), 2, base.TextFormat}, + {[]byte("\\x"), []byte("\\x"), []byte("\\x"), 2, base.TextFormat}, // valid octal value decoded to 1 digit - {[]byte("\\001"), []byte{1}, []byte("\\001"), 4, base.TextFormat}, + {[]byte("\\001"), []byte("\\001"), []byte("\\001"), 4, base.TextFormat}, // full binary value that should be as is {[]byte{1, 2, 3}, []byte{1, 2, 3}, []byte{1, 2, 3}, 3, base.TextFormat}, } @@ -204,12 +204,12 @@ func TestColumnData_readData(t *testing.T) { if err := column.readData(bytes.NewReader(testcase.data), testcase.format); err != nil { t.Fatal(i, "Error on read data by column", err) } - if !bytes.Equal(column.data.Encoded(), testcase.expected) { + if !bytes.Equal(column.data, testcase.expected) { t.Fatalf("Incorrectly encoded data, %v != %v\n", - column.data.Encoded(), testcase.expected) + column.data, testcase.expected) } - if !bytes.Equal(column.data.Data(), testcase.decoded) { - t.Fatalf("Decoded data not equal to expected, %s != %s\n", column.data.Data(), testcase.decoded) + if !bytes.Equal(column.data, testcase.decoded) { + t.Fatalf("%d. Decoded data not equal to expected, %s != %s\n", i, column.data, testcase.decoded) } } }) @@ -232,7 +232,7 @@ func TestParseColumns(t *testing.T) { if len(handler.Columns) != 1 { t.Fatal("Incorrect length of columns") } - if !bytes.Equal(handler.Columns[0].data.Encoded(), testData) { + if !bytes.Equal(handler.Columns[0].data, testData) { t.Fatal("Incorrect ") } } diff --git a/decryptor/postgresql/pg_decryptor.go b/decryptor/postgresql/pg_decryptor.go index 0eaf668a3..466fd84ff 100644 --- a/decryptor/postgresql/pg_decryptor.go +++ b/decryptor/postgresql/pg_decryptor.go @@ -22,6 +22,8 @@ import ( "context" "encoding/binary" "errors" + "github.com/cossacklabs/acra/encryptor" + "github.com/cossacklabs/acra/encryptor/config" "net" "time" @@ -92,6 +94,8 @@ const ( ParseCompleteMessageType byte = '1' BindCompleteMessageType byte = '2' ReadyForQueryMessageType byte = 'Z' + RowDescriptionType byte = 'T' + ParameterDescriptionType byte = 't' TLSTimeout = time.Second * 2 ) @@ -159,11 +163,6 @@ func NewPgProxy(session base.ClientSession, parser *sqlparser.Parser, setting ba }, nil } -// SubscribeOnColumnDecryption subscribes for notifications about the column, indexed from left to right starting with zero. -func (proxy *PgProxy) SubscribeOnColumnDecryption(column int, subscriber base.DecryptionSubscriber) { - proxy.decryptionObserver.SubscribeOnColumnDecryption(column, subscriber) -} - // SubscribeOnAllColumnsDecryption subscribes for notifications on each column. func (proxy *PgProxy) SubscribeOnAllColumnsDecryption(subscriber base.DecryptionSubscriber) { proxy.decryptionObserver.SubscribeOnAllColumnsDecryption(subscriber) @@ -177,7 +176,9 @@ func (proxy *PgProxy) Unsubscribe(subscriber base.DecryptionSubscriber) { func (proxy *PgProxy) onColumnDecryption(parentCtx context.Context, i int, data []byte, binaryFormat bool) ([]byte, error) { accessContext := base.AccessContextFromContext(parentCtx) accessContext.SetColumnInfo(base.NewColumnInfo(i, "", binaryFormat, len(data))) - return proxy.decryptionObserver.OnColumnDecryption(parentCtx, i, data) + // create new ctx per column processing + ctx := base.SetAccessContextToContext(parentCtx, accessContext) + return proxy.decryptionObserver.OnColumnDecryption(ctx, i, data) } // AddQueryObserver implement QueryObservable interface and proxy call to ObserverManager @@ -346,21 +347,21 @@ func (proxy *PgProxy) handleQueryPacket(ctx context.Context, packet *PacketHandl func (proxy *PgProxy) handleBindPacket(ctx context.Context, packet *PacketHandler, logger *log.Entry) (bool, error) { bind := proxy.protocolState.PendingBind() - log := logger.WithField("portal", bind.PortalName()).WithField("statement", bind.StatementName()) - log.Debug("Bind packet") + logger = logger.WithField("portal", bind.PortalName()).WithField("statement", bind.StatementName()) + logger.Debug("Bind packet") // There must be previously registered prepared statement with requested name. If there isn't // it's likely due to a client error. Print a warning and let the packet through as is. // We can't do anything with it and the database should respond with an error. registry := proxy.session.PreparedStatementRegistry() statement, err := registry.StatementByName(bind.StatementName()) if err != nil { - log.WithError(err).Error("Failed to handle Bind packet: can't find prepared statement") + logger.WithError(err).Error("Failed to handle Bind packet: can't find prepared statement") return false, nil } // Now, repackage the parameters for processing... If that fails, let the packet through too. parameters, err := bind.GetParameters() if err != nil { - log.WithError(err).Error("Failed to handle Bind packet: can't extract parameters") + logger.WithError(err).Error("Failed to handle Bind packet: can't extract parameters") return false, nil } // Process parameter values. If we can't -- you guessed it -- leave the packet unchanged. @@ -371,7 +372,7 @@ func (proxy *PgProxy) handleBindPacket(ctx context.Context, packet *PacketHandle return false, err } - log.WithError(err).Error("Failed to handle Bind packet") + logger.WithError(err).Error("Failed to handle Bind packet") return false, nil } // Finally, if the parameter values have been changed, update the packet. @@ -380,7 +381,7 @@ func (proxy *PgProxy) handleBindPacket(ctx context.Context, packet *PacketHandle bind.SetParameters(newParameters) err = packet.ReplaceBind(bind) if err != nil { - log.WithError(err).Error("Failed to update Bind packet") + logger.WithError(err).Error("Failed to update Bind packet") } return false, nil } @@ -638,6 +639,11 @@ func (proxy *PgProxy) handleDatabasePacket(ctx context.Context, packet *PacketHa bindPacket := proxy.protocolState.PendingBind() defer proxy.protocolState.forgetPendingBind() return proxy.registerCursor(bindPacket, logger) + case RowDescriptionPacket: + return proxy.handleRowDescription(ctx, packet, logger) + + case ParameterDescriptionPacket: + return proxy.handleParameterDescription(ctx, packet, logger) default: // Forward all other uninteresting packets to the client without processing. @@ -645,6 +651,94 @@ func (proxy *PgProxy) handleDatabasePacket(ctx context.Context, packet *PacketHa } } +func (proxy *PgProxy) handleParameterDescription(ctx context.Context, packet *PacketHandler, logger *log.Entry) error { + clientSession := base.ClientSessionFromContext(ctx) + if clientSession == nil { + logger.Warningln("ParameterDescription packet without ClientSession in context") + return nil + } + items := encryptor.PlaceholderSettingsFromClientSession(clientSession) + if items == nil { + logger.Debugln("ParameterDescription packet without registered recognized encryption settings") + return nil + } + parameterDescription, err := packet.GetParameterDescriptionData() + if err != nil { + logger.WithField(logging.FieldKeyEventCode, logging.EventCodeErrorDBProtocolError). + WithError(err). + Errorln("Can't parse ParameterDescription packet") + return nil + } + changed := false + for i := 0; i < len(parameterDescription.ParameterOIDs); i++ { + setting := items[i] + if setting == nil { + continue + } + if config.HasTypeAwareSupport(setting) { + newOID, ok := mapEncryptedTypeToOID(setting.GetEncryptedDataType()) + if ok { + parameterDescription.ParameterOIDs[i] = newOID + changed = true + } + } + } + if changed { + // 5 is MessageType[1] + PacketLength[4] + PacketPayload + newParameterDescription := make([]byte, 0, 5+packet.descriptionBuf.Len()) + newParameterDescription = parameterDescription.Encode(newParameterDescription) + packet.descriptionBuf.Reset() + packet.descriptionBuf.Write(newParameterDescription[5:]) + } + return nil +} + +func (proxy *PgProxy) handleRowDescription(ctx context.Context, packet *PacketHandler, logger *log.Entry) error { + clientSession := base.ClientSessionFromContext(ctx) + if clientSession == nil { + logger.Warningln("RowDescription packet without ClientSession in context") + return nil + } + items := encryptor.QueryDataItemsFromClientSession(clientSession) + if items == nil { + logger.Debugln("RowDescription packet without registered recognized encryption settings") + return nil + } + rowDescription, err := packet.GetRowDescriptionData() + if err != nil { + logger.WithField(logging.FieldKeyEventCode, logging.EventCodeErrorDBProtocolError). + WithError(err). + Errorln("Can't parse RowDescription packet") + return nil + } + if len(items) != len(rowDescription.Fields) { + log.Errorln("Column count in RowDescription packet not same as parsed query count of columns") + return nil + } + changed := false + for i := 0; i < len(rowDescription.Fields); i++ { + setting := items[i] + if setting == nil { + continue + } + if config.HasTypeAwareSupport(setting.Setting()) { + newOID, ok := mapEncryptedTypeToOID(setting.Setting().GetEncryptedDataType()) + if ok { + rowDescription.Fields[i].DataTypeOID = newOID + changed = true + } + } + } + if changed { + // 5 is MessageType[1] + PacketLength[4] + PacketPayload + newRowDescription := make([]byte, 0, 5+packet.descriptionBuf.Len()) + newRowDescription = rowDescription.Encode(newRowDescription) + packet.descriptionBuf.Reset() + packet.descriptionBuf.Write(newRowDescription[5:]) + } + return nil +} + func (proxy *PgProxy) handleQueryDataPacket(ctx context.Context, packet *PacketHandler, logger *log.Entry) error { logger.Debugln("Matched data row packet") // by default it's text format @@ -684,7 +778,7 @@ func (proxy *PgProxy) handleQueryDataPacket(ctx context.Context, packet *PacketH } format = int(boundFormat) } - + logger.WithField("data_length", len(column.GetData())).WithField("column_index", i).Debugln("Process columns data") newData, err := proxy.onColumnDecryption(ctx, i, column.GetData(), format == dataFormatBinary) if err != nil { logger.WithField(logging.FieldKeyEventCode, logging.EventCodeErrorGeneral). @@ -694,7 +788,12 @@ func (proxy *PgProxy) handleQueryDataPacket(ctx context.Context, packet *PacketH column.SetData(newData) } // After we're done processing the columns, update the actual packet data from them - packet.updateDataFromColumns() + queryDataItems := make([]*encryptor.QueryDataItem, packet.columnCount) + clientSession := base.ClientSessionFromContext(ctx) + if clientSession != nil { + queryDataItems = encryptor.QueryDataItemsFromClientSession(clientSession) + } + packet.updateDataFromColumns(queryDataItems) return nil } diff --git a/decryptor/postgresql/prepared_statements.go b/decryptor/postgresql/prepared_statements.go index 4dac36dd7..01f807920 100644 --- a/decryptor/postgresql/prepared_statements.go +++ b/decryptor/postgresql/prepared_statements.go @@ -19,6 +19,9 @@ package postgresql import ( "encoding/binary" "errors" + "github.com/cossacklabs/acra/encryptor" + "github.com/cossacklabs/acra/encryptor/config/common" + "github.com/cossacklabs/acra/utils" "strconv" "github.com/cossacklabs/acra/decryptor/base" @@ -235,12 +238,21 @@ func (p *pgBoundValue) GetType() byte { // SetData set new value to BoundValue using ColumnEncryptionSetting if provided func (p *pgBoundValue) SetData(newData []byte, setting config.ColumnEncryptionSetting) error { - p.data = newData - if setting == nil { + p.data = newData return nil } + if setting.IsTokenized() { + return p.setTokenizedData(newData, setting) + } else if config.IsBinaryDataOperation(setting) { + return p.setEncryptedData(newData, setting) + } + return nil +} + +func (p *pgBoundValue) setTokenizedData(newData []byte, setting config.ColumnEncryptionSetting) error { + p.data = newData switch p.format { case base.BinaryFormat: switch setting.GetTokenType() { @@ -265,26 +277,52 @@ func (p *pgBoundValue) SetData(newData []byte, setting config.ColumnEncryptionSe return nil } +func (p *pgBoundValue) setEncryptedData(newData []byte, setting config.ColumnEncryptionSetting) error { + p.data = newData + switch p.format { + case base.TextFormat: + // here we take encrypted data and encode it to SQL String value that contains binary data in hex format + // or pass it as is if it is already valid string (all other SQL literals) + p.data = encryptor.PgEncodeToHexString(newData) + return nil + case base.BinaryFormat: + // all our encryption operations applied over text format values to be compatible with text format + // and here we work with encrypted TextFormat values that we should pass as is to server + break + } + + return nil +} + // GetData return BoundValue using ColumnEncryptionSetting if provided -func (p *pgBoundValue) GetData(setting config.ColumnEncryptionSetting) []byte { +func (p *pgBoundValue) GetData(setting config.ColumnEncryptionSetting) ([]byte, error) { if setting == nil { - return p.data + return p.data, nil } decodedData := p.data switch p.format { - // TODO(ilammy, 2020-10-19): handle non-bytes binary data - // Encryptor expects binary data to be passed in raw bytes, but most non-byte-arrays - // are expected in text format. If we get binary parameters, we may need to recode them. + case base.TextFormat: + if setting.OnlyEncryption() || setting.IsSearchable() { + // binary data in TextFormat received as Hex/Octal encoded values + // so we should decode them before processing + + decoded, err := utils.DecodeEscaped(p.data) + if err != nil { + return p.data, err + } + return decoded, nil + + } case base.BinaryFormat: - if setting.IsTokenized() { - switch setting.GetTokenType() { - case tokens.TokenType_Int32: + if setting.IsTokenized() || setting.IsSearchable() || setting.OnlyEncryption() { + switch setting.GetEncryptedDataType() { + case common.EncryptedType_Int32: value := binary.BigEndian.Uint32(p.data) strValue := strconv.FormatInt(int64(value), 10) decodedData = []byte(strValue) - case tokens.TokenType_Int64: + case common.EncryptedType_Int64: // if passed int32 as int64, just extend array and fill by zeroes if len(p.data) == 4 { p.data = append([]byte{0, 0, 0, 0}, p.data...) @@ -295,7 +333,7 @@ func (p *pgBoundValue) GetData(setting config.ColumnEncryptionSetting) []byte { } } } - return decodedData + return decodedData, nil } // Encode format result BoundValue data diff --git a/decryptor/postgresql/prepared_statements_test.go b/decryptor/postgresql/prepared_statements_test.go index 98b73aec3..14bcc4d28 100644 --- a/decryptor/postgresql/prepared_statements_test.go +++ b/decryptor/postgresql/prepared_statements_test.go @@ -316,17 +316,23 @@ func TestNewPgBoundValue(t *testing.T) { boundValue := NewPgBoundValue(sourceData, base.BinaryFormat) sourceData[0] = 22 - - if reflect.DeepEqual(sourceData, boundValue.GetData(nil)) { + value, err := boundValue.GetData(nil) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(sourceData, value) { t.Fatal("BoundValue data should not be equal to sourceData") } }) t.Run("nil data provided", func(t *testing.T) { boundValue := NewPgBoundValue(nil, base.BinaryFormat) - + value, err := boundValue.GetData(nil) + if err != nil { + t.Fatal(err) + } // we need to validate that textData is nil if nil was provided - required for handling NULL values - if boundValue.GetData(nil) != nil { + if value != nil { t.Fatal("BoundValue data should be nil") } }) diff --git a/decryptor/postgresql/protocol.go b/decryptor/postgresql/protocol.go index 0a2e5f58e..77148dedf 100644 --- a/decryptor/postgresql/protocol.go +++ b/decryptor/postgresql/protocol.go @@ -20,17 +20,20 @@ import ( "github.com/cossacklabs/acra/decryptor/base" "github.com/cossacklabs/acra/logging" "github.com/cossacklabs/acra/sqlparser" + "github.com/jackc/pgx/pgproto3" ) // PgProtocolState keeps track of PostgreSQL protocol state. type PgProtocolState struct { parser *sqlparser.Parser - lastPacketType PacketType - pendingQuery base.OnQueryObject - pendingParse *ParsePacket - pendingBind *BindPacket - pendingExecute *ExecutePacket + lastPacketType PacketType + pendingQuery base.OnQueryObject + pendingParse *ParsePacket + pendingBind *BindPacket + pendingExecute *ExecutePacket + pendingRowDescription *pgproto3.RowDescription + pendingParameterDescription *pgproto3.ParameterDescription } // PacketType describes how to handle a message packet. @@ -44,6 +47,8 @@ const ( BindStatementPacket BindCompletePacket DataPacket + RowDescriptionPacket + ParameterDescriptionPacket OtherPacket ) @@ -77,6 +82,11 @@ func (p *PgProtocolState) PendingExecute() *ExecutePacket { return p.pendingExecute } +// PendingRowDescription returns the pending query parameters, if any. +func (p *PgProtocolState) PendingRowDescription() *pgproto3.RowDescription { + return p.pendingRowDescription +} + // HandleClientPacket observes a packet from client to the database, // extracts query information from it, and anticipates future database responses. func (p *PgProtocolState) HandleClientPacket(packet *PacketHandler) error { @@ -106,6 +116,8 @@ func (p *PgProtocolState) HandleClientPacket(packet *PacketHandler) error { p.lastPacketType = ParseStatementPacket p.pendingQuery = base.NewOnQueryObjectFromQuery(parsePacket.QueryString(), p.parser) p.pendingParse = parsePacket + p.pendingParameterDescription = nil + p.pendingBind = nil return nil } @@ -150,6 +162,26 @@ func (p *PgProtocolState) HandleDatabasePacket(packet *PacketHandler) error { return nil } + if packet.IsRowDescription() { + p.lastPacketType = RowDescriptionPacket + rowDescription, err := packet.GetRowDescriptionData() + if err != nil { + return err + } + p.pendingRowDescription = rowDescription + return nil + } + + if packet.IsParameterDescription() { + p.lastPacketType = ParameterDescriptionPacket + parameterDescription, err := packet.GetParameterDescriptionData() + if err != nil { + return err + } + p.pendingParameterDescription = parameterDescription + return nil + } + if packet.IsParseComplete() { p.lastPacketType = ParseCompletePacket return nil diff --git a/decryptor/postgresql/proxy.go b/decryptor/postgresql/proxy.go index 952f68160..492f84800 100644 --- a/decryptor/postgresql/proxy.go +++ b/decryptor/postgresql/proxy.go @@ -65,7 +65,7 @@ func (factory *proxyFactory) New(clientID []byte, clientSession base.ClientSessi schemaStore := factory.setting.TableSchemaStore() storeMask := schemaStore.GetGlobalSettingsMask() // register only if masking/tokenization/searching will be used - if storeMask&(config.SettingSearchFlag|config.SettingMaskingFlag|config.SettingTokenizationFlag) > 0 { + if !base.OnlyDefaultEncryptorSettings(schemaStore) { // register Query processor first before other processors because it match SELECT queries for ColumnEncryptorConfig structs // and store it in AccessContext for next decryptions/encryptions and all other processors rely on that // use nil dataEncryptor to avoid extra computations @@ -77,6 +77,18 @@ func (factory *proxyFactory) New(clientID []byte, clientSession base.ClientSessi proxy.SubscribeOnAllColumnsDecryption(queryEncryptor) } + decoderProcessor, err := NewPgSQLDataDecoderProcessor() + if err != nil { + return nil, err + } + encoderProcessor, err := NewPgSQLDataEncoderProcessor() + if err != nil { + return nil, err + } + // register first to decode all data into text/binary formats expected by handlers according to client/database + //requested formats and ColumnEncryptionSetting + proxy.SubscribeOnAllColumnsDecryption(decoderProcessor) + // poison record processor should be first if factory.setting.PoisonRecordCallbackStorage() != nil && factory.setting.PoisonRecordCallbackStorage().HasCallbacks() { // setting PoisonRecords callback for CryptoHandlers inside registry @@ -98,25 +110,11 @@ func (factory *proxyFactory) New(clientID []byte, clientSession base.ClientSessi return nil, err } - decoderProcessor, err := pseudonymization.NewPgSQLDataEncoderProcessor(pseudonymization.DataEncoderModeDecode) - if err != nil { - return nil, err - } - encoderProcessor, err := pseudonymization.NewPgSQLDataEncoderProcessor(pseudonymization.DataEncoderModeEncode) - if err != nil { - return nil, err - } - // register before tokenization to decode binary value if need - proxy.SubscribeOnAllColumnsDecryption(decoderProcessor) - tokenProcessor, err := pseudonymization.NewTokenProcessor(tokenizer) if err != nil { return nil, err } proxy.SubscribeOnAllColumnsDecryption(tokenProcessor) - // register after tokenization to encode text value to binary if need - proxy.SubscribeOnAllColumnsDecryption(encoderProcessor) - tokenEncryptor, err := pseudonymization.NewTokenEncryptor(tokenizer) if err != nil { return nil, err @@ -172,7 +170,9 @@ func (factory *proxyFactory) New(clientID []byte, clientSession base.ClientSessi return nil, err } proxy.AddQueryObserver(queryEncryptor) - proxy.SubscribeOnAllColumnsDecryption(queryEncryptor) + // register last to encode all data into correct format according to client/database requested formats + // and ColumnEncryptionSetting + proxy.SubscribeOnAllColumnsDecryption(encoderProcessor) return proxy, nil } diff --git a/decryptor/postgresql/proxy_test.go b/decryptor/postgresql/proxy_test.go index 28bc60c94..1888d3e8e 100644 --- a/decryptor/postgresql/proxy_test.go +++ b/decryptor/postgresql/proxy_test.go @@ -157,6 +157,22 @@ func (*tableSchemaStore) GetGlobalSettingsMask() config.SettingMask { type stubSession struct{} +func (s stubSession) GetData(s2 string) (interface{}, bool) { + panic("implement me") +} + +func (s stubSession) SetData(s2 string, i interface{}) { + panic("implement me") +} + +func (s stubSession) DeleteData(s2 string) { + panic("implement me") +} + +func (s stubSession) HasData(s2 string) bool { + panic("implement me") +} + func (stubSession) Context() context.Context { return context.TODO() } diff --git a/decryptor/postgresql/type_conversion.go b/decryptor/postgresql/type_conversion.go new file mode 100644 index 000000000..1c0db6674 --- /dev/null +++ b/decryptor/postgresql/type_conversion.go @@ -0,0 +1,18 @@ +package postgresql + +import ( + "github.com/cossacklabs/acra/encryptor/config/common" + "github.com/jackc/pgx/pgtype" +) + +func mapEncryptedTypeToOID(dataType common.EncryptedType) (uint32, bool) { + switch dataType { + case common.EncryptedType_String: + return pgtype.TextOID, true + case common.EncryptedType_Int32: + return pgtype.Int4OID, true + case common.EncryptedType_Int64: + return pgtype.Int8OID, true + } + return 0, false +} diff --git a/decryptor/postgresql/utils.go b/decryptor/postgresql/utils.go index 762f148b0..a32ebacb0 100644 --- a/decryptor/postgresql/utils.go +++ b/decryptor/postgresql/utils.go @@ -295,7 +295,12 @@ func (p *BindPacket) SetParameters(values []base.BoundValue) { p.paramValues = make([][]byte, len(values)) } for i := range p.paramValues { - p.paramValues[i] = values[i].GetData(nil) + value, err := values[i].GetData(nil) + if err != nil { + log.WithError(err).Errorln("Can't get BoundValue data") + return + } + p.paramValues[i] = value } } diff --git a/encryptor/config/common/encryptedTypes.go b/encryptor/config/common/encryptedTypes.go new file mode 100644 index 000000000..8ad1dcdf1 --- /dev/null +++ b/encryptor/config/common/encryptedTypes.go @@ -0,0 +1,91 @@ +package common + +import ( + "errors" + "strconv" + "unicode/utf8" +) + +// ParseStringEncryptedType parse string value to EncryptedType value +func ParseStringEncryptedType(value string) (EncryptedType, error) { + parsed, ok := encryptedTypeNames[value] + if !ok { + return EncryptedType_Unknown, ErrUnknownEncryptedType + } + return parsed, nil +} + +// Data type names as expected in the configuration file. +var encryptedTypeNames = map[string]EncryptedType{ + "int32": EncryptedType_Int32, + "int64": EncryptedType_Int64, + "str": EncryptedType_String, + "bytes": EncryptedType_Bytes, +} +var supportedEncryptedTypes = map[EncryptedType]bool{ + EncryptedType_Int32: true, + EncryptedType_Int64: true, + EncryptedType_String: true, + EncryptedType_Bytes: true, +} + +// ToConfigString converts value to string used in encryptor_config +func (x EncryptedType) ToConfigString() (val string, err error) { + err = ErrUnknownEncryptedType + switch x { + case EncryptedType_Int32: + return "int32", nil + case EncryptedType_Int64: + return "int64", nil + case EncryptedType_String: + return "str", nil + case EncryptedType_Bytes: + return "bytes", nil + } + return +} + +// Validation errors +var ( + ErrUnknownEncryptedType = errors.New("unknown token type") + ErrUnsupportedEncryptedType = errors.New("data type not supported") +) + +// ValidateEncryptedType return true if value is supported EncryptedType +func ValidateEncryptedType(value EncryptedType) error { + supported, ok := supportedEncryptedTypes[value] + if !ok { + return ErrUnknownEncryptedType + } + if !supported { + return ErrUnsupportedEncryptedType + } + return nil +} + +// ValidateDefaultValue default value according to EncryptedType +// +// str -> validates utf8 string +// bytes - do nothing +// int32/64 - try parse string as integer value with base 10 +func ValidateDefaultValue(value *string, dataType EncryptedType) (err error) { + if value == nil { + return nil + } + switch dataType { + case EncryptedType_Int32: + _, err = strconv.ParseInt(*value, 10, 32) + return err + case EncryptedType_Int64: + _, err = strconv.ParseInt(*value, 10, 64) + return err + case EncryptedType_String: + if !utf8.ValidString(*value) { + return errors.New("invalid utf8 string") + } + return nil + case EncryptedType_Bytes: + return nil + } + return errors.New("not supported EncryptedType") +} diff --git a/encryptor/config/common/encryptedTypes.pb.go b/encryptor/config/common/encryptedTypes.pb.go new file mode 100644 index 000000000..86eb87bff --- /dev/null +++ b/encryptor/config/common/encryptedTypes.pb.go @@ -0,0 +1,227 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.6.1 +// source: encryptedTypes.proto + +package common + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// EncryptedType defines types for encrypted data. +type EncryptedType int32 + +const ( + EncryptedType_Unknown EncryptedType = 0 + EncryptedType_Int32 EncryptedType = 1 + EncryptedType_Int64 EncryptedType = 2 + EncryptedType_String EncryptedType = 3 + EncryptedType_Bytes EncryptedType = 4 +) + +// Enum value maps for EncryptedType. +var ( + EncryptedType_name = map[int32]string{ + 0: "Unknown", + 1: "Int32", + 2: "Int64", + 3: "String", + 4: "Bytes", + } + EncryptedType_value = map[string]int32{ + "Unknown": 0, + "Int32": 1, + "Int64": 2, + "String": 3, + "Bytes": 4, + } +) + +func (x EncryptedType) Enum() *EncryptedType { + p := new(EncryptedType) + *p = x + return p +} + +func (x EncryptedType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (EncryptedType) Descriptor() protoreflect.EnumDescriptor { + return file_encryptedTypes_proto_enumTypes[0].Descriptor() +} + +func (EncryptedType) Type() protoreflect.EnumType { + return &file_encryptedTypes_proto_enumTypes[0] +} + +func (x EncryptedType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use EncryptedType.Descriptor instead. +func (EncryptedType) EnumDescriptor() ([]byte, []int) { + return file_encryptedTypes_proto_rawDescGZIP(), []int{0} +} + +// EncryptedValue keeps serialized encrypted value. +type EncryptedValue struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` + Type EncryptedType `protobuf:"varint,2,opt,name=type,proto3,enum=github.com.cossacklabs.acra.encryptor.config.common.EncryptedType" json:"type,omitempty"` +} + +func (x *EncryptedValue) Reset() { + *x = EncryptedValue{} + if protoimpl.UnsafeEnabled { + mi := &file_encryptedTypes_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EncryptedValue) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EncryptedValue) ProtoMessage() {} + +func (x *EncryptedValue) ProtoReflect() protoreflect.Message { + mi := &file_encryptedTypes_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EncryptedValue.ProtoReflect.Descriptor instead. +func (*EncryptedValue) Descriptor() ([]byte, []int) { + return file_encryptedTypes_proto_rawDescGZIP(), []int{0} +} + +func (x *EncryptedValue) GetValue() []byte { + if x != nil { + return x.Value + } + return nil +} + +func (x *EncryptedValue) GetType() EncryptedType { + if x != nil { + return x.Type + } + return EncryptedType_Unknown +} + +var File_encryptedTypes_proto protoreflect.FileDescriptor + +var file_encryptedTypes_proto_rawDesc = []byte{ + 0x0a, 0x14, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x54, 0x79, 0x70, 0x65, 0x73, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2e, 0x63, 0x6f, 0x73, 0x73, 0x61, 0x63, 0x6b, 0x6c, 0x61, 0x62, 0x73, 0x2e, 0x61, + 0x63, 0x72, 0x61, 0x2e, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x63, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x0e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x14, 0x0a, + 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x12, 0x56, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x42, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2e, 0x63, + 0x6f, 0x73, 0x73, 0x61, 0x63, 0x6b, 0x6c, 0x61, 0x62, 0x73, 0x2e, 0x61, 0x63, 0x72, 0x61, 0x2e, + 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x2a, 0x49, 0x0a, 0x0d, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x49, 0x6e, 0x74, + 0x33, 0x32, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x49, 0x6e, 0x74, 0x36, 0x34, 0x10, 0x02, 0x12, + 0x0a, 0x0a, 0x06, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x10, 0x03, 0x12, 0x09, 0x0a, 0x05, 0x42, + 0x79, 0x74, 0x65, 0x73, 0x10, 0x04, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x73, 0x73, 0x61, 0x63, 0x6b, 0x6c, 0x61, 0x62, 0x73, + 0x2f, 0x61, 0x63, 0x72, 0x61, 0x2f, 0x65, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x6f, 0x72, 0x2f, + 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_encryptedTypes_proto_rawDescOnce sync.Once + file_encryptedTypes_proto_rawDescData = file_encryptedTypes_proto_rawDesc +) + +func file_encryptedTypes_proto_rawDescGZIP() []byte { + file_encryptedTypes_proto_rawDescOnce.Do(func() { + file_encryptedTypes_proto_rawDescData = protoimpl.X.CompressGZIP(file_encryptedTypes_proto_rawDescData) + }) + return file_encryptedTypes_proto_rawDescData +} + +var file_encryptedTypes_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_encryptedTypes_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_encryptedTypes_proto_goTypes = []interface{}{ + (EncryptedType)(0), // 0: github.com.cossacklabs.acra.encryptor.config.common.EncryptedType + (*EncryptedValue)(nil), // 1: github.com.cossacklabs.acra.encryptor.config.common.EncryptedValue +} +var file_encryptedTypes_proto_depIdxs = []int32{ + 0, // 0: github.com.cossacklabs.acra.encryptor.config.common.EncryptedValue.type:type_name -> github.com.cossacklabs.acra.encryptor.config.common.EncryptedType + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_encryptedTypes_proto_init() } +func file_encryptedTypes_proto_init() { + if File_encryptedTypes_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_encryptedTypes_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EncryptedValue); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_encryptedTypes_proto_rawDesc, + NumEnums: 1, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_encryptedTypes_proto_goTypes, + DependencyIndexes: file_encryptedTypes_proto_depIdxs, + EnumInfos: file_encryptedTypes_proto_enumTypes, + MessageInfos: file_encryptedTypes_proto_msgTypes, + }.Build() + File_encryptedTypes_proto = out.File + file_encryptedTypes_proto_rawDesc = nil + file_encryptedTypes_proto_goTypes = nil + file_encryptedTypes_proto_depIdxs = nil +} diff --git a/encryptor/config/common/encryptedTypes.proto b/encryptor/config/common/encryptedTypes.proto new file mode 100644 index 000000000..d28036a5f --- /dev/null +++ b/encryptor/config/common/encryptedTypes.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +// we should use unique name to avoid package duplication with pseudonymization/common package (that left as is +// to not break backward compatibility) +package github.com.cossacklabs.acra.encryptor.config.common; + +option go_package = "github.com/cossacklabs/acra/encryptor/config/common"; + +// EncryptedType defines types for encrypted data. +enum EncryptedType { + Unknown = 0; + Int32 = 1; + Int64 = 2; + String = 3; + Bytes = 4; +} + +// EncryptedValue keeps serialized encrypted value. +message EncryptedValue { + bytes value = 1; + EncryptedType type = 2; +} diff --git a/encryptor/config/common/encryptedTypes_test.go b/encryptor/config/common/encryptedTypes_test.go new file mode 100644 index 000000000..8e834e5c6 --- /dev/null +++ b/encryptor/config/common/encryptedTypes_test.go @@ -0,0 +1,49 @@ +package common + +import ( + "math" + "strconv" + "testing" +) + +func TestValidateDefaultValue(t *testing.T) { + type args struct { + value *string + dataType EncryptedType + } + var ( + emptyString = "" + int32String = strconv.FormatUint(math.MaxInt32, 10) + int64String = strconv.FormatUint(math.MaxInt64, 10) + // use max uint64 as value for int64 that should overflow + invalidInt64String = strconv.FormatUint(math.MaxUint64, 10) + // valid ASCII [0, 127]. All greater values validated as UTF8 + invalidString = string([]byte{128, 129}) + someString = "some string" + ) + tests := []struct { + name string + args args + wantErr bool + }{ + {"nil value unknown", args{nil, EncryptedType_Unknown}, false}, + {"non-nil value unknown", args{&int32String, EncryptedType_Unknown}, true}, + {"invalid string", args{&invalidString, EncryptedType_String}, true}, + {"valid bytes", args{&invalidString, EncryptedType_Bytes}, false}, + {"empty string", args{&emptyString, EncryptedType_String}, false}, + {"empty bytes", args{&emptyString, EncryptedType_Bytes}, false}, + {"int32 string", args{&int32String, EncryptedType_Int32}, false}, + {"invalid integer int32 string", args{&int64String, EncryptedType_Int32}, true}, + {"invalid non-integer int32 string", args{&someString, EncryptedType_Int32}, true}, + {"int64 string", args{&int64String, EncryptedType_Int64}, false}, + {"invalid int64 string", args{&invalidInt64String, EncryptedType_Int64}, true}, + {"invalid non-integer int64 string", args{&someString, EncryptedType_Int64}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateDefaultValue(tt.args.value, tt.args.dataType); (err != nil) != tt.wantErr { + t.Errorf("ValidateDefaultValue() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/encryptor/config/encryptionSettings.go b/encryptor/config/encryptionSettings.go index 045b93d84..ebb79c4ae 100644 --- a/encryptor/config/encryptionSettings.go +++ b/encryptor/config/encryptionSettings.go @@ -18,6 +18,8 @@ package config import ( "errors" + "fmt" + common2 "github.com/cossacklabs/acra/encryptor/config/common" maskingCommon "github.com/cossacklabs/acra/masking/common" "github.com/cossacklabs/acra/pseudonymization/common" ) @@ -39,6 +41,8 @@ const ( SettingZoneIDFlag SettingAcraBlockEncryptionFlag SettingAcraStructEncryptionFlag + SettingDataTypeFlag + SettingDefaultDataValueFlag ) // validSettings store all valid combinations of encryption settings @@ -58,6 +62,34 @@ var validSettings = map[SettingMask]struct{}{ SettingZoneIDFlag | SettingAcraBlockEncryptionFlag | SettingReEncryptionFlag: {}, SettingZoneIDFlag | SettingAcraStructEncryptionFlag | SettingReEncryptionFlag: {}, + + ///////////// + // DataType tampering + ///////////// + + // AcraBlock + + // ClientID + SettingDataTypeFlag | SettingReEncryptionFlag | SettingClientIDFlag | SettingAcraBlockEncryptionFlag: {}, + SettingDataTypeFlag | SettingDefaultDataValueFlag | SettingReEncryptionFlag | SettingClientIDFlag | SettingAcraBlockEncryptionFlag: {}, + + SettingDataTypeFlag | SettingReEncryptionFlag | SettingClientIDFlag | SettingAcraBlockEncryptionFlag | SettingMaskingFlag | SettingMaskingPlaintextLengthFlag | SettingMaskingPlaintextSideFlag: {}, + + // ZoneID + + SettingDataTypeFlag | SettingReEncryptionFlag | SettingAcraBlockEncryptionFlag | SettingZoneIDFlag: {}, + SettingDataTypeFlag | SettingDefaultDataValueFlag | SettingReEncryptionFlag | SettingAcraBlockEncryptionFlag | SettingZoneIDFlag: {}, + + // AcraStruct + // ClientID + SettingDataTypeFlag | SettingReEncryptionFlag | SettingClientIDFlag | SettingAcraStructEncryptionFlag: {}, + SettingDataTypeFlag | SettingDefaultDataValueFlag | SettingReEncryptionFlag | SettingClientIDFlag | SettingAcraStructEncryptionFlag: {}, + + // ZoneID + + SettingDataTypeFlag | SettingReEncryptionFlag | SettingAcraStructEncryptionFlag | SettingZoneIDFlag: {}, + SettingDataTypeFlag | SettingDefaultDataValueFlag | SettingReEncryptionFlag | SettingAcraStructEncryptionFlag | SettingZoneIDFlag: {}, + ///////////// // SEARCHABLE // default ClientID @@ -137,6 +169,11 @@ type BasicColumnEncryptionSetting struct { UsedClientID string `yaml:"client_id"` UsedZoneID string `yaml:"zone_id"` + // same as TokenType but related for encryption operations + DataType string `yaml:"data_type"` + // string for str/email/int32/int64 ans base64 string for binary data + DefaultDataValue *string `yaml:"default_data_value"` + // Data pseudonymization (tokenization) Tokenized bool `yaml:"tokenized"` ConsistentTokenization bool `yaml:"consistent_tokenization"` @@ -153,8 +190,17 @@ type BasicColumnEncryptionSetting struct { settingMask SettingMask } +// IsBinaryDataOperation return true if setting related to operation over binary data +func IsBinaryDataOperation(setting ColumnEncryptionSetting) bool { + // tokenization for binary data or encryption/masking of binary data (not text) + hasBinaryOperation := setting.GetTokenType() == common.TokenType_Bytes + hasBinaryOperation = hasBinaryOperation || setting.OnlyEncryption() || setting.IsSearchable() + hasBinaryOperation = hasBinaryOperation || len(setting.GetMaskingPattern()) != 0 + return hasBinaryOperation +} + // Init validate and initialize SettingMask -func (s *BasicColumnEncryptionSetting) Init() error { +func (s *BasicColumnEncryptionSetting) Init() (err error) { if len(s.Name) == 0 { return ErrInvalidEncryptorConfig } @@ -169,7 +215,7 @@ func (s *BasicColumnEncryptionSetting) Init() error { s.settingMask |= SettingClientIDFlag } if s.CryptoEnvelope != nil { - if err := ValidateCryptoEnvelopeType(*s.CryptoEnvelope); err != nil { + if err = ValidateCryptoEnvelopeType(*s.CryptoEnvelope); err != nil { return err } switch *s.CryptoEnvelope { @@ -184,16 +230,46 @@ func (s *BasicColumnEncryptionSetting) Init() error { if s.ReEncryptToAcraBlock != nil && *s.ReEncryptToAcraBlock { s.settingMask |= SettingReEncryptionFlag } - - if s.Tokenized { - s.settingMask |= SettingTokenizationFlag + if s.TokenType != "" || s.Tokenized { tokenType, ok := tokenTypeNames[s.TokenType] if !ok { - return common.ErrUnknownTokenType + return fmt.Errorf("%s: %w", s.TokenType, common.ErrUnknownTokenType) } - if err := common.ValidateTokenType(tokenType); err != nil { + if err = common.ValidateTokenType(tokenType); err != nil { return err } + } + if s.DataType == "" { + // by default all encrypted data is binary + s.DataType, _ = common2.EncryptedType_Bytes.ToConfigString() + // if DataType empty but configured for tokenization then map TokenType to appropriate DataType + if s.TokenType != "" { + // we don't validate because it's already validated above + tokenType, _ := tokenTypeNames[s.TokenType] + s.DataType, err = tokenType.ToEncryptedDataType().ToConfigString() + if err != nil { + return err + } + } + } else { + s.settingMask |= SettingDataTypeFlag + } + dataType, err := common2.ParseStringEncryptedType(s.DataType) + if err != nil { + return fmt.Errorf("%s: %w", s.DataType, common2.ErrUnknownEncryptedType) + } + if err = common2.ValidateEncryptedType(dataType); err != nil { + return err + } + if s.DefaultDataValue != nil { + s.settingMask |= SettingDefaultDataValueFlag + } + if err = common2.ValidateDefaultValue(s.DefaultDataValue, dataType); err != nil { + return fmt.Errorf("invalid default value: %w", err) + } + + if s.Tokenized { + s.settingMask |= SettingTokenizationFlag s.settingMask |= SettingTokenTypeFlag if s.ConsistentTokenization { s.settingMask |= SettingConsistentTokenizationFlag @@ -205,7 +281,7 @@ func (s *BasicColumnEncryptionSetting) Init() error { } if s.MaskingPattern != "" || s.PlaintextSide != "" { - if err := maskingCommon.ValidateMaskingParams(s.MaskingPattern, s.PartialPlaintextLenBytes, s.PlaintextSide); err != nil { + if err = maskingCommon.ValidateMaskingParams(s.MaskingPattern, s.PartialPlaintextLenBytes, s.PlaintextSide, s.GetEncryptedDataType()); err != nil { return err } s.settingMask |= SettingMaskingFlag | SettingMaskingPlaintextLengthFlag | SettingMaskingPlaintextSideFlag @@ -275,7 +351,7 @@ func (s *BasicColumnEncryptionSetting) IsConsistentTokenization() bool { func (s *BasicColumnEncryptionSetting) GetTokenType() common.TokenType { // If the configuration file contains some unknown or unsupported token type, // return some safe default. - const defaultTokenType = common.TokenType_Bytes + const defaultTokenType = common.TokenType_Unknown tokenType, ok := tokenTypeNames[s.TokenType] if !ok { return defaultTokenType @@ -303,6 +379,23 @@ func (s *BasicColumnEncryptionSetting) IsEndMasking() bool { return s.PlaintextSide == maskingCommon.PlainTextSideLeft } +// GetEncryptedDataType returns data type for encrypted data +func (s *BasicColumnEncryptionSetting) GetEncryptedDataType() common2.EncryptedType { + // If the configuration file contains some unknown or unsupported token type, + // return some safe default. + const defaultDataType = common2.EncryptedType_Bytes + dataType, err := common2.ParseStringEncryptedType(s.DataType) + if err != nil { + return defaultDataType + } + return dataType +} + +// GetDefaultDataValue returns default data value for encrypted data +func (s *BasicColumnEncryptionSetting) GetDefaultDataValue() *string { + return s.DefaultDataValue +} + func (s *BasicColumnEncryptionSetting) applyDefaults(defaults defaultValues) { if s.CryptoEnvelope == nil { v := defaults.GetCryptoEnvelope() @@ -319,3 +412,16 @@ func (s *BasicColumnEncryptionSetting) applyDefaults(defaults defaultValues) { } } } + +// HasTypeAwareSupport return true if setting configured for decryption with type awareness +func HasTypeAwareSupport(setting ColumnEncryptionSetting) bool { + maskingSupport := setting.GetMaskingPattern() != "" + switch setting.GetEncryptedDataType() { + case common2.EncryptedType_String, common2.EncryptedType_Bytes, common2.EncryptedType_Int32, common2.EncryptedType_Int64: + break + default: + // intX not supported masking with type awareness + maskingSupport = false + } + return setting.OnlyEncryption() || setting.IsSearchable() || maskingSupport +} diff --git a/encryptor/config/schemaStore_test.go b/encryptor/config/schemaStore_test.go index e4155971f..c1097068a 100644 --- a/encryptor/config/schemaStore_test.go +++ b/encryptor/config/schemaStore_test.go @@ -273,13 +273,388 @@ schemas: tokenized: true token_type: invalid `, - //pseudonymization.ErrUnknownTokenType // use new declared to avoid cycle import errors.New("unknown token type")}, + + // AcraBlock + // type aware decryption, all supported types + {` +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + - column: data2 + data_type: bytes + - column: data3 + data_type: int32 + - column: data4 + data_type: int64 +`, + nil}, + // type aware decryption, all supported types + masking + {` +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + masking: "00" + plaintext_length: 2 + plaintext_side: "left" + - column: data2 + data_type: bytes + masking: "00" + plaintext_length: 2 + plaintext_side: "left" + - column: data3 + data_type: int32 + - column: data4 + data_type: int64 +`, + nil}, + // type aware decryption, all supported types, specified client id + {` +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + client_id: client + - column: data2 + data_type: bytes + client_id: client + - column: data3 + data_type: int32 + client_id: client + - column: data4 + data_type: int64 + client_id: client +`, + nil}, + // type aware decryption, all supported types, specified client id + masking + {` +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + client_id: client + masking: "00" + plaintext_length: 2 + plaintext_side: "left" + - column: data2 + data_type: bytes + client_id: client + masking: "00" + plaintext_length: 2 + plaintext_side: "left" + - column: data3 + data_type: int32 + client_id: client + - column: data4 + data_type: int64 + client_id: client +`, + nil}, + // type aware decryption, all supported types, specified zone id + {` +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + zone_id: client + - column: data2 + data_type: bytes + zone_id: client + - column: data3 + data_type: int32 + zone_id: client + - column: data4 + data_type: int64 + zone_id: client +`, + nil}, + // type aware decryption, all supported types, default value + {` +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + default_data_value: "str" + - column: data2 + data_type: bytes + default_data_value: "bytes" + - column: data3 + data_type: int32 + default_data_value: "123" + - column: data4 + data_type: int64 + default_data_value: "123" +`, + nil}, + // type aware decryption, all supported types, default value, specified client id + {` +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + default_data_value: "str" + client_id: client + - column: data2 + data_type: bytes + default_data_value: "bytes" + client_id: client + - column: data3 + data_type: int32 + default_data_value: "123" + client_id: client + - column: data4 + data_type: int64 + default_data_value: "123" + client_id: client +`, + nil}, + // type aware decryption, all supported types, default value, specified zone id + {` +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + default_data_value: "str" + zone_id: zone + - column: data2 + data_type: bytes + default_data_value: "bytes" + zone_id: zone + - column: data3 + data_type: int32 + default_data_value: "123" + zone_id: zone + - column: data4 + data_type: int64 + default_data_value: "123" + zone_id: zone +`, + nil}, + // AcraBlock + // type aware decryption, all supported types + {` +defaults: + crypto_envelope: acrastruct +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + - column: data2 + data_type: bytes + - column: data3 + data_type: int32 + - column: data4 + data_type: int64 +`, + nil}, + // type aware decryption, all supported types, specified client id + {` +defaults: + crypto_envelope: acrastruct +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + client_id: client + - column: data2 + data_type: bytes + client_id: client + - column: data3 + data_type: int32 + client_id: client + - column: data4 + data_type: int64 + client_id: client +`, + nil}, + // type aware decryption, all supported types, specified zone id + {` +defaults: + crypto_envelope: acrastruct +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + zone_id: client + - column: data2 + data_type: bytes + zone_id: client + - column: data3 + data_type: int32 + zone_id: client + - column: data4 + data_type: int64 + zone_id: client +`, + nil}, + // type aware decryption, all supported types, default value + {` +defaults: + crypto_envelope: acrastruct +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + default_data_value: "str" + - column: data2 + data_type: bytes + default_data_value: "bytes" + - column: data3 + data_type: int32 + default_data_value: "123" + - column: data4 + data_type: int64 + default_data_value: "123" +`, + nil}, + // type aware decryption, all supported types, default value, specified client id + {` +defaults: + crypto_envelope: acrastruct +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + default_data_value: "str" + client_id: client + - column: data2 + data_type: bytes + default_data_value: "bytes" + client_id: client + - column: data3 + data_type: int32 + default_data_value: "123" + client_id: client + - column: data4 + data_type: int64 + default_data_value: "123" + client_id: client +`, + nil}, + // type aware decryption, all supported types, default value, specified zone id + {` +defaults: + crypto_envelope: acrastruct +schemas: + - table: test_table + columns: + - data1 + - data2 + - data3 + - data4 + encrypted: + - column: data1 + data_type: str + default_data_value: "str" + zone_id: zone + - column: data2 + data_type: bytes + default_data_value: "bytes" + zone_id: zone + - column: data3 + data_type: int32 + default_data_value: "123" + zone_id: zone + - column: data4 + data_type: int64 + default_data_value: "123" + zone_id: zone +`, + nil}, } for i, tcase := range testcases { _, err := MapTableSchemaStoreFromConfig([]byte(tcase.config)) - if err == nil || err.Error() != tcase.err.Error() { + u, ok := err.(interface { + Unwrap() error + }) + if ok { + err = u.Unwrap() + } + if tcase.err == err { + continue + } + if err.Error() != tcase.err.Error() { t.Fatalf("[%d] Expect %s, took %s\n", i, tcase.err.Error(), err) } } diff --git a/encryptor/config/tableSchema.go b/encryptor/config/tableSchema.go index f9d442f4d..cb0b68081 100644 --- a/encryptor/config/tableSchema.go +++ b/encryptor/config/tableSchema.go @@ -17,6 +17,7 @@ limitations under the License. package config import ( + common2 "github.com/cossacklabs/acra/encryptor/config/common" "github.com/cossacklabs/acra/pseudonymization/common" ) @@ -44,6 +45,9 @@ type ColumnEncryptionSetting interface { ClientID() []byte ZoneID() []byte + GetEncryptedDataType() common2.EncryptedType + GetDefaultDataValue() *string + // Searchable encryption IsSearchable() bool // Data masking diff --git a/encryptor/dataEncryptor_test.go b/encryptor/dataEncryptor_test.go index e9bdf9956..d0dfd30c3 100644 --- a/encryptor/dataEncryptor_test.go +++ b/encryptor/dataEncryptor_test.go @@ -21,6 +21,7 @@ import ( "errors" "github.com/cossacklabs/acra/acrastruct" "github.com/cossacklabs/acra/encryptor/config" + common2 "github.com/cossacklabs/acra/encryptor/config/common" "testing" "github.com/cossacklabs/acra/pseudonymization/common" @@ -79,6 +80,14 @@ func TestAcrawriterDataEncryptor_EncryptWithClientID(t *testing.T) { type emptyEncryptionSetting struct{} +func (s *emptyEncryptionSetting) GetEncryptedDataType() common2.EncryptedType { + panic("implement me") +} + +func (s *emptyEncryptionSetting) GetDefaultDataValue() *string { + panic("implement me") +} + func (s *emptyEncryptionSetting) OnlyEncryption() bool { return true } diff --git a/encryptor/dbDataCoder.go b/encryptor/dbDataCoder.go index b0fd81041..51b588f06 100644 --- a/encryptor/dbDataCoder.go +++ b/encryptor/dbDataCoder.go @@ -23,6 +23,7 @@ import ( "github.com/cossacklabs/acra/sqlparser" "github.com/cossacklabs/acra/utils" "github.com/sirupsen/logrus" + "strconv" ) var pgHexStringPrefix = []byte{'\\', 'x'} @@ -75,10 +76,29 @@ func (*MysqlDBDataCoder) Encode(expr sqlparser.Expr, data []byte) ([]byte, error return nil, errUnsupportedExpression } -// PostgresqlDBDataCoder implement DBDataCoder for PostgreSQL +// PgEncodeToHexString return data as is if it's valid UTF string otherwise encode to hex with \x prefix +func PgEncodeToHexString(data []byte) []byte { + if utils.IsPrintablePostgresqlString(data) { + return data + } + newVal := make([]byte, len(pgHexStringPrefix)+hex.EncodedLen(len(data))) + copy(newVal, pgHexStringPrefix) + hex.Encode(newVal[len(pgHexStringPrefix):], data) + return newVal +} + +// PostgresqlDBDataCoder responsible to handle decoding/encoding SQL literals before/after QueryEncryptor handlers +// +// Acra captures SQL queries like `INSERT INTO users (age, username, email, photo) VALUES (123, 'john_wick', 'johnwick@mail.com', '\xaabbcc');` +// and manipulates with SQL values `123`, `'john_wick'`, `'johnwick@mail.com'`, `'\xaabbcc'`. On first stage Acra +// decodes with Decode method values from SQL literals into binary or leave as is. For example hex encoded values decoded into binary" +// `'\xaabbcc'` decoded into []byte{170,187,204} and passed to QueryEncryptor's callbacks `EncryptWith[Client|Zone]ID` +// After that it should be encoded with Encode method from binary form into SQL to replace values in the query. type PostgresqlDBDataCoder struct{} -// Decode literal in expression to binary +// Decode hex/escaped literals to raw binary values for encryption/decryption. String values left as is because it +// doesn't need any decoding. Historically Int values had support only for tokenization and operated over string SQL +// literals. func (*PostgresqlDBDataCoder) Decode(expr sqlparser.Expr) ([]byte, error) { switch val := expr.(type) { case *sqlparser.SQLVal: @@ -108,7 +128,7 @@ func (*PostgresqlDBDataCoder) Decode(expr sqlparser.Expr) ([]byte, error) { // return value as is because it may be string with printable characters that wasn't encoded on client return val.Val, nil } - return binValue.Data(), nil + return binValue, nil } } return nil, errUnsupportedExpression @@ -119,20 +139,25 @@ func (*PostgresqlDBDataCoder) Encode(expr sqlparser.Expr, data []byte) ([]byte, switch val := expr.(type) { case *sqlparser.SQLVal: switch val.Type { - case sqlparser.IntVal: - return data, nil case sqlparser.HexVal: output := make([]byte, hex.EncodedLen(len(data))) hex.Encode(output, data) return output, nil - case sqlparser.PgEscapeString, sqlparser.StrVal: - if utils.IsPrintableASCIIArray(data) { + case sqlparser.IntVal: + // QueryDataEncryptor can tokenize INT SQL literal and we should not do anything because it is still valid + // INT literal. Also, handler can encrypt data and replace SQL literal with encrypted data as []byte result. + // Due to invalid format for INT literals, we should encode it as valid hex encoded binary value and change + // type of SQL token for sqlparser that encoded into final SQL string + + // if data was just tokenized, so we return it as is because it is valid int literal + if _, err := strconv.Atoi(string(data)); err == nil { return data, nil } - newVal := make([]byte, len(pgHexStringPrefix)+hex.EncodedLen(len(data))) - copy(newVal, pgHexStringPrefix) - hex.Encode(newVal[len(pgHexStringPrefix):], data) - return newVal, nil + // otherwise change type and pass it below for hex encoding + val.Type = sqlparser.StrVal + fallthrough + case sqlparser.PgEscapeString, sqlparser.StrVal: + return PgEncodeToHexString(data), nil } } return nil, errUnsupportedExpression diff --git a/encryptor/queryDataEncryptor.go b/encryptor/queryDataEncryptor.go index ad0dcb436..ab0d9fb51 100644 --- a/encryptor/queryDataEncryptor.go +++ b/encryptor/queryDataEncryptor.go @@ -20,31 +20,51 @@ import ( "bytes" "context" "errors" - "reflect" - "strconv" - "strings" - "github.com/cossacklabs/acra/decryptor/base" "github.com/cossacklabs/acra/encryptor/config" "github.com/cossacklabs/acra/logging" "github.com/cossacklabs/acra/sqlparser" "github.com/cossacklabs/acra/utils" "github.com/sirupsen/logrus" + "reflect" + "strconv" + "strings" ) -type querySelectSetting struct { +// QueryDataItem stores information about table column and encryption setting +type QueryDataItem struct { setting config.ColumnEncryptionSetting tableName string columnName string columnAlias string } +// Setting return associated ColumnEncryptionSetting or nil if not found +func (q *QueryDataItem) Setting() config.ColumnEncryptionSetting { + return q.setting +} + +// TableName return table name associated with item or empty string if it is not related to any table, or not recognized +func (q *QueryDataItem) TableName() string { + return q.tableName +} + +// ColumnName return column name if it was matched to any +func (q *QueryDataItem) ColumnName() string { + return q.columnName +} + +// ColumnAlias if matched as alias to any data item +func (q *QueryDataItem) ColumnAlias() string { + return q.columnAlias +} + // QueryDataEncryptor parse query and encrypt raw data according to TableSchemaStore type QueryDataEncryptor struct { schemaStore config.TableSchemaStore encryptor DataEncryptor dataCoder DBDataCoder - querySelectSettings []*querySelectSetting + querySelectSettings []*QueryDataItem parser *sqlparser.Parser } @@ -64,7 +84,7 @@ func (encryptor *QueryDataEncryptor) ID() string { } // encryptInsertQuery encrypt data in insert query in VALUES and ON DUPLICATE KEY UPDATE statements -func (encryptor *QueryDataEncryptor) encryptInsertQuery(ctx context.Context, insert *sqlparser.Insert) (bool, error) { +func (encryptor *QueryDataEncryptor) encryptInsertQuery(ctx context.Context, insert *sqlparser.Insert, bindPlaceholders map[int]config.ColumnEncryptionSetting) (bool, error) { tableName := insert.Table.Name schema := encryptor.schemaStore.GetTableSchema(tableName.String()) if schema == nil { @@ -74,7 +94,7 @@ func (encryptor *QueryDataEncryptor) encryptInsertQuery(ctx context.Context, ins } if encryptor.encryptor == nil { - return false, encryptor.onReturning(insert.Returning, tableName.RawValue()) + return false, encryptor.onReturning(ctx, insert.Returning, tableName.RawValue()) } var columnsName []string @@ -95,8 +115,13 @@ func (encryptor *QueryDataEncryptor) encryptInsertQuery(ctx context.Context, ins for _, valTuple := range rows { // collect values per column for j, value := range valTuple { + // in case when query `INSERT INTO table1 (col1, col2) VALUES (1, 2), (3, 4, 5); + // in a tuple has incorrect amount of values ("5" in the example) + if j >= len(columnsName) { + continue + } columnName := columnsName[j] - if changedValue, err := encryptor.encryptExpression(ctx, value, schema, columnName); err != nil { + if changedValue, err := encryptor.encryptExpression(ctx, value, schema, columnName, bindPlaceholders); err != nil { logrus.WithField(logging.FieldKeyEventCode, logging.EventCodeErrorEncryptorCantEncryptExpression).WithError(err).Errorln("Can't encrypt expression") return changed, err } else if changedValue { @@ -108,7 +133,12 @@ func (encryptor *QueryDataEncryptor) encryptInsertQuery(ctx context.Context, ins } if len(insert.OnDup) > 0 { - onDupChanged, err := encryptor.encryptUpdateExpressions(ctx, sqlparser.UpdateExprs(insert.OnDup), insert.Table, AliasToTableMap{insert.Table.Name.String(): insert.Table.Name.String()}) + onDupChanged, err := encryptor.encryptUpdateExpressions( + ctx, + sqlparser.UpdateExprs(insert.OnDup), + insert.Table, + AliasToTableMap{insert.Table.Name.String(): insert.Table.Name.String()}, + bindPlaceholders) if err != nil { return changed, err } @@ -156,8 +186,15 @@ func UpdateExpressionValue(ctx context.Context, expr sqlparser.Expr, coder DBDat } // encryptExpression check that expr is SQLVal and has Hexval then try to encrypt -func (encryptor *QueryDataEncryptor) encryptExpression(ctx context.Context, expr sqlparser.Expr, schema config.TableSchema, columnName string) (bool, error) { +func (encryptor *QueryDataEncryptor) encryptExpression(ctx context.Context, expr sqlparser.Expr, schema config.TableSchema, columnName string, bindPlaceholder map[int]config.ColumnEncryptionSetting) (bool, error) { if schema.NeedToEncrypt(columnName) { + if sqlVal, ok := expr.(*sqlparser.SQLVal); ok { + placeholderIndex, err := ParsePlaceholderIndex(sqlVal) + if err == nil { + setting := schema.GetColumnEncryptionSettings(columnName) + bindPlaceholder[placeholderIndex] = setting + } + } err := UpdateExpressionValue(ctx, expr, encryptor.dataCoder, func(ctx context.Context, data []byte) ([]byte, error) { if len(data) == 0 { return data, nil @@ -219,7 +256,7 @@ func (encryptor *QueryDataEncryptor) hasTablesToEncrypt(tables []*AliasedTableNa } // encryptUpdateExpressions try to encrypt all supported exprs. Use firstTable if column has not explicit table name because it's implicitly used in DBMSs -func (encryptor *QueryDataEncryptor) encryptUpdateExpressions(ctx context.Context, exprs sqlparser.UpdateExprs, firstTable sqlparser.TableName, qualifierMap AliasToTableMap) (bool, error) { +func (encryptor *QueryDataEncryptor) encryptUpdateExpressions(ctx context.Context, exprs sqlparser.UpdateExprs, firstTable sqlparser.TableName, qualifierMap AliasToTableMap, bindPlaceholders map[int]config.ColumnEncryptionSetting) (bool, error) { var schema config.TableSchema changed := false for _, expr := range exprs { @@ -234,7 +271,7 @@ func (encryptor *QueryDataEncryptor) encryptUpdateExpressions(ctx context.Contex continue } columnName := expr.Name.Name.String() - if changedExpr, err := encryptor.encryptExpression(ctx, expr.Expr, schema, columnName); err != nil { + if changedExpr, err := encryptor.encryptExpression(ctx, expr.Expr, schema, columnName, bindPlaceholders); err != nil { logrus.WithField(logging.FieldKeyEventCode, logging.EventCodeErrorEncryptorCantEncryptExpression).WithError(err).Errorln("Can't update expression with encrypted sql value") return changed, err } else if changedExpr { @@ -261,7 +298,7 @@ func NewAliasToTableMapFromTables(tables []*AliasedTableName) AliasToTableMap { } // encryptUpdateQuery encrypt data in Update query and return true if any fields was encrypted, false if wasn't and error if error occurred -func (encryptor *QueryDataEncryptor) encryptUpdateQuery(ctx context.Context, update *sqlparser.Update) (bool, error) { +func (encryptor *QueryDataEncryptor) encryptUpdateQuery(ctx context.Context, update *sqlparser.Update, bindPlaceholders map[int]config.ColumnEncryptionSetting) (bool, error) { tables := GetTablesWithAliases(update.TableExprs) if !encryptor.hasTablesToEncrypt(tables) { return false, nil @@ -271,7 +308,7 @@ func (encryptor *QueryDataEncryptor) encryptUpdateQuery(ctx context.Context, upd } qualifierMap := NewAliasToTableMapFromTables(tables) firstTable := tables[0].TableName - return encryptor.encryptUpdateExpressions(ctx, update.Exprs, firstTable, qualifierMap) + return encryptor.encryptUpdateExpressions(ctx, update.Exprs, firstTable, qualifierMap, bindPlaceholders) } // OnColumn return new encryption setting context if info exist, otherwise column data and passed context will be returned @@ -282,7 +319,9 @@ func (encryptor *QueryDataEncryptor) OnColumn(ctx context.Context, data []byte) if columnInfo.Index() < len(encryptor.querySelectSettings) { selectSetting := encryptor.querySelectSettings[columnInfo.Index()] if selectSetting != nil { - return NewContextWithEncryptionSetting(ctx, selectSetting.setting), data, nil + + logging.GetLoggerFromContext(ctx).WithField("column_index", columnInfo.Index()).WithField("column", selectSetting.ColumnName()).Debugln("Set encryption setting") + return NewContextWithEncryptionSetting(ctx, selectSetting.Setting()), data, nil } } @@ -292,22 +331,22 @@ func (encryptor *QueryDataEncryptor) OnColumn(ctx context.Context, data []byte) const allColumnsName = "*" -func (encryptor *QueryDataEncryptor) onSelect(statement *sqlparser.Select) (bool, error) { +func (encryptor *QueryDataEncryptor) onSelect(ctx context.Context, statement *sqlparser.Select) (bool, error) { columns, err := mapColumnsToAliases(statement) if err != nil { logrus.WithError(err).Errorln("Can't extract columns from SELECT statement") return false, err } - querySelectSettings := make([]*querySelectSetting, 0, len(columns)) + querySelectSettings := make([]*QueryDataItem, 0, len(columns)) for _, data := range columns { if data != nil { if schema := encryptor.schemaStore.GetTableSchema(data.Table); schema != nil { - var setting *querySelectSetting = nil + var setting *QueryDataItem = nil if data.Name == allColumnsName { for _, name := range schema.Columns() { setting = nil if columnSetting := schema.GetColumnEncryptionSettings(name); columnSetting != nil { - setting = &querySelectSetting{ + setting = &QueryDataItem{ setting: columnSetting, tableName: data.Table, columnName: name, @@ -318,7 +357,7 @@ func (encryptor *QueryDataEncryptor) onSelect(statement *sqlparser.Select) (bool } } else { if columnSetting := schema.GetColumnEncryptionSettings(data.Name); columnSetting != nil { - setting = &querySelectSetting{ + setting = &QueryDataItem{ setting: columnSetting, tableName: data.Table, columnName: data.Name, @@ -332,22 +371,25 @@ func (encryptor *QueryDataEncryptor) onSelect(statement *sqlparser.Select) (bool } querySelectSettings = append(querySelectSettings, nil) } + clientSession := base.ClientSessionFromContext(ctx) + SaveQueryDataItemsToClientSession(clientSession, querySelectSettings) + encryptor.querySelectSettings = querySelectSettings return false, nil } -func (encryptor *QueryDataEncryptor) onReturning(returning sqlparser.Returning, tableName string) error { +func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning sqlparser.Returning, tableName string) error { if len(returning) == 0 { return nil } schema := encryptor.schemaStore.GetTableSchema(tableName) - querySelectSettings := make([]*querySelectSetting, 0, 8) + querySelectSettings := make([]*QueryDataItem, 0, 8) if _, ok := returning[0].(*sqlparser.StarExpr); ok { for _, name := range schema.Columns() { if columnSetting := schema.GetColumnEncryptionSettings(name); columnSetting != nil { - querySelectSettings = append(querySelectSettings, &querySelectSetting{ + querySelectSettings = append(querySelectSettings, &QueryDataItem{ setting: columnSetting, tableName: tableName, columnName: name, @@ -356,6 +398,8 @@ func (encryptor *QueryDataEncryptor) onReturning(returning sqlparser.Returning, } querySelectSettings = append(querySelectSettings, nil) } + clientSession := base.ClientSessionFromContext(ctx) + SaveQueryDataItemsToClientSession(clientSession, querySelectSettings) encryptor.querySelectSettings = querySelectSettings return nil } @@ -368,7 +412,7 @@ func (encryptor *QueryDataEncryptor) onReturning(returning sqlparser.Returning, rawColName := colName.Name.String() if columnSetting := schema.GetColumnEncryptionSettings(rawColName); columnSetting != nil { - querySelectSettings = append(querySelectSettings, &querySelectSetting{ + querySelectSettings = append(querySelectSettings, &QueryDataItem{ setting: columnSetting, tableName: tableName, columnName: rawColName, @@ -377,7 +421,8 @@ func (encryptor *QueryDataEncryptor) onReturning(returning sqlparser.Returning, } querySelectSettings = append(querySelectSettings, nil) } - + clientSession := base.ClientSessionFromContext(ctx) + SaveQueryDataItemsToClientSession(clientSession, querySelectSettings) encryptor.querySelectSettings = querySelectSettings return nil } @@ -389,14 +434,18 @@ func (encryptor *QueryDataEncryptor) OnQuery(ctx context.Context, query base.OnQ return query, false, err } changed := false - switch statement := statement.(type) { + // collect placeholder in queries to save for future ParameterDescription packet to replace according to + // setting's data type + clientSession := base.ClientSessionFromContext(ctx) + bindPlaceholders := PlaceholderSettingsFromClientSession(clientSession) + switch typedStatement := statement.(type) { case *sqlparser.Select: - changed, err = encryptor.onSelect(statement) + changed, err = encryptor.onSelect(ctx, typedStatement) case *sqlparser.Insert: - changed, err = encryptor.encryptInsertQuery(ctx, statement) + changed, err = encryptor.encryptInsertQuery(ctx, typedStatement, bindPlaceholders) case *sqlparser.Update: if encryptor.encryptor != nil { - changed, err = encryptor.encryptUpdateQuery(ctx, statement) + changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders) } } if err != nil { @@ -434,15 +483,15 @@ func (encryptor *QueryDataEncryptor) OnBind(ctx context.Context, statement sqlpa return newValues, changed, nil } -func (encryptor *QueryDataEncryptor) encryptInsertValues(ctx context.Context, insert *sqlparser.Insert, values []base.BoundValue) ([]base.BoundValue, bool, error) { - logrus.Debugln("QueryDataEncryptor.encryptInsertValues") +func (encryptor *QueryDataEncryptor) getInsertPlaceholders(ctx context.Context, insert *sqlparser.Insert) (map[int]string, error) { tableName := insert.Table.Name + logger := logging.GetLoggerFromContext(ctx) // Look for the schema of the table where the INSERT happens. // If we don't have a schema then we don't know what to encrypt, so do nothing. schema := encryptor.schemaStore.GetTableSchema(tableName.String()) if schema == nil { - logrus.WithField("table", tableName).Debugln("No encryption schema") - return values, false, nil + logger.WithField("table", tableName).Debugln("No encryption schema") + return nil, nil } // Gather column names from the INSERT query. If there are no columns in the query, @@ -458,11 +507,11 @@ func (encryptor *QueryDataEncryptor) encryptInsertValues(ctx context.Context, in } // If there is no column schema available, we can't encrypt values. if len(columns) == 0 { - logrus.WithField("table", tableName).Debugln("No column information") - return values, false, nil + logger.WithField("table", tableName).Debugln("No column information") + return nil, nil } - placeholders := make(map[int]string, len(values)) + placeholders := make(map[int]string, len(insert.Columns)) // We can also only process simple queries of the form // @@ -472,20 +521,67 @@ func (encryptor *QueryDataEncryptor) encryptInsertValues(ctx context.Context, in // as inserted values. We don't support functions, casts, inserting query results, etc. // // Walk through the query to find out which placeholders stand for which columns. + // Also count amount of passed value to validate that placeholder's index doesn't go out of this number + valuesCount := 0 switch rows := insert.Rows.(type) { case sqlparser.Values: for _, row := range rows { + valuesCount += len(row) for i, value := range row { + if i >= len(columns) { + logger.WithFields(logrus.Fields{"value_index": i, "column_count": len(columns)}).Warningln("Amount of values in INSERT bigger than column count") + continue + } switch value := value.(type) { case *sqlparser.SQLVal: - err := encryptor.updatePlaceholderMap(values, placeholders, value, columns[i]) + err := encryptor.updatePlaceholderMap(valuesCount, placeholders, value, columns[i]) if err != nil { - return values, false, err + return nil, err } } } } } + return placeholders, nil +} + +func (encryptor *QueryDataEncryptor) savePlaceholderSettingIntoClientSession(ctx context.Context, placeholders map[int]string, schema config.TableSchema) { + if schema == nil { + logrus.Debugln("No encryption schema") + return + } + if placeholders == nil { + logrus.Debugln("No placeholders") + return + } + clientSession := base.ClientSessionFromContext(ctx) + bindData := PlaceholderSettingsFromClientSession(clientSession) + for i, columnName := range placeholders { + if !schema.NeedToEncrypt(columnName) { + continue + } + setting := schema.GetColumnEncryptionSettings(columnName) + bindData[i] = setting + } +} + +func (encryptor *QueryDataEncryptor) encryptInsertValues(ctx context.Context, insert *sqlparser.Insert, values []base.BoundValue) ([]base.BoundValue, bool, error) { + logger := logging.GetLoggerFromContext(ctx) + logger.Debugln("QueryDataEncryptor.encryptInsertValues") + tableName := insert.Table.Name + // Look for the schema of the table where the INSERT happens. + // If we don't have a schema then we don't know what to encrypt, so do nothing. + schema := encryptor.schemaStore.GetTableSchema(tableName.String()) + if schema == nil { + logrus.WithField("table", tableName).Debugln("No encryption schema") + return values, false, nil + } + placeholders, err := encryptor.getInsertPlaceholders(ctx, insert) + if err != nil { + logger.WithError(err).Errorln("Can't extract placeholders from INSERT query") + return values, false, err + } + encryptor.savePlaceholderSettingIntoClientSession(ctx, placeholders, schema) // TODO(ilammy, 2020-10-13): handle ON DUPLICATE KEY UPDATE clauses // These clauses are handled for textual queries. It would be nice to encrypt @@ -532,7 +628,7 @@ func (encryptor *QueryDataEncryptor) encryptUpdateValues(ctx context.Context, up columnName := expr.Name.Name.String() switch value := expr.Expr.(type) { case *sqlparser.SQLVal: - err := encryptor.updatePlaceholderMap(values, placeholders, value, columnName) + err := encryptor.updatePlaceholderMap(len(values), placeholders, value, columnName) if err != nil { return values, false, err } @@ -545,7 +641,7 @@ func (encryptor *QueryDataEncryptor) encryptUpdateValues(ctx context.Context, up } // updatePlaceholderMap matches the placeholder of a value to its column and records this into the mapping. -func (encryptor *QueryDataEncryptor) updatePlaceholderMap(values []base.BoundValue, placeholders map[int]string, placeholder *sqlparser.SQLVal, columnName string) error { +func (encryptor *QueryDataEncryptor) updatePlaceholderMap(valuesCount int, placeholders map[int]string, placeholder *sqlparser.SQLVal, columnName string) error { updateMapByPlaceholderPart := func(part string) error { text := string(placeholder.Val) index, err := strconv.Atoi(strings.TrimPrefix(text, part)) @@ -555,13 +651,13 @@ func (encryptor *QueryDataEncryptor) updatePlaceholderMap(values []base.BoundVal } // Placeholders use 1-based indexing and "values" (Go slice) are 0-based. index-- - if index >= len(values) { - logrus.WithFields(logrus.Fields{"placeholder": text, "index": index, "values": len(values)}). + if index >= valuesCount { + logrus.WithFields(logrus.Fields{"placeholder": text, "index": index, "values": valuesCount}). Warning("Invalid placeholder index") return ErrInvalidPlaceholder } // Placeholders must map to columns uniquely. - // If there is already a column for given placholder and it's not the same, + // If there is already a column for given placeholder and it's not the same, // we can't handle such queries currently. name, exists := placeholders[index] if exists && name != columnName { @@ -607,16 +703,22 @@ func (encryptor *QueryDataEncryptor) encryptValuesWithPlaceholders(ctx context.C copy(values, oldValues) } changed = true - settings := schema.GetColumnEncryptionSettings(columnName) - - encryptedData, err := encryptor.encryptWithColumnSettings(ctx, settings, values[valueIndex].GetData(settings)) + setting := schema.GetColumnEncryptionSettings(columnName) + valueData, err := values[valueIndex].GetData(setting) + if err != nil { + return nil, false, err + } + if len(valueData) == 0 { + continue + } + encryptedData, err := encryptor.encryptWithColumnSettings(ctx, setting, valueData) if err != nil && err != ErrUpdateLeaveDataUnchanged { logrus.WithError(err).WithFields(logrus.Fields{"index": valueIndex, "column": columnName}). Debug("Failed to encrypt column") return oldValues, false, err } - err = values[valueIndex].SetData(encryptedData, settings) + err = values[valueIndex].SetData(encryptedData, setting) if err != nil { logrus.WithError(err).WithFields(logrus.Fields{"index": valueIndex, "column": columnName}). Debug("Failed to set encrypted value") diff --git a/encryptor/queryDataEncryptor_test.go b/encryptor/queryDataEncryptor_test.go index 1f094a47a..494926e24 100644 --- a/encryptor/queryDataEncryptor_test.go +++ b/encryptor/queryDataEncryptor_test.go @@ -22,6 +22,10 @@ import ( "encoding/hex" "fmt" "github.com/cossacklabs/acra/acrastruct" + "github.com/cossacklabs/acra/decryptor/base/mocks" + "github.com/cossacklabs/acra/logging" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/mock" "strings" "testing" @@ -427,6 +431,13 @@ schemas: } } ctx := base.SetAccessContextToContext(context.Background(), base.NewAccessContext(base.WithClientID(defaultClientID))) + clientSession := &mocks.ClientSession{} + sessionData := make(map[string]interface{}, 2) + clientSession.On("GetData", mock.Anything).Return(sessionData, true) + clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + sessionData[args[0].(string)] = args[1] + }) + ctx = base.SetClientSessionToContext(ctx, clientSession) data, changed, err := mysqlParser.OnQuery(ctx, base.NewOnQueryObjectFromQuery(query, parser)) if err != nil { t.Fatalf("%v. %s", i, err.Error()) @@ -479,7 +490,13 @@ schemas: } ctx := base.SetAccessContextToContext(context.Background(), base.NewAccessContext(base.WithClientID(defaultClientID))) - + clientSession := &mocks.ClientSession{} + data := make(map[string]interface{}, 2) + clientSession.On("GetData", mock.Anything).Return(data, true) + clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + data[args[0].(string)] = args[1] + }) + ctx = base.SetClientSessionToContext(ctx, clientSession) t.Run("RETURNING *", func(t *testing.T) { query := `INSERT INTO TableWithColumnSchema ('zone_id', 'specified_client_id', 'other_column', 'default_client_id') VALUES (1, 1, 1, 1) RETURNING *` @@ -503,7 +520,7 @@ schemas: setting := mysqlParser.querySelectSettings[i] if columns[i] != setting.columnName { - t.Fatalf("%v. Incorrect querySelectSetting \nTook: %v\nExpected: %v", i, setting.columnName, columns[i]) + t.Fatalf("%v. Incorrect QueryDataItem \nTook: %v\nExpected: %v", i, setting.columnName, columns[i]) } } }) @@ -535,7 +552,7 @@ schemas: setting := mysqlParser.querySelectSettings[i] if returningColumns[i] != setting.columnName { - t.Fatalf("%v. Incorrect querySelectSetting \nTook: %v\nExpected: %v", i, setting.columnName, columns[i]) + t.Fatalf("%v. Incorrect QueryDataItem \nTook: %v\nExpected: %v", i, setting.columnName, columns[i]) } } }) @@ -544,7 +561,7 @@ schemas: func TestEncryptionSettingCollection(t *testing.T) { type testcase struct { config string - settings []*querySelectSetting + settings []*QueryDataItem query string } testcases := []testcase{ @@ -560,7 +577,7 @@ func TestEncryptionSettingCollection(t *testing.T) { - column: data2 crypto_envelope: acrablock`, query: `select data1, data2, data3 from test_table`, - settings: []*querySelectSetting{ + settings: []*QueryDataItem{ {setting: &config.BasicColumnEncryptionSetting{Name: "data1"}, tableName: "test_table", columnName: "data1", columnAlias: "test_table"}, {setting: &config.BasicColumnEncryptionSetting{Name: "data2"}, tableName: "test_table", columnName: "data2", columnAlias: "test_table"}, nil, @@ -578,7 +595,7 @@ func TestEncryptionSettingCollection(t *testing.T) { - column: data2 crypto_envelope: acrablock`, query: `select 1 from test_table`, - settings: []*querySelectSetting{ + settings: []*QueryDataItem{ nil, }, }, @@ -594,7 +611,7 @@ func TestEncryptionSettingCollection(t *testing.T) { - column: data2 crypto_envelope: acrablock`, query: `select * from test_table`, - settings: []*querySelectSetting{ + settings: []*QueryDataItem{ {setting: &config.BasicColumnEncryptionSetting{Name: "data1"}, tableName: "test_table", columnName: "data1", columnAlias: ""}, {setting: &config.BasicColumnEncryptionSetting{Name: "data2"}, tableName: "test_table", columnName: "data2", columnAlias: ""}, nil, @@ -612,7 +629,7 @@ func TestEncryptionSettingCollection(t *testing.T) { - column: data2 crypto_envelope: acrablock`, query: `select 'some string', * from test_table`, - settings: []*querySelectSetting{ + settings: []*QueryDataItem{ nil, {setting: &config.BasicColumnEncryptionSetting{Name: "data1"}, tableName: "test_table", columnName: "data1", columnAlias: ""}, {setting: &config.BasicColumnEncryptionSetting{Name: "data2"}, tableName: "test_table", columnName: "data2", columnAlias: ""}, @@ -631,7 +648,7 @@ func TestEncryptionSettingCollection(t *testing.T) { - column: data2 crypto_envelope: acrablock`, query: `select * from test_table t1`, - settings: []*querySelectSetting{ + settings: []*QueryDataItem{ {setting: &config.BasicColumnEncryptionSetting{Name: "data1"}, tableName: "test_table", columnName: "data1", columnAlias: ""}, {setting: &config.BasicColumnEncryptionSetting{Name: "data2"}, tableName: "test_table", columnName: "data2", columnAlias: ""}, nil, @@ -649,7 +666,7 @@ func TestEncryptionSettingCollection(t *testing.T) { - column: data2 crypto_envelope: acrablock`, query: `select t1.* from test_table t1`, - settings: []*querySelectSetting{ + settings: []*QueryDataItem{ {setting: &config.BasicColumnEncryptionSetting{Name: "data1"}, tableName: "test_table", columnName: "data1", columnAlias: ""}, {setting: &config.BasicColumnEncryptionSetting{Name: "data2"}, tableName: "test_table", columnName: "data2", columnAlias: ""}, nil, @@ -676,7 +693,7 @@ func TestEncryptionSettingCollection(t *testing.T) { - column: data2 crypto_envelope: acrablock`, query: `select t1.*, t2.* from test_table t1, test_table2 t2`, - settings: []*querySelectSetting{ + settings: []*QueryDataItem{ {setting: &config.BasicColumnEncryptionSetting{Name: "data1"}, tableName: "test_table", columnName: "data1", columnAlias: ""}, {setting: &config.BasicColumnEncryptionSetting{Name: "data2"}, tableName: "test_table", columnName: "data2", columnAlias: ""}, nil, @@ -706,7 +723,15 @@ func TestEncryptionSettingCollection(t *testing.T) { if !ok { t.Fatalf("[%d] Test query should be SELECT query, took %s\n", i, tcase.query) } - _, err = encryptor.onSelect(selectExpr) + + clientSession := &mocks.ClientSession{} + data := make(map[string]interface{}, 2) + clientSession.On("GetData", mock.Anything).Return(data, true) + clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + data[args[0].(string)] = args[1] + }) + ctx := base.SetClientSessionToContext(context.Background(), clientSession) + _, err = encryptor.onSelect(ctx, selectExpr) if err != nil { t.Fatal(err) } @@ -777,9 +802,92 @@ func TestEncryptionSettingCollectionFailures(t *testing.T) { if !ok { t.Fatalf("[%d] Test query should be SELECT query, took %s\n", i, tcase.query) } - _, err = encryptor.onSelect(selectExpr) + _, err = encryptor.onSelect(context.TODO(), selectExpr) + if err != tcase.err { + t.Fatalf("Expect error %s, took %s\n", tcase.err, err) + } + } +} + +func TestInsertWithIncorrectPlaceholdersAmount(t *testing.T) { + type testcase struct { + config string + err error + query string + expectedLog string + } + testcases := []testcase{ + // placeholders more than columns + {config: `schemas: + - table: test_table + columns: + - data1 + encrypted: + - column: data1`, + query: `insert into test_table(data1) values ($1, $2);`, + err: nil, + expectedLog: "Amount of values in INSERT bigger than column count", + }, + // placeholders more than columns with several data rows + {config: `schemas: + - table: test_table + columns: + - data1 + encrypted: + - column: data1`, + query: `insert into test_table(data1) values ($1), ($2, $3);`, + err: nil, + expectedLog: "Amount of values in INSERT bigger than column count", + }, + } + parser := sqlparser.New(sqlparser.ModeDefault) + + encryptor, err := NewPostgresqlQueryEncryptor(nil, parser, NewChainDataEncryptor()) + if err != nil { + t.Fatal(err) + } + // use custom output writer to check buffer for expected log entries + logger := logrus.New() + outBuffer := &bytes.Buffer{} + logger.SetOutput(outBuffer) + ctx := logging.SetLoggerToContext(context.Background(), logrus.NewEntry(logger)) + clientSession := &mocks.ClientSession{} + sessionData := make(map[string]interface{}, 2) + clientSession.On("GetData", mock.Anything).Return(func(key string) interface{} { + return sessionData[key] + }, func(key string) bool { + _, ok := sessionData[key] + return ok + }) + clientSession.On("DeleteData", mock.Anything).Run(func(args mock.Arguments) { + delete(sessionData, args[0].(string)) + }) + clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + sessionData[args[0].(string)] = args[1] + }) + ctx = base.SetClientSessionToContext(ctx, clientSession) + for i, tcase := range testcases { + outBuffer.Reset() + t.Logf("Test tcase %d\n", i) + schemaStore, err := config.MapTableSchemaStoreFromConfig([]byte(tcase.config)) + if err != nil { + t.Fatal(err) + } + encryptor.schemaStore = schemaStore + statement, err := parser.Parse(tcase.query) + if err != nil { + t.Fatal(err) + } + insertExpr, ok := statement.(*sqlparser.Insert) + if !ok { + t.Fatalf("[%d] Test query should be INSERT query, took %s\n", i, tcase.query) + } + DeletePlaceholderSettingsFromClientSession(clientSession) + bindData := PlaceholderSettingsFromClientSession(clientSession) + _, err = encryptor.encryptInsertQuery(ctx, insertExpr, bindData) if err != tcase.err { t.Fatalf("Expect error %s, took %s\n", tcase.err, err) } + strings.Contains(outBuffer.String(), tcase.expectedLog) } } diff --git a/encryptor/utils.go b/encryptor/utils.go index daa69fe96..1366bba46 100644 --- a/encryptor/utils.go +++ b/encryptor/utils.go @@ -18,7 +18,13 @@ package encryptor import ( "errors" + "github.com/cossacklabs/acra/decryptor/base" + "github.com/cossacklabs/acra/encryptor/config" "github.com/cossacklabs/acra/sqlparser" + "github.com/sirupsen/logrus" + "strconv" + "strings" + "sync" ) var errNotFoundtable = errors.New("not found table for alias") @@ -342,3 +348,96 @@ func mapColumnsToAliases(selectQuery *sqlparser.Select) ([]*columnInfo, error) { } return out, nil } + +// InvalidPlaceholderIndex value that represent invalid index for sql placeholders +const InvalidPlaceholderIndex = -1 + +// ParsePlaceholderIndex parse placeholder index if SQLVal is PgPlaceholder/ValArg otherwise return error and InvalidPlaceholderIndex +func ParsePlaceholderIndex(placeholder *sqlparser.SQLVal) (int, error) { + updateMapByPlaceholderPart := func(part string) (int, error) { + text := string(placeholder.Val) + index, err := strconv.Atoi(strings.TrimPrefix(text, part)) + if err != nil { + logrus.WithField("placeholder", text).WithError(err).Warning("Cannot parse placeholder") + return InvalidPlaceholderIndex, err + } + // Placeholders use 1-based indexing and "values" (Go slice) are 0-based. + index-- + return index, nil + } + + switch placeholder.Type { + case sqlparser.PgPlaceholder: + // PostgreSQL placeholders look like "$1". Parse the number out of them. + return updateMapByPlaceholderPart("$") + case sqlparser.ValArg: + // MySQL placeholders look like ":v1". Parse the number out of them. + return updateMapByPlaceholderPart(":v") + } + return InvalidPlaceholderIndex, ErrInvalidPlaceholder +} + +const queryDataItemKey = "query_data_items" + +// SaveQueryDataItemsToClientSession save slice of QueryDataItem into ClientSession +func SaveQueryDataItemsToClientSession(session base.ClientSession, items []*QueryDataItem) { + session.SetData(queryDataItemKey, items) +} + +// DeleteQueryDataItemsFromClientSession delete items from ClientSession +func DeleteQueryDataItemsFromClientSession(session base.ClientSession) { + session.DeleteData(queryDataItemKey) +} + +// QueryDataItemsFromClientSession return QueryDataItems from ClientSession if saved otherwise nil +func QueryDataItemsFromClientSession(session base.ClientSession) []*QueryDataItem { + data, ok := session.GetData(queryDataItemKey) + if !ok { + return nil + } + items, ok := data.([]*QueryDataItem) + if ok { + return items + } + return nil +} + +var bindPlaceholdersPool = sync.Pool{New: func() interface{} { + return make(map[int]config.ColumnEncryptionSetting, 32) +}} + +const placeholdersSettingKey = "bind_encryption_settings" + +// PlaceholderSettingsFromClientSession return stored in client session ColumnEncryptionSettings related to placeholders +// or create new and save in session +func PlaceholderSettingsFromClientSession(session base.ClientSession) map[int]config.ColumnEncryptionSetting { + data, ok := session.GetData(placeholdersSettingKey) + if !ok { + //logger := logging.GetLoggerFromContext(session.Context()) + value := bindPlaceholdersPool.Get().(map[int]config.ColumnEncryptionSetting) + //logger.WithField("session", session).WithField("value", value).Debugln("Create placeholders") + session.SetData(placeholdersSettingKey, value) + return value + } + items, ok := data.(map[int]config.ColumnEncryptionSetting) + if ok { + return items + } + return nil +} + +// DeletePlaceholderSettingsFromClientSession delete items from ClientSession +func DeletePlaceholderSettingsFromClientSession(session base.ClientSession) { + data := PlaceholderSettingsFromClientSession(session) + if data == nil { + logrus.Warningln("Invalid type of PlaceholderSettings") + session.DeleteData(placeholdersSettingKey) + // do nothing because it's invalid + return + } + for key := range data { + delete(data, key) + } + bindPlaceholdersPool.Put(data) + session.DeleteData(placeholdersSettingKey) +} diff --git a/encryptor/utils_test.go b/encryptor/utils_test.go index fa999190b..6a6484fd4 100644 --- a/encryptor/utils_test.go +++ b/encryptor/utils_test.go @@ -1,7 +1,10 @@ package encryptor import ( + "github.com/cossacklabs/acra/decryptor/base/mocks" + "github.com/cossacklabs/acra/encryptor/config" "github.com/cossacklabs/acra/sqlparser" + "github.com/stretchr/testify/mock" "testing" ) @@ -266,3 +269,50 @@ from table1 join table2 as t2 on from_number = t2.number or to_number = t2.numbe } }) } + +func TestPlaceholderSettings(t *testing.T) { + clientSession := &mocks.ClientSession{} + sessionData := make(map[string]interface{}, 2) + clientSession.On("GetData", mock.Anything).Return(func(key string) interface{} { + return sessionData[key] + }, func(key string) bool { + _, ok := sessionData[key] + return ok + }) + clientSession.On("DeleteData", mock.Anything).Run(func(args mock.Arguments) { + delete(sessionData, args[0].(string)) + }) + clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + sessionData[args[0].(string)] = args[1] + }) + + sessionData[placeholdersSettingKey] = "trash" + + data := PlaceholderSettingsFromClientSession(clientSession) + if data != nil { + t.Fatal("Expect nil for value with invalid type") + } + DeletePlaceholderSettingsFromClientSession(clientSession) + + // get new initialized map + data = PlaceholderSettingsFromClientSession(clientSession) + // set some data + data[0] = &config.BasicColumnEncryptionSetting{} + data[1] = &config.BasicColumnEncryptionSetting{} + + newData := PlaceholderSettingsFromClientSession(clientSession) + if len(newData) != len(data) { + t.Fatal("Unexpected map with different size") + } + // clear data, force to return map to the pool cleared from data + DeletePlaceholderSettingsFromClientSession(clientSession) + + // we expect that will be returned same value from sync.Pool and check that it's cleared + newData = PlaceholderSettingsFromClientSession(clientSession) + if len(newData) != 0 { + t.Fatal("Map's data wasn't cleared") + } + if len(newData) != len(data) { + t.Fatal("Source map's data wasn't cleared") + } +} diff --git a/go.mod b/go.mod index 20cd9c097..44bf41e1a 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,8 @@ require ( github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e github.com/golang/protobuf v1.5.2 github.com/hashicorp/vault/api v1.3.0 + github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect + github.com/jackc/pgx v3.6.2+incompatible github.com/jackc/pgx/v4 v4.14.1 github.com/lib/pq v1.10.4 github.com/prometheus/client_golang v1.7.1 diff --git a/go.sum b/go.sum index 5f579c9a9..d707b0fce 100644 --- a/go.sum +++ b/go.sum @@ -331,6 +331,8 @@ github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9 github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 h1:vr3AYkKovP8uR8AvSGGUK1IDqRa5lAAvEkZG1LKaCRc= +github.com/jackc/fake v0.0.0-20150926172116-812a484cc733/go.mod h1:WrMFNQdiFJ80sQsxDoMokWK1W5TQtxBFNpzWTD84ibQ= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= @@ -365,6 +367,8 @@ github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrU github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= github.com/jackc/pgtype v1.9.1 h1:MJc2s0MFS8C3ok1wQTdQxWuXQcB6+HwAm5x1CzW7mf0= github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o= +github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= diff --git a/hmac/dataEncryptor.go b/hmac/dataEncryptor.go index 6b2911345..55b719d2f 100644 --- a/hmac/dataEncryptor.go +++ b/hmac/dataEncryptor.go @@ -79,7 +79,7 @@ func (e *SearchableDataEncryptor) EncryptWithClientID(clientID, data []byte, set return nil, err } } - logrus.WithField("decryptor", e.decryptor).Debugln("Hash data") + logrus.Debugln("Hash data") return append(hash, encryptedData...), nil } return data, nil diff --git a/hmac/decryptor/hashQuery.go b/hmac/decryptor/hashQuery.go index f9e063b01..73daf8479 100644 --- a/hmac/decryptor/hashQuery.go +++ b/hmac/decryptor/hashQuery.go @@ -19,16 +19,13 @@ package decryptor import ( "context" "fmt" - "github.com/cossacklabs/acra/utils" - "strconv" - "strings" - "github.com/cossacklabs/acra/decryptor/base" queryEncryptor "github.com/cossacklabs/acra/encryptor" "github.com/cossacklabs/acra/encryptor/config" "github.com/cossacklabs/acra/hmac" "github.com/cossacklabs/acra/keystore" "github.com/cossacklabs/acra/sqlparser" + "github.com/cossacklabs/acra/utils" "github.com/sirupsen/logrus" ) @@ -62,7 +59,7 @@ func (encryptor *HashQuery) ID() string { return "HashQuery" } -func (encryptor *HashQuery) filterSearchableComparisons(statement sqlparser.Statement) []*sqlparser.ComparisonExpr { +func (encryptor *HashQuery) filterSearchableComparisons(statement sqlparser.Statement) []searchableExprItem { // We are interested only in SELECT statements which access at least one encryptable table. // If that's not the case, we have nothing to do here. defaultTable, aliasedTables := encryptor.filterInterestingTables(statement) @@ -79,12 +76,12 @@ func (encryptor *HashQuery) filterSearchableComparisons(statement sqlparser.Stat } // And among those expressions, not all may refer to columns with searchable encryption // enabled for them. Leave only those expressions which are searchable. - exprs = encryptor.filterSerchableComparisons(exprs, defaultTable, aliasedTables) + searchableExprs := encryptor.filterSerchableComparisons(exprs, defaultTable, aliasedTables) if len(exprs) == 0 { logrus.Debugln("No searchable comparisons in search query") return nil } - return exprs + return searchableExprs } func (encryptor *HashQuery) filterInterestingTables(statement sqlparser.Statement) (*queryEncryptor.AliasedTableName, queryEncryptor.AliasToTableMap) { @@ -131,8 +128,13 @@ func (encryptor *HashQuery) filterComparisonExprs(statement sqlparser.Statement) return exprs } -func (encryptor *HashQuery) filterSerchableComparisons(exprs []*sqlparser.ComparisonExpr, defaultTable *queryEncryptor.AliasedTableName, aliasedTables queryEncryptor.AliasToTableMap) []*sqlparser.ComparisonExpr { - filtered := make([]*sqlparser.ComparisonExpr, 0, len(exprs)) +type searchableExprItem struct { + expr *sqlparser.ComparisonExpr + setting config.ColumnEncryptionSetting +} + +func (encryptor *HashQuery) filterSerchableComparisons(exprs []*sqlparser.ComparisonExpr, defaultTable *queryEncryptor.AliasedTableName, aliasedTables queryEncryptor.AliasToTableMap) []searchableExprItem { + filtered := make([]searchableExprItem, 0, len(exprs)) for _, expr := range exprs { // Leave out comparisons of columns which do not have a schema after alias resolution. column := expr.Left.(*sqlparser.ColName) @@ -146,7 +148,7 @@ func (encryptor *HashQuery) filterSerchableComparisons(exprs []*sqlparser.Compar if encryptionSetting == nil || !encryptionSetting.IsSearchable() { continue } - filtered = append(filtered, expr) + filtered = append(filtered, searchableExprItem{expr: expr, setting: encryptionSetting}) } return filtered } @@ -180,27 +182,40 @@ func (encryptor *HashQuery) OnQuery(ctx context.Context, query base.OnQueryObjec // Extract the subexpressions that we are interested in for searchable encryption. // The list might be empty for non-SELECT queries or for non-eligible SELECTs. // In that case we don't have any more work to do here. - exprs := encryptor.filterSearchableComparisons(stmt) - if len(exprs) == 0 { + items := encryptor.filterSearchableComparisons(stmt) + if len(items) == 0 { return query, false, nil } + clientSession := base.ClientSessionFromContext(ctx) + bindSettings := queryEncryptor.PlaceholderSettingsFromClientSession(clientSession) // Now that we have condition expressions, perform rewriting in them. hashSize := []byte(fmt.Sprintf("%d", hmac.GetDefaultHashSize())) - for _, expr := range exprs { + for _, item := range items { // column = 'value' ===> substring(column, 1, ) = 'value' - expr.Left = &sqlparser.SubstrExpr{ - Name: expr.Left.(*sqlparser.ColName), + item.expr.Left = &sqlparser.SubstrExpr{ + Name: item.expr.Left.(*sqlparser.ColName), From: sqlparser.NewIntVal([]byte{'1'}), To: sqlparser.NewIntVal(hashSize), } // substring(column, 1, ) = 'value' ===> substring(column, 1, ) = // substring(column, 1, ) = $1 ===> no changes - err := queryEncryptor.UpdateExpressionValue(ctx, expr.Right, encryptor.coder, encryptor.calculateHmac) + err := queryEncryptor.UpdateExpressionValue(ctx, item.expr.Right, encryptor.coder, encryptor.calculateHmac) if err != nil { logrus.WithError(err).Debugln("Failed to update expression") return query, false, err } + sqlVal, ok := item.expr.Right.(*sqlparser.SQLVal) + if !ok { + continue + } + placeholderIndex, err := queryEncryptor.ParsePlaceholderIndex(sqlVal) + if err == queryEncryptor.ErrInvalidPlaceholder { + continue + } else if err != nil { + return query, false, err + } + bindSettings[placeholderIndex] = item.setting } logrus.Debugln("HashQuery.OnQuery changed query") return base.NewOnQueryObjectFromStatement(stmt, encryptor.parser), true, nil @@ -223,55 +238,31 @@ func (encryptor *HashQuery) OnBind(ctx context.Context, statement sqlparser.Stat // Extract the subexpressions that we are interested in for searchable encryption. // The list might be empty for non-SELECT queries or for non-eligible SELECTs. // In that case we don't have any more work to do here. - exprs := encryptor.filterSearchableComparisons(statement) - if len(exprs) == 0 { + items := encryptor.filterSearchableComparisons(statement) + if len(items) == 0 { return values, false, nil } // Now that we have expressions, analyze them to look for involved placeholders // and map them onto values that we need to update. - placeholders := make([]int, 0, len(values)) - for _, expr := range exprs { - switch value := expr.Right.(type) { + indexes := make([]int, 0, len(values)) + for _, item := range items { + switch value := item.expr.Right.(type) { case *sqlparser.SQLVal: var err error - placeholders, err = encryptor.updatePlaceholderList(placeholders, values, value) + index, err := queryEncryptor.ParsePlaceholderIndex(value) if err != nil { return values, false, err } + if index >= len(values) { + logrus.WithFields(logrus.Fields{"placeholder": value.Val, "index": index, "values": len(values)}). + Warning("Invalid placeholder index") + return values, false, queryEncryptor.ErrInvalidPlaceholder + } + indexes = append(indexes, index) } } // Finally, once we know which values to replace with HMACs, do this replacement. - return encryptor.replaceValuesWithHMACs(ctx, values, placeholders) -} - -func (encryptor *HashQuery) updatePlaceholderList(placeholders []int, values []base.BoundValue, placeholder *sqlparser.SQLVal) ([]int, error) { - updateMapByPlaceholderPart := func(part string) ([]int, error) { - text := string(placeholder.Val) - index, err := strconv.Atoi(strings.TrimPrefix(text, part)) - if err != nil { - logrus.WithField("placeholder", text).WithError(err).Warning("Cannot parse placeholder") - return nil, err - } - // Placeholders use 1-based indexing and "values" (Go slice) are 0-based. - index-- - if index >= len(values) { - logrus.WithFields(logrus.Fields{"placeholder": text, "index": index, "values": len(values)}). - Warning("Invalid placeholder index") - return nil, queryEncryptor.ErrInvalidPlaceholder - } - placeholders = append(placeholders, index) - return placeholders, nil - } - - switch placeholder.Type { - case sqlparser.PgPlaceholder: - // PostgreSQL placeholders look like "$1". Parse the number out of them. - return updateMapByPlaceholderPart("$") - case sqlparser.ValArg: - // MySQL placeholders look like ":v1". Parse the number out of them. - return updateMapByPlaceholderPart(":v") - } - return placeholders, nil + return encryptor.replaceValuesWithHMACs(ctx, values, indexes) } func (encryptor *HashQuery) replaceValuesWithHMACs(ctx context.Context, values []base.BoundValue, placeholders []int) ([]base.BoundValue, bool, error) { @@ -282,31 +273,31 @@ func (encryptor *HashQuery) replaceValuesWithHMACs(ctx context.Context, values [ // Otherwise, decrypt values at positions indicated by placeholders and replace them with their HMACs. newValues := make([]base.BoundValue, len(values)) copy(newValues, values) + clientSession := base.ClientSessionFromContext(ctx) + bindData := queryEncryptor.PlaceholderSettingsFromClientSession(clientSession) for _, valueIndex := range placeholders { - format := values[valueIndex].Format() - - data := values[valueIndex].GetData(nil) - switch format { - case base.BinaryFormat: - // If we can't decrypt the data and compute its HMAC, searchable encryption failed to apply. - // Since we have already modified the query, it's likely to fail, but we can't do much about it. - hmacHash, err := encryptor.calculateHmac(ctx, data) - if err != nil { - logrus.WithError(err).WithField("index", valueIndex).Debug("Failed to encrypt column") - return values, false, err + var encryptionSetting config.ColumnEncryptionSetting = nil + if bindData != nil { + setting, ok := bindData[valueIndex] + if ok { + encryptionSetting = setting } - // it is ok to ignore the error if not column setting provided - _ = newValues[valueIndex].SetData(hmacHash, nil) - - // TODO(ilammy, 2020-10-14): implement support for base.BindText - // We should parse and decode the data, convert that into HMAC instead, - // and then either force binary format or reencode the data back into text. + } - default: - logrus.WithFields(logrus.Fields{"format": format, "index": valueIndex}). - Warning("Parameter format not supported, skipping") + data, err := values[valueIndex].GetData(encryptionSetting) + if err != nil { + return values, false, err + } + // If we can't decrypt the data and compute its HMAC, searchable encryption failed to apply. + // Since we have already modified the query, it's likely to fail, but we can't do much about it. + hmacHash, err := encryptor.calculateHmac(ctx, data) + if err != nil { + logrus.WithError(err).WithField("index", valueIndex).Debug("Failed to encrypt column") + return values, false, err } + // it is ok to ignore the error if not column setting provided + _ = newValues[valueIndex].SetData(hmacHash, encryptionSetting) } return newValues, true, nil } diff --git a/hmac/decryptor/hashQuery_test.go b/hmac/decryptor/hashQuery_test.go new file mode 100644 index 000000000..0965fa766 --- /dev/null +++ b/hmac/decryptor/hashQuery_test.go @@ -0,0 +1,99 @@ +package decryptor + +import ( + "bytes" + "context" + "github.com/cossacklabs/acra/crypto" + "github.com/cossacklabs/acra/decryptor/base" + "github.com/cossacklabs/acra/decryptor/base/mocks" + encryptor2 "github.com/cossacklabs/acra/encryptor" + "github.com/cossacklabs/acra/encryptor/config" + mocks2 "github.com/cossacklabs/acra/keystore/mocks" + "github.com/cossacklabs/acra/sqlparser" + "github.com/stretchr/testify/mock" + "testing" +) + +// TestSearchablePreparedStatementsWithTextFormat process searchable SELECT query with placeholder for prepared statement +// and use binding values in text format +func TestSearchablePreparedStatementsWithTextFormat(t *testing.T) { + clientSession := &mocks.ClientSession{} + sessionData := make(map[string]interface{}, 2) + clientSession.On("GetData", mock.Anything).Return(func(key string) interface{} { + return sessionData[key] + }, func(key string) bool { + _, ok := sessionData[key] + return ok + }) + clientSession.On("DeleteData", mock.Anything).Run(func(args mock.Arguments) { + delete(sessionData, args[0].(string)) + }) + clientSession.On("SetData", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + sessionData[args[0].(string)] = args[1] + }) + schemaConfig := `schemas: + - table: test_table + columns: + - data1 + encrypted: + - column: data1 + searchable: true` + + query := `select data1 from test_table where data1=$1` + + schema, err := config.MapTableSchemaStoreFromConfig([]byte(schemaConfig)) + if err != nil { + t.Fatal(err) + } + ctx := base.SetClientSessionToContext(context.Background(), clientSession) + parser := sqlparser.New(sqlparser.ModeDefault) + keyStore := &mocks2.ServerKeyStore{} + keyStore.On("GetHMACSecretKey", mock.Anything).Return([]byte(`some key`), nil) + registryHandler := crypto.NewRegistryHandler(nil) + encryptor := NewPostgresqlHashQuery(keyStore, schema, registryHandler) + sourceBindValue := []byte{0, 1, 2, 3} + boundValue := &mocks.BoundValue{} + bindValue := sourceBindValue + boundValue.On("Format").Return(base.TextFormat) + boundValue.On("GetData", mock.Anything).Return(func(config.ColumnEncryptionSetting) []byte { + return bindValue + }, nil) + boundValue.On("SetData", mock.MatchedBy(func(data []byte) bool { + bindValue = data + return true + }), mock.Anything).Return(nil) + _ = bindValue + + queryObj := base.NewOnQueryObjectFromQuery(query, parser) + queryObj, _, err = encryptor.OnQuery(ctx, queryObj) + if err != nil { + t.Fatal(err) + } + bindPlaceholders := encryptor2.PlaceholderSettingsFromClientSession(clientSession) + if len(bindPlaceholders) != 1 { + t.Fatal("Not found expected amount of placeholders") + } + queryObj = base.NewOnQueryObjectFromQuery(query, parser) + statement, err := queryObj.Statement() + if err != nil { + t.Fatal(err) + } + newVals, ok, err := encryptor.OnBind(ctx, statement, []base.BoundValue{boundValue}) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Values should be changed") + } + if len(newVals) != 1 { + t.Fatal("Invalid amount of bound values") + } + setting := schema.GetTableSchema("test_table").GetColumnEncryptionSettings("data1") + newData, err := newVals[0].GetData(setting) + if err != nil { + t.Fatal(err) + } + if bytes.Equal(newData, sourceBindValue) { + t.Fatal("Data wasn't changed") + } +} diff --git a/logging/event_codes.go b/logging/event_codes.go index cec688055..ecabea8d0 100644 --- a/logging/event_codes.go +++ b/logging/event_codes.go @@ -153,6 +153,8 @@ const ( EventCodeErrorCodingPostgresqlCantParseColumnsDescription = 1207 EventCodeErrorCodingPostgresqlOctalEscape = 1208 EventCodeErrorCodingCantDecodeSQLValue = 1209 + // used as general error + EventCodeErrorDBProtocolError = 1210 // network additional EventCodeErrorNetworkWrite = 1300 diff --git a/masking/common/patterns.go b/masking/common/patterns.go index 759904e08..e7382ad79 100644 --- a/masking/common/patterns.go +++ b/masking/common/patterns.go @@ -16,7 +16,11 @@ package common -import "errors" +import ( + "errors" + "fmt" + "github.com/cossacklabs/acra/encryptor/config/common" +) // PlainTextSide defines which side of data is left untouched (in plain), and which is masked with a pattern. type PlainTextSide string @@ -35,7 +39,7 @@ var ( ) // ValidateMaskingParams checks and returns an error if masking parameters are incorrect. -func ValidateMaskingParams(pattern string, plaintextLength int, plaintextSide PlainTextSide) error { +func ValidateMaskingParams(pattern string, plaintextLength int, plaintextSide PlainTextSide, dataType common.EncryptedType) error { if len(pattern) == 0 { return ErrInvalidMaskingPattern } @@ -45,5 +49,12 @@ func ValidateMaskingParams(pattern string, plaintextLength int, plaintextSide Pl if plaintextSide != PlainTextSideRight && plaintextSide != PlainTextSideLeft { return ErrInvalidPlaintextSide } + switch dataType { + case common.EncryptedType_String, common.EncryptedType_Bytes: + break + default: + // intX not supported masking with type awareness + return fmt.Errorf("masking configuration error: %w", common.ErrUnsupportedEncryptedType) + } return nil } diff --git a/pseudonymization/common/metadata.pb.go b/pseudonymization/common/metadata.pb.go index ee6de4175..8c2468e58 100644 --- a/pseudonymization/common/metadata.pb.go +++ b/pseudonymization/common/metadata.pb.go @@ -1,18 +1,16 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.25.0 +// protoc-gen-go v1.26.0 // protoc v3.6.1 // source: metadata.proto package common import ( - reflect "reflect" - sync "sync" - - proto "github.com/golang/protobuf/proto" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" ) const ( @@ -22,10 +20,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// This is a compile-time assertion that a sufficiently up-to-date version -// of the legacy proto package is being used. -const _ = proto.ProtoPackageIsVersion4 - // MetadataContainer is Protobuf container for TokenMetadata. type MetadataContainer struct { state protoimpl.MessageState @@ -110,11 +104,11 @@ var file_metadata_proto_rawDesc = []byte{ 0x63, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x64, 0x69, 0x73, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x42, 0x40, 0x5a, 0x3e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6c, 0x65, 0x64, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x73, 0x73, 0x61, 0x63, 0x6b, 0x6c, 0x61, 0x62, 0x73, 0x2f, 0x61, 0x63, - 0x72, 0x61, 0x2d, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x69, 0x73, 0x65, 0x2f, 0x70, 0x73, - 0x65, 0x75, 0x64, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x63, - 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x61, 0x2f, 0x70, 0x73, 0x65, 0x75, 0x64, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( diff --git a/pseudonymization/common/tokenTypes.go b/pseudonymization/common/tokenTypes.go index e7b3afc08..a7873819d 100644 --- a/pseudonymization/common/tokenTypes.go +++ b/pseudonymization/common/tokenTypes.go @@ -18,6 +18,7 @@ package common import ( "errors" + "github.com/cossacklabs/acra/encryptor/config/common" "github.com/golang/protobuf/proto" ) @@ -30,6 +31,41 @@ var supportedTokenTypes = map[TokenType]bool{ TokenType_Email: true, } +// ToConfigString converts value to string used in encryptor_config +func (x TokenType) ToConfigString() (val string, err error) { + err = ErrUnknownTokenType + switch x { + case TokenType_Int32: + return "int32", nil + case TokenType_Int64: + return "int64", nil + case TokenType_String: + return "str", nil + case TokenType_Bytes: + return "bytes", nil + case TokenType_Email: + return "email", nil + } + return +} + +// ToEncryptedDataType converts value to appropriate EncryptedType +func (x TokenType) ToEncryptedDataType() common.EncryptedType { + switch x { + case TokenType_Int32: + return common.EncryptedType_Int32 + case TokenType_Int64: + return common.EncryptedType_Int64 + case TokenType_String: + return common.EncryptedType_String + case TokenType_Bytes: + return common.EncryptedType_Bytes + case TokenType_Email: + return common.EncryptedType_String + } + return common.EncryptedType_Unknown +} + // Validation errors var ( ErrUnknownTokenType = errors.New("unknown token type") diff --git a/pseudonymization/common/tokenTypes.pb.go b/pseudonymization/common/tokenTypes.pb.go index cf20cb9c7..791d17ee4 100644 --- a/pseudonymization/common/tokenTypes.pb.go +++ b/pseudonymization/common/tokenTypes.pb.go @@ -1,18 +1,16 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.25.0 +// protoc-gen-go v1.26.0 // protoc v3.6.1 // source: tokenTypes.proto package common import ( - reflect "reflect" - sync "sync" - - proto "github.com/golang/protobuf/proto" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" ) const ( @@ -22,10 +20,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// This is a compile-time assertion that a sufficiently up-to-date version -// of the legacy proto package is being used. -const _ = proto.ProtoPackageIsVersion4 - // TokenType defines tokenization type. type TokenType int32 @@ -163,11 +157,11 @@ var file_tokenTypes_proto_rawDesc = []byte{ 0x03, 0x12, 0x09, 0x0a, 0x05, 0x42, 0x79, 0x74, 0x65, 0x73, 0x10, 0x04, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x10, 0x05, 0x12, 0x0c, 0x0a, 0x08, 0x49, 0x6e, 0x74, 0x33, 0x32, 0x53, 0x74, 0x72, 0x10, 0x06, 0x12, 0x0c, 0x0a, 0x08, 0x49, 0x6e, 0x74, 0x36, 0x34, 0x53, 0x74, - 0x72, 0x10, 0x07, 0x42, 0x40, 0x5a, 0x3e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x72, 0x10, 0x07, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x73, 0x73, 0x61, 0x63, 0x6b, 0x6c, 0x61, 0x62, 0x73, 0x2f, 0x61, 0x63, - 0x72, 0x61, 0x2d, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x69, 0x73, 0x65, 0x2f, 0x70, 0x73, - 0x65, 0x75, 0x64, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x63, - 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x61, 0x2f, 0x70, 0x73, 0x65, 0x75, 0x64, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( diff --git a/pseudonymization/dataProcessor.go b/pseudonymization/dataProcessor.go deleted file mode 100644 index effce7e1f..000000000 --- a/pseudonymization/dataProcessor.go +++ /dev/null @@ -1,144 +0,0 @@ -package pseudonymization - -import ( - "context" - "encoding/binary" - "errors" - "github.com/cossacklabs/acra/logging" - "strconv" - - "github.com/cossacklabs/acra/decryptor/base" - "github.com/cossacklabs/acra/encryptor" - "github.com/cossacklabs/acra/pseudonymization/common" -) - -// TokenProcessor implements processor which tokenize/detokenize data for acra-server used in decryptor module -type TokenProcessor struct { - tokenizer *DataTokenizer -} - -// NewTokenProcessor return new processor -func NewTokenProcessor(tokenizer *DataTokenizer) (*TokenProcessor, error) { - return &TokenProcessor{tokenizer}, nil -} - -// ID return name of processor -func (p *TokenProcessor) ID() string { - return "TokenProcessor" -} - -// OnColumn tokenize data if configured by encryptor config -func (p *TokenProcessor) OnColumn(ctx context.Context, data []byte) (context.Context, []byte, error) { - accessContext := base.AccessContextFromContext(ctx) - columnSetting, ok := encryptor.EncryptionSettingFromContext(ctx) - if ok && columnSetting.IsTokenized() { - tokenContext := common.TokenContext{ClientID: accessContext.GetClientID(), ZoneID: accessContext.GetZoneID()} - data, err := p.tokenizer.Detokenize(data, tokenContext, columnSetting) - return ctx, data, err - } - return ctx, data, nil -} - -// ErrInvalidDataEncoderMode unsupported DataEncoderMode value -var ErrInvalidDataEncoderMode = errors.New("unsupported DataEncoderMode value") - -// ErrInvalidIntValueBinarySize unsupported DataEncoderMode value -var ErrInvalidIntValueBinarySize = errors.New("unsupported binary size of int value") - -// DataEncoderMode mode of PgSQLDataEncoderProcessor -type DataEncoderMode int8 - -// Available modes of DataEncoderMode -const ( - DataEncoderModeEncode = iota - DataEncoderModeDecode -) - -// PgSQLDataEncoderProcessor implements processor and encode/decode binary intX values to text format which acceptable by Tokenizer -type PgSQLDataEncoderProcessor struct { - mode DataEncoderMode -} - -// NewPgSQLDataEncoderProcessor return new data encoder/decoder from/to binary format for tokenization -func NewPgSQLDataEncoderProcessor(mode DataEncoderMode) (*PgSQLDataEncoderProcessor, error) { - switch mode { - case DataEncoderModeDecode, DataEncoderModeEncode: - return &PgSQLDataEncoderProcessor{mode}, nil - } - return nil, ErrInvalidDataEncoderMode -} - -// ID return name of processor -func (p *PgSQLDataEncoderProcessor) ID() string { - return "PgSQLDataEncoderProcessor" -} - -// OnColumn encode binary value to text and back. Should be before and after tokenizer processor -func (p *PgSQLDataEncoderProcessor) OnColumn(ctx context.Context, data []byte) (context.Context, []byte, error) { - columnSetting, ok := encryptor.EncryptionSettingFromContext(ctx) - if !(ok && columnSetting.IsTokenized()) { - return ctx, data, nil - } - // process only int tokenization - switch columnSetting.GetTokenType() { - case common.TokenType_Int64, common.TokenType_Int32: - break - default: - return ctx, data, nil - } - logger := logging.GetLoggerFromContext(ctx) - newData := data - columnInfo, ok := base.ColumnInfoFromContext(ctx) - if !ok { - logger.WithField("processor", "PgSQLDataEncoderProcessor").Warningln("No column info in ctx") - // we can't do anything - return ctx, data, nil - } - // we should decode only if data in binary format - if !columnInfo.IsBinaryFormat() { - return ctx, data, nil - } - if p.mode == DataEncoderModeEncode { - // convert back from text to binary - value, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return ctx, data, err - } - newData = make([]byte, columnInfo.DataBinarySize()) - switch columnInfo.DataBinarySize() { - case 4: - binary.BigEndian.PutUint32(newData, uint32(value)) - break - case 8: - binary.BigEndian.PutUint64(newData, uint64(value)) - break - default: - logger.WithField("size", columnInfo.DataBinarySize()).Warningln("Unsupported int value size") - return ctx, data, ErrInvalidIntValueBinarySize - } - } else if p.mode == DataEncoderModeDecode { - // convert from binary to text literal because tokenizer expects int value as string literal - switch columnSetting.GetTokenType() { - case common.TokenType_Int32, common.TokenType_Int64: - if len(newData) == 4 { - // if high byte is 0xff then it is negative number and we should fill all previous bytes with 0xx too - // otherwise with zeroes - if data[0] == 0xff { - newData = append([]byte{0xff, 0xff, 0xff, 0xff}, data...) - } else { - // extend int32 from 4 bytes to int64 with zeroes - newData = append([]byte{0, 0, 0, 0}, data...) - } - // we accept here only 4 or 8 byte values - } else if len(newData) != 8 { - return ctx, data, ErrInvalidIntValueBinarySize - } - value := binary.BigEndian.Uint64(newData) - newData = []byte(strconv.FormatInt(int64(value), 10)) - } - } else { - return ctx, data, ErrInvalidDataEncoderMode - } - return ctx, newData, nil - -} diff --git a/pseudonymization/dataProcessor_test.go b/pseudonymization/dataProcessor_test.go deleted file mode 100644 index eb6781222..000000000 --- a/pseudonymization/dataProcessor_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package pseudonymization - -import ( - "bytes" - "context" - "github.com/cossacklabs/acra/decryptor/base" - "github.com/cossacklabs/acra/encryptor" - "github.com/cossacklabs/acra/encryptor/config" - "strconv" - "testing" -) - -func TestEncodingDecodingProcessor(t *testing.T) { - type testcase struct { - binValue []byte - stringValue []byte - encodeErr error - decodeErr error - binarySize int - } - testcases := []testcase{ - // int32 without errors - {binValue: []byte{0, 0, 0, 0}, stringValue: []byte("0"), encodeErr: nil, decodeErr: nil, binarySize: 4}, - {binValue: []byte{255, 255, 255, 255}, stringValue: []byte("-1"), encodeErr: nil, decodeErr: nil, binarySize: 4}, - {binValue: []byte{0, 0, 0, 128}, stringValue: []byte("128"), encodeErr: nil, decodeErr: nil, binarySize: 4}, - {binValue: []byte{255, 255, 255, 128}, stringValue: []byte("-128"), encodeErr: nil, decodeErr: nil, binarySize: 4}, - - // int32 with invalid size. returned stringValue should be unchanged - {binValue: []byte{255, 255, 255, 128}, stringValue: []byte("-128"), encodeErr: ErrInvalidIntValueBinarySize, decodeErr: nil, binarySize: 3}, - - // int64 without errors - {binValue: []byte{0, 0, 0, 0, 0, 0, 0, 0}, stringValue: []byte("0"), encodeErr: nil, decodeErr: nil, binarySize: 8}, - {binValue: []byte{255, 255, 255, 255, 255, 255, 255, 255}, stringValue: []byte("-1"), encodeErr: nil, decodeErr: nil, binarySize: 8}, - {binValue: []byte{0, 0, 0, 0, 0, 0, 0, 128}, stringValue: []byte("128"), encodeErr: nil, decodeErr: nil, binarySize: 8}, - {binValue: []byte{255, 255, 255, 255, 255, 255, 255, 128}, stringValue: []byte("-128"), encodeErr: nil, decodeErr: nil, binarySize: 8}, - - // int64 with invalid size. returned stringValue should be unchanged - {binValue: []byte{255, 255, 255, 255, 255, 255, 255, 128}, stringValue: []byte("-128"), encodeErr: ErrInvalidIntValueBinarySize, decodeErr: nil, binarySize: 7}, - } - sizeToTokenType := map[int]string{ - 4: "int32", - 8: "int64", - // set correct values for incorrect sizes - 3: "int32", - 7: "int64", - } - - encoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeEncode) - if err != nil { - t.Fatal(err) - } - decoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeDecode) - if err != nil { - t.Fatal(err) - } - for i, tcase := range testcases { - columnInfo := base.NewColumnInfo(0, "", true, tcase.binarySize) - accessContext := &base.AccessContext{} - accessContext.SetColumnInfo(columnInfo) - ctx := base.SetAccessContextToContext(context.Background(), accessContext) - testSetting := config.BasicColumnEncryptionSetting{Tokenized: true, TokenType: sizeToTokenType[tcase.binarySize]} - ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) - ctx, strData, err := decoder.OnColumn(ctx, tcase.binValue) - if err != tcase.decodeErr { - t.Fatalf("[%d] Expect %s, took %s\n", i, tcase.decodeErr, err) - } - if !bytes.Equal(tcase.stringValue, strData) { - t.Fatalf("[%d] Expect '%s', took '%s'\n", i, tcase.stringValue, strData) - } - _, binData, err := encoder.OnColumn(ctx, strData) - if err != tcase.encodeErr { - t.Fatalf("[%d] Expect %s, took %s\n", i, tcase.encodeErr, err) - } - // we check that start value == final value only if err == nil and we check success whole encoding/decoding - if err == nil { - if !bytes.Equal(binData, tcase.binValue) { - t.Fatalf("[%d] Expect '%s', took '%s'\n", i, binData, tcase.binValue) - } - } else { - // if was error then decoded data should be the same as encoded - if !bytes.Equal(binData, tcase.stringValue) { - t.Fatalf("[%d] Expect '%s', took '%s'\n", i, tcase.stringValue, binData) - } - } - } -} - -func TestSkipWithoutSetting(t *testing.T) { - encoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeEncode) - if err != nil { - t.Fatal(err) - } - decoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeDecode) - if err != nil { - t.Fatal(err) - } - testData := []byte("some data") - for _, subscriber := range []base.DecryptionSubscriber{encoder, decoder} { - _, data, err := subscriber.OnColumn(context.Background(), testData) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, testData) { - t.Fatal("Result data should be the same") - } - } -} - -func TestSkipWithoutBinaryMode(t *testing.T) { - encoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeEncode) - if err != nil { - t.Fatal(err) - } - decoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeDecode) - if err != nil { - t.Fatal(err) - } - testData := []byte("some data") - columnInfo := base.NewColumnInfo(0, "", false, 4) - accessContext := &base.AccessContext{} - accessContext.SetColumnInfo(columnInfo) - ctx := base.SetAccessContextToContext(context.Background(), accessContext) - testSetting := config.BasicColumnEncryptionSetting{Tokenized: true, TokenType: "int32"} - ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) - for _, subscriber := range []base.DecryptionSubscriber{encoder, decoder} { - _, data, err := subscriber.OnColumn(ctx, testData) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, testData) { - t.Fatal("Result data should be the same") - } - } -} - -func TestSkipNotIntTokenType(t *testing.T) { - encoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeEncode) - if err != nil { - t.Fatal(err) - } - decoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeDecode) - if err != nil { - t.Fatal(err) - } - testData := []byte("some data") - columnInfo := base.NewColumnInfo(0, "", false, 4) - accessContext := &base.AccessContext{} - accessContext.SetColumnInfo(columnInfo) - ctx := base.SetAccessContextToContext(context.Background(), accessContext) - testSetting := config.BasicColumnEncryptionSetting{Tokenized: true, TokenType: "string"} - ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) - for _, subscriber := range []base.DecryptionSubscriber{encoder, decoder} { - _, data, err := subscriber.OnColumn(ctx, testData) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, testData) { - t.Fatal("Result data should be the same") - } - } -} - -func TestSkipWithoutColumnInfo(t *testing.T) { - encoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeEncode) - if err != nil { - t.Fatal(err) - } - decoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeDecode) - if err != nil { - t.Fatal(err) - } - testData := []byte("some data") - accessContext := &base.AccessContext{} - ctx := base.SetAccessContextToContext(context.Background(), accessContext) - testSetting := config.BasicColumnEncryptionSetting{Tokenized: true, TokenType: "int32"} - ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) - for _, subscriber := range []base.DecryptionSubscriber{encoder, decoder} { - _, data, err := subscriber.OnColumn(ctx, testData) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, testData) { - t.Fatal("Result data should be the same") - } - } -} - -func TestEncodeDecodeModeValidation(t *testing.T) { - // invalid mode - _, err := NewPgSQLDataEncoderProcessor(3) - if err != ErrInvalidDataEncoderMode { - t.Fatalf("Expect ErrInvalidDataEncoderMode, took %s\n", err) - } - - // create valid, but then change internally to invalid - encoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeEncode) - if err != nil { - t.Fatal(err) - } - // set invalid - encoder.mode = 3 - testData := []byte("some data") - columnInfo := base.NewColumnInfo(0, "", true, 4) - accessContext := &base.AccessContext{} - accessContext.SetColumnInfo(columnInfo) - ctx := base.SetAccessContextToContext(context.Background(), accessContext) - testSetting := config.BasicColumnEncryptionSetting{Tokenized: true, TokenType: "int32"} - ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) - _, _, err = encoder.OnColumn(ctx, testData) - if err != ErrInvalidDataEncoderMode { - t.Fatalf("Expect ErrInvalidDataEncoderMode, took %s\n", err) - } -} - -func TestFailedEncodingInvalidTextValue(t *testing.T) { - encoder, err := NewPgSQLDataEncoderProcessor(DataEncoderModeEncode) - if err != nil { - t.Fatal(err) - } - columnInfo := base.NewColumnInfo(0, "", true, 4) - accessContext := &base.AccessContext{} - accessContext.SetColumnInfo(columnInfo) - ctx := base.SetAccessContextToContext(context.Background(), accessContext) - testSetting := config.BasicColumnEncryptionSetting{Tokenized: true, TokenType: "int32"} - ctx = encryptor.NewContextWithEncryptionSetting(ctx, &testSetting) - testData := []byte("asdas") - _, data, err := encoder.OnColumn(ctx, testData) - numErr, ok := err.(*strconv.NumError) - if !ok { - t.Fatal("Expect strconv.NumError") - } - if numErr.Err != strconv.ErrSyntax { - t.Fatalf("Expect ErrSyntax, took %s\n", numErr.Err) - } - if !bytes.Equal(data, testData) { - t.Fatal("Result data should be the same") - } -} diff --git a/pseudonymization/data_encoder.go b/pseudonymization/data_encoder.go new file mode 100644 index 000000000..914ac44c4 --- /dev/null +++ b/pseudonymization/data_encoder.go @@ -0,0 +1,35 @@ +package pseudonymization + +import ( + "context" + "github.com/cossacklabs/acra/decryptor/base" + "github.com/cossacklabs/acra/encryptor" + "github.com/cossacklabs/acra/pseudonymization/common" +) + +// TokenProcessor implements processor which tokenize/detokenize data for acra-server used in decryptor module +type TokenProcessor struct { + tokenizer *DataTokenizer +} + +// NewTokenProcessor return new processor +func NewTokenProcessor(tokenizer *DataTokenizer) (*TokenProcessor, error) { + return &TokenProcessor{tokenizer}, nil +} + +// ID return name of processor +func (p *TokenProcessor) ID() string { + return "TokenProcessor" +} + +// OnColumn tokenize data if configured by encryptor config +func (p *TokenProcessor) OnColumn(ctx context.Context, data []byte) (context.Context, []byte, error) { + accessContext := base.AccessContextFromContext(ctx) + columnSetting, ok := encryptor.EncryptionSettingFromContext(ctx) + if ok && columnSetting.IsTokenized() { + tokenContext := common.TokenContext{ClientID: accessContext.GetClientID(), ZoneID: accessContext.GetZoneID()} + data, err := p.tokenizer.Detokenize(data, tokenContext, columnSetting) + return ctx, data, err + } + return ctx, data, nil +} diff --git a/tests/encryptor_configs/transparent_type_aware_decryption.yaml b/tests/encryptor_configs/transparent_type_aware_decryption.yaml new file mode 100644 index 000000000..000902857 --- /dev/null +++ b/tests/encryptor_configs/transparent_type_aware_decryption.yaml @@ -0,0 +1,71 @@ +schemas: + # used in test.py +- table: test_type_aware_decryption_with_defaults + columns: + - id + - value_str + - value_bytes + - value_int32 + - value_int64 + - value_null_str + - value_null_int32 + - value_empty_str + + encrypted: + - column: value_str + data_type: "str" + default_data_value: "value_str" + + - column: value_bytes + data_type: "bytes" + default_data_value: "dmFsdWVfYnl0ZXM=" + + - column: value_int32 + data_type: "int32" + default_data_value: "32" + + - column: value_int64 + data_type: "int64" + default_data_value: "64" + + - column: value_null_str + data_type: "str" + + - column: value_null_int32 + data_type: "str" + + - column: value_empty_str + data_type: "str" + +- table: test_type_aware_decryption_without_defaults + columns: + - id + - value_str + - value_bytes + - value_int32 + - value_int64 + - value_null_str + - value_null_int32 + - value_empty_str + + encrypted: + - column: value_str + data_type: "str" + + - column: value_bytes + data_type: "bytes" + + - column: value_int32 + data_type: "int32" + + - column: value_int64 + data_type: "int64" + + - column: value_null_str + data_type: "str" + + - column: value_null_int32 + data_type: "str" + + - column: value_empty_str + data_type: "str" diff --git a/tests/random_utils.py b/tests/random_utils.py index 0144c2706..691fe7eb9 100644 --- a/tests/random_utils.py +++ b/tests/random_utils.py @@ -21,6 +21,9 @@ def random_int64(): def random_bytes(n=100): + # if default size then return binary and printable data + if n == 100: + return bytes([i for i in range(n)]) return os.urandom(n) diff --git a/tests/test.py b/tests/test.py index dfc5f6258..b110e53e2 100644 --- a/tests/test.py +++ b/tests/test.py @@ -954,7 +954,8 @@ def _connect(self, loop): asyncpg.connect( host=self.connection_args.host, port=self.connection_args.port, user=self.connection_args.user, password=self.connection_args.password, - database=self.connection_args.dbname, ssl=ssl_context, + database=self.connection_args.dbname, + ssl=ssl_context, **asyncpg_connect_args)) def _set_text_format(self, conn): @@ -1713,9 +1714,9 @@ def checkSkip(self): def setUp(self): super().setUp() - def executor_with_ssl(ssl_key, ssl_cert): + def executor_with_ssl(ssl_key, ssl_cert, port=self.ACRASERVER_PORT): args = ConnectionArgs( - host=get_db_host(), port=self.ACRASERVER_PORT, dbname=DB_NAME, + host=get_db_host(), port=port, dbname=DB_NAME, user=DB_USER, password=DB_USER_PASSWORD, ssl_ca=TEST_TLS_CA, ssl_key=ssl_key, @@ -1726,6 +1727,7 @@ def executor_with_ssl(ssl_key, ssl_cert): self.executor1 = executor_with_ssl(TEST_TLS_CLIENT_KEY, TEST_TLS_CLIENT_CERT) self.executor2 = executor_with_ssl(TEST_TLS_CLIENT_2_KEY, TEST_TLS_CLIENT_2_CERT) + self.raw_executor = executor_with_ssl(TEST_TLS_CLIENT_KEY, TEST_TLS_CLIENT_CERT, 5432) def compileQuery(self, query, parameters={}, literal_binds=False): """ @@ -1981,7 +1983,9 @@ def testBlacklist(self): MysqlExecutor(connection_args)] if TEST_POSTGRESQL: expectedException = (psycopg2.ProgrammingError, - asyncpg.exceptions.SyntaxOrAccessError) + asyncpg.exceptions.SyntaxOrAccessError, + # https://github.com/MagicStack/asyncpg/issues/240 + AttributeError) expectedExceptionInPreparedStatement = asyncpg.exceptions.SyntaxOrAccessError executors = [Psycopg2Executor(connection_args), AsyncpgExecutor(connection_args)] @@ -2020,7 +2024,10 @@ def testWhitelist(self): if TEST_POSTGRESQL: expectedException = (psycopg2.ProgrammingError, asyncpg.exceptions.SyntaxOrAccessError) - expectedExceptionInPreparedStatement = asyncpg.exceptions.SyntaxOrAccessError + expectedExceptionInPreparedStatement = ( + asyncpg.exceptions.SyntaxOrAccessError, + # due to https://github.com/MagicStack/asyncpg/issues/240 + AttributeError) executors = [Psycopg2Executor(connection_args), AsyncpgExecutor(connection_args)] @@ -5292,33 +5299,32 @@ def testAcraTranslator(self): self.checkMetrics(metrics_url, labels) -class TestTransparentEncryption(BaseTestCase): - WHOLECELL_MODE = True - encryptor_table = sa.Table('test_transparent_encryption', metadata, +class BaseTransparentEncryption(BaseTestCase): + encryptor_table = sa.Table( + 'test_transparent_encryption', metadata, sa.Column('id', sa.Integer, primary_key=True), sa.Column('specified_client_id', sa.LargeBinary(length=COLUMN_DATA_SIZE)), - sa.Column('default_client_id', + sa.Column('default_client_id', sa.LargeBinary(length=COLUMN_DATA_SIZE)), - sa.Column('number', sa.Integer), sa.Column('zone_id', sa.LargeBinary(length=COLUMN_DATA_SIZE)), sa.Column('raw_data', sa.LargeBinary(length=COLUMN_DATA_SIZE)), sa.Column('nullable', sa.Text, nullable=True), sa.Column('empty', sa.LargeBinary(length=COLUMN_DATA_SIZE), nullable=False, default=b''), - ) + ) ENCRYPTOR_CONFIG = get_encryptor_config('tests/encryptor_config.yaml') def setUp(self): self.prepare_encryptor_config(client_id=TLS_CERT_CLIENT_ID_1) - super(TestTransparentEncryption, self).setUp() + super(BaseTransparentEncryption, self).setUp() def prepare_encryptor_config(self, client_id=None): prepare_encryptor_config(zone_id=zones[0][ZONE_ID], config_path=self.ENCRYPTOR_CONFIG, client_id=client_id) def tearDown(self): self.engine_raw.execute(self.encryptor_table.delete()) - super(TestTransparentEncryption, self).tearDown() + super(BaseTransparentEncryption, self).tearDown() try: os.remove(get_test_encryptor_config(self.ENCRYPTOR_CONFIG)) except FileNotFoundError: @@ -5327,9 +5333,12 @@ def tearDown(self): def fork_acra(self, popen_kwargs: dict=None, **acra_kwargs: dict): acra_kwargs['encryptor_config_file'] = get_test_encryptor_config( self.ENCRYPTOR_CONFIG) - return super(TestTransparentEncryption, self).fork_acra( + return super(BaseTransparentEncryption, self).fork_acra( popen_kwargs, **acra_kwargs) + +class TestTransparentEncryption(BaseTransparentEncryption): + def get_context_data(self): context = { 'id': get_random_id(), @@ -6858,7 +6867,7 @@ def execute(self, query, ssl_key, ssl_cert): for column, value in row.items(): if isinstance(value, (bytes, bytearray)): try: - row[column] = value.decode('utf8') + row[column] = bytes(value) except (LookupError, UnicodeDecodeError): pass return result @@ -6910,11 +6919,12 @@ def testTokenizationDefaultClientID(self): # data owner take source data for k in ('token_i32', 'token_i64', 'token_str', 'token_bytes', 'token_email'): - if isinstance(source_data[0][k], bytearray) and isinstance(data[k], str): - self.assertEqual(source_data[0][k], bytearray(data[k], encoding='utf-8')) + if isinstance(source_data[0][k], (bytearray, bytes)) and isinstance(data[k], str): + self.assertEqual(source_data[0][k], data[k].encode('utf-8')) + self.assertNotEqual(hidden_data[0][k], data[k].encode('utf-8')) else: self.assertEqual(source_data[0][k], data[k]) - self.assertNotEqual(hidden_data[0][k], data[k]) + self.assertNotEqual(hidden_data[0][k], data[k]) def testTokenizationDefaultClientIDWithBulkInsert(self): default_client_id_table = sa.Table( @@ -6960,8 +6970,8 @@ def testTokenizationDefaultClientIDWithBulkInsert(self): for idx in range(len(source_data)): # data owner take source data for k in ('token_i32', 'token_i64', 'token_str', 'token_bytes', 'token_email'): - if isinstance(source_data[idx][k], bytearray) and isinstance(values[idx][k], str): - self.assertEqual(source_data[idx][k], bytearray(values[idx][k], encoding='utf-8')) + if isinstance(source_data[idx][k], (bytearray, bytes)) and isinstance(values[idx][k], str): + self.assertEqual(source_data[idx][k], bytes(values[idx][k], encoding='utf-8')) else: self.assertEqual(source_data[idx][k], values[idx][k]) self.assertNotEqual(hidden_data[idx][k], values[idx][k]) @@ -7009,8 +7019,8 @@ def testTokenizationSpecificClientID(self): # data owner take source data for k in ('token_i32', 'token_i64', 'token_str', 'token_bytes', 'token_email'): - if isinstance(source_data[0][k], bytearray) and isinstance(data[k], str): - self.assertEqual(source_data[0][k], bytearray(data[k], encoding='utf-8')) + if isinstance(source_data[0][k], (bytearray, bytes)) and isinstance(data[k], str): + self.assertEqual(source_data[0][k], bytes(data[k], encoding='utf-8')) else: self.assertEqual(source_data[0][k], data[k]) self.assertNotEqual(hidden_data[0][k], data[k]) @@ -7058,12 +7068,14 @@ def testTokenizationDefaultClientIDStarExpression(self): # data owner take source data for k in ('token_i32', 'token_i64', 'token_str', 'token_bytes', 'token_email'): - # binary data returned as memoryview objects - if isinstance(source_data[0][k], bytearray) and isinstance(data[k], str): - self.assertEqual(utils.memoryview_to_bytes(source_data[0][k]), bytearray(data[k], encoding='utf-8')) + # successfully decrypted data returned as string otherwise as bytes + # always encode to bytes to compare values with same type coercions + if isinstance(source_data[0][k], (bytearray, bytes, memoryview)) and isinstance(data[k], str): + self.assertEqual(utils.memoryview_to_bytes(source_data[0][k]), data[k].encode('utf-8')) + self.assertNotEqual(utils.memoryview_to_bytes(hidden_data[0][k]), data[k].encode('utf-8')) else: self.assertEqual(utils.memoryview_to_bytes(source_data[0][k]), data[k]) - self.assertNotEqual(utils.memoryview_to_bytes(hidden_data[0][k]), data[k]) + self.assertNotEqual(utils.memoryview_to_bytes(hidden_data[0][k]), data[k]) class TestReturningProcessingMixing: @@ -7260,9 +7272,9 @@ def testTokenizationSpecificZoneID(self): token_fields = ('token_i32', 'token_i64', 'token_str', 'token_bytes', 'token_email') # data owner take source data for k in token_fields: - if isinstance(source_data[0][k], bytearray) and isinstance(data[k], str): - self.assertEqual(source_data[0][k], bytearray(data[k], encoding='utf-8')) - self.assertEqual(hidden_data[0][k], bytearray(data[k], encoding='utf-8')) + if isinstance(source_data[0][k], (bytearray, bytes)) and isinstance(data[k], str): + self.assertEqual(source_data[0][k], data[k].encode('utf-8')) + self.assertEqual(hidden_data[0][k], data[k].encode('utf-8')) else: self.assertEqual(source_data[0][k], data[k]) self.assertEqual(hidden_data[0][k], data[k]) @@ -7332,9 +7344,9 @@ def testTokenizationSpecificZoneIDStarExpression(self): token_fields = ('token_i32', 'token_i64', 'token_str', 'token_bytes', 'token_email') # data owner take source data for k in token_fields: - if isinstance(source_data[0][k], bytearray) and isinstance(data[k], str): - self.assertEqual(utils.memoryview_to_bytes(source_data[0][k]), bytearray(data[k], encoding='utf-8')) - self.assertEqual(utils.memoryview_to_bytes(hidden_data[0][k]), bytearray(data[k], encoding='utf-8')) + if isinstance(source_data[0][k], (bytearray, bytes)) and isinstance(data[k], str): + self.assertEqual(utils.memoryview_to_bytes(source_data[0][k]), data[k].encode('utf-8')) + self.assertEqual(utils.memoryview_to_bytes(hidden_data[0][k]), data[k].encode('utf-8')) else: self.assertEqual(utils.memoryview_to_bytes(source_data[0][k]), data[k]) self.assertEqual(utils.memoryview_to_bytes(hidden_data[0][k]), data[k]) @@ -8435,6 +8447,317 @@ def testOctalIntegerValue(self): self.assertNotEqual(hidden_data[0][k], data[k]) +class TestPostgresqlTextFormatTypeAwareDecryptionWithDefaults(BaseTransparentEncryption): + # test table used for queries and data mapping into python types + test_table = sa.Table( + # use new object of metadata to avoid name conflict + 'test_type_aware_decryption_with_defaults', sa.MetaData(), + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('value_str', sa.Text), + sa.Column('value_bytes', sa.LargeBinary), + sa.Column('value_int32', sa.Integer), + sa.Column('value_int64', sa.BigInteger), + sa.Column('value_null_str', sa.Text, nullable=True, default=None), + sa.Column('value_null_int32', sa.Integer, nullable=True, default=None), + sa.Column('value_empty_str', sa.Text, nullable=False, default=''), + ) + # schema table used to generate table in the database with binary column types + schema_table = sa.Table( + + 'test_type_aware_decryption_with_defaults', metadata, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('value_str', sa.LargeBinary), + sa.Column('value_bytes', sa.LargeBinary), + sa.Column('value_int32', sa.LargeBinary), + sa.Column('value_int64', sa.LargeBinary), + sa.Column('value_null_str', sa.LargeBinary, nullable=True, default=None), + sa.Column('value_null_int32', sa.LargeBinary, nullable=True, default=None), + sa.Column('value_empty_str', sa.LargeBinary, nullable=False, default=b''), + ) + ENCRYPTOR_CONFIG = get_encryptor_config('tests/encryptor_configs/transparent_type_aware_decryption.yaml') + + def checkSkip(self): + if not (TEST_POSTGRESQL and TEST_WITH_TLS): + self.skipTest("Test only for PostgreSQL with TLS") + + def testClientIDRead(self): + """test decrypting with correct clientID and not decrypting with + incorrect clientID or using direct connection to db + All result data should be valid for application. Not decrypted data should be returned with their default value + """ + data = { + 'id': get_random_id(), + 'value_str': random_str(), + 'value_bytes': random_bytes(), + 'value_int32': random_int32(), + 'value_int64': random_int64(), + 'value_null_str': None, + 'value_null_int32': None, + 'value_empty_str': '' + } + default_expected_values = { + 'value_int32': 32, + 'value_int64': 64, + 'value_bytes': b'value_bytes', + 'value_str': 'value_str', + } + + self.schema_table.create(bind=self.engine_raw, checkfirst=True) + columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_null_str', 'value_null_int32', + 'value_empty_str') + self.engine1.execute(self.test_table.insert(), data) + result = self.engine1.execute( + sa.select([self.test_table]) + .where(self.test_table.c.id == data['id'])) + row = result.fetchone() + for column in columns: + self.assertEqual(data[column], row[column]) + self.assertIsInstance(row[column], type(data[column])) + + result = self.engine2.execute( + sa.select([self.test_table]) + .where(self.test_table.c.id == data['id'])) + row = result.fetchone() + for column in columns: + self.assertIsInstance(row[column], type(data[column])) + if 'null' in column: + self.assertIsNone(row[column]) + continue + if 'empty' in column: + self.assertEqual(row[column], '') + continue + self.assertNotEqual(data[column], row[column]) + if column in ('value_int32', 'value_int64'): + self.assertEqual(row[column], default_expected_values[column]) + + result = self.engine_raw.execute( + sa.select([self.test_table]) + .where(self.test_table.c.id == data['id'])) + row = result.fetchone() + for column in columns: + if 'null' in column: + self.assertIsNone(row[column]) + continue + self.assertIsInstance(utils.memoryview_to_bytes(row[column]), bytes) + if column in ('value_str', 'value_bytes'): + # length of data should be greater than source data due to encryption overhead + self.assertTrue(len(utils.memoryview_to_bytes(row[column])) > len(data[column])) + + +class TestPostgresqlBinaryFormatTypeAwareDecryptionWithDefaults( + BaseBinaryPostgreSQLTestCase, TestPostgresqlTextFormatTypeAwareDecryptionWithDefaults): + def testClientIDRead(self): + """test decrypting with correct clientID and not decrypting with + incorrect clientID or using direct connection to db + All result data should be valid for application. Not decrypted data should be returned with their default value + """ + data = { + 'id': get_random_id(), + 'value_str': random_str(), + 'value_bytes': random_bytes(), + 'value_int32': random_int32(), + 'value_int64': random_int64(), + 'value_null_str': None, + 'value_null_int32': None, + 'value_empty_str': '' + } + default_expected_values = { + 'value_int32': 32, + 'value_int64': 64, + 'value_bytes': b'value_bytes', + 'value_str': 'value_str', + } + + self.schema_table.create(bind=self.engine_raw, checkfirst=True) + columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_null_str', 'value_null_int32', + 'value_empty_str') + query, args = self.compileQuery(self.test_table.insert(), data) + self.executor1.execute_prepared_statement(query, args) + + query, args = self.compileQuery( + sa.select([self.test_table]) + .where(self.test_table.c.id == sa.bindparam('id')), {'id': data['id']}) + row = self.executor1.execute_prepared_statement(query, args)[0] + for column in columns: + if 'null' in column: + # asyncpg decodes None values as empty str/bytes value + self.assertFalse(row[column]) + continue + self.assertEqual(data[column], row[column]) + self.assertIsInstance(row[column], type(data[column])) + + row = self.executor2.execute_prepared_statement(query, args)[0] + for column in columns: + if 'null' in column: + # asyncpg decodes None values as empty str/bytes value + self.assertFalse(row[column]) + continue + if 'empty' in column: + self.assertEqual(data[column], row[column]) + else: + self.assertNotEqual(data[column], row[column]) + self.assertIsInstance(row[column], type(data[column])) + + row = self.executor2.execute_prepared_statement(query, args)[0] + for column in columns: + if 'null' in column or 'empty' in column: + # asyncpg decodes None values as empty str/bytes value + self.assertFalse(row[column]) + continue + self.assertNotEqual(data[column], row[column]) + + +class TestPostgresqlTextTypeAwareDecryptionWithoutDefaults(BaseTransparentEncryption): + # test table used for queries and data mapping into python types + test_table = sa.Table( + # use new object of metadata to avoid name conflict + 'test_type_aware_decryption_without_defaults', sa.MetaData(), + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('value_str', sa.Text), + sa.Column('value_bytes', sa.LargeBinary), + sa.Column('value_int32', sa.Integer), + sa.Column('value_int64', sa.BigInteger), + sa.Column('value_null_str', sa.Text, nullable=True, default=None), + sa.Column('value_null_int32', sa.Integer, nullable=True, default=None), + sa.Column('value_empty_str', sa.Text, nullable=False, default=''), + ) + # schema table used to generate table in the database with binary column types + schema_table = sa.Table( + + 'test_type_aware_decryption_without_defaults', metadata, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('value_str', sa.LargeBinary), + sa.Column('value_bytes', sa.LargeBinary), + sa.Column('value_int32', sa.LargeBinary), + sa.Column('value_int64', sa.LargeBinary), + sa.Column('value_null_str', sa.LargeBinary, nullable=True, default=None), + sa.Column('value_null_int32', sa.LargeBinary, nullable=True, default=None), + sa.Column('value_empty_str', sa.LargeBinary, nullable=False, default=b''), + ) + ENCRYPTOR_CONFIG = get_encryptor_config('tests/encryptor_configs/transparent_type_aware_decryption.yaml') + + def checkSkip(self): + if not (TEST_POSTGRESQL and TEST_WITH_TLS): + self.skipTest("Test only for PostgreSQL with TLS") + + def testClientIDRead(self): + """test decrypting with correct clientID and not decrypting with + incorrect clientID or using direct connection to db + All result data should be valid for application. Not decrypted data should be returned as is and DB driver + should cause error + """ + data = { + 'id': get_random_id(), + 'value_str': random_str(), + 'value_bytes': random_bytes(), + 'value_int32': random_int32(), + 'value_int64': random_int64(), + 'value_null_str': None, + 'value_null_int32': None, + 'value_empty_str': '' + } + self.schema_table.create(bind=self.engine_raw, checkfirst=True) + self.engine1.execute(self.test_table.insert(), data) + columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_null_str', 'value_null_int32', + 'value_empty_str') + + result = self.engine2.execute( + sa.select([self.test_table]) + .where(self.test_table.c.id == data['id'])) + # acra change types for binary data columns and python driver can't decode to correct types + with self.assertRaises(UnicodeDecodeError): + row = result.fetchone() + + # direct connection should receive binary data according to real scheme + result = self.engine_raw.execute( + sa.select([self.test_table]) + .where(self.test_table.c.id == data['id'])) + row = result.fetchone() + for column in columns: + if 'null' in column or 'empty' in column: + # asyncpg decodes None values as empty str/bytes value + self.assertFalse(row[column]) + continue + value = utils.memoryview_to_bytes(row[column]) + self.assertIsInstance(value, bytes, column) + self.assertNotEqual(data[column], value, column) + + +class TestPostgresqlBinaryTypeAwareDecryptionWithoutDefaults(TestPostgresqlBinaryFormatTypeAwareDecryptionWithDefaults): + # test table used for queries and data mapping into python types + test_table = sa.Table( + # use new object of metadata to avoid name conflict + 'test_type_aware_decryption_without_defaults', sa.MetaData(), + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('value_str', sa.Text), + sa.Column('value_bytes', sa.LargeBinary), + sa.Column('value_int32', sa.Integer), + sa.Column('value_int64', sa.BigInteger), + sa.Column('value_null_str', sa.Text, nullable=True, default=None), + sa.Column('value_null_int32', sa.Integer, nullable=True, default=None), + sa.Column('value_empty_str', sa.Text, nullable=False, default=''), + extend_existing=True + ) + # schema table used to generate table in the database with binary column types + schema_table = sa.Table( + 'test_type_aware_decryption_without_defaults', metadata, + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('value_str', sa.LargeBinary), + sa.Column('value_bytes', sa.LargeBinary), + sa.Column('value_int32', sa.LargeBinary), + sa.Column('value_int64', sa.LargeBinary), + sa.Column('value_null_str', sa.LargeBinary, nullable=True, default=None), + sa.Column('value_null_int32', sa.LargeBinary, nullable=True, default=None), + sa.Column('value_empty_str', sa.LargeBinary, nullable=False, default=b''), + extend_existing=True + ) + ENCRYPTOR_CONFIG = get_encryptor_config('tests/encryptor_configs/transparent_type_aware_decryption.yaml') + + def checkSkip(self): + if not (TEST_POSTGRESQL and TEST_WITH_TLS): + self.skipTest("Test only for PostgreSQL with TLS") + + def testClientIDRead(self): + """test decrypting with correct clientID and not decrypting with + incorrect clientID or using direct connection to db + All result data should be valid for application. Not decrypted data should be returned as is and DB driver + should cause error + """ + data = { + 'id': get_random_id(), + 'value_str': random_str(), + 'value_bytes': random_bytes(), + 'value_int32': random_int32(), + 'value_int64': random_int64(), + 'value_null_str': None, + 'value_null_int32': None, + 'value_empty_str': '' + } + self.schema_table.create(bind=self.engine_raw, checkfirst=True) + ###### + columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_null_str', 'value_null_int32', + 'value_empty_str') + query, args = self.compileQuery(self.test_table.insert(), data) + self.executor1.execute_prepared_statement(query, args) + + query, args = self.compileQuery( + sa.select([self.test_table]) + .where(self.test_table.c.id == sa.bindparam('id')), {'id': data['id']}) + + with self.assertRaises(UnicodeDecodeError): + row = self.executor2.execute_prepared_statement(query, args)[0] + + row = self.raw_executor.execute_prepared_statement(query, args)[0] + for column in columns: + if 'null' in column or 'empty' in column: + # asyncpg decodes None values as empty str/bytes value + self.assertFalse(row[column]) + continue + value = utils.memoryview_to_bytes(row[column]) + self.assertIsInstance(value, bytes, column) + self.assertNotEqual(data[column], value, column) + + if __name__ == '__main__': import xmlrunner output_path = os.environ.get('TEST_XMLOUTPUT', '') diff --git a/utils/dbByteArrayEncoders.go b/utils/dbByteArrayEncoders.go index 01de40694..8d237efc6 100644 --- a/utils/dbByteArrayEncoders.go +++ b/utils/dbByteArrayEncoders.go @@ -87,56 +87,30 @@ func DecodeOctal(data []byte) ([]byte, error) { return output, nil } -// DecodedData wrap binary data which should be encoded in final format after usage -type DecodedData struct { - data []byte - encodeFunc func([]byte) []byte -} - -// Data return binary data -func (d *DecodedData) Data() []byte { - return d.data -} - -// Set set binary data -func (d *DecodedData) Set(data []byte) { - d.data = data -} - -// Encoded return encoded binary data in final format according to encoding logic -func (d *DecodedData) Encoded() []byte { - return d.encodeFunc(d.data) -} - -func hexEncode(data []byte) []byte { +// PgEncodeToHex encode binary data to hex SQL literal +func PgEncodeToHex(data []byte) []byte { output := make([]byte, 2+hex.EncodedLen(len(data))) copy(output[:2], []byte{'\\', 'x'}) hex.Encode(output[2:], data) return output } -func dryEncode(data []byte) []byte { - return data -} - -// WrapRawDataAsDecoded return DecodedData with Encode function which return data as is -func WrapRawDataAsDecoded(data []byte) *DecodedData { - return &DecodedData{data: data, encodeFunc: dryEncode} -} - // DecodeEscaped with hex or octal encodings -func DecodeEscaped(data []byte) (*DecodedData, error) { +func DecodeEscaped(data []byte) ([]byte, error) { if len(data) >= 2 && bytes.Equal(data[:2], []byte{'\\', 'x'}) { hexdata := data[2:] output := make([]byte, hex.DecodedLen(len(hexdata))) _, err := hex.Decode(output, hexdata) - return &DecodedData{data: output, encodeFunc: hexEncode}, err + if err != nil { + return data, err + } + return output, err } result, err := DecodeOctal(data) if err != nil { - return &DecodedData{data: data, encodeFunc: dryEncode}, ErrDecodeOctalString + return data, ErrDecodeOctalString } - return &DecodedData{data: result, encodeFunc: EncodeToOctal}, nil + return result, nil } // QuoteValue returns name in quotes, if name contains quotes, doubles them diff --git a/utils/dbByteArrayEncoders_test.go b/utils/dbByteArrayEncoders_test.go index b88a6b057..5508df000 100644 --- a/utils/dbByteArrayEncoders_test.go +++ b/utils/dbByteArrayEncoders_test.go @@ -35,33 +35,6 @@ func TestEncodeToOctal(t *testing.T) { } } -func TestDecodeEscaped(t *testing.T) { - type testcase struct { - data []byte - expected []byte - err error - } - testcases := []testcase{ - {[]byte("\\x"), []byte{}, nil}, - {[]byte{}, []byte{}, nil}, - {[]byte("\\001"), []byte{1}, nil}, - {[]byte("\\x01"), []byte{1}, nil}, - } - for _, tcase := range testcases { - result, err := DecodeEscaped(tcase.data) - if err != tcase.err { - t.Fatalf("Incorrect error, took %s, expects %s\n", err, tcase.err) - } - if result == nil { - t.Fatal("Result is nil") - } - if !bytes.Equal(result.Data(), tcase.expected) { - t.Fatalf("Invalid result, took %v, expects %v\n", result, tcase.expected) - } - } - -} - func BenchmarkEncodeToOctal(b *testing.B) { data := make([]byte, 256) for i := 0; i < len(data); i++ { diff --git a/utils/escape_format.go b/utils/escape_format.go index 393facdc2..e09bc55c5 100644 --- a/utils/escape_format.go +++ b/utils/escape_format.go @@ -16,6 +16,10 @@ limitations under the License. package utils +import ( + "unicode" +) + // IsPrintableEscapeChar returns true if character is ASCII printable (code between 32 and 126) func IsPrintableEscapeChar(c byte) bool { if c >= 32 && c <= 126 { @@ -24,10 +28,17 @@ func IsPrintableEscapeChar(c byte) bool { return false } -// IsPrintableASCIIArray return true if all symbols in data are ASCII printable symbols -func IsPrintableASCIIArray(data []byte) bool { - for _, c := range data { - if !IsPrintableEscapeChar(c) { +// IsPrintablePostgresqlString returns true if it's valid ASCII or utf8 string except '\' character that used as escape +// character in strings +func IsPrintablePostgresqlString(data []byte) bool { + if len(data) == 0 { + return true + } + // convert byte slice to string to work with Runes instead of bytes + stringValue := BytesToString(data) + // '\' is special case because PostgreSQL escapes it + for _, c := range stringValue { + if c == '\\' || !unicode.IsPrint(c) { return false } } diff --git a/utils/utils.go b/utils/utils.go index 634d531cb..78c2f41a4 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -25,6 +25,7 @@ import ( "io/ioutil" "sync" "time" + "unsafe" log "github.com/sirupsen/logrus" "os" @@ -224,3 +225,14 @@ func WaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { return true // timed out } } + +// BytesToString converts data to string with re-using same allocated memory +// Warning: data shouldn't be changed after that because it will cause runtime error due to strings are immutable +// Only for read/iterate operations +// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ . +// +// Note it may break if string and/or slice header will change +// in the future go versions. +func BytesToString(data []byte) string { + return *(*string)(unsafe.Pointer(&data)) +} diff --git a/utils/utils_test.go b/utils/utils_test.go index 623ded446..c3cc22496 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -20,7 +20,9 @@ import ( "crypto/rand" "github.com/cossacklabs/themis/gothemis/keys" "os" + "reflect" "testing" + "unsafe" ) func TestFileExists(t *testing.T) { @@ -98,3 +100,34 @@ func TestZeroizeNilKeyPair(t *testing.T) { ZeroizeKeyPair(nil) // no panic ZeroizeKeyPair(&keys.Keypair{}) } + +func TestBytesToString(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want string + }{ + {"nil array", args{nil}, ""}, + {"empty array", args{[]byte{}}, ""}, + {"empty array", args{[]byte(`some data`)}, `some data`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := BytesToString(tt.args.data) + if got != tt.want { + t.Errorf("BytesToString() = %v, want %v", got, tt.want) + } + byteHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tt.args.data)) + strHeader := (*reflect.StringHeader)(unsafe.Pointer(&got)) + if byteHeader.Len == 0 { + return + } + if byteHeader.Data != strHeader.Data { + t.Fatal("Result string uses another block of memory") + } + }) + } +}