Skip to content

Commit

Permalink
src|tests|python-bindings: Make DiskProver(de)serializable (#323)
Browse files Browse the repository at this point in the history
* serialize|tests: Implement basic serialization and corresponding tests

* prover_disk|tests: Make `DiskProver` serializable

* prover_disk|tests: Implement and test move constructor for `DiskProver`

* python-bindings|tests: Introduce `DiskProver.{__bytes__|from_bytes}`

* tests: Move `(De)serialization` test case down in hierarchy

This is to work around an issue in the `blake3` source code namely the 
following function having a "init phase" which isn't thread safe 
(Assigning the global variable `g_cpu_features` if its undefined yet):

https://github.com/BLAKE3-team/BLAKE3/blob/4056af6d7ffdf4d13bb776b7ea1db2a6b52d4d75/c/blake3_dispatch.c#L85 

Since it's not thread safe it lead to data race here prior to this 
commit because 24e3057 introduced the 
`(De)Serialization` test as first test in the file and it triggeres plot 
creation with 2 phase1 threads which trigger `get_cpu_features` and in 
this case concurrently. The workaround here is to move the new test 
below the `F Functions` test which is the first test in the file which 
is calling `get_cpu_features` from a single thread. This single thread 
runs the "init phase" and since this is only required once the follow up 
tests are as fine like they were before on `main`.

* prover_disk: Explicitly delete the copy constructor

* prover_disk: Don't lock the mutex in move contructor

* prover_disk: Move the `version` field also

* serialize: Be more specific in some cases about the in/out type

* serialize: Constrain functions in the template to trivial types

* serialize: `elements` -> `size` and `size` -> `offset`
  • Loading branch information
xdustinface authored Apr 4, 2022
1 parent 7ed53db commit cc9fd7a
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 2 deletions.
15 changes: 15 additions & 0 deletions python-bindings/chiapos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,21 @@ PYBIND11_MODULE(chiapos, m)

py::class_<DiskProver>(m, "DiskProver")
.def(py::init<const std::string &>(), py::call_guard<py::gil_scoped_release>())
.def_static("from_bytes", [](const py::bytes &bytes) -> DiskProver {
py::buffer_info info(py::buffer(bytes).request());
auto data = reinterpret_cast<const uint8_t*>(info.ptr);
auto vecBytes = std::vector<uint8_t>(data, data + info.size);
py::gil_scoped_release release;
return DiskProver(vecBytes);
})
.def("__bytes__", [](DiskProver &dp) {
std::vector<uint8_t> vecBytes;
{
py::gil_scoped_release release;
vecBytes = dp.ToBytes();
}
return py::bytes(reinterpret_cast<const char*>(vecBytes.data()), vecBytes.size());
})
.def(
"get_memo",
[](DiskProver &dp) {
Expand Down
45 changes: 44 additions & 1 deletion src/prover_disk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "calculate_bucket.hpp"
#include "encoding.hpp"
#include "entry_sizes.hpp"
#include "serialize.hpp"
#include "util.hpp"

struct plot_header {
Expand All @@ -48,6 +49,7 @@ struct plot_header {
// of space, for a given challenge.
class DiskProver {
public:
static const uint16_t VERSION{1};
// The constructor opens the file, and reads the contents of the file header. The table pointers
// will be used to find and seek to all seven tables, at the time of proving.
explicit DiskProver(const std::string& filename) : id(kIdLen)
Expand Down Expand Up @@ -117,6 +119,35 @@ class DiskProver {
delete[] c2_buf;
}

explicit DiskProver(const std::vector<uint8_t>& vecBytes)
{
Deserializer deserializer(vecBytes);
deserializer >> version;
if (version != VERSION) {
// TODO: Migrate to new version if we change something related to the data structure
throw std::invalid_argument("DiskProver: Invalid version.");
}
deserializer >> filename;
deserializer >> memo;
deserializer >> id;
deserializer >> k;
deserializer >> table_begin_pointers;
deserializer >> C2;
}

DiskProver(DiskProver const&) = delete;

DiskProver(DiskProver&& other) noexcept
{
filename = std::move(other.filename);
memo = std::move(other.memo);
id = std::move(other.id);
k = other.k;
table_begin_pointers = std::move(other.table_begin_pointers);
C2 = std::move(other.C2);
version = std::move(other.version);
}

~DiskProver()
{
std::lock_guard<std::mutex> l(_mtx);
Expand All @@ -130,7 +161,11 @@ class DiskProver {

const std::vector<uint8_t>& GetId() { return id; }

std::string GetFilename() const noexcept { return filename; }
const std::vector<uint64_t>& GetTableBeginPointers() const noexcept { return table_begin_pointers; }

const std::vector<uint64_t>& GetC2() const noexcept { return C2; }

const std::string& GetFilename() const noexcept { return filename; }

uint8_t GetSize() const noexcept { return k; }

Expand Down Expand Up @@ -233,7 +268,15 @@ class DiskProver {
return full_proof;
}

std::vector<uint8_t> ToBytes() const
{
Serializer serializer;
serializer << version << filename << memo << id << k << table_begin_pointers << C2;
return serializer.Data();
}

private:
uint16_t version{VERSION};
mutable std::mutex _mtx;
std::string filename;
std::vector<uint8_t> memo;
Expand Down
153 changes: 153 additions & 0 deletions src/serialize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Copyright 2022 Chia Network Inc

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SRC_SERIALIZE_HPP_
#define SRC_SERIALIZE_HPP_

#include <stdexcept>
#include <string>
#include <vector>

template<typename Type> class Serializable{
public:
static void SerializeImpl(const Type& in, std::vector<uint8_t>& out)
{
static_assert(std::is_trivial_v<Type>);
size_t nSize = sizeof(in);
out.reserve(out.size() + nSize);
for (size_t i = 0; i < nSize; ++i) {
out.push_back(*((uint8_t*)&in + i));
}
}
static size_t DeserializeImpl(const std::vector<uint8_t>& in, Type& out, size_t position)
{
static_assert(std::is_trivial_v<Type>);
size_t size = sizeof(out);
if (position + size > in.size()) {
throw std::invalid_argument("DeserializeImpl: Trying to read out of bounds.");
}
for (size_t i = 0; i < size; ++i) {
*((uint8_t*)&out + i) = in[position + i];
}
return size;
}
};

template<typename TypeIn>
void Serialize(const TypeIn& in, std::vector<uint8_t>& out)
{
Serializable<TypeIn>::SerializeImpl(in, out);
}

template<typename TypeOut>
size_t Deserialize(const std::vector<uint8_t>& in, TypeOut& out, const size_t position)
{
return Serializable<TypeOut>::DeserializeImpl(in, out, position);
}

template<typename TypeIn>
void SerializeContainer(const TypeIn& in, std::vector<uint8_t>& out)
{
Serialize(in.size(), out);
for (auto& entry : in) {
Serialize(entry, out);
}
}

template<typename TypeOut>
size_t DeserializeContainer(const std::vector<uint8_t>& in, TypeOut& out, const size_t position)
{
size_t size;
size_t offset = Deserialize(in, size, position);
if (size == 0) {
return offset;
}
out.clear();
out.resize(size);
for (size_t i = 0; i < size; ++i) {
offset += Deserialize(in, out[i], position + offset);
}
return offset;
}

template<typename Type>
class Serializable<std::vector<Type>>{
public:
static void SerializeImpl(const std::vector<Type>& in, std::vector<uint8_t>& out)
{
SerializeContainer(in, out);
}
static size_t DeserializeImpl(const std::vector<uint8_t>& in, std::vector<Type>& out, const size_t position)
{
return DeserializeContainer(in, out, position);
}
};

template<>
class Serializable<std::string>{
public:
static void SerializeImpl(const std::string& in, std::vector<uint8_t>& out)
{
SerializeContainer(in, out);
}
static size_t DeserializeImpl(const std::vector<uint8_t>& in, std::string& out, const size_t position)
{
return DeserializeContainer(in, out, position);
}
};

class Serializer
{
std::vector<uint8_t> data;
public:
template <typename InputType>
friend Serializer& operator <<(Serializer& serializer, const InputType& value) {
Serialize(value, serializer.data);
return serializer;
}
std::vector<uint8_t>& Data()
{
return data;
}
void Reset()
{
data.clear();
}
};


class Deserializer
{
size_t position{0};
const std::vector<uint8_t>& data;
public:
explicit Deserializer(const std::vector<uint8_t>& data) : data(data) {}
void Reset()
{
position = 0;
}
template <typename OutputType>
friend Deserializer& operator>>(Deserializer& deserializer, OutputType& output) {
deserializer.position += Deserialize(deserializer.data,
output,
deserializer.position);
return deserializer;
}
bool End() const
{
return position == data.size();
}
};

#endif // SRC_SERIALIZE_HPP_
Loading

0 comments on commit cc9fd7a

Please sign in to comment.