Skip to content

Commit

Permalink
expression: rewrite builtin function: MD5, SHA1, SHA2 (pingcap#4109)
Browse files Browse the repository at this point in the history
  • Loading branch information
breezewish authored and XuHuaiyu committed Aug 10, 2017
1 parent 68f58dc commit bf97933
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 85 deletions.
154 changes: 75 additions & 79 deletions expression/builtin_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,36 +265,6 @@ func (b *builtinEncryptSig) eval(row []types.Datum) (d types.Datum, err error) {
return d, errFunctionNotExists.GenByArgs("ENCRYPT")
}

type md5FunctionClass struct {
baseFunctionClass
}

func (c *md5FunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp.Flen = 32
sig := &builtinMD5Sig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
}

type builtinMD5Sig struct {
baseStringBuiltinFunc
}

// evalString evals a builtinMD5Sig.
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_md5
func (b *builtinMD5Sig) evalString(row []types.Datum) (string, bool, error) {
arg, isNull, err := b.args[0].EvalString(row, b.ctx.GetSessionVars().StmtCtx)
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
sum := md5.Sum([]byte(arg))
hexStr := fmt.Sprintf("%x", sum)
return hexStr, false, nil
}

type oldPasswordFunctionClass struct {
baseFunctionClass
}
Expand Down Expand Up @@ -386,54 +356,92 @@ func (b *builtinRandomBytesSig) eval(row []types.Datum) (d types.Datum, err erro
return d, nil
}

type md5FunctionClass struct {
baseFunctionClass
}

func (c *md5FunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp.Flen = 32
sig := &builtinMD5Sig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinMD5Sig struct {
baseStringBuiltinFunc
}

// evalString evals a builtinMD5Sig.
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_md5
func (b *builtinMD5Sig) evalString(row []types.Datum) (string, bool, error) {
arg, isNull, err := b.args[0].EvalString(row, b.ctx.GetSessionVars().StmtCtx)
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
sum := md5.Sum([]byte(arg))
hexStr := fmt.Sprintf("%x", sum)
return hexStr, false, nil
}

type sha1FunctionClass struct {
baseFunctionClass
}

func (c *sha1FunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
sig := &builtinSHA1Sig{newBaseBuiltinFunc(args, ctx)}
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp.Flen = 40
sig := &builtinSHA1Sig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinSHA1Sig struct {
baseBuiltinFunc
baseStringBuiltinFunc
}

// eval evals a builtinSHA1Sig.
// evalString evals SHA1(str).
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_sha1
// The value is returned as a string of 40 hexadecimal digits, or NULL if the argument was NULL.
func (b *builtinSHA1Sig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
// SHA/SHA1 function only accept 1 parameter
arg := args[0]
if arg.IsNull() {
return d, nil
}
bin, err := arg.ToBytes()
if err != nil {
return d, errors.Trace(err)
func (b *builtinSHA1Sig) evalString(row []types.Datum) (string, bool, error) {
str, isNull, err := b.args[0].EvalString(row, b.ctx.GetSessionVars().StmtCtx)
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
hasher := sha1.New()
hasher.Write(bin)
data := fmt.Sprintf("%x", hasher.Sum(nil))
d.SetString(data)
return d, nil
hasher.Write([]byte(str))
return fmt.Sprintf("%x", hasher.Sum(nil)), false, nil
}

type sha2FunctionClass struct {
baseFunctionClass
}

func (c *sha2FunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
sig := &builtinSHA2Sig{newBaseBuiltinFunc(args, ctx)}
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpInt)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp.Flen = 128 // sha512
sig := &builtinSHA2Sig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinSHA2Sig struct {
baseBuiltinFunc
baseStringBuiltinFunc
}

// Supported hash length of SHA-2 family
Expand All @@ -445,28 +453,16 @@ const (
SHA512 int = 512
)

// eval evals a builtinSHA2Sig.
// evalString evals SHA2(str, hash_length).
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_sha2
func (b *builtinSHA2Sig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return d, errors.Trace(err)
}
for _, arg := range args {
if arg.IsNull() {
return d, nil
}
}
// Meaning of each argument:
// args[0]: the cleartext string to be hashed
// args[1]: desired bit length of result
bin, err := args[0].ToBytes()
if err != nil {
return d, errors.Trace(err)
func (b *builtinSHA2Sig) evalString(row []types.Datum) (string, bool, error) {
str, isNull, err := b.args[0].EvalString(row, b.ctx.GetSessionVars().StmtCtx)
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
hashLength, err := args[1].ToInt64(b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return d, errors.Trace(err)
hashLength, isNull, err := b.args[1].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
var hasher hash.Hash
switch int(hashLength) {
Expand All @@ -479,12 +475,12 @@ func (b *builtinSHA2Sig) eval(row []types.Datum) (d types.Datum, err error) {
case SHA512:
hasher = sha512.New()
}
if hasher != nil {
hasher.Write(bin)
data := fmt.Sprintf("%x", hasher.Sum(nil))
d.SetString(data)
if hasher == nil {
return "", true, nil
}
return d, nil

hasher.Write([]byte(str))
return fmt.Sprintf("%x", hasher.Sum(nil)), false, nil
}

// deflate compresses a string using the DEFLATE format.
Expand Down
10 changes: 5 additions & 5 deletions expression/builtin_encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func fromHex(str interface{}) (d types.Datum) {
return d
}

var shaTests = []struct {
var sha1Tests = []struct {
origin interface{}
crypt string
}{
Expand All @@ -112,10 +112,10 @@ var shaTests = []struct {
{123.45, "22f8b438ad7e89300b51d88684f3f0b9fa1d7a32"},
}

func (s *testEvaluatorSuite) TestShaEncrypt(c *C) {
func (s *testEvaluatorSuite) TestSha1Hash(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.SHA]
for _, tt := range shaTests {
for _, tt := range sha1Tests {
in := types.NewDatum(tt.origin)
f, _ := fc.getFunction(datumsToConstants([]types.Datum{in}), s.ctx)
crypt, err := f.eval(nil)
Expand Down Expand Up @@ -153,7 +153,7 @@ var sha2Tests = []struct {
{"pingcap", 123, nil, false},
}

func (s *testEvaluatorSuite) TestSha2Encrypt(c *C) {
func (s *testEvaluatorSuite) TestSha2Hash(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.SHA2]
for _, tt := range sha2Tests {
Expand All @@ -172,7 +172,7 @@ func (s *testEvaluatorSuite) TestSha2Encrypt(c *C) {
}
}

func (s *testEvaluatorSuite) TestMD5(c *C) {
func (s *testEvaluatorSuite) TestMD5Hash(c *C) {
defer testleak.AfterTest(c)()

cases := []struct {
Expand Down
25 changes: 25 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,31 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) {
result = tk.MustQuery("select md5('123'), md5(123), md5(''), md5('你好'), md5(NULL), md5('👍')")
result.Check(testkit.Rows(`202cb962ac59075b964b07152d234b70 202cb962ac59075b964b07152d234b70 d41d8cd98f00b204e9800998ecf8427e 7eca689f0d3389d9dea66ae112e5cfd7 <nil> 0215ac4dab1ecaf71d83f98af5726984`))

// for sha/sha1
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))")
tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`)
result = tk.MustQuery("select sha1(a), sha1(b), sha1(c), sha1(d), sha1(e), sha1(f), sha1(g), sha1(h), sha1(i) from t")
result.Check(testkit.Rows("da4b9237bacccdf19c0760cab7aec4a8359010b0 da4b9237bacccdf19c0760cab7aec4a8359010b0 ce0d88c5002b6cf7664052f1fc7d652cbdadccec 6c6956de323692298e4e5ad3028ff491f7ad363c 1906f8aeb5a717ca0f84154724045839330b0ea9 adc83b19e793491b1c6ea0fd8b46cd9f32e592fc 9aadd14ceb737b28697b8026f205f4b3e31de147 64e095fe763fc62418378753f9402623bea9e227 4df56fc09a3e66b48fb896e90b0a6fc02c978e9e"))
result = tk.MustQuery("select sha1('123'), sha1(123), sha1(''), sha1('你好'), sha1(NULL)")
result.Check(testkit.Rows(`40bd001563085fc35165329ea1ff5c5ecbdbbeef 40bd001563085fc35165329ea1ff5c5ecbdbbeef da39a3ee5e6b4b0d3255bfef95601890afd80709 440ee0853ad1e99f962b63e459ef992d7c211722 <nil>`))
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))")
tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`)
result = tk.MustQuery("select sha(a), sha(b), sha(c), sha(d), sha(e), sha(f), sha(g), sha(h), sha(i) from t")
result.Check(testkit.Rows("da4b9237bacccdf19c0760cab7aec4a8359010b0 da4b9237bacccdf19c0760cab7aec4a8359010b0 ce0d88c5002b6cf7664052f1fc7d652cbdadccec 6c6956de323692298e4e5ad3028ff491f7ad363c 1906f8aeb5a717ca0f84154724045839330b0ea9 adc83b19e793491b1c6ea0fd8b46cd9f32e592fc 9aadd14ceb737b28697b8026f205f4b3e31de147 64e095fe763fc62418378753f9402623bea9e227 4df56fc09a3e66b48fb896e90b0a6fc02c978e9e"))
result = tk.MustQuery("select sha('123'), sha(123), sha(''), sha('你好'), sha(NULL)")
result.Check(testkit.Rows(`40bd001563085fc35165329ea1ff5c5ecbdbbeef 40bd001563085fc35165329ea1ff5c5ecbdbbeef da39a3ee5e6b4b0d3255bfef95601890afd80709 440ee0853ad1e99f962b63e459ef992d7c211722 <nil>`))

// for sha2
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))")
tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`)
result = tk.MustQuery("select sha2(a, 224), sha2(b, 0), sha2(c, 512), sha2(d, 256), sha2(e, 384), sha2(f, 0), sha2(g, 512), sha2(h, 256), sha2(i, 224) from t")
result.Check(testkit.Rows("58b2aaa0bfae7acc021b3260e941117b529b2e69de878fd7d45c61a9 d4735e3a265e16eee03f59718b9b5d03019c07d8b6c51f90da3a666eec13ab35 42415572557b0ca47e14fa928e83f5746d33f90c74270172cc75c61a78db37fe1485159a4fd75f33ab571b154572a5a300938f7d25969bdd05d8ac9dd6c66123 8c2fa3f276952c92b0b40ed7d27454e44b8399a19769e6bceb40da236e45a20a b11d35f1a37e54d5800d210d8e6b80b42c9f6d20ea7ae548c762383ebaa12c5954c559223c6c7a428e37af96bb4f1e0d 01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b 9550da35ea1683abaf5bfa8de68fe02b9c6d756c64589d1ef8367544c254f5f09218a6466cadcee8d74214f0c0b7fb342d1a9f3bd4d406aacf7be59c327c9306 98010bd9270f9b100b6214a21754fd33bdc8d41b2bc9f9dd16ff54d3c34ffd71 a7cddb7346fbc66ab7f803e865b74cbd99aace8e7dabbd8884c148cb"))
result = tk.MustQuery("select sha2('123', 512), sha2(123, 512), sha2('', 512), sha2('你好', 224), sha2(NULL, 256), sha2('foo', 123)")
result.Check(testkit.Rows(`3c9909afec25354d551dae21590bb26e38d53f2173b8d3dc3eee4c047e7ab1c1eb8b85103e3be7ba613b31bb5c9c36214dc9f14a42fd7a2fdb84856bca5c44c2 3c9909afec25354d551dae21590bb26e38d53f2173b8d3dc3eee4c047e7ab1c1eb8b85103e3be7ba613b31bb5c9c36214dc9f14a42fd7a2fdb84856bca5c44c2 cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e e91f006ed4e0882de2f6a3c96ec228a6a5c715f356d00091bce842b5 <nil> <nil>`))

// for AES_ENCRYPT
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))")
Expand Down
Loading

0 comments on commit bf97933

Please sign in to comment.