Skip to content

Commit

Permalink
bug fix: flag of builtin 'IFNULL' 's result is not consistent with my…
Browse files Browse the repository at this point in the history
…sql (pingcap#4158)
  • Loading branch information
spongedu authored and zz-jason committed Aug 21, 2017
1 parent 97e9dfe commit 2df9456
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 110 deletions.
166 changes: 79 additions & 87 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,67 @@ type caseWhenFunctionClass struct {
baseFunctionClass
}

// Infer result type for builtin IF, IFNULL && NULLIF.
func inferType4ControlFuncs(tp1, tp2 *types.FieldType) *types.FieldType {
retTp, typeClass := &types.FieldType{}, types.ClassString
if tp1.Tp == mysql.TypeNull {
*retTp, typeClass = *tp2, tp2.ToClass()
// If both arguments are NULL, make resulting type BINARY(0).
if tp2.Tp == mysql.TypeNull {
retTp.Tp, typeClass = mysql.TypeString, types.ClassString
retTp.Flen, retTp.Decimal = 0, 0
types.SetBinChsClnFlag(retTp)
}
} else if tp2.Tp == mysql.TypeNull {
*retTp, typeClass = *tp1, tp1.ToClass()
} else {
var unsignedFlag uint
typeClass = types.AggTypeClass([]*types.FieldType{tp1, tp2}, &unsignedFlag)
retTp = types.AggFieldType([]*types.FieldType{tp1, tp2})
retTp.Decimal = 0
if typeClass != types.ClassInt {
retTp.Decimal = mathutil.Max(tp1.Decimal, tp2.Decimal)
}
if types.IsNonBinaryStr(tp1) && !types.IsBinaryStr(tp2) {
retTp.Charset, retTp.Collate, retTp.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
if mysql.HasBinaryFlag(tp1.Flag) {
retTp.Flag |= mysql.BinaryFlag
}
} else if types.IsNonBinaryStr(tp2) && !types.IsBinaryStr(tp1) {
retTp.Charset, retTp.Collate, retTp.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
if mysql.HasBinaryFlag(tp2.Flag) {
retTp.Flag |= mysql.BinaryFlag
}
} else if types.IsBinaryStr(tp1) || types.IsBinaryStr(tp2) || typeClass != types.ClassString {
types.SetBinChsClnFlag(retTp)
} else {
retTp.Charset, retTp.Collate, retTp.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
}
if typeClass == types.ClassDecimal || typeClass == types.ClassInt {
unsignedFlag1, unsignedFlag2 := mysql.HasUnsignedFlag(tp1.Flag), mysql.HasUnsignedFlag(tp2.Flag)
flagLen1, flagLen2 := 0, 0
if !unsignedFlag1 {
flagLen1 = 1
}
if !unsignedFlag2 {
flagLen2 = 1
}
len1 := tp1.Flen - flagLen1
len2 := tp2.Flen - flagLen2
if tp1.Decimal != types.UnspecifiedLength {
len1 -= tp1.Decimal
}
if tp1.Decimal != types.UnspecifiedLength {
len2 -= tp2.Decimal
}
retTp.Flen = mathutil.Max(len1, len2) + retTp.Decimal + 1
} else {
retTp.Flen = mathutil.Max(tp1.Flen, tp2.Flen)
}
}
return retTp
}

func (c *caseWhenFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) {
if err = c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
Expand Down Expand Up @@ -314,7 +375,7 @@ func (c *ifFunctionClass) getFunction(args []Expression, ctx context.Context) (s
if err = c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
retTp := c.inferType(args[1].GetType(), args[2].GetType())
retTp := inferType4ControlFuncs(args[1].GetType(), args[2].GetType())
evalTps := fieldTp2EvalTp(retTp)
bf, err := newBaseBuiltinFuncWithTp(args, ctx, evalTps, tpInt, evalTps, evalTps)
if err != nil {
Expand All @@ -338,54 +399,6 @@ func (c *ifFunctionClass) getFunction(args []Expression, ctx context.Context) (s
return sig.setSelf(sig), nil
}

func (c *ifFunctionClass) inferType(tp1, tp2 *types.FieldType) *types.FieldType {
retTp, typeClass := &types.FieldType{}, types.ClassString
if tp1.Tp == mysql.TypeNull {
*retTp, typeClass = *tp2, tp2.ToClass()
// If both arguments are NULL, make resulting type BINARY(0).
if tp2.Tp == mysql.TypeNull {
retTp.Tp, typeClass = mysql.TypeString, types.ClassString
retTp.Flen, retTp.Decimal = 0, 0
types.SetBinChsClnFlag(retTp)
}
} else if tp2.Tp == mysql.TypeNull {
*retTp, typeClass = *tp1, tp1.ToClass()
} else {
var unsignedFlag uint
typeClass = types.AggTypeClass([]*types.FieldType{tp1, tp2}, &unsignedFlag)
retTp = types.AggFieldType([]*types.FieldType{tp1, tp2})
retTp.Decimal = mathutil.Max(tp1.Decimal, tp2.Decimal)
types.SetBinChsClnFlag(retTp)
if types.IsNonBinaryStr(tp1) && !types.IsBinaryStr(tp2) {
retTp.Charset, retTp.Collate, retTp.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
if mysql.HasBinaryFlag(tp1.Flag) {
retTp.Flag |= mysql.BinaryFlag
}
} else if types.IsNonBinaryStr(tp2) && !types.IsBinaryStr(tp1) {
retTp.Charset, retTp.Collate, retTp.Flag = charset.CharsetUTF8, charset.CollationUTF8, 0
if mysql.HasBinaryFlag(tp2.Flag) {
retTp.Flag |= mysql.BinaryFlag
}
}
if typeClass == types.ClassDecimal || typeClass == types.ClassInt {
unsignedFlag1, unsignedFlag2 := mysql.HasUnsignedFlag(tp1.Flag), mysql.HasUnsignedFlag(tp2.Flag)
flagLen1, flagLen2 := 0, 0
if !unsignedFlag1 {
flagLen1 = 1
}
if !unsignedFlag2 {
flagLen2 = 1
}
len1 := tp1.Flen - tp1.Decimal - flagLen1
len2 := tp2.Flen - tp2.Decimal - flagLen2
retTp.Flen = mathutil.Max(len1, len2) + retTp.Decimal + 1
} else {
retTp.Flen = mathutil.Max(tp1.Flen, tp2.Flen)
}
}
return retTp
}

type builtinIfIntSig struct {
baseIntBuiltinFunc
}
Expand Down Expand Up @@ -557,53 +570,32 @@ func (c *ifNullFunctionClass) getFunction(args []Expression, ctx context.Context
return nil, errors.Trace(err)
}
tp0, tp1 := args[0].GetType(), args[1].GetType()
fieldTp := types.AggFieldType([]*types.FieldType{tp0, tp1})
types.SetBinChsClnFlag(fieldTp)
classType := types.AggTypeClass([]*types.FieldType{tp0, tp1}, &fieldTp.Flag)
fieldTp.Decimal = mathutil.Max(tp0.Decimal, tp1.Decimal)
// TODO: make it more accurate when inferring FLEN
fieldTp.Flen = tp0.Flen + tp1.Flen

var evalTps evalTp
switch classType {
case types.ClassInt:
evalTps = tpInt
fieldTp.Decimal = 0
case types.ClassReal:
evalTps = tpReal
case types.ClassDecimal:
evalTps = tpDecimal
case types.ClassString:
evalTps = tpString
if !types.IsBinaryStr(tp0) && !types.IsBinaryStr(tp1) {
fieldTp.Charset, fieldTp.Collate = mysql.DefaultCharset, mysql.DefaultCollationName
fieldTp.Flag ^= mysql.BinaryFlag
}
if types.IsTypeTime(fieldTp.Tp) {
evalTps = tpTime
} else if fieldTp.Tp == mysql.TypeDuration {
evalTps = tpDuration
}
retTp := inferType4ControlFuncs(tp0, tp1)
retTp.Flag |= (tp0.Flag & mysql.NotNullFlag) | (tp1.Flag & mysql.NotNullFlag)
if tp0.Tp == mysql.TypeNull && tp1.Tp == mysql.TypeNull {
retTp.Tp = mysql.TypeNull
retTp.Flen, retTp.Decimal = 0, -1
types.SetBinChsClnFlag(retTp)
}
evalTps := fieldTp2EvalTp(retTp)
bf, err := newBaseBuiltinFuncWithTp(args, ctx, evalTps, evalTps, evalTps)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp = fieldTp
switch classType {
case types.ClassInt:
bf.tp = retTp
switch evalTps {
case tpInt:
sig = &builtinIfNullIntSig{baseIntBuiltinFunc{bf}}
case types.ClassReal:
case tpReal:
sig = &builtinIfNullRealSig{baseRealBuiltinFunc{bf}}
case types.ClassDecimal:
case tpDecimal:
sig = &builtinIfNullDecimalSig{baseDecimalBuiltinFunc{bf}}
case types.ClassString:
case tpString:
sig = &builtinIfNullStringSig{baseStringBuiltinFunc{bf}}
if types.IsTypeTime(fieldTp.Tp) {
sig = &builtinIfNullTimeSig{baseTimeBuiltinFunc{bf}}
} else if fieldTp.Tp == mysql.TypeDuration {
sig = &builtinIfNullDurationSig{baseDurationBuiltinFunc{bf}}
}
case tpTime:
sig = &builtinIfNullTimeSig{baseTimeBuiltinFunc{bf}}
case tpDuration:
sig = &builtinIfNullDurationSig{baseDurationBuiltinFunc{bf}}
}
return sig.setSelf(sig), nil
}
Expand Down
52 changes: 29 additions & 23 deletions plan/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,16 +633,22 @@ func (s *testPlanSuite) createTestCase4LogicalFuncs() []typeInferTestCase {

func (s *testPlanSuite) createTestCase4ControlFuncs() []typeInferTestCase {
return []typeInferTestCase{
{"ifnull(c_int_d, c_int_d )", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 22, 0},
{"ifnull(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 17, 3},
{"ifnull(c_int_d, c_int_d )", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"ifnull(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 14, 3},
{"ifnull(c_int_d, c_char)", mysql.TypeString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"ifnull(c_int_d, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"ifnull(c_char, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"ifnull(null, null)", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, types.UnspecifiedLength},
{"ifnull(c_double_d, c_timestamp_d)", mysql.TypeVarchar, charset.CharsetUTF8, mysql.NotNullFlag, 22, types.UnspecifiedLength},
{"if(c_int_d, c_decimal, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 15, 3},
{"if(c_int_d, c_char, c_int_d)", mysql.TypeString, charset.CharsetUTF8, 0, 20, -1},
{"if(c_int_d, c_binary, c_int_d)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, -1},
{"if(c_int_d, c_bchar, c_int_d)", mysql.TypeString, charset.CharsetUTF8, mysql.BinaryFlag, 20, -1},
{"if(c_int_d, c_char, c_decimal)", mysql.TypeString, charset.CharsetUTF8, 0, 20, 3},
{"if(c_int_d, c_datetime, c_int_d)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 22, 2},
{"if(c_int_d, c_datetime, c_int_d)", mysql.TypeVarchar, charset.CharsetUTF8, 0, 22, 2},
{"if(c_int_d, c_int_d, c_double_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, types.UnspecifiedLength},
{"if(c_int_d, c_time_d, c_datetime)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 22, 2},
{"if(c_int_d, c_time_d, c_datetime_d)", mysql.TypeDatetime, charset.CharsetUTF8, 0, 19, 0},
{"if(null, null, null)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 0, 0},
{"case when c_int_d then c_char else c_varchar end", mysql.TypeVarchar, charset.CharsetUTF8, 0, 20, -1},
{"case when c_int_d > 1 then c_double_d else c_bchar end", mysql.TypeString, charset.CharsetUTF8, mysql.BinaryFlag, 22, -1},
{"case when c_int_d > 2 then c_double_d when c_int_d < 1 then c_decimal else c_double_d end", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, 3},
Expand Down Expand Up @@ -803,32 +809,32 @@ func (s *testPlanSuite) createTestCase4CompareFuncs() []typeInferTestCase {
{"isnull(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_float_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_double_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_decimal )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_datetime )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_decimal )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_datetime )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_time_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_timestamp_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_char )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_varchar )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_char )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_varchar )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_text_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_binary )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_varbinary )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_binary )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_varbinary)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_blob_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_set )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_enum )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_set )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"isnull(c_enum )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},

{"nullif(c_int_d , 123)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, types.UnspecifiedLength}, // TODO: tp should be TypeLonglong, decimal should be 0
{"nullif(c_bigint_d , 123)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 21, types.UnspecifiedLength}, // TODO: flen should be 20, decimal should be 0
{"nullif(c_float_d , 123)", mysql.TypeFloat, charset.CharsetBin, mysql.BinaryFlag, 12, types.UnspecifiedLength}, // TODO: tp should be TypeDouble
{"nullif(c_int_d , 123)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, // TODO: tp should be TypeLonglong, decimal should be 0
{"nullif(c_bigint_d , 123)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 21, 0}, // TODO: flen should be 20, decimal should be 0
{"nullif(c_float_d , 123)", mysql.TypeFloat, charset.CharsetBin, mysql.BinaryFlag, 12, types.UnspecifiedLength}, // TODO: tp should be TypeDouble
{"nullif(c_double_d , 123)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, types.UnspecifiedLength},
{"nullif(c_decimal , 123)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 3},
{"nullif(c_datetime , 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 22, 2}, // TODO: tp should be TypeVarString, binary flag
{"nullif(c_time_d , 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, // TODO: tp should be TypeVarString, no binary flag
{"nullif(c_timestamp_d, 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 19, types.UnspecifiedLength}, // TODO: tp should be TypeVarString, decimal should be 0, no binary flag
{"nullif(c_char , 123)", mysql.TypeString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"nullif(c_varchar , 123)", mysql.TypeVarchar, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_decimal , 123)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 9, 3},
{"nullif(c_datetime , 123)", mysql.TypeVarchar, charset.CharsetUTF8, 0, 22, 2}, // TODO: tp should be TypeVarString
{"nullif(c_time_d , 123)", mysql.TypeVarchar, charset.CharsetUTF8, 0, 10, 0}, // TODO: tp should be TypeVarString
{"nullif(c_timestamp_d, 123)", mysql.TypeVarchar, charset.CharsetUTF8, 0, 19, -1}, // TODO: tp should be TypeVarString, decimal should be 0
{"nullif(c_char , 123)", mysql.TypeString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"nullif(c_varchar , 123)", mysql.TypeVarchar, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_text_d , 123)", mysql.TypeBlob, charset.CharsetUTF8, 0, 65535, types.UnspecifiedLength}, // TODO: tp should be TypeMediumBlob
{"nullif(c_binary , 123)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_varbinary , 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_binary , 123)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_varbinary, 123)", mysql.TypeVarchar, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
{"nullif(c_blob_d , 123)", mysql.TypeBlob, charset.CharsetBin, mysql.BinaryFlag, 65535, types.UnspecifiedLength}, // TODO: tp should be TypeVarString
}
}
Expand Down

0 comments on commit 2df9456

Please sign in to comment.