Skip to content

Commit

Permalink
* Add class ExtraContext(extraContext), used to compute extra
Browse files Browse the repository at this point in the history
  constraints for different types.
* Add an argument to the internal derive functions to pass
  a value which will be used by extraContext.
* Add a test case that shows deriving SafeCopy for a type rather
  than a type name, and supplying extra context for the instance.
  • Loading branch information
ddssff committed Dec 1, 2024
1 parent 14d9173 commit bd0147b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 35 deletions.
82 changes: 53 additions & 29 deletions src/Data/SafeCopy/Derive.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE TemplateHaskell, LambdaCase, FlexibleInstances, CPP #-}
{-# LANGUAGE TemplateHaskell, NoOverloadedStrings, LambdaCase, FlexibleInstances, CPP #-}

module Data.SafeCopy.Derive where

Expand Down Expand Up @@ -106,10 +106,10 @@ import Text.Regex.TDFA ((=~), MatchResult(MR))
-- version without any problems.
deriveSafeCopy :: Version a -> Name -> Name -> Q [Dec]
deriveSafeCopy versionId kindName tyName =
internalDeriveSafeCopy Normal versionId kindName (conT tyName)
internalDeriveSafeCopy Normal versionId kindName tyName (conT tyName)

deriveSafeCopy' :: Version a -> Name -> TypeQ -> Q [Dec]
deriveSafeCopy' versionId kindName typ = internalDeriveSafeCopy Normal versionId kindName typ
deriveSafeCopy' versionId kindName typ = internalDeriveSafeCopy Normal versionId kindName typ typ

deriveSafeCopyIndexedType :: Version a -> Name -> Name -> [Name] -> Q [Dec]
deriveSafeCopyIndexedType versionId kindName tyName =
Expand Down Expand Up @@ -174,7 +174,7 @@ deriveSafeCopySimple versionId kindName tyName =

deriveSafeCopySimple' :: Version a -> Name -> TypeQ -> Q [Dec]
deriveSafeCopySimple' versionId kindName typ =
internalDeriveSafeCopy Simple versionId kindName typ
internalDeriveSafeCopy Simple versionId kindName typ typ

deriveSafeCopySimpleIndexedType :: Version a -> Name -> Name -> [Name] -> Q [Dec]
deriveSafeCopySimpleIndexedType versionId kindName tyName =
Expand Down Expand Up @@ -230,10 +230,11 @@ deriveSafeCopySimpleIndexedType' versionId kindName typ =
-- without any problems.
deriveSafeCopyHappstackData :: Version a -> Name -> Name -> Q [Dec]
deriveSafeCopyHappstackData versionId kindName tyName =
deriveSafeCopyHappstackData' versionId kindName (conT tyName)
deriveSafeCopyHappstackData' versionId kindName (conT tyName) tyName

deriveSafeCopyHappstackData' :: Version a -> Name -> TypeQ -> Q [Dec]
deriveSafeCopyHappstackData' = internalDeriveSafeCopy HappstackData
deriveSafeCopyHappstackData' :: ExtraContext t => Version a -> Name -> TypeQ -> t -> Q [Dec]
deriveSafeCopyHappstackData' versionId kindName typq t =
internalDeriveSafeCopy HappstackData versionId kindName t typq

deriveSafeCopyHappstackDataIndexedType :: Version a -> Name -> Name -> [Name] -> Q [Dec]
deriveSafeCopyHappstackDataIndexedType versionId kindName tyName =
Expand Down Expand Up @@ -263,30 +264,53 @@ class ExtraContext a where
extraContext :: a -> Q Cxt

-- | Generate SafeCopy constraints for a list of type variables
instance ExtraContext Cxt where
extraContext context = pure context

instance ExtraContext [TyVarBndr] where
extraContext tyvars =
pure $ fmap (\var -> AppT (ConT ''SafeCopy) (VarT $ tyVarName var)) tyvars

internalDeriveSafeCopy :: DeriveType -> Version a -> Name -> TypeQ -> Q [Dec]
internalDeriveSafeCopy deriveType versionId kindName typq = do
instance ExtraContext Name where
extraContext tyName =
reify tyName >>= \case
TyConI (DataD _ _ tyvars _ _ _) -> extraContext tyvars
TyConI (NewtypeD _ _ tyvars _ _ _) -> extraContext tyvars
FamilyI _ _ -> pure []
info -> fail $ "Can't derive SafeCopy instance for: " ++ show (tyName, info)

instance ExtraContext TypeQ where
extraContext typq =
typq >>= \case
ConT tyName -> extraContext tyName
ForallT _ context _ -> pure context
typ -> fail $ "Can't derive SafeCopy instance for: " ++ show typ

internalDeriveSafeCopy :: ExtraContext t => DeriveType -> Version a -> Name -> t -> TypeQ -> Q [Dec]
internalDeriveSafeCopy deriveType versionId kindName t typq = do
typq >>= \case
typ@(ConT tyName) -> do
reify tyName >>= \case
TyConI (DataD context _name tyvars _kind cons _derivs)
| length cons > 255 -> fail $ "Can't derive SafeCopy instance for: " ++ show tyName ++
". The datatype must have less than 256 constructors."
| otherwise -> do
extra <- extraContext tyvars
worker1 deriveType versionId kindName tyName typ (context ++ extra) tyvars (zip [0..] cons)
ConT tyName -> doInfo deriveType versionId kindName t tyName =<< reify tyName
ForallT _ cxt' typ' -> internalDeriveSafeCopy deriveType versionId kindName cxt' (pure typ')
AppT t1 _t2 -> internalDeriveSafeCopy deriveType versionId kindName t (pure t1)
TupleT n -> let tyName = tupleTypeName n in doInfo deriveType versionId kindName t tyName =<< reify tyName
typ -> fail $ "Can't derive SafeCopy instance for: " ++ show typ

