From 328f6d2ac965f9cb01233cade14738c9c90c208d Mon Sep 17 00:00:00 2001 From: Stefan Haller Date: Thu, 29 Mar 2018 08:50:58 +0200 Subject: [PATCH 1/3] Allow saving GMs in OpenGM format --- include/hdf5_routines.hxx | 150 ++++++++++++++++++--------- include/mrf_problem_construction.hxx | 6 ++ 2 files changed, 105 insertions(+), 51 deletions(-) diff --git a/include/hdf5_routines.hxx b/include/hdf5_routines.hxx index 3d76d44..8741fef 100644 --- a/include/hdf5_routines.hxx +++ b/include/hdf5_routines.hxx @@ -164,71 +164,119 @@ namespace LP_MP { } */ - template - bool ParseGM(const std::string filename, MRF_CONSTRUCTOR& mrf) - { - typedef double ValueType; - typedef size_t IndexType; - typedef size_t LabelType; - typedef opengm::Adder OperatorType; - typedef opengm::Minimizer AccumulatorType; - typedef opengm::DiscreteSpace SpaceType; + namespace opengm_mrf_types { + typedef double ValueType; + typedef size_t IndexType; + typedef size_t LabelType; + typedef opengm::Adder OperatorType; + typedef opengm::Minimizer AccumulatorType; + typedef opengm::DiscreteSpace SpaceType; - // Set functions for graphical model - typedef opengm::meta::TypeListGenerator< - opengm::ExplicitFunction, - opengm::PottsFunction, - opengm::PottsNFunction, - opengm::PottsGFunction, - opengm::TruncatedSquaredDifferenceFunction, - opengm::TruncatedAbsoluteDifferenceFunction - >::type FunctionTypeList; + typedef opengm::meta::TypeListGenerator< + opengm::ExplicitFunction, + opengm::PottsFunction, + opengm::PottsNFunction, + opengm::PottsGFunction, + opengm::TruncatedSquaredDifferenceFunction, + opengm::TruncatedAbsoluteDifferenceFunction + >::type FunctionTypeList; + typedef opengm::GraphicalModel< + ValueType, + OperatorType, + FunctionTypeList, + SpaceType + > GmType; + } - typedef opengm::GraphicalModel< - ValueType, - OperatorType, - FunctionTypeList, - SpaceType - > GmType; - + template + bool ParseGM(const std::string filename, MRF_CONSTRUCTOR& mrf) + { + using namespace opengm_mrf_types; - GmType gm; - opengm::hdf5::load(gm, filename,"gm"); + GmType gm; + ::opengm::hdf5::load(gm, filename, "gm"); - for(INDEX f=0; f unaryCost(gm[f].numberOfLabels(0)); - for(INDEX l=0; l({l}).begin()); + if(gm[f].numberOfVariables()==0){ + // ignore for now + } + else if(gm[f].numberOfVariables()==1){ + const INDEX i = gm.variableOfFactor(f,0); + std::vector unaryCost(gm[f].numberOfLabels(0)); + for(INDEX l=0; l({l}).begin()); + } + mrf.AddUnaryFactor(i,unaryCost); } - mrf.AddUnaryFactor(i,unaryCost); - } - else if(gm[f].numberOfVariables()==2){ - const INDEX i = gm.variableOfFactor(f,0); - const INDEX j = gm.variableOfFactor(f,1); - matrix pairwiseCost(gm[f].numberOfLabels(0), gm[f].numberOfLabels(1));//gm[f].size()); - for(INDEX l1=0; l1({l1,l2}).begin()); + else if(gm[f].numberOfVariables()==2){ + const INDEX i = gm.variableOfFactor(f,0); + const INDEX j = gm.variableOfFactor(f,1); + matrix pairwiseCost(gm[f].numberOfLabels(0), gm[f].numberOfLabels(1));//gm[f].size()); + for(INDEX l1=0; l1({l1,l2}).begin()); + } } + mrf.AddPairwiseFactor(i,j,pairwiseCost); + } + else{ + std::cout << "Factors of order higher than 2 are so far not supported !" < + bool WriteGM(const std::string filename, MRF_CONSTRUCTOR& mrf) + { + using namespace opengm_mrf_types; + GmType gm; + + // TODO: Handle constant term (future commit...). + + for (INDEX i = 0; i < mrf.GetNumberOfVariables(); ++i) { + INDEX gm_i = gm.addVariable(mrf.GetNumberOfLabels(i)); assert(gm_i == i); + std::array variables { i }; + std::array space { mrf.GetNumberOfLabels(i) }; + opengm::ExplicitFunction f(space.begin(), space.end()); + for (INDEX j = 0; j < mrf.GetNumberOfLabels(i); ++j) + f(j) = mrf.GetUnaryFactor(i)->GetFactor()->operator[](j); + gm.addFactor(gm.addFunction(f), variables.begin(), variables.end()); + } - return true; + for (INDEX i = 0; i < mrf.GetNumberOfPairwiseFactors(); ++i) { + auto variables = mrf.GetPairwiseVariables(i); + std::array space { + mrf.GetNumberOfLabels(variables[0]), + mrf.GetNumberOfLabels(variables[1]) }; + opengm::ExplicitFunction f(space.begin(), space.begin() + 2); + for (INDEX j = 0; j < mrf.GetNumberOfLabels(variables[0]); ++j) + for (INDEX k = 0; k < mrf.GetNumberOfLabels(variables[1]); ++k) + f(j, k) = mrf.GetPairwiseFactor(i)->GetFactor()->operator()(j, k); + gm.addFactor(gm.addFunction(f), variables.begin(), variables.end()); + } + + if constexpr (MrfHasTripletFactors::value) { + for (INDEX i = 0; i < mrf.GetNumberOfTripletFactors(); ++i) { + auto variables = mrf.GetTripletIndices(i); + std::array space { + mrf.GetNumberOfLabels(variables[0]), + mrf.GetNumberOfLabels(variables[1]), + mrf.GetNumberOfLabels(variables[2]) }; + opengm::ExplicitFunction f(space.begin(), space.begin() + 2); + for (INDEX j = 0; j < mrf.GetNumberOfLabels(variables[0]); ++j) + for (INDEX k = 0; k < mrf.GetNumberOfLabels(variables[1]); ++k) + for (INDEX l = 0; l < mrf.GetNumberOfLabels(variables[2]); ++l) + f(j, k, l) = mrf.GetPairwiseFactor(i)->GetFactor()->operator()(j, k, l); + gm.addFactor(gm.addFunction(f), variables.begin(), variables.end()); + } + } + + opengm::hdf5::save(gm, filename, "gm"); } template diff --git a/include/mrf_problem_construction.hxx b/include/mrf_problem_construction.hxx index fb27f24..f9921da 100644 --- a/include/mrf_problem_construction.hxx +++ b/include/mrf_problem_construction.hxx @@ -726,6 +726,12 @@ protected: std::map, INDEX> tripletMap_; // given two sorted indices, return factorId belonging to that index. }; +template> +struct MrfHasTripletFactors : std::false_type { }; + +template +struct MrfHasTripletFactors> : std::true_type { }; + /////////////////////////////////////////////////////////////////// // From baaf6a2c5a71eaf0abb3c566d2062320e6fdf268 Mon Sep 17 00:00:00 2001 From: Stefan Haller Date: Tue, 10 Apr 2018 11:26:14 +0200 Subject: [PATCH 2/3] Adjust PairwiseFactorContainer to match LP_MP This allows to arbitrary compatible arguments to the underlying Factor. --- include/mrf_problem_construction.hxx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/mrf_problem_construction.hxx b/include/mrf_problem_construction.hxx index f9921da..e990987 100644 --- a/include/mrf_problem_construction.hxx +++ b/include/mrf_problem_construction.hxx @@ -70,8 +70,8 @@ public: } unaryFactor_[node_number] = u; } - template - PairwiseFactorContainer* AddPairwiseFactor(INDEX var1, INDEX var2, const COST& cost) + template + PairwiseFactorContainer* AddPairwiseFactor(INDEX var1, INDEX var2, ARGS... args) { //if(var1 > var2) std::swap(var1,var2); assert(var1AddFactor(p); ConstructPairwiseFactor(*(p->GetFactor()), var1, var2); pairwiseFactor_.push_back(p); From 22ce444ec2740809f91d8746d2114f359fc1fda6 Mon Sep 17 00:00:00 2001 From: Stefan Haller Date: Tue, 17 Apr 2018 10:09:31 +0200 Subject: [PATCH 3/3] Get rid of MrfHasTripletFactors trait --- include/hdf5_routines.hxx | 2 +- include/mrf_problem_construction.hxx | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/include/hdf5_routines.hxx b/include/hdf5_routines.hxx index 8741fef..3a06d13 100644 --- a/include/hdf5_routines.hxx +++ b/include/hdf5_routines.hxx @@ -260,7 +260,7 @@ namespace LP_MP { gm.addFactor(gm.addFunction(f), variables.begin(), variables.end()); } - if constexpr (MrfHasTripletFactors::value) { + if constexpr (mrf.arity() >= 3) { for (INDEX i = 0; i < mrf.GetNumberOfTripletFactors(); ++i) { auto variables = mrf.GetTripletIndices(i); std::array space { diff --git a/include/mrf_problem_construction.hxx b/include/mrf_problem_construction.hxx index e990987..7cfbd05 100644 --- a/include/mrf_problem_construction.hxx +++ b/include/mrf_problem_construction.hxx @@ -348,6 +348,10 @@ public: return std::move(trees); } + constexpr INDEX arity() { + return 2; + } + protected: std::vector unaryFactor_; std::vector pairwiseFactor_; @@ -719,6 +723,10 @@ public: return no_triplets_added; } + constexpr INDEX arity() { + return 3; + } + protected: std::vector tripletFactor_; @@ -726,12 +734,6 @@ protected: std::map, INDEX> tripletMap_; // given two sorted indices, return factorId belonging to that index. }; -template> -struct MrfHasTripletFactors : std::false_type { }; - -template -struct MrfHasTripletFactors> : std::true_type { }; - /////////////////////////////////////////////////////////////////// //