Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transfer span from Singe #47

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions ports-of-call/span.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#ifndef _PORTS_OF_CALL_SPAN_HPP_
#define _PORTS_OF_CALL_SPAN_HPP_

#include "ports-of-call/portability.hpp"

#include <cassert>
#include <cstddef>

namespace PortsOfCall {

// ================================================================================================
// Heavily simplified relative to std::span, but provides a similar concept in a way
// that's portable to GPUs.

template <typename T>
class span {
private:
T *ptr_{nullptr};
std::size_t size_{0};

public:
BrendanKKrueger marked this conversation as resolved.
Show resolved Hide resolved
// Member types
using element_type = T;
using value_type = std::remove_cv_t<T>;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using pointer = T*;
using const_pointer = const T*;
using reference = T&;
using const_reference = const T&;
using iterator = T*;
using reverse_iterator = std::reverse_iterator<iterator>;

// Construct an empty span
constexpr span() noexcept = default;

// Construct a span from a pointer, along with offsets.
// -- ptr : Pointer to the beginning of the array
// -- count : Number of entries in the span
// For example:
// double* my_array = new double[10]
// span full_range(my_array, 10); // elements 0-9 (all 10 elements in the array)
// span sub_range(my_array+2, 6); // elements 2-7 (total of six elements)
// The full_range object provides access to the full range of my_array. The sub_range
// object provides access to indices 2 <= i < 8 of the array my_array, but the indices
// are shifted so that sub_range[i] is my_array[i+2].
template <typename SizeType>
PORTABLE_FUNCTION constexpr span(T *ptr, const SizeType count)
: ptr_{ptr}, size_(count) {
assert(count >= 0);
}

// Query the size of the range
PORTABLE_FUNCTION constexpr auto size() const { return size_; }

// Iterator (really a pointer) to the beginning of the range, providing mutable access.
PORTABLE_FUNCTION constexpr T *begin() { return ptr_; }

// Iterator (really a pointer) to the beginning of the range, providing constant access.
PORTABLE_FUNCTION constexpr const T *begin() const { return ptr_; }

// Iterator (really a pointer) to the beginning of the range, providing constant access.
PORTABLE_FUNCTION constexpr const T *cbegin() const { return ptr_; }

// Iterator (really a pointer) to the end of the range, providing mutable access.
PORTABLE_FUNCTION constexpr T *end() { return ptr_ + size_; }

// Iterator (really a pointer) to the beginning of the range, providing constant access.
PORTABLE_FUNCTION constexpr const T *end() const { return ptr_ + size_; }

// Iterator (really a pointer) to the beginning of the range, providing constant access.
PORTABLE_FUNCTION constexpr const T *cend() const { return ptr_ + size_; }

// Index operator to obtain mutable access to an element of the range.
template <typename Index>
PORTABLE_FUNCTION constexpr T &operator[](const Index &index) {
assert(index >= 0);
assert(static_cast<std::size_t>(index) < size_);
return *(ptr_ + index);
}

// Index operator to obtain constant access to an element of the range.
template <typename Index>
PORTABLE_FUNCTION constexpr const T &operator[](const Index index) const {
assert(index >= static_cast<Index>(0));
assert(index < static_cast<Index>(size_));
return *(ptr_ + index);
}
};

// ================================================================================================

template <typename T, typename SizeType>
PORTABLE_FUNCTION constexpr auto make_span(T *const pointer,
const SizeType count) {
return span<T>(pointer, count);
}

// ================================================================================================

} // end namespace PortsOfCall

#endif // #ifndef _PORTS_OF_CALL_SPAN_HPP_
7 changes: 6 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,9 @@ target_link_libraries(test_portsofcall
include(Catch)
catch_discover_tests(test_portsofcall)

target_sources(test_portsofcall PRIVATE test_portability.cpp test_array.cpp)
target_sources(test_portsofcall
PRIVATE
test_portability.cpp
test_array.cpp
test_span.cpp
)
103 changes: 103 additions & 0 deletions test/test_span.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include "ports-of-call/span.hpp"

#ifndef CATCH_CONFIG_FAST_COMPILE
#define CATCH_CONFIG_FAST_COMPILE
#include <catch2/catch_test_macros.hpp>
#endif

#include <catch2/matchers/catch_matchers_floating_point.hpp>

#include <algorithm>
#include <array>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <vector>

namespace span_test {

template <typename T>
constexpr static bool really_const = std::is_const<std::remove_reference_t<T>>::value;
}

TEST_CASE("span", "[util][span]") {
using std::begin;
using std::cbegin;
using std::cend;
using std::end;

using span_test::really_const;

SECTION("begin/end iteration") {
SECTION("with non-zero size") {
std::vector<double> data0 = {1, 2, 3};
PortsOfCall::span<double> data(data0.data(), data0.size());

auto b = begin(data);
auto e = end(data);

CHECK(std::distance(b, e) == 3);
CHECK(b + 3 == e);
CHECK(not really_const<decltype(*b)>);
CHECK(not really_const<decltype(*e)>);

auto cb = cbegin(data);
auto ce = cend(data);

CHECK(really_const<decltype(*cb)>);
CHECK(really_const<decltype(*ce)>);
}
}

SECTION("operator[]") {
SECTION("with non-zero size") {
SECTION("non-const data") {
std::vector<int> data0{1, 2, 3};
PortsOfCall::span<int> data(data0.data(), data0.size());

for (int i = 0; i < 3; ++i) {
CHECK(data[i] == i + 1);
}
CHECK(not really_const<decltype(data[0])>);
}

SECTION("with const data") {
std::vector<int> const data0{1, 2, 3};
PortsOfCall::span<int const> data(data0.data(), data0.size());

for (int i = 0; i < 3; ++i) {
CHECK(data[i] == i + 1);
}
CHECK(really_const<decltype(data[0])>);
}
}
}

SECTION("range-based for") {
constexpr int N{10};
std::vector<float> vec(N);
float *ptr = vec.data();
PortsOfCall::span<float> span(ptr, N);
float const denom = static_cast<float>(1) / static_cast<float>(N);
int n{0};
for (auto &x : span) {
x = static_cast<float>(n++) * denom;
}
for (int i{0}; i < N; ++i) {
REQUIRE_THAT(span[i], Catch::Matchers::WithinRel(static_cast<float>(i) * denom));
}
}

SECTION("STL algorithms") {
constexpr int N{10};
std::vector<int> vec(N);
PortsOfCall::span<int> span(vec.data(), N);
std::fill(span.begin(), span.end(), 42);
bool all42 =
std::all_of(span.begin(), span.end(), [](int const x) { return x == 42; });
CHECK(all42);
std::iota(span.begin(), span.end(), 1);
int sum = std::accumulate(span.begin(), span.end(), 5);
CHECK(sum == 60);
}
}
Loading