diff --git a/src/main/java/meka/classifiers/multilabel/meta/MLRF.java b/src/main/java/meka/classifiers/multilabel/meta/MLRF.java index 9601401a..1a8571f2 100644 --- a/src/main/java/meka/classifiers/multilabel/meta/MLRF.java +++ b/src/main/java/meka/classifiers/multilabel/meta/MLRF.java @@ -32,6 +32,7 @@ import no.uib.cipr.matrix.SVD; import no.uib.cipr.matrix.sparse.CompRowMatrix; import weka.core.Attribute; +import weka.core.Capabilities; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; @@ -40,9 +41,9 @@ import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import weka.core.TechnicalInformationHandler; +import weka.core.WekaException; import java.util.ArrayList; -import java.util.Collections; import java.util.Enumeration; import java.util.List; import java.util.Random; @@ -60,12 +61,12 @@ public class MLRF extends MetaProblemTransformationMethod implements Randomizabl /** * The number of subsets to generate. */ - protected int K = 10; + protected int K = 2; /** * The dimensionality reduction parameter. */ - protected int numFeatures = 10; + protected int numFeatures = 5; private DenseMatrix[] R; @@ -121,14 +122,12 @@ private List> generateFeatureSubsets(int d) { A.shuffle(indices, r); for (int i = 0; i < K; i++) { List subset = new ArrayList<>(); - for (int j = 0; j < d / K; j++) { - if (i * (d / K) + j >= d) - break; + for (int j = 0; j < d / K; j++) subset.add(indices[i * (d / K) + j]); - } - Collections.sort(subset); result.add(subset); } + for (int i = K * (d / K); i < d; i++) + result.get(i - K * (d / K)).add(indices[i]); return result; } @@ -139,6 +138,11 @@ public void buildClassifier(Instances D) throws Exception { int L = D.classIndex(); int d = D.numAttributes() - L; + if (K > d) + throw new WekaException("Number of subsets is larger than the number of attributes."); + if (numFeatures > (d / K)) + throw new WekaException("Dimensionality reduction parameter k is larger than the subset size."); + m_Classifiers = ProblemTransformationMethod.makeCopies((MultiLabelClassifier) m_Classifier, m_NumIterations); R = new DenseMatrix[m_NumIterations]; Instances transformedInstances = F.remove(D, A.make_sequence(L), true); @@ -195,8 +199,7 @@ public void buildClassifier(Instances D) throws Exception { } // Latent semantic indexing - SVD svd = new SVD(nNew, subSize); - svd = svd.factor(mat); + SVD svd = SVD.factorize(mat); DenseMatrix transformationMatrix = getTransformationMatrix(svd, subSize); // Insert each subset transformation matrix into full rotation matrix R[i] @@ -279,8 +282,8 @@ public String[] getOptions() { @Override public void setOptions(String[] options) throws Exception { - K = OptionUtils.parse(options, "K", 10); - numFeatures = OptionUtils.parse(options, "k", 10); + K = OptionUtils.parse(options, "K", 2); + numFeatures = OptionUtils.parse(options, "k", 5); super.setOptions(options); } @@ -307,7 +310,15 @@ public TechnicalInformation getTechnicalInformation() { return info; } - public static void main(String[] args) { - ProblemTransformationMethod.runClassifier(new MLRF(), args); + @Override + public Capabilities getCapabilities() { + Capabilities result = super.getCapabilities(); + result.disable(Capabilities.Capability.NOMINAL_ATTRIBUTES); + result.disable(Capabilities.Capability.MISSING_VALUES); + return result; + } + + public static void main(String[] args) { + ProblemTransformationMethod.runClassifier(new MLRF(), args); } }