{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.Socket.BufferPool.Recv (
receive
, makeRecvN
) where
import qualified Data.ByteString as BS
import Data.ByteString.Internal (ByteString(..), unsafeCreate)
import Data.IORef
import Foreign.C.Error (eAGAIN, eINTR, eWOULDBLOCK, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.Ptr (Ptr, castPtr)
import GHC.Conc (threadWaitRead)
import Network.Socket (Socket, withFdSocket)
import System.Posix.Types (Fd(..))
#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Socket.BufferPool.Windows
#endif
import Network.Socket.BufferPool.Types
import Network.Socket.BufferPool.Buffer
-- | Build a 'Recv' action that reads from @sock@ using buffers from @pool@.
--
-- The callback passed to 'withBufferPool' is given a buffer and an 'Int'
-- size; it performs a single 'tryReceive' into that buffer and returns the
-- number of bytes received as an 'Int'.  How 'withBufferPool' turns that
-- count into the resulting 'ByteString' is defined elsewhere
-- (NOTE(review): presumably it slices the filled prefix; confirm there).
receive :: Socket -> BufferPool -> Recv
receive sock pool = withBufferPool pool $ \ptr size -> do
#if MIN_VERSION_network(3,1,0)
  -- The descriptor is only used inside the 'withFdSocket' callback.
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    -- The C call takes its length as a CSize and reports its result as a
    -- CInt; convert on the way in and on the way out.
    let size' :: CSize
        size' = fromIntegral size
    fromIntegral <$> tryReceive fd ptr size'
-- | Receive at most @size@ bytes from the descriptor @sock@ into @ptr@.
--
-- Returns the value reported by the underlying receive call when it is not
-- @-1@.  When the call reports @-1@ the current errno decides what happens:
--
-- * @EINTR@: the call is simply issued again.
--
-- * @EAGAIN@ or @EWOULDBLOCK@: the calling Haskell thread waits with
--   'threadWaitRead' until the descriptor is readable, then tries again.
--
-- * anything else: an 'IOError' is raised via 'throwErrno' with the
--   location string @"tryReceive"@.
tryReceive :: CInt -> Buffer -> CSize -> IO CInt
tryReceive sock ptr size = go
  where
    go :: IO CInt
    go = do
#ifdef mingw32_HOST_OS
      bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "tryReceive" (FD sock 1) (castPtr ptr) 0 size
#else
      bytes <- c_recv sock (castPtr ptr) size 0
#endif
      if bytes /= -1
        then return bytes
        else getErrno >>= onFailure

    -- Decide between retrying and failing.  No foreign call happens between
    -- reading errno and 'throwErrno', so the reported error is the one that
    -- made the receive call fail.
    onFailure errno
      | errno == eINTR = go
      | errno == eAGAIN || errno == eWOULDBLOCK =
          threadWaitRead (Fd sock) >> go
      | otherwise = throwErrno "tryReceive"
-- | Create a 'RecvN' function on top of a plain 'Recv' action.
--
-- @bs0@ seeds the internal leftover cache (held in a fresh 'IORef'), so
-- bytes that were already read before this call are served first.  Each
-- call of the returned function goes through 'recvN' with that shared
-- reference.
makeRecvN :: ByteString -> Recv -> IO RecvN
makeRecvN bs0 recv = do
  ref <- newIORef bs0
  pure (recvN ref recv)
-- | One step of a 'RecvN': serve @size@ bytes using the cached leftover in
-- @ref@ and, if that is not enough, further calls to @recv@.
--
-- The leftover returned by 'tryRecvN' replaces the contents of @ref@ before
-- the requested bytes are returned.  The cache is read and written in two
-- separate steps, not as one atomic update.
recvN :: IORef ByteString -> Recv -> RecvN
recvN ref recv size = do
  cached <- readIORef ref
  (bs, leftover) <- tryRecvN cached size recv
  writeIORef ref leftover
  return bs
-- | Try to obtain exactly @siz0@ bytes, starting from the already available
-- bytes @init0@ and calling @recv@ for more as needed.
--
-- The result is @(requested, leftover)@:
--
-- * If @init0@ already holds at least @siz0@ bytes, it is split at @siz0@
--   and @recv@ is never called.
--
-- * Otherwise chunks are received until enough bytes have arrived; the
--   last chunk is split so that any surplus becomes the leftover.
--
-- * If @recv@ yields an empty chunk before enough bytes have arrived, the
--   result is @("", "")@ and everything gathered so far is dropped.
tryRecvN :: ByteString -> Int -> Recv -> IO (ByteString, ByteString)
tryRecvN init0 siz0 recv
  | siz0 <= len0 = return $ BS.splitAt siz0 init0
  | otherwise = go (init0 :) (siz0 - len0)
  where
    len0 :: Int
    len0 = BS.length init0

    -- @build@ is a difference list of the chunks gathered so far (in
    -- arrival order); @left@ is the number of bytes still missing.
    go :: ([ByteString] -> [ByteString]) -> Int -> IO (ByteString, ByteString)
    go build left = do
      bs <- recv
      let len = BS.length bs
      if len == 0
        then return ("", "")
        else if len >= left
          then do
            let (consume, leftover) = BS.splitAt left bs
                ret = concatN siz0 $ build [consume]
            return (ret, leftover)
          else do
            let build' = build . (bs :)
                left' = left - len
            go build' left'
-- | Concatenate @bss0@ into one freshly allocated 'ByteString' of exactly
-- @total@ bytes.
--
-- The chunks are written one after another with 'copy', each call
-- continuing from the pointer returned by the previous one.
-- NOTE(review): 'unsafeCreate' does not initialise the buffer and nothing
-- here checks lengths, so this assumes @total@ equals the combined length
-- of the chunks and that 'copy' returns the position just past the bytes it
-- wrote; confirm against the definition of 'copy'.
concatN :: Int -> [ByteString] -> ByteString
concatN total bss0 = unsafeCreate total (goCopy bss0)
  where
    goCopy :: [ByteString] -> Buffer -> IO ()
    goCopy [] _ = return ()
    goCopy (bs : bss) ptr = copy ptr bs >>= goCopy bss
#ifndef mingw32_HOST_OS
-- | Direct binding to the C function @recv@, imported as an @unsafe@
-- foreign call.  The arguments are, in order: descriptor, destination
-- buffer, buffer length and flags.  Only compiled on non-Windows hosts.
-- NOTE(review): POSIX declares the result of @recv@ as @ssize_t@; it is
-- bound here as 'CInt'.
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif