Skip to content

Commit 7bd0ac7

Browse files
committed
Use ST monad for diffBy
This reduces extra memory requirement from O(nm) to O(n + m).
1 parent 19cbafe commit 7bd0ac7

File tree

2 files changed

+97
-33
lines changed

2 files changed

+97
-33
lines changed

src/Data/TreeDiff/List.hs

Lines changed: 96 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE BangPatterns #-}
12
{-# LANGUAGE ScopedTypeVariables #-}
23
-- | A list diff.
34
module Data.TreeDiff.List (
@@ -6,9 +7,12 @@ module Data.TreeDiff.List (
67
) where
78

89
import Control.DeepSeq (NFData (..))
10+
import Control.Monad.ST (ST, runST)
911

1012
import qualified Data.Primitive as P
1113

14+
-- import Debug.Trace
15+
1216
-- | List edit operations
1317
--
1418
-- The 'Swp' constructor is redundant, but it let us spot
@@ -40,45 +44,105 @@ instance NFData a => NFData (Edit a) where
4044
-- /Note:/ currently this has O(n*m) memory requirements, for the sake
4145
-- of more obviously correct implementation.
4246
--
43-
diffBy :: forall a. (a -> a -> Bool) -> [a] -> [a] -> [Edit a]
44-
diffBy eq xs' ys' = reverse (getCell (lcs xn yn))
47+
diffBy :: forall a. Show a => (a -> a -> Bool) -> [a] -> [a] -> [Edit a]
48+
diffBy _ [] [] = []
49+
diffBy _ [] ys' = map Ins ys'
50+
diffBy _ xs' [] = map Del xs'
51+
diffBy eq xs' ys'
52+
| otherwise = reverse (getCell lcs)
4553
where
4654
xn = length xs'
4755
yn = length ys'
4856

4957
xs = P.arrayFromListN xn xs'
5058
ys = P.arrayFromListN yn ys'
5159

52-
memo :: P.Array (Cell [Edit a])
53-
memo = P.arrayFromListN ((xn + 1) * (yn + 1))
54-
[ impl xi yi
55-
| xi <- [0 .. xn]
56-
, yi <- [0 .. yn]
57-
]
58-
59-
lcs :: Int -> Int -> Cell [Edit a]
60-
lcs xi yi = P.indexArray memo (yi + xi * (yn + 1))
61-
62-
impl :: Int -> Int -> Cell [Edit a]
63-
impl 0 0 = Cell 0 []
64-
impl 0 m = case lcs 0 (m - 1) of
65-
Cell w edit -> Cell (w + 1) (Ins (P.indexArray ys (m - 1)) : edit)
66-
impl n 0 = case lcs (n - 1) 0 of
67-
Cell w edit -> Cell (w + 1) (Del (P.indexArray xs (n - 1)) : edit)
68-
69-
impl n m = bestOfThree
70-
edit
71-
(bimap (+1) (Ins y :) (lcs n (m - 1)))
72-
(bimap (+1) (Del x :) (lcs (n - 1) m))
73-
where
74-
x = P.indexArray xs (n - 1)
75-
y = P.indexArray ys (m - 1)
76-
77-
edit
78-
| eq x y = bimap id (Cpy x :) (lcs (n - 1) (m - 1))
79-
| otherwise = bimap (+1) (Swp x y :) (lcs (n - 1) (m - 1))
80-
81-
data Cell a = Cell !Int !a
60+
lcs :: Cell [Edit a]
61+
lcs = runST $ do
62+
-- traceShowM ("sizes", xn, yn)
63+
64+
-- create two buffers.
65+
buf1 <- P.newArray yn (Cell 0 [])
66+
buf2 <- P.newArray yn (Cell 0 [])
67+
68+
-- fill the first row
69+
-- 0,0 case is filled already
70+
yLoop (Cell 0 []) $ \m (Cell w edit) -> do
71+
let cell = Cell (w + 1) (Ins (P.indexArray ys m) : edit)
72+
P.writeArray buf1 m cell
73+
P.writeArray buf2 m cell
74+
-- traceShowM ("init", m, cell)
75+
return cell
76+
77+
-- following rows
78+
--
79+
-- cellC cellT
80+
-- cellL cellX
81+
(buf1final, _, _) <- xLoop (buf1, buf2, Cell 0 []) $ \n (prev, curr, cellC) -> do
82+
-- prevZ <- P.unsafeFreezeArray prev
83+
-- currZ <- P.unsafeFreezeArray prev
84+
-- traceShowM ("prev", n, prevZ)
85+
-- traceShowM ("curr", n, currZ)
86+
87+
let cellL :: Cell [Edit a]
88+
cellL = case cellC of (Cell w edit) -> Cell (w + 1) (Del (P.indexArray xs n) : edit)
89+
90+
-- traceShowM ("cellC, cellL", n, cellC, cellL)
91+
92+
yLoop (cellC, cellL) $ \m (cellC', cellL') -> do
93+
-- traceShowM ("inner loop", n, m)
94+
cellT <- P.readArray prev m
95+
96+
-- traceShowM ("cellT", n, m, cellT)
97+
98+
let x, y :: a
99+
x = P.indexArray xs n
100+
y = P.indexArray ys m
101+
102+
-- from diagonal
103+
let cellX1 :: Cell [Edit a]
104+
cellX1
105+
| eq x y = bimap id (Cpy x :) cellC'
106+
| otherwise = bimap (+1) (Swp x y :) cellC'
107+
108+
-- from top
109+
let cellX2 :: Cell [Edit a]
110+
cellX2 = bimap (+1) (Del x :) cellT
111+
112+
-- from left
113+
let cellX3 :: Cell [Edit a]
114+
cellX3 = bimap (+1) (Ins y :) cellL'
115+
116+
-- the actual cell is best of three
117+
let cellX :: Cell [Edit a]
118+
cellX = bestOfThree cellX1 cellX2 cellX3
119+
120+
-- traceShowM ("cellX", n, m, cellX)
121+
122+
-- memoize
123+
P.writeArray curr m cellX
124+
125+
return (cellT, cellX)
126+
127+
return (curr, prev, cellL)
128+
129+
P.readArray buf1final (yn - 1)
130+
131+
xLoop :: acc -> (Int -> acc -> ST s acc) -> ST s acc
132+
xLoop !acc0 f = go acc0 0 where
133+
go !acc !n | n < xn = do
134+
acc' <- f n acc
135+
go acc' (n + 1)
136+
go !acc _ = return acc
137+
138+
yLoop :: acc -> (Int -> acc -> ST s acc) -> ST s ()
139+
yLoop !acc0 f = go acc0 0 where
140+
go !acc !m | m < yn = do
141+
acc' <- f m acc
142+
go acc' (m + 1)
143+
go _ _ = return ()
144+
145+
data Cell a = Cell !Int !a deriving Show
82146

83147
getCell :: Cell a -> a
84148
getCell (Cell _ x) = x

src/Data/TreeDiff/Tree.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ import Data.TreeDiff.List
7272
-- >>> ppEditTree PP.char (treeDiff x w)
7373
-- (a b (c d +x e) f)
7474
--
75-
treeDiff :: Eq a => Tree a -> Tree a -> Edit (EditTree a)
75+
treeDiff :: (Show a, Eq a) => Tree a -> Tree a -> Edit (EditTree a)
7676
treeDiff ta@(Node a as) tb@(Node b bs)
7777
| a == b = Cpy $ EditNode a (map rec (diffBy (==) as bs))
7878
| otherwise = Swp (treeToEdit ta) (treeToEdit tb)

0 commit comments

Comments
 (0)