Skip to content

Commit

Permalink
Merge pull request #1172 from aprokop/remove_check_get_return_type
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop authored Nov 12, 2024
2 parents 8c047ee + 4fb3a91 commit 7c4ff52
Show file tree
Hide file tree
Showing 32 changed files with 239 additions and 281 deletions.
53 changes: 30 additions & 23 deletions benchmarks/brute_force_vs_bvh/brute_force_vs_bvh_timpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,51 @@ using MemorySpace = ExecutionSpace::memory_space;

namespace ArborXBenchmark
{
template <int DIM, typename FloatingPoint>
struct PrimitivesTag
{};
struct PredicatesTag
{};

template <int DIM, typename FloatingPoint, typename Tag>
struct Placeholder
{
int count;
};
} // namespace ArborXBenchmark

// Primitives are a set of points located at (i, i, i),
// with i = 0, ..., n-1
template <int DIM, typename FloatingPoint>
struct ArborX::AccessTraits<ArborXBenchmark::Placeholder<DIM, FloatingPoint>,
ArborX::PrimitivesTag>
template <int DIM, typename FloatingPoint, typename Tag>
struct ArborX::AccessTraits<
ArborXBenchmark::Placeholder<DIM, FloatingPoint, Tag>>
{
using Primitives = ArborXBenchmark::Placeholder<DIM, FloatingPoint>;
using memory_space = MemorySpace;
using size_type = typename MemorySpace::size_type;
static KOKKOS_FUNCTION size_type size(Primitives d) { return d.count; }
static KOKKOS_FUNCTION auto get(Primitives, size_type i)

static KOKKOS_FUNCTION size_type
size(ArborXBenchmark::Placeholder<DIM, FloatingPoint, Tag> d)
{
return d.count;
}

static KOKKOS_FUNCTION auto
get(ArborXBenchmark::Placeholder<DIM, FloatingPoint,
ArborXBenchmark::PrimitivesTag>,
size_type i)
{
// Primitives are a set of points located at (i, i, i),
// with i = 0, ..., n-1
ArborX::Point<DIM, FloatingPoint> point;
for (int d = 0; d < DIM; ++d)
point[d] = i;
return point;
}
};

