module Snap.Internal.Http.Server.SimpleBackend
( simpleEventLoop
) where
import Control.Monad.Trans
import Control.Concurrent hiding (yield)
import Control.Concurrent.Extended (forkOnLabeledWithUnmaskBs)
import Control.Exception
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as SC
import Data.ByteString.Internal (c2w)
import Foreign hiding (new)
import Foreign.C
#if MIN_VERSION_base(4,4,0)
import GHC.Conc (forkOn, labelThread)
#else
import GHC.Conc (forkOnIO,
labelThread)
#endif
import Network.Socket
#if !MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif
import Snap.Internal.Debug
import Snap.Internal.Http.Server.Address
import Snap.Internal.Http.Server.Backend
import Snap.Internal.Http.Server.Date
import qualified Snap.Internal.Http.Server.ListenHelpers as Listen
import Snap.Internal.Http.Server.TimeoutManager (TimeoutManager)
import qualified Snap.Internal.Http.Server.TimeoutManager as TM
import Snap.Iteratee hiding (map)
#if defined(HAS_SENDFILE)
import System.Posix.IO
import System.Posix.Types (Fd (..))
import qualified System.SendFile as SF
#endif
#if !MIN_VERSION_base(4,4,0)
-- | Compatibility alias: when built against a @base@ older than 4.4 this
-- module imports 'forkOnIO' instead of @forkOn@, so the newer name is
-- provided here in terms of the older one.
forkOn :: Int -> IO () -> IO ThreadId
forkOn capability action = forkOnIO capability action
#endif
-- | State of one event loop, as assembled by 'newLoop'.
data EventLoopCpu = EventLoopCpu
    { _boundCpu :: Int
      -- ^ Capability number that was handed to 'newLoop' and on which the
      -- loop's accept threads are forked.
    , _acceptThreads :: [ThreadId]
      -- ^ One accept thread per listening socket; 'stopLoop' kills these.
    , _timeoutManager :: TimeoutManager
      -- ^ Timeout manager created for this loop; 'stopLoop' stops it.
    , _exitMVar :: !(MVar ())
      -- ^ Filled (via 'tryPutMVar') when an accept thread of this loop
      -- terminates; 'simpleEventLoop' blocks on it with 'takeMVar'.
    }
-- | Event loop entry point of the simple backend.
--
-- Creates one 'EventLoopCpu' for every capability index in
-- @[0 .. cap - 1]@, runs the @initial@ action, and then blocks until the
-- exit 'MVar' of every loop has been filled. Whether that wait finishes
-- normally or is interrupted by an exception, every loop is stopped and
-- every listening socket is closed afterwards.
simpleEventLoop :: EventLoop
simpleEventLoop defaultTimeout sockets cap elog initial handler = do
    -- One loop per capability; capabilities are numbered from zero, hence
    -- the upper bound of @cap - 1@.
    loops <- Prelude.mapM (newLoop defaultTimeout sockets handler elog)
                          [0 .. cap - 1]

    initial

    debug "simpleEventLoop: waiting for mvars"

    -- The cleanup runs on normal completion as well as when the waiting
    -- thread receives an exception.
    Prelude.mapM_ (takeMVar . _exitMVar) loops `finally` do
        debug "simpleEventLoop: killing all threads"
        mapM_ stopLoop loops
        mapM_ Listen.closeSocket sockets
-- | Build the event loop state for a single capability.
--
-- Initializes a timeout manager, then forks one labeled accept thread per
-- listening socket on capability @cpu@. When any of those accept threads
-- terminates (normally or by exception) the loop's exit 'MVar' is filled.
newLoop :: Int
        -> [ListenSocket]
        -> SessionHandler
        -> (S.ByteString -> IO ())
        -> Int
        -> IO EventLoopCpu
newLoop defaultTimeout sockets handler elog cpu = do
    manager <- TM.initialize defaultTimeout getCurrentDateTime
    exitVar <- newEmptyMVar

    let -- Non-blocking: only the first terminating thread fills the MVar.
        signalExit = tryPutMVar exitVar () >> return ()

        threadLabel lsock = S.concat
            [ "snap-server: ", SC.pack (show lsock)
            , " on capability: ", SC.pack (show cpu)
            ]

        spawn lsock =
            forkOnLabeledWithUnmaskBs (threadLabel lsock) cpu $ \unmask ->
                acceptThread defaultTimeout handler manager elog cpu lsock unmask
                    `finally` signalExit

    workers <- Prelude.mapM spawn sockets
    return $! EventLoopCpu cpu workers manager exitVar
-- | Shut an event loop down: stop its timeout manager, then kill each of
-- its accept threads. Runs with asynchronous exceptions masked.
stopLoop :: EventLoopCpu -> IO ()
stopLoop loop = mask_ (stopTimeouts >> killAcceptors)
  where
    stopTimeouts  = TM.stop (_timeoutManager loop)
    killAcceptors = Prelude.mapM_ killThread (_acceptThreads loop)
