forked from jaspervdj/websockets
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Stream.hs
200 lines (175 loc) · 7.4 KB
/
Stream.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
--------------------------------------------------------------------------------
-- | Lightweight abstraction over an input/output stream.
{-# LANGUAGE CPP #-}
module Network.WebSockets.Stream
( Stream
, makeStream
, makeSocketStream
, makeEchoStream
, parse
, parseBin
, write
, close
) where
import           Control.Concurrent.MVar        (MVar, newEmptyMVar, newMVar,
                                                 putMVar, takeMVar, withMVar)
import           Control.Exception              (onException, throwIO)
import           Control.Monad                  (forM_, when)
import           Data.IORef                     (IORef, atomicModifyIORef,
                                                 atomicModifyIORef', newIORef,
                                                 readIORef, writeIORef)

import qualified Data.Attoparsec.ByteString     as Atto
import qualified Data.Binary.Get                as BIN
import qualified Data.ByteString                as B
import qualified Data.ByteString.Lazy           as BL
import qualified Network.Socket                 as S
import qualified Network.Socket.ByteString      as SB (recv)
#if !defined(mingw32_HOST_OS)
import qualified Network.Socket.ByteString.Lazy as SBL (sendAll)
#else
import qualified Network.Socket.ByteString      as SB (sendAll)
#endif

import           Network.WebSockets.Types
--------------------------------------------------------------------------------
-- | Bookkeeping state of a 'Stream': whether it is still open, together with
-- any input that has been received but not yet consumed by a parser.
data StreamState
    = Closed !B.ByteString  -- ^ Closed; carries the unconsumed remainder
    | Open   !B.ByteString  -- ^ Open; carries buffered, not-yet-parsed input
--------------------------------------------------------------------------------
-- | Lightweight abstraction over an input/output stream: a reader, a writer,
-- and the mutable open/closed state (with leftover input) shared by the
-- parsing functions.
data Stream = Stream
    { streamIn    :: IO (Maybe B.ByteString)
      -- ^ Pull the next chunk of input; 'Nothing' signals end of input.
    , streamOut   :: Maybe BL.ByteString -> IO ()
      -- ^ Push output; 'Nothing' is a request to close the stream.
    , streamState :: !(IORef StreamState)
      -- ^ Open/closed flag plus any received-but-unparsed bytes.
    }
--------------------------------------------------------------------------------
-- | Create a stream from a @receive@ and a @send@ action. The following
-- properties apply:
--
-- - Regardless of the provided @receive@ and @send@ functions, reading from
--   and writing to the stream is serialised: this function creates a
--   receive lock and a send lock that are held around every call.
--
-- - Reading from or writing to a closed 'Stream' always throws
--   'ConnectionClosed', even if the underlying @receive@ and @send@
--   functions do not (we do the bookkeeping). This includes closing:
--   handing 'Nothing' to the writer of an already-closed stream throws too.
--
-- - The stream becomes closed when @receive@ returns 'Nothing', when the
--   writer is given 'Nothing', or when @receive@ or @send@ throws (the
--   exception is then re-thrown).
--
-- - Streams should always be closed.
makeStream
    :: IO (Maybe B.ByteString)         -- ^ Reading; 'Nothing' means end of input
    -> (Maybe BL.ByteString -> IO ())  -- ^ Writing; 'Nothing' means close
    -> IO Stream                       -- ^ Resulting stream
makeStream receive send = do
    ref         <- newIORef (Open B.empty)
    receiveLock <- newMVar ()
    sendLock    <- newMVar ()
    return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref
  where
    -- Mark the stream as closed while keeping any buffered input, so that
    -- 'parse' can still consume bytes that arrived before the close. The
    -- strict atomicModifyIORef' is used so the IORef never holds an
    -- unevaluated thunk of the new state.
    closeRef :: IORef StreamState -> IO ()
    closeRef ref = atomicModifyIORef' ref $ \state -> case state of
        Open   buf -> (Closed buf, ())
        Closed buf -> (Closed buf, ())

    -- Run the action only while the stream is open; otherwise throw
    -- 'ConnectionClosed'.
    assertNotClosed :: IORef StreamState -> IO a -> IO a
    assertNotClosed ref io = do
        state <- readIORef ref
        case state of
            Closed _ -> throwIO ConnectionClosed
            Open   _ -> io

    -- Locked read. End of input ('Nothing') or an exception from the
    -- underlying reader marks the stream closed; exceptions are re-thrown.
    receive' :: IORef StreamState -> MVar () -> IO (Maybe B.ByteString)
    receive' ref lock = withMVar lock $ \() -> assertNotClosed ref $ do
        mbBs <- onException receive (closeRef ref)
        case mbBs of
            Nothing -> closeRef ref >> return Nothing
            Just bs -> return (Just bs)

    -- Locked write. 'Nothing' is a close request, so the stream is marked
    -- closed before the underlying writer sees it; an exception from the
    -- writer also marks the stream closed and is re-thrown.
    send' :: IORef StreamState -> MVar () -> (Maybe BL.ByteString -> IO ())
    send' ref lock mbBs = withMVar lock $ \() -> assertNotClosed ref $ do
        when (mbBs == Nothing) (closeRef ref)
        onException (send mbBs) (closeRef ref)
--------------------------------------------------------------------------------
-- | Create a 'Stream' backed by a socket.
--
-- Input is read in chunks of at most 1024 bytes, and an empty read is
-- treated as end of input. Closing the resulting stream does not close the
-- socket itself: the close request is a no-op at this level.
makeSocketStream :: S.Socket -> IO Stream
makeSocketStream sock = makeStream recvChunk sendChunk
  where
    recvChunk = nonEmpty <$> SB.recv sock 1024

    -- Map the empty read to 'Nothing' (end of input).
    nonEmpty bytes
        | B.null bytes = Nothing
        | otherwise    = Just bytes

    sendChunk = maybe (return ()) sendPayload

    sendPayload payload =
#if !defined(mingw32_HOST_OS)
        SBL.sendAll sock payload
#else
        forM_ (BL.toChunks payload) (SB.sendAll sock)
#endif
--------------------------------------------------------------------------------
-- | Create a 'Stream' whose output is fed back as its own input, through a
-- single 'MVar' slot.
--
-- Each strict chunk of a written lazy 'BL.ByteString' is put into the slot
-- separately. Because the slot holds at most one value, writing more than
-- one chunk blocks until a reader takes the previous one. Closing puts
-- 'Nothing' into the slot, which the reader sees as end of input.
makeEchoStream :: IO Stream
makeEchoStream = do
    slot <- newEmptyMVar
    let echo Nothing        = putMVar slot Nothing
        echo (Just payload) = forM_ (BL.toChunks payload) (putMVar slot . Just)
    makeStream (takeMVar slot) echo
--------------------------------------------------------------------------------
-- | Run a "Data.Binary.Get" decoder against the stream.
--
-- Returns 'Nothing' when the stream is exhausted before any input is
-- available. Otherwise chunks are pulled from 'streamIn' until the decoder
-- finishes. Its leftover input is stored back in the stream state ('Closed'
-- if end of input was reached along the way, 'Open' otherwise) and the
-- decoded value is returned. A decoder failure throws 'ParseException'.
--
-- NOTE(review): the stream-state buffer is read and written here without a
-- lock, so concurrent calls on one stream are presumably unsupported;
-- confirm with callers.
parseBin :: Stream -> BIN.Get a -> IO (Maybe a)
parseBin stream parser = readIORef stateRef >>= begin
  where
    stateRef = streamState stream
    decoder  = BIN.runGetIncremental parser

    -- Pick the initial input: leftover bytes if there are any, otherwise a
    -- fresh chunk (unless the stream is already closed or ends right now).
    begin (Closed leftover)
        | B.null leftover = return Nothing
        | otherwise       = feed (decoder `BIN.pushChunk` leftover) True
    begin (Open buffered)
        | B.null buffered = do
            next <- streamIn stream
            case next of
                Nothing -> do
                    writeIORef stateRef (Closed B.empty)
                    return Nothing
                Just chunk -> feed (decoder `BIN.pushChunk` chunk) False
        | otherwise = feed (decoder `BIN.pushChunk` buffered) False

    -- Drive the decoder to completion. The buffer held in stateRef counts as
    -- consumed from here on; only a successful 'BIN.Done' writes the new
    -- leftover back. The Bool records whether end of input has been seen.
    feed (BIN.Done leftover _ x) atEof = do
        writeIORef stateRef $ if atEof then Closed leftover else Open leftover
        return (Just x)
    feed (BIN.Partial resume) atEof
        | atEof     = feed (resume Nothing) True
        | otherwise = do
            next <- streamIn stream
            case next of
                Nothing    -> feed (resume Nothing) True
                Just chunk -> feed (resume (Just chunk)) False
    feed (BIN.Fail _ _ err) _ = throwIO (ParseException err)
-- | Run an attoparsec parser against the stream.
--
-- Returns 'Nothing' when the stream is exhausted before any input is
-- available. Otherwise chunks are pulled from 'streamIn' until the parser
-- finishes. Its leftover input is stored back in the stream state ('Closed'
-- if end of input was reached along the way, 'Open' otherwise) and the
-- parsed value is returned. A parser failure throws 'ParseException'.
--
-- NOTE(review): the stream-state buffer is read and written here without a
-- lock, so concurrent calls on one stream are presumably unsupported;
-- confirm with callers.
parse :: Stream -> Atto.Parser a -> IO (Maybe a)
parse stream parser = readIORef stateRef >>= begin
  where
    stateRef = streamState stream

    -- Pick the initial input: leftover bytes if there are any, otherwise a
    -- fresh chunk (unless the stream is already closed or ends right now).
    begin (Closed leftover)
        | B.null leftover = return Nothing
        | otherwise       = feed (Atto.parse parser leftover) True
    begin (Open buffered)
        | B.null buffered = do
            next <- streamIn stream
            case next of
                Nothing -> do
                    writeIORef stateRef (Closed B.empty)
                    return Nothing
                Just chunk -> feed (Atto.parse parser chunk) False
        | otherwise = feed (Atto.parse parser buffered) False

    -- Drive the parser to completion. The buffer held in stateRef counts as
    -- consumed from here on; only a successful 'Atto.Done' writes the new
    -- leftover back. Feeding an empty chunk tells attoparsec the input has
    -- ended. The Bool records whether end of input has been seen.
    feed (Atto.Done leftover x) atEof = do
        writeIORef stateRef $ if atEof then Closed leftover else Open leftover
        return (Just x)
    feed (Atto.Partial resume) atEof
        | atEof     = feed (resume B.empty) True
        | otherwise = do
            next <- streamIn stream
            case next of
                Nothing    -> feed (resume B.empty) True
                Just chunk -> feed (resume chunk) False
    feed (Atto.Fail _ _ err) _ = throwIO (ParseException err)
--------------------------------------------------------------------------------
-- | Send a lazy 'BL.ByteString' over the stream. Throws 'ConnectionClosed'
-- if the stream has already been closed.
write :: Stream -> BL.ByteString -> IO ()
write stream payload = streamOut stream (Just payload)
--------------------------------------------------------------------------------
-- | Close the stream by handing 'Nothing' to its writer, which marks the
-- stream closed. Like 'write', this throws 'ConnectionClosed' if the stream
-- is already closed. That includes the case where end of input was
-- previously observed on the reading side.
close :: Stream -> IO ()
close = flip streamOut Nothing