// Predicates are sphere intersections with spheres of radius i
// centered at (i, i, i), with i = 0, ..., n-1
template <int DIM, typename FloatingPoint>
struct ArborX::AccessTraits<ArborXBenchmark::Placeholder<DIM, FloatingPoint>,
ArborX::PredicatesTag>
{
using Predicates = ArborXBenchmark::Placeholder<DIM, FloatingPoint>;
using memory_space = MemorySpace;
using size_type = typename MemorySpace::size_type;
static KOKKOS_FUNCTION size_type size(Predicates d) { return d.count; }
static KOKKOS_FUNCTION auto get(Predicates, size_type i)
static KOKKOS_FUNCTION auto
get(ArborXBenchmark::Placeholder<DIM, FloatingPoint,
ArborXBenchmark::PredicatesTag>,
size_type i)
{
// Predicates are sphere intersections with spheres of radius i
// centered at (i, i, i), with i = 0, ..., n-1
ArborX::Point<DIM, FloatingPoint> center;
for (int d = 0; d < DIM; ++d)
center[d] = i;
Expand All @@ -76,8 +83,8 @@ static void run_fp(int nprimitives, int nqueries, int nrepeats)
{
ExecutionSpace space{};

Placeholder<DIM, FloatingPoint> primitives{nprimitives};
Placeholder<DIM, FloatingPoint> predicates{nqueries};
Placeholder<DIM, FloatingPoint, PrimitivesTag> primitives{nprimitives};
Placeholder<DIM, FloatingPoint, PredicatesTag> predicates{nqueries};
using Point = ArborX::Point<DIM, FloatingPoint>;

for (int i = 0; i < nrepeats; i++)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/bvh_driver/benchmark_registration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct Iota
};

template <typename MemorySpace>
struct ArborX::AccessTraits<Iota<MemorySpace>, ArborX::PrimitivesTag>
struct ArborX::AccessTraits<Iota<MemorySpace>>
{
using Self = Iota<MemorySpace>;

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/dbscan/ArborX_DBSCANVerification.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ bool verifyDBSCAN(ExecutionSpace exec_space, Primitives const &primitives,

static_assert(Kokkos::is_view<LabelsView>{});

using Points = Details::AccessValues<Primitives, PrimitivesTag>;
using Points = Details::AccessValues<Primitives>;
using MemorySpace = typename Points::memory_space;

static_assert(std::is_same<typename LabelsView::value_type, int>{});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct Triangles
};

template <typename MemorySpace>
class ArborX::AccessTraits<Triangles<MemorySpace>, ArborX::PrimitivesTag>
class ArborX::AccessTraits<Triangles<MemorySpace>>
{
using Self = Triangles<MemorySpace>;

Expand Down
4 changes: 2 additions & 2 deletions examples/access_traits/example_cuda_access_traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct Spheres
};

template <>
struct ArborX::AccessTraits<PointCloud, ArborX::PrimitivesTag>
struct ArborX::AccessTraits<PointCloud>
{
static KOKKOS_FUNCTION std::size_t size(PointCloud const &cloud)
{
Expand All @@ -49,7 +49,7 @@ struct ArborX::AccessTraits<PointCloud, ArborX::PrimitivesTag>
};

template <>
struct ArborX::AccessTraits<Spheres, ArborX::PredicatesTag>
struct ArborX::AccessTraits<Spheres>
{
static KOKKOS_FUNCTION std::size_t size(Spheres const &d) { return d.N; }
static KOKKOS_FUNCTION auto get(Spheres const &d, std::size_t i)
Expand Down
4 changes: 2 additions & 2 deletions examples/access_traits/example_host_access_traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include <random>
#include <vector>

template <typename T, typename Tag>
struct ArborX::AccessTraits<std::vector<T>, Tag>
template <typename T>
struct ArborX::AccessTraits<std::vector<T>>
{
static std::size_t size(std::vector<T> const &v) { return v.size(); }
static T const &get(std::vector<T> const &v, std::size_t i) { return v[i]; }
Expand Down
4 changes: 2 additions & 2 deletions examples/brute_force/example_brute_force.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct Iota
};

template <typename MemorySpace>
struct ArborX::AccessTraits<Iota<MemorySpace>, ArborX::PrimitivesTag>
struct ArborX::AccessTraits<Iota<MemorySpace>>
{
using Self = Iota<MemorySpace>;

Expand All @@ -54,7 +54,7 @@ struct DummyIndexableGetter
};

template <>
struct ArborX::AccessTraits<Dummy, ArborX::PredicatesTag>
struct ArborX::AccessTraits<Dummy>
{
using memory_space = MemorySpace;
using size_type = typename MemorySpace::size_type;
Expand Down
4 changes: 2 additions & 2 deletions examples/callback/example_callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct NearestToOrigin
};

template <>
struct ArborX::AccessTraits<FirstOctant, ArborX::PredicatesTag>
struct ArborX::AccessTraits<FirstOctant>
{
static KOKKOS_FUNCTION std::size_t size(FirstOctant) { return 1; }
static KOKKOS_FUNCTION auto get(FirstOctant, std::size_t)
Expand All @@ -40,7 +40,7 @@ struct ArborX::AccessTraits<FirstOctant, ArborX::PredicatesTag>
};

template <>
struct ArborX::AccessTraits<NearestToOrigin, ArborX::PredicatesTag>
struct ArborX::AccessTraits<NearestToOrigin>
{
static KOKKOS_FUNCTION std::size_t size(NearestToOrigin) { return 1; }
static KOKKOS_FUNCTION auto get(NearestToOrigin d, std::size_t)
Expand Down
2 changes: 1 addition & 1 deletion examples/molecular_dynamics/example_molecular_dynamics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct Neighbors
};

template <class MemorySpace>
struct ArborX::AccessTraits<Neighbors<MemorySpace>, ArborX::PredicatesTag>
struct ArborX::AccessTraits<Neighbors<MemorySpace>>
{
using memory_space = MemorySpace;
using size_type = std::size_t;
Expand Down
3 changes: 1 addition & 2 deletions examples/raytracing/example_raytracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ struct DepositEnergy
} // namespace OrderedIntersectsBased

template <typename MemorySpace>
struct ArborX::AccessTraits<OrderedIntersectsBased::Rays<MemorySpace>,
ArborX::PredicatesTag>
struct ArborX::AccessTraits<OrderedIntersectsBased::Rays<MemorySpace>>
{
using memory_space = MemorySpace;
using size_type = std::size_t;
Expand Down
11 changes: 4 additions & 7 deletions src/cluster/ArborX_DBSCAN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ struct MixedBoxPrimitives

template <typename Primitives, typename PermuteFilter>
struct AccessTraits<Details::PrimitivesWithRadiusReorderedAndFiltered<
Primitives, PermuteFilter>,
PredicatesTag>
Primitives, PermuteFilter>>
{
using memory_space = typename Primitives::memory_space;
using Predicates =
Expand Down Expand Up @@ -125,8 +124,7 @@ struct AccessTraits<Details::PrimitivesWithRadiusReorderedAndFiltered<
template <typename Points, typename MixedOffsets, typename CellIndices,
typename Permutation>
struct AccessTraits<
Details::MixedBoxPrimitives<Points, MixedOffsets, CellIndices, Permutation>,
ArborX::PrimitivesTag>
Details::MixedBoxPrimitives<Points, MixedOffsets, CellIndices, Permutation>>
{
using Primitives = Details::MixedBoxPrimitives<Points, MixedOffsets,
CellIndices, Permutation>;
Expand Down Expand Up @@ -200,8 +198,7 @@ struct Parameters
} // namespace DBSCAN

template <typename ExecutionSpace, typename Primitives>
Kokkos::View<int *,
typename AccessTraits<Primitives, PrimitivesTag>::memory_space>
Kokkos::View<int *, typename AccessTraits<Primitives>::memory_space>
dbscan(ExecutionSpace const &exec_space, Primitives const &primitives,
float eps, int core_min_size,
DBSCAN::Parameters const &parameters = DBSCAN::Parameters())
Expand All @@ -210,7 +207,7 @@ dbscan(ExecutionSpace const &exec_space, Primitives const &primitives,

namespace KokkosExt = ArborX::Details::KokkosExt;

using Points = Details::AccessValues<Primitives, PrimitivesTag>;
using Points = Details::AccessValues<Primitives>;
using MemorySpace = typename Points::memory_space;

static_assert(
Expand Down
4 changes: 2 additions & 2 deletions src/cluster/ArborX_MinimumSpanningTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ struct MinimumSpanningTree
int k = 1)
: edges(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::MST::edges"),
AccessTraits<Primitives, PrimitivesTag>::size(primitives) - 1)
AccessTraits<Primitives>::size(primitives) - 1)
, dendrogram_parents("ArborX::MST::dendrogram_parents", 0)
, dendrogram_parent_heights("ArborX::MST::dendrogram_parent_heights", 0)
{
Kokkos::Profiling::pushRegion("ArborX::MST::MST");

using Points = Details::AccessValues<Primitives, PrimitivesTag>;
using Points = Details::AccessValues<Primitives>;
using Point = typename Points::value_type;
static_assert(GeometryTraits::is_point_v<Point>);

Expand Down
8 changes: 4 additions & 4 deletions src/distributed/ArborX_DistributedTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class DistributedTreeBase
Details::KokkosExt::is_accessible_from<MemorySpace,
ExecutionSpace>::value);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Predicates = Details::AccessValues<UserPredicates>;
static_assert(Details::KokkosExt::is_accessible_from<
typename Predicates::memory_space, ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
Expand Down Expand Up @@ -156,9 +156,9 @@ KOKKOS_DEDUCTION_GUIDE
#else
KOKKOS_FUNCTION
#endif
DistributedTree(MPI_Comm, ExecutionSpace, Values) -> DistributedTree<
typename Details::AccessValues<Values, PrimitivesTag>::memory_space,
typename Details::AccessValues<Values, PrimitivesTag>::value_type>;
DistributedTree(MPI_Comm, ExecutionSpace, Values)
-> DistributedTree<typename Details::AccessValues<Values>::memory_space,
typename Details::AccessValues<Values>::value_type>;

template <typename BottomTree>
template <typename ExecutionSpace, typename... Args>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ KOKKOS_INLINE_FUNCTION auto approx_expand_by_radius(Geometry const &geometry,

template <class Predicates, class Distances>
struct AccessTraits<
Details::WithinDistanceFromPredicates<Predicates, Distances>, PredicatesTag>
Details::WithinDistanceFromPredicates<Predicates, Distances>>
{
using Predicate = typename Predicates::value_type;
using Geometry =
Expand Down
10 changes: 4 additions & 6 deletions src/interpolation/ArborX_InterpMovingLeastSquares.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,8 @@ class MovingLeastSquares
"Memory space must be accessible from the execution space");

// SourcePoints is an access trait of points
ArborX::Details::check_valid_access_traits(PrimitivesTag{}, source_points);
using SourceAccess =
ArborX::Details::AccessValues<SourcePoints, PrimitivesTag>;
ArborX::Details::check_valid_access_traits(source_points);
using SourceAccess = ArborX::Details::AccessValues<SourcePoints>;
static_assert(
KokkosExt::is_accessible_from<typename SourceAccess::memory_space,
ExecutionSpace>::value,
Expand All @@ -94,9 +93,8 @@ class MovingLeastSquares
static constexpr int dimension = GeometryTraits::dimension_v<SourcePoint>;

// TargetPoints is an access trait of points
ArborX::Details::check_valid_access_traits(PrimitivesTag{}, target_points);
using TargetAccess =
ArborX::Details::AccessValues<TargetPoints, PrimitivesTag>;
ArborX::Details::check_valid_access_traits(target_points);
using TargetAccess = ArborX::Details::AccessValues<TargetPoints>;
static_assert(
KokkosExt::is_accessible_from<typename TargetAccess::memory_space,
ExecutionSpace>::value,
Expand Down
29 changes: 14 additions & 15 deletions src/spatial/ArborX_BruteForce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class BruteForce
check_valid_callback_if_first_argument_is_not_a_view<value_type>(
callback_or_view, user_predicates, view);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Predicates = Details::AccessValues<UserPredicates>;
using Tag = typename Predicates::value_type::Tag;

Details::CrsGraphWrapperImpl::queryDispatch(
Expand All @@ -105,40 +105,38 @@ KOKKOS_DEDUCTION_GUIDE
#else
KOKKOS_FUNCTION
#endif
BruteForce(ExecutionSpace, Values) -> BruteForce<
typename Details::AccessValues<Values, PrimitivesTag>::memory_space,
typename Details::AccessValues<Values, PrimitivesTag>::value_type>;
BruteForce(ExecutionSpace, Values)
-> BruteForce<typename Details::AccessValues<Values>::memory_space,
typename Details::AccessValues<Values>::value_type>;

template <typename ExecutionSpace, typename Values, typename IndexableGetter>
#if KOKKOS_VERSION >= 40400
KOKKOS_DEDUCTION_GUIDE
#else
KOKKOS_FUNCTION
#endif
BruteForce(ExecutionSpace, Values, IndexableGetter) -> BruteForce<
typename Details::AccessValues<Values, PrimitivesTag>::memory_space,
typename Details::AccessValues<Values, PrimitivesTag>::value_type,
IndexableGetter>;
BruteForce(ExecutionSpace, Values, IndexableGetter)
-> BruteForce<typename Details::AccessValues<Values>::memory_space,
typename Details::AccessValues<Values>::value_type,
IndexableGetter>;

template <typename MemorySpace, typename Value, typename IndexableGetter,
typename BoundingVolume>
template <typename ExecutionSpace, typename UserValues>
BruteForce<MemorySpace, Value, IndexableGetter, BoundingVolume>::BruteForce(
ExecutionSpace const &space, UserValues const &user_values,
IndexableGetter const &indexable_getter)
: _size(AccessTraits<UserValues, PrimitivesTag>::size(user_values))
: _size(AccessTraits<UserValues>::size(user_values))
, _values(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::BruteForce::values"),
_size)
, _indexable_getter(indexable_getter)
{
static_assert(Details::KokkosExt::is_accessible_from<MemorySpace,
ExecutionSpace>::value);
// FIXME redo with RangeTraits
Details::check_valid_access_traits<UserValues>(
PrimitivesTag{}, user_values, Details::DoNotCheckGetReturnType());
Details::check_valid_access_traits<UserValues>(user_values);

using Values = Details::AccessValues<UserValues, PrimitivesTag>;
using Values = Details::AccessValues<UserValues>;
Values values{user_values}; // NOLINT

static_assert(
Expand Down Expand Up @@ -167,10 +165,11 @@ void BruteForce<MemorySpace, Value, IndexableGetter, BoundingVolume>::query(
{
static_assert(Details::KokkosExt::is_accessible_from<MemorySpace,
ExecutionSpace>::value);
Details::check_valid_access_traits(PredicatesTag{}, user_predicates);
Details::check_valid_access_traits(user_predicates,
Details::CheckReturnTypeTag{});
Details::check_valid_callback<value_type>(callback, user_predicates);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Predicates = Details::AccessValues<UserPredicates>;
static_assert(
Details::KokkosExt::is_accessible_from<typename Predicates::memory_space,
ExecutionSpace>::value,
Expand Down
Loading

0 comments on commit 7c4ff52

Please sign in to comment.