-- | Body of an accept thread.
--
-- Repeatedly accepts connections on the given listening socket and forks a
-- labeled session thread (on capability @cpu@) for each one. A synchronous
-- exception raised while accepting is logged through @elog@, followed by a
-- short delay before accepting resumes; an 'AsyncException' is rethrown
-- and therefore ends the thread.
acceptThread :: Int
             -> SessionHandler
             -> TimeoutManager
             -> (S.ByteString -> IO ())
             -> Int
             -> ListenSocket
             -> (forall a. IO a -> IO a)
             -> IO ()
acceptThread defaultTimeout handler tmgr elog cpu sock unmask =
    forever $ unmask (forever acceptOne) `catches` acceptFailures
  where
    -- Accept a single connection and hand it off to its own thread.
    acceptOne = do
        debug $ "acceptThread: calling accept() on socket " ++ show sock
        (conn, peer) <- accept (Listen.listenSocket sock)
        setSocketOption conn NoDelay 1
        debug $ "acceptThread: accepted connection from remote: " ++ show peer
        _ <- forkOnLabeledWithUnmaskBs (sessionLabel conn peer) cpu $ \unmask' ->
                 unmask' (runSession defaultTimeout handler tmgr sock conn peer)
                     `catches` sessionFailures
        return ()

    sessionLabel conn peer = S.concat
        [ "snap-server: connection from: "
        , SC.pack (show peer)
        , " on socket: "
        , SC.pack (show (fdSocket conn))
        , "\0"
        ]

    -- Handlers for the accept loop itself.
    acceptFailures =
        [ Handler rethrowAsync
        , Handler logAcceptFailure
        ]

    rethrowAsync :: AsyncException -> IO ()
    rethrowAsync = throwIO

    logAcceptFailure :: SomeException -> IO ()
    logAcceptFailure err = do
        elog $ S.concat [ "SimpleBackend.acceptThread: accept threw: "
                        , SC.pack (show err) ]
        threadDelay 10000

    -- Handlers for a session thread: being killed or interrupted is a
    -- silent exit, any other exception is either rethrown (async) or
    -- logged (everything else).
    sessionFailures =
        [ Handler ignoreKill
        , Handler logSessionFailure
        ]

    ignoreKill :: AsyncException -> IO ()
    ignoreKill ThreadKilled  = return ()
    ignoreKill UserInterrupt = return ()
    ignoreKill other         = throwIO other

    logSessionFailure :: SomeException -> IO ()
    logSessionFailure err =
        elog $ S.concat [ "SimpleBackend.acceptThread: "
                        , SC.pack (show err) ]
-- | Run one client session on an accepted connection.
--
-- Resolves the local and remote addresses, registers a timeout action that
-- kills the current thread, and then invokes the session handler inside a
-- 'bracket' that creates the network session up front and, on exit,
-- cancels the timeout, ends the session, and shuts down and closes the
-- socket (ignoring any exception raised by those cleanup steps).
runSession :: Int
           -> SessionHandler
           -> TimeoutManager
           -> ListenSocket
           -> Socket
           -> SockAddr -> IO ()
runSession defaultTimeout handler tmgr lsock sock addr = do
    self <- myThreadId
    debug $ "Backend.withConnection: running session: " ++ show addr
    (remotePort, remoteHost) <- getAddress addr
    (localPort, localHost)   <- getAddress =<< getSocketName sock
    timeoutHandle <- TM.register (killThread self) tmgr

    let sinfo = SessionInfo localHost localPort remoteHost remotePort
                            (Listen.isSecure lsock)

        modifyTimeout = TM.modify timeoutHandle

        -- Applies @max defaultTimeout@ to the timeout via 'TM.modify'.
        tickle = modifyTimeout (max defaultTimeout)

        open = Listen.createSession lsock 8192 fd
                                    (threadWaitRead (fromIntegral fd))

        close session = mask_ $ do
            debug "thread killed, closing socket"
            TM.cancel timeoutHandle
            eatException $ Listen.endSession lsock session
            eatException $ shutdown sock ShutdownBoth
            eatException $ sClose sock

        serve session =
            let writeEnd = writeOut lsock session sock tickle
            in handler sinfo
                       (enumerate lsock session sock)
                       writeEnd
                       (sendFile lsock tickle fd writeEnd)
                       modifyTimeout

    bracket open close serve
  where
    fd = fdSocket sock
-- | Run an action, discarding both its result and any exception it throws.
-- Used for best-effort cleanup steps whose failure must not abort the
-- remaining cleanup.
eatException :: IO a -> IO ()
eatException act = do
    outcome <- try act
    case outcome of
      Left (SomeException _) -> return ()
      Right _                -> return ()
