Skip to content

Commit

Permalink
Merge pull request #24 from cpeikert/better_benchmarks2
Browse files Browse the repository at this point in the history
Better benchmarks
  • Loading branch information
cpeikert authored Aug 13, 2016
2 parents 377d5ae + 01443ee commit c4e1da8
Show file tree
Hide file tree
Showing 21 changed files with 912 additions and 212 deletions.
2 changes: 1 addition & 1 deletion lol/Crypto/Lol/Cyclotomic/CRTSentinel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ embedCRTCS _ _ = fromJust embedCRT
twaceCRTCS :: (Tensor t, m `Divides` m', CRTrans Maybe r, TElt t r)
=> CSentinel t m' r -> CSentinel t m r -> t m' r -> t m r
twaceCRTCS _ _ = fromJust twaceCRT
{-# INLINABLE twaceCRTCS #-}
{-# INLINE twaceCRTCS #-}

10 changes: 5 additions & 5 deletions lol/Crypto/Lol/Cyclotomic/Cyc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ embed' (Sub (c :: Cyc t k r)) = embed' c
-- | The "tweaked trace" (twace) function
-- \(\Tw(x) = (\hat{m} / \hat{m}') \cdot \Tr((g' / g) \cdot x)\),
-- which fixes \(R\) pointwise (i.e., @twace . embed == id@).
twace :: forall t m m' r . (m `Divides` m', CElt t r)
twace :: forall t m m' r . (m `Divides` m', UCRTElt t r, ZeroTestable r)
=> Cyc t m' r -> Cyc t m r
{-# INLINABLE twace #-}
twace (Pow u) = Pow $ U.twacePow u
Expand Down Expand Up @@ -562,10 +562,10 @@ instance (Correct gad zq, Fact m, CElt t zq) => Correct gad (Cyc t m zq) where

---------- Change of representation (internal use only) ----------

toPow', toDec', toCRT' :: (Fact m, CElt t r) => Cyc t m r -> Cyc t m r
{-# INLINE toPow' #-}
{-# INLINE toDec' #-}
{-# INLINE toCRT' #-}
toPow', toDec', toCRT' :: (Fact m, UCRTElt t r, ZeroTestable r) => Cyc t m r -> Cyc t m r
{-# INLINABLE toPow' #-}
{-# INLINABLE toDec' #-}
{-# INLINABLE toCRT' #-}

-- | Force to powerful-basis representation (for internal use only).
toPow' c@(Pow _) = c
Expand Down
8 changes: 5 additions & 3 deletions lol/Crypto/Lol/Cyclotomic/Tensor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ mulGCRT, divGCRT, crt, crtInv ::
{-# INLINABLE mulGCRT #-}
{-# INLINABLE divGCRT #-}
{-# INLINABLE crt #-}
{-# INLINABLE crtInv #-}
{-# INLINE crtInv #-}

-- | Multiply by \(g_m\) in the CRT basis. (This function is simply an
-- appropriate entry from 'crtFuncs'.)
Expand All @@ -223,6 +223,7 @@ crtInv = (\(_,_,_,_,f) -> f) <$> crtFuncs
-- (This function is simply an appropriate entry from 'crtExtFuncs'.)
twaceCRT :: forall t m m' mon r . (CRTrans mon r, Tensor t, m `Divides` m', TElt t r)
=> mon (t m' r -> t m r)
{-# INLINABLE twaceCRT #-}
twaceCRT = proxyT hasCRTFuncs (Proxy::Proxy (t m' r)) *>
proxyT hasCRTFuncs (Proxy::Proxy (t m r)) *>
(fst <$> crtExtFuncs)
Expand Down Expand Up @@ -413,6 +414,7 @@ indexInfo = let pps = proxy ppsFact (Proxy::Proxy m)
-- the index into the powerful\/decoding basis of \(\O_{m'}\) of the
-- \(i\)th entry of the powerful/decoding basis of \(\O_m\).
extIndicesPowDec :: (m `Divides` m') => Tagged '(m, m') (U.Vector Int)
{-# INLINABLE extIndicesPowDec #-}
extIndicesPowDec = do
(_, phi, _, tots) <- indexInfo
return $ U.generate phi (fromIndexPair tots . (0,))
Expand All @@ -438,15 +440,15 @@ baseWrapper f = do
-- | A lookup table for 'toIndexPair' applied to indices \([\varphi(m')]\).
baseIndicesPow :: forall m m' . (m `Divides` m')
=> Tagged '(m, m') (U.Vector (Int,Int))
{-# INLINABLE baseIndicesPow #-}
-- | A lookup table for 'baseIndexDec' applied to indices \([\varphi(m')]\).
baseIndicesDec :: forall m m' . (m `Divides` m')
=> Tagged '(m, m') (U.Vector (Maybe (Int,Bool)))

{-# INLINABLE baseIndicesDec #-}
-- | Same as 'baseIndicesPow', but only includes the second component
-- of each pair.
baseIndicesCRT :: forall m m' . (m `Divides` m')
=> Tagged '(m, m') (U.Vector Int)

baseIndicesPow = baseWrapper (toIndexPair . totients)

-- this one is more complicated; requires the prime powers
Expand Down
84 changes: 40 additions & 44 deletions lol/Crypto/Lol/Cyclotomic/Tensor/CTensor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@

-- | Wrapper for a C++ implementation of the 'Tensor' interface.

module Crypto.Lol.Cyclotomic.Tensor.CTensor
( CT ) where
module Crypto.Lol.Cyclotomic.Tensor.CTensor (CT) where

import Algebra.Additive as Additive (C)
import Algebra.Module as Module (C)
Expand All @@ -44,7 +43,7 @@ import Data.Traversable as T
import Data.Vector.Generic as V (fromList, toList, unzip)
import Data.Vector.Storable as SV (Vector, convert, foldl',
fromList, generate,
length, map, mapM, replicate,
length, map, replicate,
replicateM, thaw, thaw, toList,
unsafeFreeze,
unsafeWith, zipWith, (!))
Expand Down Expand Up @@ -158,13 +157,14 @@ toZV v@(ZV _) = v
zvToCT' :: forall m r . (Storable r) => IZipVector m r -> CT' m r
zvToCT' v = coerce (convert $ unIZipVector v :: Vector r)

wrap :: (Storable r) => (CT' l r -> CT' m r) -> (CT l r -> CT m r)
wrap :: (Storable s, Storable r) => (CT' l s -> CT' m r) -> (CT l s -> CT m r)
{-# INLINABLE wrap #-}
wrap f (CT v) = CT $ f v
wrap f (ZV v) = CT $ f $ zvToCT' v

wrapM :: (Storable r, Monad mon) => (CT' l r -> mon (CT' m r))
-> (CT l r -> mon (CT m r))
wrapM :: (Storable s, Storable r, Monad mon) => (CT' l s -> mon (CT' m r))
-> (CT l s -> mon (CT m r))
{-# INLINABLE wrapM #-}
wrapM f (CT v) = CT <$> f v
wrapM f (ZV v) = CT <$> f (zvToCT' v)

Expand Down Expand Up @@ -248,15 +248,14 @@ instance Tensor CT where

scalarPow = CT . scalarPow' -- Vector code

l = wrap $ untag $ basicDispatch dl
lInv = wrap $ untag $ basicDispatch dlinv
l = wrap $ basicDispatch dl
lInv = wrap $ basicDispatch dlinv

mulGPow = wrap mulGPow'
mulGDec = wrap $ untag $ basicDispatch dmulgdec
mulGPow = wrap $ basicDispatch dmulgpow
mulGDec = wrap $ basicDispatch dmulgdec

divGPow = wrapM divGPow'
-- we divide by p in the C code (for divGDec only(?)), do NOT call checkDiv!
divGDec = wrapM $ Just . untag (basicDispatch dginvdec)
divGPow = wrapM $ dispatchGInv dginvpow
divGDec = wrapM $ dispatchGInv dginvdec

crtFuncs = (,,,,) <$>
return (CT . repl) <*>
Expand Down Expand Up @@ -285,14 +284,16 @@ instance Tensor CT where

crtSetDec = (CT <$>) <$> coerceBasis crtSetDec'

fmapT f (CT v) = CT $ coerce (SV.map f) v
fmapT f v@(ZV _) = fmapT f $ toCT v
fmapT f = wrap $ coerce (SV.map f)

zipWithT f (CT (CT' v1)) (CT (CT' v2)) = CT $ CT' $ SV.zipWith f v1 v2
zipWithT f v1 v2 = zipWithT f (toCT v1) (toCT v2)
zipWithT f v1' v2' =
let (CT (CT' v1)) = toCT v1'
(CT (CT' v2)) = toCT v2'
in CT $ CT' $ SV.zipWith f v1 v2

unzipT (CT (CT' v)) = (CT . CT') *** (CT . CT') $ unzip v
unzipT v = unzipT $ toCT v
unzipT v =
let (CT (CT' x)) = toCT v
in (CT . CT') *** (CT . CT') $ unzip x

{-# INLINABLE entailIndexT #-}
{-# INLINABLE entailEqT #-}
Expand All @@ -313,14 +314,13 @@ instance Tensor CT where
{-# INLINABLE embedDec #-}
{-# INLINABLE tGaussianDec #-}
{-# INLINABLE gSqNormDec #-}
{-# INLINABLE crtExtFuncs #-}
{-# INLINE crtExtFuncs #-}
{-# INLINABLE coeffs #-}
{-# INLINABLE powBasisPow #-}
{-# INLINABLE crtSetDec #-}
{-# INLINABLE fmapT #-}
{-# INLINABLE zipWithT #-}
{-# INLINABLE unzipT #-}

{-# INLINE zipWithT #-}
{-# INLINE unzipT #-}

coerceTw :: (Functor mon) => TaggedT '(m, m') mon (Vector r -> Vector r) -> mon (CT' m' r -> CT' m r)
coerceTw = (coerce <$>) . untagT
Expand All @@ -338,12 +338,21 @@ coerceCoeffs = coerce
coerceBasis :: Tagged '(m,m') [Vector r] -> Tagged m [CT' m' r]
coerceBasis = coerce

mulGPow' :: (TElt CT r, Fact m) => CT' m r -> CT' m r
mulGPow' = untag $ basicDispatch dmulgpow

divGPow' :: (TElt CT r, Fact m, IntegralDomain r, ZeroTestable r)
=> CT' m r -> Maybe (CT' m r)
divGPow' = untag $ checkDiv $ basicDispatch dginvpow
dispatchGInv :: forall m r . (Storable r, Fact m)
=> (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16)
-> CT' m r -> Maybe (CT' m r)
dispatchGInv f =
let factors = proxy (marshalFactors <$> ppsFact) (Proxy::Proxy m)
totm = proxy (fromIntegral <$> totientFact) (Proxy::Proxy m)
numFacts = fromIntegral $ SV.length factors
in \(CT' x) -> unsafePerformIO $ do
yout <- SV.thaw x
ret <- SM.unsafeWith yout (\pout ->
SV.unsafeWith factors (\pfac ->
f pout totm pfac numFacts))
if ret /= 0
then Just . CT' <$> unsafeFreeze yout
else return Nothing

withBasicArgs :: forall m r . (Fact m, Storable r)
=> (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
Expand All @@ -361,8 +370,8 @@ withBasicArgs f =

basicDispatch :: (Storable r, Fact m)
=> (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ())
-> Tagged m (CT' m r -> CT' m r)
basicDispatch f = return $ unsafePerformIO . withBasicArgs f
-> CT' m r -> CT' m r
basicDispatch f = unsafePerformIO . withBasicArgs f

gSqNormDec' :: (Storable r, Fact m, Dispatch r)
=> Tagged m (CT' m r -> r)
Expand All @@ -384,19 +393,6 @@ ctCRTInv = do
return $ \x -> unsafePerformIO $
withPtrArray ruinv' (\ruptr -> with mhatInv (flip withBasicArgs x . dcrtinv ruptr))

checkDiv :: (Storable r, IntegralDomain r, ZeroTestable r, Fact m)
=> Tagged m (CT' m r -> CT' m r) -> Tagged m (CT' m r -> Maybe (CT' m r))
checkDiv f = do
f' <- f
oddRad' <- fromIntegral <$> oddRadicalFact
return $ \x ->
let (CT' y) = f' x
in CT' <$> SV.mapM (`divIfDivis` oddRad') y

divIfDivis :: (IntegralDomain r, ZeroTestable r) => r -> r -> Maybe r
divIfDivis num den = let (q,r) = num `divMod` den
in if isZero r then Just q else Nothing

cZipDispatch :: (Storable r, Fact m)
=> (Ptr r -> Ptr r -> Int64 -> IO ())
-> Tagged m (CT' m r -> CT' m r -> CT' m r)
Expand Down
16 changes: 8 additions & 8 deletions lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Backend.hs
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ class (repr ~ CTypeOf r) => Dispatch' repr r where
-- | Equivalent to 'Tensor's @mulGDec@.
dmulgdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
-- | Equivalent to 'Tensor's @divGPow@.
dginvpow :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
dginvpow :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16
-- | Equivalent to 'Tensor's @divGDec@.
dginvdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()
dginvdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16
-- | Equivalent to @zipWith (*)@
dmul :: Ptr r -> Ptr r -> Int64 -> IO ()

Expand Down Expand Up @@ -315,12 +315,12 @@ foreign import ccall unsafe "tensorGPowC" tensorGPowC :: Int16 -> Ptr (C
foreign import ccall unsafe "tensorGDecR" tensorGDecR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGDecRq" tensorGDecRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorGDecC" tensorGDecC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGInvPowR" tensorGInvPowR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGInvPowRq" tensorGInvPowRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorGInvPowC" tensorGInvPowC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGInvDecR" tensorGInvDecR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGInvDecRq" tensorGInvDecRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorGInvDecC" tensorGInvDecC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO ()
foreign import ccall unsafe "tensorGInvPowR" tensorGInvPowR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvPowRq" tensorGInvPowRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO Int16
foreign import ccall unsafe "tensorGInvPowC" tensorGInvPowC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvDecR" tensorGInvDecR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16
foreign import ccall unsafe "tensorGInvDecRq" tensorGInvDecRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO Int16
foreign import ccall unsafe "tensorGInvDecC" tensorGInvDecC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16

foreign import ccall unsafe "tensorCRTRq" tensorCRTRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Ptr Int64 -> IO ()
foreign import ccall unsafe "tensorCRTC" tensorCRTC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> IO ()
Expand Down
3 changes: 3 additions & 0 deletions lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Extension.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ backpermute' is v = generate (U.length is) (\i -> v ! (is U.! i))

embedPow', embedDec' :: (Additive r, Storable r, m `Divides` m')
=> Tagged '(m, m') (Vector r -> Vector r)
{-# INLINABLE embedPow' #-}
{-# INLINABLE embedDec' #-}
-- | Embeds an vector in the powerful basis of the the mth cyclotomic ring
-- to an vector in the powerful basis of the m'th cyclotomic ring when @m | m'@
embedPow' = (\indices arr -> generate (U.length indices) $ \idx ->
Expand Down Expand Up @@ -98,6 +100,7 @@ kronToVec v = do
twaceCRT' :: forall mon m m' r .
(Storable r, CRTrans mon r, m `Divides` m')
=> TaggedT '(m, m') mon (Vector r -> Vector r)
{-# INLINE twaceCRT' #-}
twaceCRT' = tagT $ do
g' <- proxyT (kronToVec gCRTK) (Proxy::Proxy m')
gInv <- proxyT (kronToVec gInvCRTK) (Proxy::Proxy m)
Expand Down
9 changes: 0 additions & 9 deletions lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/common.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
#ifndef COMMON_H_
#define COMMON_H_

#include <stdio.h>
#include <stdlib.h>
#include "types.h"

#define ASSERT(EXP) { \
if (!(EXP)) { \
fprintf (stderr, "Assertion in file '%s' line %d : " #EXP " is false\n", __FILE__, __LINE__); \
exit(-1); \
} \
}

// calculates base ** exp
hDim_t ipow(hDim_t base, hShort_t exp);

Expand Down
Loading

0 comments on commit c4e1da8

Please sign in to comment.