TyConI (NewtypeD context _name tyvars _kind con _derivs) ->
worker1 deriveType versionId kindName tyName typ context tyvars [(0, con)]
doInfo :: ExtraContext t => DeriveType -> Version a -> Name -> t -> Name -> Info -> Q [Dec]
doInfo deriveType versionId kindName t tyName info =
case info of
TyConI (DataD context _name tyvars _kind cons _derivs)
| length cons > 255 -> fail $ "Can't derive SafeCopy instance for: " ++ show tyName ++
". The datatype must have less than 256 constructors."
| otherwise -> do
extra <- extraContext t
worker1 deriveType versionId kindName tyName (ConT tyName) (context ++ extra) tyvars (zip [0..] cons)

FamilyI _ insts -> do
concat <$> (forM insts $ withInst typ (worker1 deriveType versionId kindName tyName))
info -> fail $ "Can't derive SafeCopy instance for: " ++ show (tyName, info)
-- typ@(Forall tyvars cxt' typ') -> undefined
typ -> fail $ "Can't derive SafeCopy instance for: " ++ show typ
TyConI (NewtypeD context _name tyvars _kind con _derivs) ->
worker1 deriveType versionId kindName tyName (ConT tyName) context tyvars [(0, con)]

FamilyI _ insts -> do
concat <$> (forM insts $ withInst (ConT tyName) (worker1 deriveType versionId kindName tyName))
_ -> fail $ "Can't derive SafeCopy instance for: " ++ show (tyName, info)

internalDeriveSafeCopyIndexedType :: DeriveType -> Version a -> Name -> TypeQ -> [Name] -> Q [Dec]
internalDeriveSafeCopyIndexedType deriveType versionId kindName typq tyIndex' = do
Expand Down Expand Up @@ -517,11 +541,11 @@ fixChars s =
-- * Remove suffixes on the four ids we export
-- * Leave suffixes on all variables and type variables
safeName :: Name -> Name
safeName (Name oc (NameG _ns _pn _mn)) = traceShowId $ Name oc NameS
safeName (Name oc (NameQ _mn)) = traceShowId $ Name oc NameS
safeName (Name oc@(OccName _) (NameU _)) = traceShowId $ Name oc NameS
safeName name@(Name _ (NameL _)) = traceShowId $ name -- Not seeing any of these
safeName name@(Name _ NameS) = traceShowId $ name
safeName (Name oc (NameG _ns _pn _mn)) = Name oc NameS
safeName (Name oc (NameQ _mn)) = Name oc NameS
safeName (Name oc@(OccName _) (NameU _)) = Name oc NameS
safeName name@(Name _ (NameL _)) = name -- Not seeing any of these
safeName name@(Name _ NameS) = name

