diff --git a/src/Language/Wasm/Interpreter.hs b/src/Language/Wasm/Interpreter.hs index 907e2ef..77e0ea8 100644 --- a/src/Language/Wasm/Interpreter.hs +++ b/src/Language/Wasm/Interpreter.hs @@ -172,7 +172,7 @@ data Label = Label ResultType deriving (Show, Eq) type Address = Int -type TableStore = IOVector (Maybe Address) +type TableStore = IORef (IOVector (Maybe Address)) data TableInstance = TableInstance { t :: TableType, @@ -488,7 +488,7 @@ allocTables = fmap Vector.fromList . mapM allocTable allocTable :: Table -> IO TableInstance allocTable (Table t@(TableType lim@(Limit from to) _)) = let elements = MVector.replicate (fromIntegral from) Nothing in - TableInstance t <$> elements + TableInstance t <$> (elements >>= newIORef) defaultBudget :: Natural defaultBudget = 300 @@ -562,14 +562,14 @@ initialize inst Module {elems, datas, start} = do let idx = tableaddrs inst ! fromIntegral tableIndex let last = from + length funcs let TableInstance lim elems = tableInstances st ! idx - let len = MVector.length elems + len <- MVector.length <$> (liftIO $ readIORef elems) Monad.when (last > len) $ throwError "out of bounds table access" return (idx, elemaddrs inst ! elemN, from, funcs) initElem :: (Address, Address, Int, [Maybe Address]) -> Initialize () initElem (tableIdx, elemIdx, from, funcs) = do Store {tableInstances, elemInstances} <- State.get - let elems = items $ tableInstances ! tableIdx + elems <- liftIO $ readIORef $ items $ tableInstances ! tableIdx let ElemInstance {isDropped} = elemInstances ! elemIdx liftIO $ writeIORef isDropped True Monad.forM_ (zip [from..] funcs) $ uncurry $ MVector.unsafeWrite elems @@ -777,10 +777,11 @@ eval budget store FunctionInstance { funcType, moduleInstance, code = Function { let funcType = funcTypes moduleInstance ! fromIntegral typeIdx let TableInstance { items } = tableInstances store ! (tableaddrs moduleInstance ! fromIntegral tableIdx) let pos = fromIntegral v - if pos >= MVector.length items + funcs <- readIORef items + if pos >= MVector.length funcs then return Trap else do - maybeAddr <- MVector.unsafeRead items pos + maybeAddr <- MVector.unsafeRead funcs pos let checks = do addr <- maybeAddr let funcInst = funcInstances store ! addr @@ -922,15 +923,16 @@ eval budget store FunctionInstance { funcType, moduleInstance, code = Function { let src = fromIntegral s let dst = fromIntegral d let len = fromIntegral n + els <- readIORef items isDropped <- readIORef dropFlag if src + len > Vector.length refs - || dst + len > MVector.length items + || dst + len > MVector.length els || isDropped || isDeclarative mode then return Trap else do Vector.iforM_ (Vector.slice src len refs) $ \idx (RF fn) -> do - MVector.unsafeWrite items (dst + idx) (fromIntegral <$> fn) + MVector.unsafeWrite els (dst + idx) (fromIntegral <$> fn) return $ Done ctx { stack = rest } step ctx@EvalCtx{ stack = (VI32 n:VI32 s:VI32 d:rest) } (TableCopy toIdx fromIdx) = do let fromAddr = tableaddrs moduleInstance ! fromIntegral fromIdx @@ -940,14 +942,41 @@ eval budget store FunctionInstance { funcType, moduleInstance, code = Function { let src = fromIntegral s let dst = fromIntegral d let len = fromIntegral n - if src + len > MVector.length fromItems || dst + len > MVector.length toItems + fromEls <- readIORef fromItems + toEls <- readIORef toItems + if src + len > MVector.length fromEls || dst + len > MVector.length toEls then return Trap else do let range = if dst <= src then [0..len - 1] else reverse [0..len - 1] flip mapM_ range $ \off -> do - el <- MVector.unsafeRead fromItems (src + off) - MVector.unsafeWrite toItems (dst + off) el + el <- MVector.unsafeRead fromEls (src + off) + MVector.unsafeWrite toEls (dst + off) el return $ Done ctx { stack = rest } + step ctx@EvalCtx{ stack } (TableSize tableIdx) = do + let tableAddr = tableaddrs moduleInstance ! fromIntegral tableIdx + let TableInstance { items } = tableInstances store ! tableAddr + len <- MVector.length <$> readIORef items + return $ Done ctx { stack = VI32 (fromIntegral len) : stack } + step ctx@EvalCtx{ stack = (VI32 growBy:ref:rest) } (TableGrow tableIdx) = do + let tableAddr = tableaddrs moduleInstance ! fromIntegral tableIdx + let TableInstance { items, t } = tableInstances store ! tableAddr + let TableType (Limit _ max) _ = t + let inc = fromIntegral growBy + let val = case ref of + RE extRef -> extRef + RF fnRef -> fnRef + v -> error "Impossible due to validation" + els <- readIORef items + let currLen = MVector.length els + let newLen = currLen + inc + if maybe False ((newLen >) . fromIntegral) max || newLen > 0xFFFFFFFF + then return $ Done ctx { stack = VI32 (asWord32 $ -1):rest } + else do + newEls <- MVector.grow els inc + writeIORef items newEls + Monad.forM_ [0..inc - 1] $ \off -> + MVector.unsafeWrite newEls (currLen + off) (fromIntegral <$> val) + return $ Done ctx { stack = VI32 (fromIntegral currLen):rest } step ctx@EvalCtx{ stack = (ref:VI32 offset:rest) } (TableSet tableIdx) = do let tableAddr = tableaddrs moduleInstance ! fromIntegral tableIdx let TableInstance { items } = tableInstances store ! tableAddr @@ -956,19 +985,21 @@ eval budget store FunctionInstance { funcType, moduleInstance, code = Function { RE extRef -> extRef RF fnRef -> fnRef v -> error "Impossible due to validation" - if dst > MVector.length items + els <- readIORef items + if dst > MVector.length els then return Trap else do - MVector.unsafeWrite items dst (fromIntegral <$> val) + MVector.unsafeWrite els dst (fromIntegral <$> val) return $ Done ctx { stack = rest } step ctx@EvalCtx{ stack = (VI32 offset:rest) } (TableGet tableIdx) = do let tableAddr = tableaddrs moduleInstance ! fromIntegral tableIdx let TableInstance { t = TableType _ et, items } = tableInstances store ! tableAddr let dst = fromIntegral offset - if dst > MVector.length items + els <- readIORef items + if dst > MVector.length els then return Trap else do - v <- MVector.unsafeRead items dst + v <- MVector.unsafeRead els dst let val = (case et of {FuncRef -> RF; ExternRef -> RE}) (fromIntegral <$> v) return $ Done ctx { stack = val : rest } step ctx (ElemDrop elemIdx) = do diff --git a/src/Language/Wasm/Parser.y b/src/Language/Wasm/Parser.y index a06baf2..0003310 100644 --- a/src/Language/Wasm/Parser.y +++ b/src/Language/Wasm/Parser.y @@ -482,6 +482,8 @@ plaininstr :: { PlainInstr } | 'table.get' index { TableGet $2 } | 'table.set' index { TableSet $2 } | 'table.copy' index index { TableCopy $2 $3 } + | 'table.size' index { TableSize $2 } + | 'table.grow' index { TableGrow $2 } | 'elem.drop' index { ElemDrop $2 } -- numeric instructions | 'i32.const' int32 { I32Const $2 } @@ -1729,6 +1731,14 @@ desugarize fields = do Just toIdx -> return $ S.TableCopy toIdx fromIdx Nothing -> Left "unknown table" Nothing -> Left "unknown table" + synInstrToStruct FunCtx { ctxMod } (PlainInstr (TableSize tableIdx)) = + case getTableIndex ctxMod tableIdx of + Just tableIdx -> return $ S.TableSize tableIdx + Nothing -> Left "unknown table" + synInstrToStruct FunCtx { ctxMod } (PlainInstr (TableGrow tableIdx)) = + case getTableIndex ctxMod tableIdx of + Just tableIdx -> return $ S.TableGrow tableIdx + Nothing -> Left "unknown table" synInstrToStruct FunCtx { ctxMod } (PlainInstr (TableSet tableIdx)) = case getTableIndex ctxMod tableIdx of Just tableIdx -> return $ S.TableSet tableIdx diff --git a/src/Language/Wasm/Script.hs b/src/Language/Wasm/Script.hs index fb895a5..7872005 100644 --- a/src/Language/Wasm/Script.hs +++ b/src/Language/Wasm/Script.hs @@ -193,7 +193,8 @@ runScript onAssertFail script = do getFailureString Validate.GlobalIsImmutable = ["global is immutable"] getFailureString Validate.InvalidStartFunctionType = ["start function"] getFailureString Validate.InvalidTableType = ["size minimum must not be greater than maximum"] - getFailureString r = [TL.concat ["not implemented ", (TL.pack $ show r)]] + getFailureString (Validate.ElemIndexOutOfRange idx) = ["unknown elem segment " <> TL.pack (show idx)] + getFailureString r = [TL.concat ["not implemented ", TL.pack $ show r]] printFailedAssert :: String -> Assertion -> AssertM () printFailedAssert msg assert = do diff --git a/src/Language/Wasm/Validate.hs b/src/Language/Wasm/Validate.hs index 6eb9091..c8f068b 100644 --- a/src/Language/Wasm/Validate.hs +++ b/src/Language/Wasm/Validate.hs @@ -388,6 +388,15 @@ getInstrType (TableCopy toIdx fromIdx) = do let TableType _ toType = tables !! to when (fromType /= toType) $ throwError (RefTypeMismatch fromType toType) return $ [I32, I32, I32] ==> empty +getInstrType (TableSize tableIdx) = do + Ctx { tables } <- ask + when (length tables <= fromIntegral tableIdx) $ throwError (TableIndexOutOfRange tableIdx) + return $ empty ==> I32 +getInstrType (TableGrow tableIdx) = do + Ctx { tables } <- ask + when (length tables <= fromIntegral tableIdx) $ throwError (TableIndexOutOfRange tableIdx) + let TableType _ tableType = tables !! fromIntegral tableIdx + return $ [elemTypeToRefType tableType, I32] ==> I32 getInstrType (TableGet tableIdx) = do Ctx { tables } <- ask when (length tables <= fromIntegral tableIdx) $ throwError (TableIndexOutOfRange tableIdx) diff --git a/tests/Test.hs b/tests/Test.hs index b38edad..afdd538 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -19,7 +19,7 @@ main = do files <- filter (not . List.isPrefixOf "simd") . filter (List.isSuffixOf ".wast") <$> Directory.listDirectory "tests/spec" - let files = ["table_init.wast"] + let files = ["table_copy.wast"] scriptTestCases <- (`mapM` files) $ \file -> do test <- LBS.readFile ("tests/spec/" ++ file) return $ testCase file $ do