diff --git a/warp/Network/Wai/Handler/Warp/HTTP1.hs b/warp/Network/Wai/Handler/Warp/HTTP1.hs index 0d16d2423..5e11239d5 100644 --- a/warp/Network/Wai/Handler/Warp/HTTP1.hs +++ b/warp/Network/Wai/Handler/Warp/HTTP1.hs @@ -115,7 +115,7 @@ http1server -> Source -> IO () http1server settings ii conn transport app addr th istatus src = - loop True `UnliftIO.catchAny` handler + loop FirstRequest `UnliftIO.catchAny` handler where handler e -- See comment below referencing @@ -154,18 +154,22 @@ http1server settings ii conn transport app addr th istatus src = `UnliftIO.catchAny` \e -> do settingsOnException settings (Just req) e -- Don't throw the error again to prevent calling settingsOnException twice. - return False + return CloseConnection -- When doing a keep-alive connection, the other side may just -- close the connection. We don't want to treat that as an - -- exceptional situation, so we pass in False to http1 (which - -- in turn passes in False to recvRequest), indicating that + -- exceptional situation, so we pass in SubsequentRequest to http1 (which + -- in turn passes in SubsequentRequest to recvRequest), indicating that -- this is not the first request. If, when trying to read the -- request headers, no data is available, recvRequest will -- throw a NoKeepAliveRequest exception, which we catch here -- and ignore. See: https://github.com/yesodweb/wai/issues/618 - when keepAlive $ loop False + case keepAlive of + ReuseConnection -> loop SubsequentRequest + CloseConnection -> return () + +data ReuseConnection = ReuseConnection | CloseConnection processRequest :: Settings @@ -179,7 +183,7 @@ processRequest -> Maybe (IORef Int) -> IndexedHeader -> IO ByteString - -> IO Bool + -> IO ReuseConnection processRequest settings ii conn app th istatus src req mremainingRef idxhdr nextBodyFlush = do -- Let the application run for as long as it wants T.pause th @@ -226,7 +230,7 @@ processRequest settings ii conn app th istatus src req mremainingRef idxhdr next Nothing -> do flushEntireBody nextBodyFlush T.resume th - return True + return ReuseConnection Just maxToRead -> do let tryKeepAlive = do -- flush the rest of the request body @@ -234,16 +238,16 @@ processRequest settings ii conn app th istatus src req mremainingRef idxhdr next if isComplete then do T.resume th - return True - else return False + return ReuseConnection + else return CloseConnection case mremainingRef of Just ref -> do remaining <- readIORef ref if remaining <= maxToRead then tryKeepAlive - else return False + else return CloseConnection Nothing -> tryKeepAlive - else return False + else return CloseConnection sendErrorResponse :: Settings diff --git a/warp/Network/Wai/Handler/Warp/Internal.hs b/warp/Network/Wai/Handler/Warp/Internal.hs index 1fc6983b7..e686272f0 100644 --- a/warp/Network/Wai/Handler/Warp/Internal.hs +++ b/warp/Network/Wai/Handler/Warp/Internal.hs @@ -75,6 +75,7 @@ module Network.Wai.Handler.Warp.Internal ( -- * Request and response Source, + FirstRequest (..), recvRequest, sendResponse, diff --git a/warp/Network/Wai/Handler/Warp/Request.hs b/warp/Network/Wai/Handler/Warp/Request.hs index 0b51d4a68..e9632d56a 100644 --- a/warp/Network/Wai/Handler/Warp/Request.hs +++ b/warp/Network/Wai/Handler/Warp/Request.hs @@ -4,6 +4,7 @@ {-# OPTIONS_GHC -fno-warn-deprecations #-} module Network.Wai.Handler.Warp.Request ( + FirstRequest(..), recvRequest, headerLines, pauseTimeoutKey, @@ -50,11 +51,13 @@ import Network.Wai.Handler.Warp.Settings ( ---------------------------------------------------------------- +-- | first request on this connection? +data FirstRequest = FirstRequest | SubsequentRequest + -- | Receiving a HTTP request from 'Connection' and parsing its header -- to create 'Request'. recvRequest - :: Bool - -- ^ first request on this connection? + :: FirstRequest -> Settings -> Connection -> InternalInfo @@ -118,7 +121,7 @@ recvRequest firstRequest settings conn ii th addr src transport = do ---------------------------------------------------------------- -headerLines :: Int -> Bool -> Source -> IO [ByteString] +headerLines :: Int -> FirstRequest -> Source -> IO [ByteString] headerLines maxTotalHeaderLength firstRequest src = do bs <- readSource src if S.null bs @@ -127,9 +130,9 @@ headerLines maxTotalHeaderLength firstRequest src = do -- lack of data as a real exception. See the http1 function in -- the Run module for more details. - if firstRequest - then throwIO ConnectionClosedByPeer - else throwIO NoKeepAliveRequest + case firstRequest of + FirstRequest -> throwIO ConnectionClosedByPeer + SubsequentRequest -> throwIO NoKeepAliveRequest else push maxTotalHeaderLength src (THStatus 0 0 id id) bs data NoKeepAliveRequest = NoKeepAliveRequest diff --git a/warp/bench/Parser.hs b/warp/bench/Parser.hs index 0c5ec19fb..ea58d3fe3 100644 --- a/warp/bench/Parser.hs +++ b/warp/bench/Parser.hs @@ -19,7 +19,7 @@ import qualified Network.HTTP.Types as H import UnliftIO.Exception (impureThrow, throwIO) import Prelude hiding (lines) -import Network.Wai.Handler.Warp.Request (headerLines) +import Network.Wai.Handler.Warp.Request (FirstRequest (..), headerLines) import Network.Wai.Handler.Warp.Types #if MIN_VERSION_gauge(0, 2, 0) @@ -61,7 +61,7 @@ main = do ] ] where - testIt req = producer req >>= headerLines 800 False + testIt req = producer req >>= headerLines 800 FirstRequest ---------------------------------------------------------------- diff --git a/warp/test/RequestSpec.hs b/warp/test/RequestSpec.hs index 6eafe3b2d..b93af3060 100644 --- a/warp/test/RequestSpec.hs +++ b/warp/test/RequestSpec.hs @@ -70,7 +70,7 @@ spec = do describe "headerLines" $ do let parseHeaderLine chunks = do src <- mkSourceFunc chunks >>= mkSource - x <- headerLines defaultMaxTotalHeaderLength True src + x <- headerLines defaultMaxTotalHeaderLength FirstRequest src x `shouldBe` ["Status: 200", "Content-Type: text/plain"] it "can handle a normal case" $ @@ -95,9 +95,9 @@ spec = do it "can (not) handle an illegal case (1)" $ do let chunks = ["\nStatus:", "\n 200", "\nContent-Type: text/plain", "\r\n\r\n"] src <- mkSourceFunc chunks >>= mkSource - x <- headerLines defaultMaxTotalHeaderLength True src + x <- headerLines defaultMaxTotalHeaderLength FirstRequest src x `shouldBe` [] - y <- headerLines defaultMaxTotalHeaderLength True src + y <- headerLines defaultMaxTotalHeaderLength FirstRequest src y `shouldBe` ["Status:", " 200", "Content-Type: text/plain"] let testLengthHeaders = ["Sta", "tus: 200\r", "\n", "Content-Type: ", "text/plain\r\n\r\n"] @@ -106,12 +106,12 @@ spec = do -- Length is 39, this shouldn't fail it "doesn't throw on correct length" $ do src <- mkSourceFunc testLengthHeaders >>= mkSource - x <- headerLines testLength True src + x <- headerLines testLength FirstRequest src x `shouldBe` ["Status: 200", "Content-Type: text/plain"] -- Length is still 39, this should fail it "throws error on correct length too long" $ do src <- mkSourceFunc testLengthHeaders >>= mkSource - headerLines (testLength - 1) True src `shouldThrow` (== OverLargeHeader) + headerLines (testLength - 1) FirstRequest src `shouldThrow` (== OverLargeHeader) where blankSafe = headerLinesList ["f", "oo\n", "bar\nbaz\n\r\n"] whiteSafe = headerLinesList ["foo\r\nbar\r\nbaz\r\n\r\n hi there"] @@ -135,7 +135,7 @@ headerLinesList' orig = do writeIORef ref z return y src' <- mkSource src - res <- headerLines defaultMaxTotalHeaderLength True src' + res <- headerLines defaultMaxTotalHeaderLength FirstRequest src' return (res, src') consumeLen :: Int -> Source -> IO S8.ByteString