diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index b8a8be57ca05..56914108b7a8 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -25,6 +25,7 @@ import annotation.constructorOnly import cc.* import NameKinds.WildcardParamName import MatchTypes.isConcrete +import scala.util.boundary, boundary.break /** Provides methods to compare types. */ @@ -2054,6 +2055,45 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling else op2 end necessaryEither + /** Finds the necessary (the weakest) GADT constraint among a list of them. + * It returns the one being subsumed by all others if exists, and `None` otherwise. + * + * This is used when typechecking pattern alternatives, for instance: + * + * enum Expr[+T]: + * case I1(x: Int) extends Expr[Int] + * case I2(x: Int) extends Expr[Int] + * case B(x: Boolean) extends Expr[Boolean] + * import Expr.* + * + * The following function should compile: + * + * def foo[T](e: Expr[T]): T = e match + * case I1(_) | I2(_) => 42 + * + * since `T >: Int` is subsumed by both alternatives in the first match clause. + * + * However, the following should not: + * + * def foo[T](e: Expr[T]): T = e match + * case I1(_) | B(_) => 42 + * + * since the `I1(_)` case gives the constraint `T >: Int` while `B(_)` gives `T >: Boolean`. + * Neither of the constraints is subsumed by the other. + */ + def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] = boundary: + constrs match + case Nil => break(None) + case c0 :: constrs => + var weakest = c0 + for c <- constrs do + if subsumes(weakest.constraint, c.constraint, preGadt.constraint) then + weakest = c + else if !subsumes(c.constraint, weakest.constraint, preGadt.constraint) then + // this two constraints are disjoint + break(None) + break(Some(weakest)) + inline def rollbackConstraintsUnless(inline op: Boolean): Boolean = val saved = constraint var result = false @@ -3376,6 +3416,9 @@ object TypeComparer { def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean = comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement)) + def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] = + comparing(_.necessaryGadtConstraint(constrs, preGadt)) + def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String = comparing(_.explained(op, header, short)) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index b69873e2b49f..40e30aed582b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2826,10 +2826,20 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else assert(ctx.reporter.errorsReported) tree.withType(defn.AnyType) - val savedGadt = nestedCtx.gadt - val trees1 = tree.trees.mapconserve(typed(_, pt)(using nestedCtx)) + val preGadt = nestedCtx.gadt + var gadtConstrs: mutable.ArrayBuffer[GadtConstraint] = mutable.ArrayBuffer.empty + val trees1 = tree.trees.mapconserve: t => + nestedCtx.gadtState.restore(preGadt) + val res = typed(t, pt)(using nestedCtx) + gadtConstrs += ctx.gadt + res .mapconserve(ensureValueTypeOrWildcard) - nestedCtx.gadtState.restore(savedGadt) // Disable GADT reasoning for pattern alternatives + // Look for the necessary constraint that is subsumed by all alternatives. + // Use that constraint as the outcome if possible, otherwise fallback to not using + // GADT reasoning for soundness. + TypeComparer.necessaryGadtConstraint(gadtConstrs.toList, preGadt) match + case Some(constr) => nestedCtx.gadtState.restore(constr) + case None => nestedCtx.gadtState.restore(preGadt) assignType(cpy.Alternative(tree)(trees1), trees1) } diff --git a/tests/neg/gadt-alt-expr1.scala b/tests/neg/gadt-alt-expr1.scala new file mode 100644 index 000000000000..dc3296e781c8 --- /dev/null +++ b/tests/neg/gadt-alt-expr1.scala @@ -0,0 +1,13 @@ +enum Expr[+T]: + case I1() extends Expr[Int] + case I2() extends Expr[Int] + case B() extends Expr[Boolean] +import Expr.* +def foo[T](e: Expr[T]): T = + e match + case I1() | I2() => 42 // ok + case B() => true +def bar[T](e: Expr[T]): T = + e match + case I1() | B() => 42 // error + case I2() => 0 diff --git a/tests/neg/gadt-alt-expr2.scala b/tests/neg/gadt-alt-expr2.scala new file mode 100644 index 000000000000..a7e05cbe0870 --- /dev/null +++ b/tests/neg/gadt-alt-expr2.scala @@ -0,0 +1,15 @@ +enum Expr[+T]: + case I1() extends Expr[Int] + case I2() extends Expr[Int] + case I3() extends Expr[Int] + case I4() extends Expr[Int] + case I5() extends Expr[Int] + case B() extends Expr[Boolean] +import Expr.* +def test1[T](e: Expr[T]): T = + e match + case I1() | I2() | I3() | I4() | I5() => 42 // ok + case B() => true +def test2[T](e: Expr[T]): T = + e match + case I1() | I2() | I3() | I4() | I5() | B() => 42 // error diff --git a/tests/neg/gadt-alt-expr3.scala b/tests/neg/gadt-alt-expr3.scala new file mode 100644 index 000000000000..a6ebf69ba8d1 --- /dev/null +++ b/tests/neg/gadt-alt-expr3.scala @@ -0,0 +1,34 @@ +trait A +trait B extends A +trait C extends B +trait D +enum Expr[+T]: + case IsA() extends Expr[A] + case IsB() extends Expr[B] + case IsC() extends Expr[C] + case IsD() extends Expr[D] +import Expr.* +def test1[T](e: Expr[T]): T = e match + case IsA() => new A {} + case IsB() => new B {} + case IsC() => new C {} +def test2[T](e: Expr[T]): T = e match + case IsA() | IsB() => + // IsA() implies T >: A + // IsB() implies T >: B + // So T >: B is chosen + new B {} + case IsC() => new C {} +def test3[T](e: Expr[T]): T = e match + case IsA() | IsB() | IsC() => + // T >: C is chosen + new C {} +def test4[T](e: Expr[T]): T = e match + case IsA() | IsB() | IsC() => + new B {} // error +def test5[T](e: Expr[T]): T = e match + case IsA() | IsB() => + new A {} // error +def test6[T](e: Expr[T]): T = e match + case IsA() | IsC() | IsD() => + new C {} // error diff --git a/tests/neg/gadt-alt-expr4.scala b/tests/neg/gadt-alt-expr4.scala new file mode 100644 index 000000000000..6b85655c40c1 --- /dev/null +++ b/tests/neg/gadt-alt-expr4.scala @@ -0,0 +1,19 @@ +trait A +trait B extends A +trait C extends B +enum Expr[T]: + case IsA() extends Expr[A] + case IsB() extends Expr[B] + case IsC() extends Expr[C] +import Expr.* +def test1[T](e: Expr[T]): T = e match + case IsA() => new A {} + case IsB() => new B {} + case IsC() => new C {} +def test2[T](e: Expr[T]): T = e match + case IsA() | IsB() => + // IsA() implies T =:= A + // IsB() implies T =:= B + // No necessary constraint can be found + new B {} // error + case IsC() => new C {} diff --git a/tests/neg/gadt-alt-expr5.scala b/tests/neg/gadt-alt-expr5.scala new file mode 100644 index 000000000000..19993a6f34ab --- /dev/null +++ b/tests/neg/gadt-alt-expr5.scala @@ -0,0 +1,14 @@ +trait A +trait B extends A +trait C extends B +enum Expr[-T]: + case IsA() extends Expr[A] + case IsB() extends Expr[B] + case IsC() extends Expr[C] +import Expr.* +def test1[T](e: Expr[T]): Unit = e match + case IsA() | IsB() => + val t1: T = ??? + val t2: A = t1 + val t3: B = t1 // error + case IsC() => diff --git a/tests/neg/gadt-alternatives.scala b/tests/neg/gadt-alternatives.scala index 034f88ea5f24..e105e1f1d0fd 100644 --- a/tests/neg/gadt-alternatives.scala +++ b/tests/neg/gadt-alternatives.scala @@ -6,4 +6,4 @@ import Expr.* def eval[T](e: Expr[T]): T = e match case StringVal(_) | IntVal(_) => "42" // error def eval1[T](e: Expr[T]): T = e match - case IntValAlt(_) | IntVal(_) => 42 // error // limitation + case IntValAlt(_) | IntVal(_) => 42 // previously error, now ok diff --git a/tests/pos/gadt-alt-doc1.scala b/tests/pos/gadt-alt-doc1.scala new file mode 100644 index 000000000000..f6457b4660e0 --- /dev/null +++ b/tests/pos/gadt-alt-doc1.scala @@ -0,0 +1,8 @@ +trait Document[Doc <: Document[Doc]] +sealed trait Conversion[Doc, V] + +case class C[Doc <: Document[Doc]]() extends Conversion[Doc, Doc] + +def Test[Doc <: Document[Doc], V](conversion: Conversion[Doc, V]) = + conversion match + case C() | C() => ??? // error