Skip to content

Commit

Permalink
Merge pull request apple#247 from lorentey/fix-TreeDictionary-merge
Browse files Browse the repository at this point in the history
[TreeDictionary] Fix in-place merge operation to properly update the count
  • Loading branch information
lorentey authored Dec 1, 2022
2 parents edc166e + b8ed495 commit 531d2e6
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 34 deletions.
65 changes: 39 additions & 26 deletions Sources/HashTreeCollections/Node/_Node+Structural merge.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,26 @@
//===----------------------------------------------------------------------===//

extension _Node {
/// - Returns: The number of new items added to `self`.
@inlinable
internal mutating func merge(
_ level: _Level,
_ other: _Node,
_ combine: (Value, Value) throws -> Value
) rethrows {
guard other.count > 0 else { return }
) rethrows -> Int {
guard other.count > 0 else { return 0 }
guard self.count > 0 else {
self = other
return
return self.count
}
if level.isAtRoot, self.hasSingletonItem {
// In this special case, the root node may turn into a collision node
// during the merge process. Prevent this from causing issues below by
// handling it up front.
var copy = other
try self.read { l in
let delta = try self.read { l in
let lp = l.itemPtr(at: .zero)
let c = copy.count
let res = copy.updateValue(
level, forKey: lp.pointee.key, _Hash(lp.pointee.key)
) {
Expand All @@ -39,31 +41,32 @@ extension _Node {
p.pointee.value = try combine(lp.pointee.value, p.pointee.value)
}
}
return c - (res.inserted ? 0 : 1)
}
self = copy
return
return delta
}

try _merge(level, other, combine)
return try _merge(level, other, combine)
}

@inlinable
internal mutating func _merge(
_ level: _Level,
_ other: _Node,
_ combine: (Value, Value) throws -> Value
) rethrows {
) rethrows -> Int {
// Note: don't check storage identities -- we do need to merge the contents
// of identical nodes.

if self.isCollisionNode || other.isCollisionNode {
try _merge_slow(level, other, combine)
return
return try _merge_slow(level, other, combine)
}

var isUnique = self.isUnique()
return try other.read { r in
var isUnique = self.isUnique()
var delta = 0

try other.read { r in
let (originalItems, originalChildren) = self.read {
($0.itemMap, $0.childMap)
}
Expand Down Expand Up @@ -93,6 +96,7 @@ extension _Node {
// then this call would sometimes turn `self` into a collision
// node on a compressed path, causing mischief.
assert(!self.isCollisionNode)
delta &+= 1
}
isUnique = true
}
Expand All @@ -103,6 +107,7 @@ extension _Node {
self.ensureUnique(
isUnique: isUnique, withFreeSpace: _Node.spaceForSpawningChild)
let item = self.removeItem(at: bucket)
delta &-= 1
var child = rp.pointee
let r = child.updateValue(
level.descend(), forKey: item.key, _Hash(item.key)
Expand All @@ -117,6 +122,7 @@ extension _Node {
}
self.insertChild(child, bucket)
isUnique = true
delta &+= child.count
}
}

Expand All @@ -138,6 +144,7 @@ extension _Node {
}
if res.inserted {
self.count &+= 1
delta &+= 1
} else {
try UnsafeHandle.update(res.leaf) {
let p = $0.itemPtr(at: res.slot)
Expand All @@ -149,12 +156,14 @@ extension _Node {
else if r.childMap.contains(bucket) {
let rslot = r.childMap.slot(of: bucket)
self.ensureUnique(isUnique: isUnique)
try self.update { l in
let d = try self.update { l in
try l[child: lslot].merge(
level.descend(),
r[child: rslot],
combine)
}
self.count &+= d
delta &+= d
isUnique = true
}
}
Expand All @@ -167,17 +176,20 @@ extension _Node {
let rslot = r.itemMap.slot(of: bucket)
self.ensureUniqueAndInsertItem(
isUnique: isUnique, r[item: rslot], at: bucket)
delta &+= 1
isUnique = true
}
for (bucket, _) in r.childMap.subtracting(seen) {
let rslot = r.childMap.slot(of: bucket)
self.ensureUnique(
isUnique: isUnique, withFreeSpace: _Node.spaceForNewChild)
self.insertChild(r[child: rslot], bucket)
delta &+= r[child: rslot].count
isUnique = true
}

assert(isUnique)
return delta
}
}

Expand All @@ -186,7 +198,7 @@ extension _Node {
_ level: _Level,
_ other: _Node,
_ combine: (Value, Value) throws -> Value
) rethrows {
) rethrows -> Int {
let lc = self.isCollisionNode
let rc = other.isCollisionNode
if lc && rc {
Expand All @@ -195,10 +207,11 @@ extension _Node {
level: level,
child1: self, self.collisionHash,
child2: other, other.collisionHash)
return
return other.count
}
var isUnique = self.isUnique()
return try other.read { r in
var isUnique = self.isUnique()
var delta = 0
let originalItemCount = self.count
for rs: _Slot in stride(from: .zero, to: r.itemsEndSlot, by: 1) {
let rp = r.itemPtr(at: rs)
Expand All @@ -218,10 +231,11 @@ extension _Node {
} else {
_ = self.ensureUniqueAndAppendCollision(
isUnique: isUnique, rp.pointee)
delta &+= 1
}
isUnique = true
}
return
return delta
}
}

Expand Down Expand Up @@ -251,22 +265,23 @@ extension _Node {
}
self = other._copyNodeAndReplaceItemWithNewChild(
level: level, self, at: bucket, itemSlot: rslot)
return
return other.count - (res.inserted ? 0 : 1)
}

if r.childMap.contains(bucket) {
let originalCount = self.count
let rslot = r.childMap.slot(of: bucket)
try self._merge(level.descend(), r[child: rslot], combine)
_ = try self._merge(level.descend(), r[child: rslot], combine)
var node = other.copy()
_ = node.replaceChild(at: bucket, rslot, with: self)
self = node
return
return self.count - originalCount
}

var node = other.copy(withFreeSpace: _Node.spaceForNewChild)
node.insertChild(self, bucket)
self = node
return
return other.count
}
}

Expand All @@ -292,25 +307,23 @@ extension _Node {
}
assert(self.count > 0) // Singleton case handled up front above
self.insertChild(copy, bucket)
return
return other.count - (res.inserted ? 0 : 1)
}
if self.read({ $0.childMap.contains(bucket) }) {
self.ensureUnique(isUnique: isUnique)
let delta: Int = try self.update { l in
let lslot = l.childMap.slot(of: bucket)
let lchild = l.childPtr(at: lslot)
let origCount = lchild.pointee.count
try lchild.pointee._merge(level.descend(), other, combine)
return lchild.pointee.count &- origCount
return try lchild.pointee._merge(level.descend(), other, combine)
}
assert(delta >= 0)
self.count &+= delta
return
return delta
}
self.ensureUnique(
isUnique: isUnique, withFreeSpace: _Node.spaceForNewChild)
self.insertChild(other, bucket)
return
return other.count
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ extension TreeDictionary {
) rethrows -> Self {
let result = try _root.filter(.top, isIncluded)
guard let result = result else { return self }
return TreeDictionary(_new: result.finalize(.top))
let r = TreeDictionary(_new: result.finalize(.top))
r._invariantCheck()
return r
}

/// Removes all the elements that satisfy the given predicate.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,5 +300,6 @@ extension TreeDictionary {
array.append(value)
}
}
_invariantCheck()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ extension TreeDictionary.Keys {
guard let r = _base._root.intersection(.top, other._base._root) else {
return self
}
return TreeDictionary(_new: r).keys
let d = TreeDictionary(_new: r)
d._invariantCheck()
return d.keys
}

