Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connections: serialization, op== #827

Merged
merged 1 commit into from
Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 42 additions & 35 deletions src/htm/algorithms/Connections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,49 +763,56 @@ std::ostream& operator<< (std::ostream& stream, const Connections& self)



bool Connections::operator==(const Connections &other) const {
if (cells_.size() != other.cells_.size())
return false;
bool Connections::operator==(const Connections &o) const {
try {
NTA_CHECK (cells_.size() == o.cells_.size()) << "Connections equals: cells_" << cells_.size() << " vs. " << o.cells_.size();
NTA_CHECK (cells_ == o.cells_) << "Connections equals: cells_" << cells_.size() << " vs. " << o.cells_.size();

NTA_CHECK (segments_ == o.segments_ ) << "Connections equals: segments_";
NTA_CHECK (destroyedSegments_ == o.destroyedSegments_ ) << "Connections equals: destroyedSegments_";

NTA_CHECK (synapses_ == o.synapses_ ) << "Connections equals: synapses_";
NTA_CHECK (destroyedSynapses_ == o.destroyedSynapses_ ) << "Connections equals: destroyedSynapses_";


//also check underlying datastructures (segments, and subsequently synapses). Can be time consuming.
//1.cells:
for(const auto cellD : cells_) {
//2.segments:
const auto& segments = cellD.segments;
for(const auto seg : segments) {
NTA_CHECK( dataForSegment(seg) == o.dataForSegment(seg) ) << "CellData equals: segmentData";
//3.synapses:
const auto& synapses = dataForSegment(seg).synapses;
for(const auto syn : synapses) {
NTA_CHECK(dataForSynapse(syn) == o.dataForSynapse(syn) ) << "SegmentData equals: synapseData";
}
}
}

if(iteration_ != other.iteration_) return false;

for (CellIdx i = 0; i < static_cast<CellIdx>(cells_.size()); i++) {
const CellData &cellData = cells_[i];
const CellData &otherCellData = other.cells_[i];
NTA_CHECK (connectedThreshold_ == o.connectedThreshold_ ) << "Connections equals: connectedThreshold_";
NTA_CHECK (iteration_ == o.iteration_ ) << "Connections equals: iteration_";

if (cellData.segments.size() != otherCellData.segments.size()) {
return false;
}
NTA_CHECK(potentialSynapsesForPresynapticCell_ == o.potentialSynapsesForPresynapticCell_);
NTA_CHECK(connectedSynapsesForPresynapticCell_ == o.connectedSynapsesForPresynapticCell_);
NTA_CHECK(potentialSegmentsForPresynapticCell_ == o.potentialSegmentsForPresynapticCell_);
NTA_CHECK(connectedSegmentsForPresynapticCell_ == o.connectedSegmentsForPresynapticCell_);

for (SegmentIdx j = 0; j < static_cast<SegmentIdx>(cellData.segments.size()); j++) {
const Segment segment = cellData.segments[j];
const SegmentData &segmentData = segments_[segment];
const Segment otherSegment = otherCellData.segments[j];
const SegmentData &otherSegmentData = other.segments_[otherSegment];
NTA_CHECK (nextSegmentOrdinal_ == o.nextSegmentOrdinal_ ) << "Connections equals: nextSegmentOrdinal_";
NTA_CHECK (nextSynapseOrdinal_ == o.nextSynapseOrdinal_ ) << "Connections equals: nextSynapseOrdinal_";

if (segmentData.synapses.size() != otherSegmentData.synapses.size() ||
segmentData.cell != otherSegmentData.cell) {
return false;
}
NTA_CHECK (timeseries_ == o.timeseries_ ) << "Connections equals: timeseries_";
NTA_CHECK (previousUpdates_ == o.previousUpdates_ ) << "Connections equals: previousUpdates_";
NTA_CHECK (currentUpdates_ == o.currentUpdates_ ) << "Connections equals: currentUpdates_";

for (SynapseIdx k = 0; k < static_cast<SynapseIdx>(segmentData.synapses.size()); k++) {
const Synapse synapse = segmentData.synapses[k];
const SynapseData &synapseData = synapses_[synapse];
const Synapse otherSynapse = otherSegmentData.synapses[k];
const SynapseData &otherSynapseData = other.synapses_[otherSynapse];
NTA_CHECK (prunedSyns_ == o.prunedSyns_ ) << "Connections equals: prunedSyns_";
NTA_CHECK (prunedSegs_ == o.prunedSegs_ ) << "Connections equals: prunedSegs_";

if (synapseData.presynapticCell != otherSynapseData.presynapticCell ||
synapseData.permanence != otherSynapseData.permanence) {
return false;
}

// Two functionally identical instances may have different flatIdxs.
NTA_ASSERT(synapseData.segment == segment);
NTA_ASSERT(otherSynapseData.segment == otherSegment);
}
}
} catch(const htm::Exception& ex) {
std::cout << "Connection equals: differ! " << ex.what();
return false;
}

return true;
}

182 changes: 132 additions & 50 deletions src/htm/algorithms/Connections.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,40 @@ struct SynapseData: public Serializable {

SynapseData() {}

//Serialization
CerealAdapter;
template<class Archive>
void save_ar(Archive & ar) const {
ar(cereal::make_nvp("perm", permanence),
cereal::make_nvp("presyn", presynapticCell));
ar(CEREAL_NVP(permanence),
CEREAL_NVP(presynapticCell),
CEREAL_NVP(segment),
CEREAL_NVP(presynapticMapIndex_),
CEREAL_NVP(id)
);
}
template<class Archive>
void load_ar(Archive & ar) {
ar( permanence, presynapticCell);
ar( permanence, presynapticCell, segment, presynapticMapIndex_, id);
}

//operator==
bool operator==(const SynapseData& o) const {
try {
NTA_CHECK(presynapticCell == o.presynapticCell ) << "SynapseData equals: presynapticCell";
NTA_CHECK(permanence == o.permanence ) << "SynapseData equals: permanence";
NTA_CHECK(segment == o.segment ) << "SynapseData equals: segment";
NTA_CHECK(presynapticMapIndex_ == o.presynapticMapIndex_ ) << "SynapseData equals: presynapticMapIndex_";
NTA_CHECK(id == o.id ) << "SynapseData equals: id";
} catch(const htm::Exception& ex) {
//NTA_WARN << "SynapseData equals: " << ex.what(); //Note: uncomment for debug, tells you
//where the diff is. It's perfectly OK for the "exception" to occur, as it just denotes
//that the data is NOT equal.
return false;
}
return true;
}
inline bool operator!=(const SynapseData& o) const { return !operator==(o); }

};

/**
Expand All @@ -94,14 +117,48 @@ struct SynapseData: public Serializable {
* @param cell
* The cell that this segment is on.
*/
struct SegmentData {
struct SegmentData: public Serializable {
SegmentData(const CellIdx cell, Segment id, UInt32 lastUsed = 0) : cell(cell), numConnected(0), lastUsed(lastUsed), id(id) {} //default constructor

std::vector<Synapse> synapses;
CellIdx cell; //mother cell that this segment originates from
SynapseIdx numConnected; //number of permanences from `synapses` that are >= synPermConnected, ie connected synapses
UInt32 lastUsed = 0; //last used time (iteration). Used for segment pruning by "least recently used" (LRU) in `createSegment`
Segment id;

//Serialize
SegmentData() {}; //empty constructor for serialization, do not use
CerealAdapter;
template<class Archive>
void save_ar(Archive & ar) const {
ar(CEREAL_NVP(synapses),
CEREAL_NVP(cell),
CEREAL_NVP(numConnected),
CEREAL_NVP(lastUsed),
CEREAL_NVP(id)
);
}
template<class Archive>
void load_ar(Archive & ar) {
ar( synapses, cell, numConnected, lastUsed, id);
}

//equals op==
bool operator==(const SegmentData& o) const {
try {
NTA_CHECK(synapses == o.synapses) << "SegmentData equals: synapses";
NTA_CHECK(cell == o.cell) << "SegmentData equals: cell";
NTA_CHECK(numConnected == o.numConnected) << "SegmentData equals: numConnected";
NTA_CHECK(lastUsed == o.lastUsed) << "SegmentData equals: lastUsed";
NTA_CHECK(id == o.id) << "SegmentData equals: id";

} catch(const htm::Exception& ex) {
//NTA_WARN << "SegmentData equals: " << ex.what();
return false;
}
return true;
}
inline bool operator!=(const SegmentData& o) const { return !operator==(o); }
};

/**
Expand All @@ -115,10 +172,35 @@ struct SegmentData {
* Segments on this cell.
*
*/
struct CellData {
struct CellData : public Serializable {
std::vector<Segment> segments;

//Serialization
CerealAdapter;
template<class Archive>
void save_ar(Archive & ar) const {
ar(CEREAL_NVP(segments)
);
}
template<class Archive>
void load_ar(Archive & ar) {
ar( segments);
}

//operator==
bool operator==(const CellData& o) const {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

helper structures in connections SegmentData, SynapseData now implement Serializable and op==, which simplifies these operations in Connections.

try {
NTA_CHECK( segments == o.segments ) << "CellData equals: segments";
} catch(const htm::Exception& ex) {
//NTA_WARN << "CellData equals: " << ex.what();
return false;
}
return true;
}
inline bool operator!=(const CellData& o) const { return !operator==(o); }
};


/**
* A base class for Connections event handlers.
*
Expand Down Expand Up @@ -557,58 +639,58 @@ class Connections : public Serializable
CerealAdapter;
template<class Archive>
void save_ar(Archive & ar) const {
// make this look like a queue of items to be sent.
// and a queue of sizes so we can distribute the
// correct number for each level when deserializing.
std::deque<SynapseData> syndata;
std::deque<size_t> sizes;
sizes.push_back(cells_.size());
for (CellData cellData : cells_) {
const std::vector<Segment> &segments = cellData.segments;
sizes.push_back(segments.size());
for (Segment segment : segments) {
const SegmentData &segmentData = segments_[segment];
const std::vector<Synapse> &synapses = segmentData.synapses;
sizes.push_back(synapses.size());
for (Synapse synapse : synapses) {
const SynapseData &synapseData = synapses_[synapse];
syndata.push_back(synapseData);
}
}
}
ar(CEREAL_NVP(connectedThreshold_));
//the following member must not be serialized (so is set to =0).
//That is because of we serialize only active segments & synapses,
//excluding the "destroyed", so those fields start empty.
//! ar(CEREAL_NVP(destroyedSegments_));
ar(CEREAL_NVP(sizes));
ar(CEREAL_NVP(syndata));
ar(CEREAL_NVP(iteration_));
ar(CEREAL_NVP(cells_));
ar(CEREAL_NVP(segments_));
ar(CEREAL_NVP(synapses_));

ar(CEREAL_NVP(destroyedSynapses_));
ar(CEREAL_NVP(destroyedSegments_));

ar(CEREAL_NVP(potentialSynapsesForPresynapticCell_));
ar(CEREAL_NVP(connectedSynapsesForPresynapticCell_));
ar(CEREAL_NVP(potentialSegmentsForPresynapticCell_));
ar(CEREAL_NVP(connectedSegmentsForPresynapticCell_));

ar(CEREAL_NVP(nextSegmentOrdinal_));
ar(CEREAL_NVP(nextSynapseOrdinal_));

ar(CEREAL_NVP(timeseries_));
ar(CEREAL_NVP(previousUpdates_));
ar(CEREAL_NVP(currentUpdates_));

ar(CEREAL_NVP(prunedSyns_));
ar(CEREAL_NVP(prunedSegs_));
}

template<class Archive>
void load_ar(Archive & ar) {
std::deque<size_t> sizes;
std::deque<SynapseData> syndata;
ar(CEREAL_NVP(connectedThreshold_));
ar(CEREAL_NVP(sizes));
ar(CEREAL_NVP(syndata));

CellIdx numCells = static_cast<CellIdx>(sizes.front()); sizes.pop_front();
initialize(numCells, connectedThreshold_);
for (UInt cell = 0; cell < numCells; cell++) {
size_t numSegments = sizes.front(); sizes.pop_front();
for (SegmentIdx j = 0; j < static_cast<SegmentIdx>(numSegments); j++) {
Segment segment = createSegment( cell );

size_t numSynapses = sizes.front(); sizes.pop_front();
for (SynapseIdx k = 0; k < static_cast<SynapseIdx>(numSynapses); k++) {
SynapseData& syn = syndata.front(); syndata.pop_front();
createSynapse( segment, syn.presynapticCell, syn.permanence );
}
}
}
ar(CEREAL_NVP(iteration_));
//!initialize(numCells, connectedThreshold_); //initialize Connections //Note: we actually don't call Connections
//initialize() as all the members are de/serialized.
ar(CEREAL_NVP(cells_));
ar(CEREAL_NVP(segments_));
ar(CEREAL_NVP(synapses_));

ar(CEREAL_NVP(destroyedSynapses_));
ar(CEREAL_NVP(destroyedSegments_));

ar(CEREAL_NVP(potentialSynapsesForPresynapticCell_));
ar(CEREAL_NVP(connectedSynapsesForPresynapticCell_));
ar(CEREAL_NVP(potentialSegmentsForPresynapticCell_));
ar(CEREAL_NVP(connectedSegmentsForPresynapticCell_));

ar(CEREAL_NVP(nextSegmentOrdinal_));
ar(CEREAL_NVP(nextSynapseOrdinal_));

ar(CEREAL_NVP(timeseries_));
ar(CEREAL_NVP(previousUpdates_));
ar(CEREAL_NVP(currentUpdates_));

ar(CEREAL_NVP(prunedSyns_));
ar(CEREAL_NVP(prunedSegs_));
}

/**
Expand Down Expand Up @@ -771,7 +853,7 @@ class Connections : public Serializable
Synapse prunedSyns_ = 0; //how many synapses have been removed?
Segment prunedSegs_ = 0;

//for listeners
//for listeners //TODO listeners are not serialized, nor included in equals ==
UInt32 nextEventToken_;
std::map<UInt32, ConnectionsEventHandler *> eventHandlers_;
}; // end class Connections
Expand Down