-- | Send @sz@ bytes of the file @fp@, starting at offset @start@, to the
-- client.
--
-- When built with @HAS_SENDFILE@ and the listening socket is a
-- 'ListenHttp' one, the file is opened read-only and pushed to the socket
-- descriptor with 'SF.sendFile', calling the @tickle@ action after every
-- partial send. In every other case the requested byte range is streamed
-- through the @writeEnd@ iteratee with 'enumFilePartial'.
sendFile :: ListenSocket
         -> IO ()
         -> CInt
         -> Iteratee ByteString IO ()
         -> FilePath
         -> Int64
         -> Int64
         -> IO ()
#if defined(HAS_SENDFILE)
sendFile lsock tickle sock writeEnd fp start sz =
    case lsock of
        ListenHttp _ -> bracket (openFd fp ReadOnly Nothing defaultFileFlags)
                                closeFd
                                (go start sz)
        _            -> do
            step <- runIteratee writeEnd
            run_ $ enumFilePartial fp (start, start + sz) step
  where
    -- Keep sending until the whole range has gone out; a short send
    -- advances the offset and shrinks the remaining byte count by the
    -- number of bytes actually sent.
    go off bytes fd
      | bytes == 0 = return ()
      | otherwise  = do
            sent <- SF.sendFile (threadWaitWrite $ fromIntegral sock)
                                sfd fd off bytes
            if sent < bytes
              then tickle >> go (off + sent) (bytes - sent) fd
              else return ()

    sfd = Fd sock
#else
sendFile _ _ _ writeEnd fp start sz = do
    step <- runIteratee writeEnd
    run_ $ enumFilePartial fp (start, start + sz) step
    return ()
#endif
-- | Enumerator over the bytes received on a connection.
--
-- While the downstream iteratee is in the 'Continue' state, data is
-- received via 'Listen.recv' and fed to it one chunk at a time. An absent
-- result from the receive (or an empty chunk) is delivered as EOF. A step
-- that is not 'Continue' is returned unchanged.
enumerate :: (MonadIO m)
          => ListenSocket
          -> NetworkSession
          -> Socket
          -> Enumerator ByteString m a
enumerate port session sock = pump
  where
    fd = fdSocket sock

#ifdef PORTABLE
    timeoutRecv = Listen.recv port sock (threadWaitRead $
                                         fromIntegral fd) session
#else
    timeoutRecv = Listen.recv port (threadWaitRead $
                                    fromIntegral fd) session
#endif

    note msg = debug $ "SimpleBackend.enumerate(" ++ show (_socket session)
                     ++ "): " ++ msg

    -- Receive once and forward the result while the iteratee wants input.
    pump (Continue k) = do
        note "reading from socket"
        received <- liftIO timeoutRecv
        case received of
          Nothing -> do
              note "got EOF from socket"
              feed k ""
          Just bytes -> do
              note $ "got " ++ show (S.length bytes)
                            ++ " bytes from read end"
              feed k bytes
    pump step = returnI step

    -- An empty chunk stands for EOF; anything else is passed on as data.
    feed k bytes
      | S.null bytes = do
          note "sending EOF to continuation"
          enumEOF $ Continue k
      | otherwise = do
          note $ "sending " ++ show bytes ++ " to continuation"
          next <- lift $ runIteratee $ k $ Chunks [bytes]
          case next of
            Yield x rest -> do
                note $ "got yield, remainder is " ++ show rest
                yield x rest
            r@(Continue _) -> do
                note "got continue"
                pump r
            Error e -> throwError e
-- | Iteratee that writes everything it is fed to a connection.
--
-- Each incoming group of chunks is concatenated and sent via
-- 'Listen.send', which is also given the @tickle@ action. An exception
-- raised by the send is turned into an iteratee error with 'throwError'.
-- On EOF the iteratee yields @()@.
writeOut :: (MonadIO m)
         => ListenSocket
         -> NetworkSession
         -> Socket
         -> (IO ())
         -> Iteratee ByteString m ()
writeOut port session sock tickle = loop
  where
    dbg s = debug $ "SimpleBackend.writeOut(" ++ show (_socket session)
                  ++ "): " ++ s

    loop = continue k

    k EOF = yield () EOF
    k (Chunks xs) = do
        let s = S.concat xs
        let n = S.length s
        dbg $ "got chunk with " ++ show n ++ " bytes"
        ee <- liftIO $ try $ timeoutSend s
        case ee of
          (Left (e :: SomeException)) -> do
              dbg $ "timeoutSend got error " ++ show e
              throwError e
          (Right _) -> do
              -- Tail of the data just written, for the debug log only;
              -- for chunks shorter than 10 bytes this is the whole chunk.
              let last10 = S.drop (n - 10) s
              dbg $ "wrote " ++ show n ++ " bytes, last 10=" ++ show last10
              loop

    fd = fdSocket sock

#ifdef PORTABLE
    timeoutSend = Listen.send port sock tickle
                              (threadWaitWrite $ fromIntegral fd) session
#else
    timeoutSend = Listen.send port tickle
                              (threadWaitWrite $ fromIntegral fd) session
#endif