Skip to content

Commit

Permalink
Fix the typing of Cypher collections.
Browse files Browse the repository at this point in the history
o Previously all collections were CollectionType(AnyType()) because of a bug in
   the type derivation code.
o This fixes a reported bug where indexing twice into a collection didn't work.
  • Loading branch information
benbc committed Oct 16, 2013
1 parent d10c65d commit 4da50a8
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@ case class Collection(children: Expression*) extends Expression {

def calculateType(symbols: SymbolTable): CypherType = {
children.map(_.getType(symbols)) match {

case Seq() => CollectionType(AnyType())

case types =>
val innerType = types.foldLeft(AnyType().asInstanceOf[CypherType])(_ mergeDown _)
CollectionType( innerType )
case types => CollectionType(types.reduce(_ mergeDown _))
}

}

def symbolTableDependencies = children.flatMap(_.symbolTableDependencies).toSet
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
*/
package org.neo4j.cypher.internal.commands.expressions

import org.neo4j.cypher.internal.symbols.{SymbolTable, CypherType}
import org.neo4j.cypher.internal.symbols._
import org.neo4j.cypher.internal.ExecutionContext
import org.neo4j.cypher.internal.helpers.{IsCollection, IsMap}
import org.neo4j.cypher.internal.symbols.SymbolTable
import org.neo4j.cypher.internal.pipes.QueryState
import org.neo4j.cypher.internal.symbols.AnyType

case class Literal(v: Any) extends Expression {
def apply(ctx: ExecutionContext)(implicit state: QueryState): Any = v
Expand All @@ -30,9 +33,20 @@ case class Literal(v: Any) extends Expression {

def children = Nil

def calculateType(symbols: SymbolTable): CypherType = CypherType.fromJava(v)
def calculateType(symbols: SymbolTable): CypherType = deriveType(v)

def symbolTableDependencies = Set()

override def toString() = "Literal(" + v + ")"
}
override def toString = "Literal(" + v + ")"

private def deriveType(obj: Any): CypherType = obj match {
case _: String => StringType()
case _: Char => StringType()
case _: Number => NumberType()
case _: Boolean => BooleanType()
case IsMap(_) => MapType()
case IsCollection(coll) if coll.isEmpty => CollectionType(AnyType())
case IsCollection(coll) => CollectionType(coll.map(deriveType).reduce(_ mergeDown _))
case _ => AnyType()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.neo4j.cypher.internal.symbols

import org.neo4j.cypher.CypherTypeException
import org.neo4j.cypher.internal.helpers.{IsCollection, IsMap}

trait CypherType {
def isAssignableFrom(other: CypherType): Boolean = this.getClass.isAssignableFrom(other.getClass)
Expand All @@ -44,21 +43,6 @@ trait CypherType {
def rewrite(f: CypherType => CypherType) = f(this)
}


object CypherType {
def fromJava(obj: Any): CypherType = obj match {
case _: String => StringType()
case _: Char => StringType()
case _: Number => NumberType()
case _: Boolean => BooleanType()
case IsMap(_) => MapType()
case IsCollection(coll) if coll.isEmpty => CollectionType(AnyType())
case IsCollection(coll) => CollectionType(coll.map(fromJava).reduce(_ mergeDown _))
case _ => AnyType()
}
}


/*
TypeSafe is everything that needs to check it's types
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ foreach(x in [1,2,3] |
relate(start, a2, "x")

val result = execute(
s"start start=node(${start.getId}) MATCH (start)-[rel:x]-(a) WHERE a.name = '${name}' return a"
s"start start=node(${start.getId}) MATCH (start)-[rel:x]-(a) WHERE a.name = '$name' return a"
)
assertEquals(List(a2), result.columnAs[Node]("a").toList)
}
Expand Down Expand Up @@ -190,14 +190,11 @@ foreach(x in [1,2,3] |
val result = execute(
s"start node=node(${n1.getId}) match (node)-[rel:KNOWS]->(x) return x, node"
)

val textOutput = result.dumpToString()
result.dumpToString()
}

@Test def doesNotFailOnVisualizingEmptyOutput() {
val result = execute(
s"start node=node(${refNode.getId}) where 1 = 0 return node"
)
execute(s"start node=node(${refNode.getId}) where 1 = 0 return node")
}

@Test def shouldGetRelatedToRelatedTo() {
Expand All @@ -219,7 +216,7 @@ foreach(x in [1,2,3] |
val value = "andres"
indexNode(n, idxName, key, value)

val query = s"start n=node:${idxName}(${key} = '${value}') return n"
val query = s"start n=node:$idxName($key = '$value') return n"

assertInTx(List(Map("n" -> n)) === execute(query).toList)
}
Expand All @@ -231,7 +228,7 @@ foreach(x in [1,2,3] |
val value = "andres"
indexNode(n, idxName, key, value)

val query = s"start n=node:${idxName}('${key}: ${value}') return n"
val query = s"start n=node:$idxName('$key: $value') return n"

assertInTx(List(Map("n" -> n)) === execute(query).toList)
}
Expand All @@ -242,7 +239,7 @@ foreach(x in [1,2,3] |
val key = "key"
indexNode(n, idxName, key, "Andres")

val query = s"start n=node:${idxName}(key = {value}) return n"
val query = s"start n=node:$idxName(key = {value}) return n"

assertInTx(List(Map("n" -> n)) === execute(query, "value" -> "Andres").toList)
}
Expand All @@ -254,7 +251,7 @@ foreach(x in [1,2,3] |
val value = "andres"
indexNode(n, idxName, key, value)

val query = s"start n=node:${idxName}('${key}:andr*') return n"
val query = s"start n=node:$idxName('$key:andr*') return n"

assertInTx(List(Map("n" -> n)) === execute(query).toList)
}
Expand All @@ -263,7 +260,7 @@ foreach(x in [1,2,3] |
val n1 = createNode(Map("name" -> "boy"))
val n2 = createNode(Map("name" -> "girl"))

val query = Query.
Query.
start(NodeById("n", n1.getId, n2.getId)).
where(Or(
Equals(Property(Identifier("n"), PropertyKey("name")), Literal("boy")),
Expand Down Expand Up @@ -368,7 +365,7 @@ foreach(x in [1,2,3] |
relate(refNode, a, "A")
relate(refNode, b, "A")

val query = Query.
Query.
start(NodeById("a", refNode.getId)).
matches(RelatedTo("a", "b", "rel", Seq(), Direction.OUTGOING)).
aggregation(CountStar()).
Expand Down Expand Up @@ -768,8 +765,7 @@ foreach(x in [1,2,3] |
relate("A" -> "KNOWS" -> "B")

//Checking that we don't get an exception
val result = execute("start a = node(1), b = node(2) match p = shortestPath(a-[*]-b) return p").
toList
execute("start a = node(1), b = node(2) match p = shortestPath(a-[*]-b) return p").toList
}

@Test def shouldBeAbleToTakeParamsInDifferentTypes() {
Expand Down Expand Up @@ -1024,7 +1020,7 @@ return x, p""").toList
@Test def shouldHandleOptionalPathsFromACombo() {
val a = createNode("A")
val b = createNode("B")
val r = relate(a, b, "X")
relate(a, b, "X")

val result = execute( """
start a = node(1)
Expand Down Expand Up @@ -1073,7 +1069,7 @@ return a""")
}

@Test def shouldSupportColumnRenamingForAggregatesAsWell() {
val a = createNode(Map("name" -> "Andreas"))
createNode(Map("name" -> "Andreas"))

val result = execute( """
start a = node(1)
Expand Down Expand Up @@ -1765,7 +1761,7 @@ RETURN x0.name
relate(peter, bread, "ATE", Map("times"->7))
relate(peter, meat, "ATE", Map("times"->4))

val result = execute(
execute(
""" start me=node(1)
match me-[r1:ATE]->food<-[r2:ATE]-you
Expand Down Expand Up @@ -1845,7 +1841,7 @@ RETURN x0.name
def var_length_predicate() {
val a = createNode()
val b = createNode()
val r = relate(a, b)
relate(a, b)

val resultPath = execute("START a=node(1), b=node(2) RETURN a-[*]->b as path")
.toList.head("path")
Expand Down Expand Up @@ -1873,7 +1869,7 @@ RETURN x0.name
relate(a,b)
relate(a,c)

val result = execute("CYPHER 1.9 START a=node(1) foreach(n in extract(p in a-->() | last(p)) | set n.touched = true) return a-->()").dumpToString()
execute("CYPHER 1.9 START a=node(1) foreach(n in extract(p in a-->() | last(p)) | set n.touched = true) return a-->()").dumpToString()
}

@Test
Expand Down Expand Up @@ -2172,7 +2168,7 @@ RETURN x0.name

@Test
def extract_string_from_node_collection() {
val a = createNode("name"->"a")
createNode("name"->"a")

val result = execute("""START n=node(1) with collect(n) as nodes return head(extract(x in nodes | x.name)) + "test" as test """)

Expand Down Expand Up @@ -2266,7 +2262,7 @@ RETURN x0.name
// GIVEN
val a = createLabeledNode("foo")
val b = createLabeledNode("foo", "bar")
val c = createNode()
createNode()

// WHEN
val result = execute("""START n=node(1, 2, 3) WHERE n:foo RETURN n""")
Expand All @@ -2277,8 +2273,8 @@ RETURN x0.name

@Test def should_filter_nodes_by_single_negated_label() {
// GIVEN
val a = createLabeledNode("foo")
val b = createLabeledNode("foo", "bar")
createLabeledNode("foo")
createLabeledNode("foo", "bar")
val c = createNode()

// WHEN
Expand All @@ -2290,9 +2286,9 @@ RETURN x0.name

@Test def should_filter_nodes_by_multiple_labels() {
// GIVEN
val a = createLabeledNode("foo")
createLabeledNode("foo")
val b = createLabeledNode("foo", "bar")
val c = createNode()
createNode()

// WHEN
val result = execute("""START n=node(1, 2, 3) WHERE n:foo:bar RETURN n""")
Expand Down Expand Up @@ -2726,7 +2722,7 @@ RETURN x0.name
// given any database

// then shouldn't throw
val result = execute("START x=node(0) RETURN DISTINCT x as otherName ORDER BY x.name ")
execute("START x=node(0) RETURN DISTINCT x as otherName ORDER BY x.name ")
}

def should_not_hang() {
Expand Down Expand Up @@ -2769,4 +2765,10 @@ RETURN x0.name
assert(node.getProperty("second") === "value")
}
}

@Test
def should_be_able_to_index_into_nested_literal_lists() {
execute("RETURN [[1]][0][0]")
// shoud not throw an exception
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.neo4j.cypher.internal.commands.expressions

import org.scalatest.Assertions
import org.junit.Test
import org.neo4j.cypher.internal.symbols._
import org.neo4j.cypher.internal.symbols.SymbolTable
import org.neo4j.cypher.internal.symbols.AnyType

/**
* Copyright (c) 2002-2013 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
class CollectionTest extends Assertions {
@Test
def empty_collection_should_have_any_type() {
assert(Collection().getType(SymbolTable()) === CollectionType(AnyType()))
}

@Test
def collection_with_one_item_should_be_typed_for_that_items_type() {
assert(Collection(Literal(1)).getType(SymbolTable()) === CollectionType(NumberType()))
}

@Test
def collection_with_several_items_should_be_typed_for_their_common_supertype(){
assert(Collection(Literal(1), Literal(true)).getType(SymbolTable()) === CollectionType(ScalarType()))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package org.neo4j.cypher.internal.commands.expressions

import org.scalatest.Assertions
import org.junit.Test
import org.neo4j.cypher.internal.symbols.{SymbolTable, StringType, CollectionType}

/**
* Copyright (c) 2002-2013 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
class LiteralTest extends Assertions {
@Test
def collections_should_be_typed_correctly() {
val value = Literal(Seq(Seq("Text")))
val typ = CollectionType(CollectionType(StringType()))

assert(value.calculateType(SymbolTable()) === typ)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@ import org.junit.Test
import org.scalatest.Assertions

class CypherTypeTest extends Assertions {
@Test def collections_should_be_typed_correctly() {
val value = Seq(Seq("Text"))
val typ = CollectionType(CollectionType(StringType()))

assert(CypherType.fromJava(value) === typ)
}

@Test
def testTypeMergeDown() {
assertCorrectTypeMergeDown(NumberType(), NumberType(), NumberType())
Expand All @@ -43,7 +36,7 @@ class CypherTypeTest extends Assertions {
def assertCorrectTypeMergeDown(a: CypherType, b: CypherType, result: CypherType) {
val simpleMergedType: CypherType = a mergeDown b
assert(simpleMergedType === result)
val collectionMergedType: CypherType = (CollectionType(a)) mergeDown (CollectionType(b))
val collectionMergedType: CypherType = CollectionType(a) mergeDown CollectionType(b)
assert(collectionMergedType === CollectionType(result))
}

Expand All @@ -61,7 +54,7 @@ class CypherTypeTest extends Assertions {
def assertCorrectTypeMergeUp(a: CypherType, b: CypherType, result: Option[CypherType]) {
val simpleMergedType: Option[CypherType] = a mergeUp b
assert(simpleMergedType === result)
val collectionMergedType: Option[CypherType] = (CollectionType(a)) mergeUp (CollectionType(b))
val collectionMergedType: Option[CypherType] = CollectionType(a) mergeUp CollectionType(b)
assert(collectionMergedType === (for (t <- result) yield CollectionType(t)))
}
}
}

0 comments on commit 4da50a8

Please sign in to comment.