-- This will probably make the expression invalid, but it
-- removes random elements that will make tests fail.
Expand Down
22 changes: 16 additions & 6 deletions test/instances.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import Data.SafeCopy.Internal (pprWithoutSuffixes)
import Data.Serialize (runPut, runGet)
import Data.Time (UniversalTime(..), ZonedTime(..))
import Data.Tree (Tree)
import Data.Typeable (Typeable)
import Language.Haskell.TH
import Language.Haskell.TH.Instances ()
import Language.Haskell.TH.Syntax
Expand Down Expand Up @@ -119,11 +120,20 @@ main = defaultMain $ testGroup "SafeCopy instances"
[ testCase "deriveSafeCopy 0 'base ''(,,,,,,,)" $ do
let decs = $(lift =<< deriveSafeCopy 0 'base ''(,,,,,,,))
pprWithoutSuffixes ppr decs @?= intercalate "\n"
["instance (SafeCopy a, SafeCopy b, SafeCopy c, SafeCopy d, SafeCopy e, SafeCopy f, SafeCopy g, SafeCopy h) => SafeCopy ((,,,,,,,) a b c d e f g h)",
" where putCopy ((,,,,,,,) a1 a2 a3 a4 a5 a6 a7 a8) = contain (do {safePut_a <- getSafePut; safePut_b <- getSafePut; safePut_c <- getSafePut; safePut_d <- getSafePut; safePut_e <- getSafePut; safePut_f <- getSafePut; safePut_g <- getSafePut; safePut_h <- getSafePut; safePut_a a1; safePut_b a2; safePut_c a3; safePut_d a4; safePut_e a5; safePut_f a6; safePut_g a7; safePut_h a8; return ()})",
" getCopy = contain (label \"(,,,,,,,):\" (do {safeGet_a <- getSafeGet; safeGet_b <- getSafeGet; safeGet_c <- getSafeGet; safeGet_d <- getSafeGet; safeGet_e <- getSafeGet; safeGet_f <- getSafeGet; safeGet_g <- getSafeGet; safeGet_h <- getSafeGet; (((((((return (,,,,,,,) <*> safeGet_a) <*> safeGet_b) <*> safeGet_c) <*> safeGet_d) <*> safeGet_e) <*> safeGet_f) <*> safeGet_g) <*> safeGet_h}))",
" version = 0",
" kind = base",
" errorTypeName _ = \"(,,,,,,,)\""]
["instance (SafeCopy a, SafeCopy b, SafeCopy c, SafeCopy d, SafeCopy e, SafeCopy f, SafeCopy g, SafeCopy h) => SafeCopy ((,,,,,,,) a b c d e f g h)",
" where putCopy ((,,,,,,,) a1 a2 a3 a4 a5 a6 a7 a8) = contain (do {safePut_a <- getSafePut; safePut_b <- getSafePut; safePut_c <- getSafePut; safePut_d <- getSafePut; safePut_e <- getSafePut; safePut_f <- getSafePut; safePut_g <- getSafePut; safePut_h <- getSafePut; safePut_a a1; safePut_b a2; safePut_c a3; safePut_d a4; safePut_e a5; safePut_f a6; safePut_g a7; safePut_h a8; return ()})",
" getCopy = contain (label \"(,,,,,,,):\" (do {safeGet_a <- getSafeGet; safeGet_b <- getSafeGet; safeGet_c <- getSafeGet; safeGet_d <- getSafeGet; safeGet_e <- getSafeGet; safeGet_f <- getSafeGet; safeGet_g <- getSafeGet; safeGet_h <- getSafeGet; (((((((return (,,,,,,,) <*> safeGet_a) <*> safeGet_b) <*> safeGet_c) <*> safeGet_d) <*> safeGet_e) <*> safeGet_f) <*> safeGet_g) <*> safeGet_h}))",
" version = 0",
" kind = base",
" errorTypeName _ = \"(,,,,,,,)\""]
, testCase "deriveSafeCopy' 0 'base [t(,,,,,,,)|]" $ do
let decs = $(lift =<< deriveSafeCopy' 0 'base [t|forall a b c d e f g h. (Show a, Typeable a, SafeCopy a, SafeCopy b, SafeCopy c, SafeCopy d, SafeCopy e, SafeCopy f, SafeCopy g, SafeCopy h) => (a,b,c,d,e,f,g,h)|])
pprWithoutSuffixes ppr decs @?= intercalate "\n"
["instance (Show a, Typeable a, SafeCopy a, SafeCopy b, SafeCopy c, SafeCopy d, SafeCopy e, SafeCopy f, SafeCopy g, SafeCopy h) => SafeCopy ((,,,,,,,) a b c d e f g h)",
" where putCopy ((,,,,,,,) a1 a2 a3 a4 a5 a6 a7 a8) = contain (do {safePut_a <- getSafePut; safePut_b <- getSafePut; safePut_c <- getSafePut; safePut_d <- getSafePut; safePut_e <- getSafePut; safePut_f <- getSafePut; safePut_g <- getSafePut; safePut_h <- getSafePut; safePut_a a1; safePut_b a2; safePut_c a3; safePut_d a4; safePut_e a5; safePut_f a6; safePut_g a7; safePut_h a8; return ()})",
" getCopy = contain (label \"(,,,,,,,):\" (do {safeGet_a <- getSafeGet; safeGet_b <- getSafeGet; safeGet_c <- getSafeGet; safeGet_d <- getSafeGet; safeGet_e <- getSafeGet; safeGet_f <- getSafeGet; safeGet_g <- getSafeGet; safeGet_h <- getSafeGet; (((((((return (,,,,,,,) <*> safeGet_a) <*> safeGet_b) <*> safeGet_c) <*> safeGet_d) <*> safeGet_e) <*> safeGet_f) <*> safeGet_g) <*> safeGet_h}))",
" version = 0",
" kind = base",
" errorTypeName _ = \"(,,,,,,,)\""]
]
]

0 comments on commit bd0147b

Please sign in to comment.