Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General improvements for GM handling #1

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 99 additions & 51 deletions include/hdf5_routines.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -164,71 +164,119 @@ namespace LP_MP {
}
*/

template<typename MRF_CONSTRUCTOR>
bool ParseGM(const std::string filename, MRF_CONSTRUCTOR& mrf)
{
typedef double ValueType;
typedef size_t IndexType;
typedef size_t LabelType;
typedef opengm::Adder OperatorType;
typedef opengm::Minimizer AccumulatorType;
typedef opengm::DiscreteSpace<IndexType, LabelType> SpaceType;
namespace opengm_mrf_types {
typedef double ValueType;
typedef size_t IndexType;
typedef size_t LabelType;
typedef opengm::Adder OperatorType;
typedef opengm::Minimizer AccumulatorType;
typedef opengm::DiscreteSpace<IndexType, LabelType> SpaceType;

// Set functions for graphical model
typedef opengm::meta::TypeListGenerator<
opengm::ExplicitFunction<ValueType, IndexType, LabelType>,
opengm::PottsFunction<ValueType, IndexType, LabelType>,
opengm::PottsNFunction<ValueType, IndexType, LabelType>,
opengm::PottsGFunction<ValueType, IndexType, LabelType>,
opengm::TruncatedSquaredDifferenceFunction<ValueType, IndexType, LabelType>,
opengm::TruncatedAbsoluteDifferenceFunction<ValueType, IndexType, LabelType>
>::type FunctionTypeList;
typedef opengm::meta::TypeListGenerator<
opengm::ExplicitFunction<ValueType, IndexType, LabelType>,
opengm::PottsFunction<ValueType, IndexType, LabelType>,
opengm::PottsNFunction<ValueType, IndexType, LabelType>,
opengm::PottsGFunction<ValueType, IndexType, LabelType>,
opengm::TruncatedSquaredDifferenceFunction<ValueType, IndexType, LabelType>,
opengm::TruncatedAbsoluteDifferenceFunction<ValueType, IndexType, LabelType>
>::type FunctionTypeList;

typedef opengm::GraphicalModel<
ValueType,
OperatorType,
FunctionTypeList,
SpaceType
> GmType;
}

typedef opengm::GraphicalModel<
ValueType,
OperatorType,
FunctionTypeList,
SpaceType
> GmType;

template<typename MRF_CONSTRUCTOR>
bool ParseGM(const std::string filename, MRF_CONSTRUCTOR& mrf)
{
using namespace opengm_mrf_types;

GmType gm;
opengm::hdf5::load(gm, filename,"gm");
GmType gm;
::opengm::hdf5::load(gm, filename, "gm");

for(INDEX f=0; f<gm.numberOfFactors(); ++f){
for(INDEX f=0; f<gm.numberOfFactors(); ++f){

if(gm[f].numberOfVariables()==0){
// ignore for now
}
else if(gm[f].numberOfVariables()==1){
const INDEX i = gm.variableOfFactor(f,0);
std::vector<REAL> unaryCost(gm[f].numberOfLabels(0));
for(INDEX l=0; l<gm[f].numberOfLabels(0); ++l){
unaryCost[l] = gm[f](std::array<INDEX,1>({l}).begin());
if(gm[f].numberOfVariables()==0){
// ignore for now
}
else if(gm[f].numberOfVariables()==1){
const INDEX i = gm.variableOfFactor(f,0);
std::vector<REAL> unaryCost(gm[f].numberOfLabels(0));
for(INDEX l=0; l<gm[f].numberOfLabels(0); ++l){
unaryCost[l] = gm[f](std::array<INDEX,1>({l}).begin());
}
mrf.AddUnaryFactor(i,unaryCost);
}
mrf.AddUnaryFactor(i,unaryCost);
}
else if(gm[f].numberOfVariables()==2){
const INDEX i = gm.variableOfFactor(f,0);
const INDEX j = gm.variableOfFactor(f,1);
matrix<REAL> pairwiseCost(gm[f].numberOfLabels(0), gm[f].numberOfLabels(1));//gm[f].size());
for(INDEX l1=0; l1<gm[f].numberOfLabels(0); ++l1){
for(INDEX l2=0; l2<gm[f].numberOfLabels(1); ++l2){
pairwiseCost(l1, l2) = gm[f](std::array<INDEX,2>({l1,l2}).begin());
else if(gm[f].numberOfVariables()==2){
const INDEX i = gm.variableOfFactor(f,0);
const INDEX j = gm.variableOfFactor(f,1);
matrix<REAL> pairwiseCost(gm[f].numberOfLabels(0), gm[f].numberOfLabels(1));//gm[f].size());
for(INDEX l1=0; l1<gm[f].numberOfLabels(0); ++l1){
for(INDEX l2=0; l2<gm[f].numberOfLabels(1); ++l2){
pairwiseCost(l1, l2) = gm[f](std::array<INDEX,2>({l1,l2}).begin());
}
}
mrf.AddPairwiseFactor(i,j,pairwiseCost);
}
else{
std::cout << "Factors of order higher than 2 are so far not supported !" <<std::endl;
return 1;
}
mrf.AddPairwiseFactor(i,j,pairwiseCost);
}
else{
std::cout << "Factors of order higher than 2 are so far not supported !" <<std::endl;
return 1;
}

return true;
}

template<typename MRF_CONSTRUCTOR>
bool WriteGM(const std::string filename, MRF_CONSTRUCTOR& mrf)
{
using namespace opengm_mrf_types;
GmType gm;

// TODO: Handle constant term (future commit...).

for (INDEX i = 0; i < mrf.GetNumberOfVariables(); ++i) {
INDEX gm_i = gm.addVariable(mrf.GetNumberOfLabels(i)); assert(gm_i == i);
std::array<IndexType, 1> variables { i };
std::array<IndexType, 1> space { mrf.GetNumberOfLabels(i) };
opengm::ExplicitFunction<ValueType, IndexType, LabelType> f(space.begin(), space.end());
for (INDEX j = 0; j < mrf.GetNumberOfLabels(i); ++j)
f(j) = mrf.GetUnaryFactor(i)->GetFactor()->operator[](j);
gm.addFactor(gm.addFunction(f), variables.begin(), variables.end());
}

return true;
for (INDEX i = 0; i < mrf.GetNumberOfPairwiseFactors(); ++i) {
auto variables = mrf.GetPairwiseVariables(i);
std::array<IndexType, 2> space {
mrf.GetNumberOfLabels(variables[0]),
mrf.GetNumberOfLabels(variables[1]) };
opengm::ExplicitFunction<ValueType, IndexType, LabelType> f(space.begin(), space.begin() + 2);
for (INDEX j = 0; j < mrf.GetNumberOfLabels(variables[0]); ++j)
for (INDEX k = 0; k < mrf.GetNumberOfLabels(variables[1]); ++k)
f(j, k) = mrf.GetPairwiseFactor(i)->GetFactor()->operator()(j, k);
gm.addFactor(gm.addFunction(f), variables.begin(), variables.end());
}

if constexpr (mrf.arity() >= 3) {
for (INDEX i = 0; i < mrf.GetNumberOfTripletFactors(); ++i) {
auto variables = mrf.GetTripletIndices(i);
std::array<IndexType, 3> space {
mrf.GetNumberOfLabels(variables[0]),
mrf.GetNumberOfLabels(variables[1]),
mrf.GetNumberOfLabels(variables[2]) };
opengm::ExplicitFunction<ValueType, IndexType, LabelType> f(space.begin(), space.begin() + 2);
for (INDEX j = 0; j < mrf.GetNumberOfLabels(variables[0]); ++j)
for (INDEX k = 0; k < mrf.GetNumberOfLabels(variables[1]); ++k)
for (INDEX l = 0; l < mrf.GetNumberOfLabels(variables[2]); ++l)
f(j, k, l) = mrf.GetPairwiseFactor(i)->GetFactor()->operator()(j, k, l);
gm.addFactor(gm.addFunction(f), variables.begin(), variables.end());
}
}

opengm::hdf5::save(gm, filename, "gm");
}

template<typename MRF_CONSTRUCTOR>
Expand Down
14 changes: 11 additions & 3 deletions include/mrf_problem_construction.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ public:
}
unaryFactor_[node_number] = u;
}
template<typename COST>
PairwiseFactorContainer* AddPairwiseFactor(INDEX var1, INDEX var2, const COST& cost)
template<typename... ARGS>
PairwiseFactorContainer* AddPairwiseFactor(INDEX var1, INDEX var2, ARGS... args)
{
//if(var1 > var2) std::swap(var1,var2);
assert(var1<var2);
assert(!HasPairwiseFactor(var1,var2));
//assert(cost.size() == GetNumberOfLabels(var1) * GetNumberOfLabels(var2));
//assert(pairwiseMap_.find(std::make_tuple(var1,var2)) == pairwiseMap_.end());
//PairwiseFactorContainer* p = new PairwiseFactorContainer(PairwiseFactor(cost), cost);
auto* p = new PairwiseFactorContainer(GetNumberOfLabels(var1), GetNumberOfLabels(var2), cost);
auto* p = new PairwiseFactorContainer(GetNumberOfLabels(var1), GetNumberOfLabels(var2), args...);
lp_->AddFactor(p);
ConstructPairwiseFactor(*(p->GetFactor()), var1, var2);
pairwiseFactor_.push_back(p);
Expand Down Expand Up @@ -348,6 +348,10 @@ public:
return std::move(trees);
}

constexpr INDEX arity() {
return 2;
}

protected:
std::vector<UnaryFactorContainer*> unaryFactor_;
std::vector<PairwiseFactorContainer*> pairwiseFactor_;
Expand Down Expand Up @@ -719,6 +723,10 @@ public:
return no_triplets_added;
}

constexpr INDEX arity() {
return 3;
}


protected:
std::vector<TripletFactorContainer*> tripletFactor_;
Expand Down