/// Returns a new keys view with the elements that are common to both this
Expand All @@ -235,7 +237,9 @@ extension TreeDictionary.Keys {
guard let r = _base._root.intersection(.top, other._root) else {
return self
}
return TreeDictionary(_new: r).keys
let d = TreeDictionary(_new: r)
d._invariantCheck()
return d.keys
}

/// Returns a new keys view containing the elements of `self` that do not
Expand All @@ -260,7 +264,9 @@ extension TreeDictionary.Keys {
guard let r = _base._root.subtracting(.top, other._base._root) else {
return self
}
return TreeDictionary(_new: r).keys
let d = TreeDictionary(_new: r)
d._invariantCheck()
return d.keys
}

/// Returns a new keys view containing the elements of `self` that do not
Expand All @@ -282,6 +288,8 @@ extension TreeDictionary.Keys {
guard let r = _base._root.subtracting(.top, other._root) else {
return self
}
return TreeDictionary(_new: r).keys
let d = TreeDictionary(_new: r)
d._invariantCheck()
return d.keys
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ extension TreeDictionary {
_ transform: (Value) throws -> T
) rethrows -> TreeDictionary<Key, T> {
let transformed = try _root.mapValues { try transform($0.value) }
return TreeDictionary<Key, T>(_new: transformed)
let r = TreeDictionary<Key, T>(_new: transformed)
r._invariantCheck()
return r
}

/// Returns a new dictionary containing only the key-value pairs that have
Expand Down Expand Up @@ -59,6 +61,8 @@ extension TreeDictionary {
_ transform: (Value) throws -> T?
) rethrows -> TreeDictionary<Key, T> {
let result = try _root.compactMapValues(.top, transform)
return TreeDictionary<Key, T>(_new: result.finalize(.top))
let d = TreeDictionary<Key, T>(_new: result.finalize(.top))
d._invariantCheck()
return d
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ extension TreeDictionary {
uniquingKeysWith combine: (Value, Value) throws -> Value
) rethrows {
_invalidateIndices()
try _root.merge(.top, keysAndValues._root, combine)
_ = try _root.merge(.top, keysAndValues._root, combine)
_invariantCheck()
}

/// Merges the key-value pairs in the given sequence into the dictionary,
Expand Down Expand Up @@ -87,6 +88,7 @@ extension TreeDictionary {
}
}
}
_invariantCheck()
}

/// Merges the key-value pairs in the given sequence into the dictionary,
Expand Down
1 change: 1 addition & 0 deletions Tests/HashTreeCollectionsTests/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ func expectEqualDictionaries<Key: Hashable, Value: Equatable>(
file: StaticString = #file,
line: UInt = #line
) {
expectEqual(map.count, dict.count, "Mismatching count", file: file, line: line)
var dict = dict
var seen: Set<Key> = []
var mismatches: [(key: Key, map: Value?, dict: Value?)] = []
Expand Down

0 comments on commit 531d2e6

Please sign in to comment.