labels, Instances D) {
return super.splitLabels(rng.nextInt(5) + 2, labels, D);
}
}
diff --git a/src/main/java/meka/classifiers/multilabel/meta/HOMER.java b/src/main/java/meka/classifiers/multilabel/meta/HOMER.java
index 88e8766a..5120d056 100644
--- a/src/main/java/meka/classifiers/multilabel/meta/HOMER.java
+++ b/src/main/java/meka/classifiers/multilabel/meta/HOMER.java
@@ -17,7 +17,7 @@
import meka.classifiers.multilabel.*;
import meka.core.*;
-import weka.classifiers.*;
+import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.core.*;
import weka.core.TechnicalInformation.*;
@@ -26,7 +26,7 @@
import java.util.*;
/**
- * HOMER (Hierarchy Of Multi-label classifiERs algorithm.
+ * HOMER (Hierarchy Of Multi-label classifiERs) algorithm.
*
* This algorithm divides the multilabel classification problem into a tree structure where at each level the labels are
* partitioned into k subsets and one multilabel classifier classifies the labels per subset, using k metalabels
@@ -49,14 +49,14 @@ public class HOMER extends ProblemTransformationMethod implements Randomizable,
* The threshold to use for the multi-label classifier distribution at each node.
*/
protected String threshold = "0.3";
- /**
- * Root node of the HOMER tree.
- */
- private HOMERNode root;
/**
* The label splitter.
*/
protected LabelSplitter labelSplitter = new RandomLabelSplitter(0);
+ /**
+ * Root node of the HOMER tree.
+ */
+ private HOMERNode root;
/**
* HOMER node ID counter for debugging purposes.
*/
@@ -89,10 +89,6 @@ public void setThreshold(String t) {
threshold = t;
}
- public LabelSplitter getLabelSplitter() {
- return labelSplitter;
- }
-
public void setLabelSplitter(LabelSplitter splitter) {
this.labelSplitter = splitter;
}
@@ -138,7 +134,8 @@ public String globalInfo() {
public TechnicalInformation getTechnicalInformation() {
TechnicalInformation info = new TechnicalInformation(Type.INPROCEEDINGS);
info.setValue(Field.AUTHOR, "Tsoumakas, Grigorios and Katakis, Ioannis and Vlahavas, Ioannis");
- info.setValue(Field.TITLE, "Effective and efficient multilabel classification in domains with large number of labels");
+ info.setValue(Field.TITLE,
+ "Effective and efficient multilabel classification in domains with large number of labels");
info.setValue(Field.YEAR, "2008");
info.setValue(Field.PAGES, "53--59");
info.setValue(Field.BOOKTITLE, "Proc. ECML/PKDD 2008 Workshop on Mining Multidimensional Data (MMD’08)");
@@ -164,10 +161,8 @@ public void buildClassifier(Instances D) throws Exception {
if (!(m_Classifier instanceof MultiLabelClassifier))
throw new IllegalStateException("Classifier must be a MultiLabelClassifier!");
- root = new HOMERNode();
- root.setClassifier(ProblemTransformationMethod.makeCopy(m_Classifier));
- root.setSplitter(labelSplitter);
- root.buildClassifier(D);
+ root = new HOMERNode((ProblemTransformationMethod) ProblemTransformationMethod.makeCopy(m_Classifier));
+ root.buildTree(D);
if (getDebug())
System.out.println("Trained all nodes.");
}
@@ -203,11 +198,11 @@ public String[] getOptions() {
@Override
public void setOptions(String[] options) throws Exception {
- super.setOptions(options);
k = OptionUtils.parse(options, "k", 3);
seed = OptionUtils.parse(options, "S", 0);
threshold = OptionUtils.parse(options, "t", "0.3");
setLabelSplitterString(OptionUtils.parse(options, "ls", "RandomLabelSplitter"));
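+ // Reordered so this class's options are parsed first; super.setOptions is assumed to consume what remains.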
+ super.setOptions(options);
}
@Override
@@ -241,7 +236,7 @@ public interface LabelSplitter extends Serializable {
* @param D the dataset that defines the labels
* @return a partition of the labels in {@code labels}
*/
- Collection<Set<Integer>> splitLabels(int k, Collection<Integer> labels, Instances D);
+ Collection<Set<Integer>> splitLabels(int k, Set<Integer> labels, Instances D);
}
/**
@@ -256,7 +251,7 @@ public abstract static class RandomizableLabelSplitter implements LabelSplitter,
/**
* Random seed for {@link #rng}.
*/
- protected int seed;
+ private int seed;
public RandomizableLabelSplitter(int seed) {
this.seed = seed;
@@ -287,25 +282,24 @@ public RandomLabelSplitter(int seed) {
}
@Override
- public Collection<Set<Integer>> splitLabels(int k, Collection<Integer> labels, Instances D) {
+ public Collection<Set<Integer>> splitLabels(int k, Set<Integer> labels, Instances D) {
List<Set<Integer>> result = new ArrayList<>();
- int L = labels.size();
- List<Integer> labelList = Arrays.asList(labels.toArray(new Integer[0]));
+ int Ln = labels.size();
+ if (Ln <= k) {
+ // Ln children, each with a single label
+ for (int label : labels)
+ result.add(new HashSet<>(Collections.singletonList(label)));
+ return result;
+ }
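+ // Otherwise produce a balanced k-way split, e.g. Ln = 7, k = 3 gives subset sizes 3, 2, 2.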
+ List<Integer> labelList = new ArrayList<>(labels);
Collections.shuffle(labelList, rng);
int idx = 0;
- if (L <= k) {
- for (int i = 0; i < L; i++) {
- Set<Integer> labelSet = new HashSet<>(1);
+ for (int i = 0; i < k; i++) {
+ // k children, with either Ln / k or Ln / k + 1 labels per child
+ Set<Integer> labelSet = new HashSet<>(Ln / k);
+ for (int x = 0; x < Ln / k + (i < Ln % k ? 1 : 0); x++)
labelSet.add(labelList.get(idx++));
- result.add(labelSet);
- }
- } else {
- for (int i = 0; i < k; i++) {
- Set<Integer> labelSet = new HashSet<>(L / k);
- for (int x = 0; x < (i < L % k ? L / k + 1 : L / k); x++)
- labelSet.add(labelList.get(idx++));
- result.add(labelSet);
- }
+ result.add(labelSet);
}
return result;
}
@@ -319,59 +313,71 @@ public Collection<Set<Integer>> splitLabels(int k, Collection<Integer> labels, I
public static class ClusterLabelSplitter extends RandomizableLabelSplitter {
private static final long serialVersionUID = 3545709733670034266L;
- private double[][] labelVectors = null;
public ClusterLabelSplitter(int seed) {
super(seed);
}
@Override
- public Collection<Set<Integer>> splitLabels(int k, Collection<Integer> labels, Instances D) {
- int L = D.classIndex();
- int n = D.numInstances();
- if (labelVectors == null) {
- labelVectors = new double[L][n];
- for (int i = 0; i < n; i++)
- for (int l = 0; l < L; l++)
- labelVectors[l][i] = D.get(i).value(l);
- }
-
- List<Integer> labelList = new ArrayList<>(labels);
- int idx = 0;
+ public Collection<Set<Integer>> splitLabels(int k, Set<Integer> labels, Instances D) {
int Ln = labels.size();
-
if (Ln <= k) {
+ // Ln children, each with a single label
List<Set<Integer>> result = new ArrayList<>(Ln);
- for (int i = 0; i < Ln; i++)
- result.add(new HashSet<>(Collections.singletonList(labelList.get(idx++))));
+ for (int label : labels)
+ result.add(new HashSet<>(Collections.singletonList(label)));
return result;
}
+ int L = D.classIndex();
+ int n = D.numInstances();
+
+ /*
+ * L x n label matrix. Each row corresponds to a label and represents the presence of that label for
+ * each of the n instances.
+ */
+ double[][] labelMatrix = new double[L][n];
+ for (int i = 0; i < n; i++)
+ for (int l = 0; l < L; l++)
+ labelMatrix[l][i] = D.get(i).value(l);
+
+ /* See Fig. 2 of the paper for the steps below */
+
+ // Init kmeans centroids to random labels
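+ // (Ln > k here thanks to the early return above, so k distinct label indices can always be drawn.)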
+ Set<Integer> centreIdx = new HashSet<>(k);
+ while (centreIdx.size() != k)
+ centreIdx.add(rng.nextInt(labels.size()));
+ List<Integer> centreIdxList = List.copyOf(centreIdx);
+ List<Integer> labelList = List.copyOf(labels);
+ double[][] centroids = new double[k][n];
+ for (int i = 0; i < k; i++)
+ centroids[i] = labelMatrix[labelList.get(centreIdxList.get(i))];
+
List<List<Integer>> clusters = new ArrayList<>(k);
- int[] centres = new int[k];
- Collections.shuffle(labelList, rng);
- for (int i = 0; i < k; i++) {
+ for (int i = 0; i < k; i++)
clusters.add(new ArrayList<>(L / k));
- centres[i] = labelList.get(idx++);
- }
double[][] d = new double[L][k];
- for (int it = 0; it < 2; it++) {
- for (Integer l : labels) {
+ for (int it = 0; it < 2; it++) { // kmeans iterations
+ for (List<Integer> cluster : clusters)
+ cluster.clear();
+ for (Integer label : labelList) {
for (int i = 0; i < k; i++) {
- /* Calculate distance to cluster centres */
- double[] vci = labelVectors[centres[i]];
- double[] vl = labelVectors[l];
+ // Calculate distance to cluster centres
double sum = 0;
- for (int j = 0; j < n; j++)
- sum += (vci[j] - vl[j]) * (vci[j] - vl[j]);
- d[l][i] = Math.sqrt(sum);
+ for (int j = 0; j < n; j++) {
+ double diff = centroids[i][j] - labelMatrix[label][j];
+ sum += diff * diff;
+ }
+ d[label][i] = Math.sqrt(sum);
}
boolean finished = false;
- int nu = l;
+ int nu = label;
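+ // Balanced k-means assignment (Fig. 2): place nu in its nearest cluster; if that cluster grows beyond
+ // ceil(Ln / k), its farthest label is evicted and reassigned on the next pass of this loop.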
while (!finished) {
int minj = 0;
+
+ // argmin distance over centroids
double currMin = Double.MAX_VALUE;
for (int j = 0; j < k; j++) {
if (d[nu][j] < currMin) {
@@ -385,7 +391,7 @@ public Collection<Set<Integer>> splitLabels(int k, Collection<Integer> labels, I
break;
Cj.add(nu);
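+ // Copy into an effectively-final local so the comparator lambda below can capture it.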
final int _minj = minj;
- Cj.sort(Comparator.comparingDouble(o -> d[o][_minj]));
+ Cj.sort(Comparator.comparingDouble(val -> d[val][_minj]));
int size = Cj.size();
if (size > Math.ceil(Ln / (double) k)) {
@@ -397,15 +403,15 @@ public Collection<Set<Integer>> splitLabels(int k, Collection<Integer> labels, I
}
}
for (int i = 0; i < k; i++) {
- double[] cent = new double[n];
- for (Integer l : clusters.get(i)) {
- /* Calculate new cluster centres */
- double[] vl = labelVectors[l];
+ // Calculate new cluster centres
+ List<Integer> cluster = clusters.get(i);
+ Arrays.fill(centroids[i], 0); // reset before accumulating, otherwise the previous centre leaks into the mean
+ for (Integer l : cluster) {
+ double[] vl = labelMatrix[l];
for (int j = 0; j < n; j++)
- cent[j] += vl[j];
+ centroids[i][j] += vl[j];
}
for (int j = 0; j < n; j++)
- cent[j] /= clusters.get(i).size();
+ centroids[i][j] /= cluster.size();
}
}
Collection<Set<Integer>> clusterSets = new ArrayList<>(clusters.size());
@@ -418,19 +424,18 @@ public Collection<Set<Integer>> splitLabels(int k, Collection<Integer> labels, I
/**
* Utility class for storing node info.
*/
- private class HOMERNode extends SingleClassifierEnhancer {
+ private class HOMERNode {
private final int id;
- private final List<HOMERNode> children;
- private Set<Integer> labels;
+ private final List<HOMERNode> children = new ArrayList<>();
+ private final ProblemTransformationMethod classifier;
+ private Set<Integer> labels = new HashSet<>();
private Instances instances;
private double[] thresholds;
- private LabelSplitter splitter;
- public HOMERNode() {
+ private HOMERNode(ProblemTransformationMethod classifier) {
id = debugNodeId++;
- children = new ArrayList<>();
- labels = new HashSet<>();
+ this.classifier = classifier;
}
/**
@@ -447,24 +452,18 @@ private void buildNode() throws Exception {
}
Instances culledInstances = new Instances(instances, 0);
- for (Instance i : getInstances()) {
- for (Integer label : labels) {
- if (i.stringValue(label).equals("1")) {
- culledInstances.add(i);
- break;
- }
- }
- }
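+ // Cull to instances exhibiting at least one of this node's labels.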
+ for (Instance instance : instances)
+ if (labels.stream().anyMatch(label -> instance.stringValue(label).equals("1")))
+ culledInstances.add(instance);
- getChildren().clear();
- Collection<Set<Integer>> labelSplits = splitter.splitLabels(k, labels, culledInstances);
+ children.clear();
+ Collection<Set<Integer>> labelSplits = labelSplitter.splitLabels(k, labels, culledInstances);
for (Set<Integer> labelSet : labelSplits) {
- HOMERNode child = new HOMERNode();
- child.setClassifier(AbstractClassifier.makeCopy(m_Classifier));
- child.setSplitter(splitter);
- child.setInstances(culledInstances);
- child.setLabels(labelSet);
- getChildren().add(child);
+ HOMERNode child = new HOMERNode((ProblemTransformationMethod) ProblemTransformationMethod.makeCopy(
+ classifier));
+ child.instances = culledInstances;
+ child.labels = labelSet;
+ children.add(child);
if (labelSet.size() > 1) // Leaves don't need to recurse down any further
child.buildNode();
}
@@ -473,8 +472,8 @@ private void buildNode() throws Exception {
/**
* Train the HOMER tree.
*
- * @throws Exception If an exception is thrown during label preprocessing, training the base classifier, or training any
- * children of this node.
+ * @throws Exception If an exception is thrown during label preprocessing, training the base classifier, or
+ * training any children of this node.
*/
private void trainNode() throws Exception {
int L = instances.classIndex();
@@ -488,53 +487,52 @@ private void trainNode() throws Exception {
}
newInstances.setClassIndex(c);
- /* Set each meta-label value to the disjunction of instance label values */
+ /* Set each meta-label value to the disjunction of the instance's values over the corresponding child's
+ * labels. Remove the instance if it has none of the children's labels.
+ */
Iterator<Instance> it = newInstances.iterator();
for (int i = 0; it.hasNext(); i++) {
boolean remove = true;
- Instance next = it.next();
+ Instance newInstance = it.next();
+ Instance oldInstance = instances.get(i);
for (int m = 0; m < c; m++) {
- next.setValue(m, "0");
- for (Integer l : children.get(m).getLabels()) {
- if (instances.get(i).stringValue(l).equals("1")) {
- next.setValue(m, "1");
- remove = false;
- break;
- }
+ newInstance.setValue(m, "0");
+ if (children.get(m).labels.stream().anyMatch(lab -> oldInstance.stringValue(lab).equals("1"))) {
+ newInstance.setValue(m, "1");
+ remove = false;
}
}
if (remove)
it.remove();
}
- setInstances(newInstances);
- if (!newInstances.isEmpty()) { // This should rarely not be the case
- Classifier classifier = ((ProblemTransformationMethod) m_Classifier).getClassifier();
- if (newInstances.size() == 1 && classifier instanceof Bagging)
- ((Bagging) classifier).setBagSizePercent(100);
- m_Classifier.buildClassifier(newInstances);
+ instances = newInstances;
+ if (!newInstances.isEmpty()) { // This should usually be true
+ Classifier baseClassifier = classifier.getClassifier();
+ if (newInstances.size() == 1 && baseClassifier instanceof Bagging)
+ ((Bagging) baseClassifier).setBagSizePercent(100);
+ classifier.buildClassifier(newInstances);
}
if (getDebug())
- System.out.println("Trained node " + getId() + " with " + newInstances.size() + " instances.");
+ System.out.println("Trained node " + id + " with " + newInstances.size() + " instances.");
if (newInstances.size() < c) {
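+ // Fewer instances than meta-labels: too little data to calibrate thresholds, so default to 0.5.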
double[] thresholds = new double[c];
Arrays.fill(thresholds, 0.5);
- setThresholds(thresholds);
+ this.thresholds = thresholds;
} else {
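+ // A numeric threshold option is used as-is; otherwise it names a threshold-selection strategy that is
+ // calibrated on the training predictions.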
try {
Double.parseDouble(threshold);
- setThresholds(ThresholdUtils.thresholdStringToArray(threshold, c));
+ this.thresholds = ThresholdUtils.thresholdStringToArray(threshold, c);
} catch (NumberFormatException e) {
- Result r = meka.classifiers.multilabel.Evaluation.testClassifier((MultiLabelClassifier) m_Classifier, newInstances);
+ Result r = meka.classifiers.multilabel.Evaluation.testClassifier(classifier, newInstances);
String threshStr = MLEvalUtils.getThreshold(r.predictions, newInstances, threshold);
- setThresholds(ThresholdUtils.thresholdStringToArray(threshStr, c));
+ this.thresholds = ThresholdUtils.thresholdStringToArray(threshStr, c);
}
}
- for (HOMERNode child : children) {
- if (child.getLabels().size() > 1)
+ for (HOMERNode child : children)
+ if (child.labels.size() > 1)
child.trainNode();
- }
}
/**
@@ -544,80 +542,39 @@ private void trainNode() throws Exception {
* @param y shared array of prediction values
* @throws Exception If this node's classifier or any child classifiers throw an exception.
*/
- public void classify(Instance x, double[] y) throws Exception {
- int c = getChildren().size();
+ private void classify(Instance x, double[] y) throws Exception {
+ int c = children.size();
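+ // Re-map x onto this node's meta-label space: drop the original label attributes, prepend c meta-labels.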
Instance newX = (Instance) x.copy();
int oldL = newX.classIndex();
newX.setDataset(null);
MLUtils.deleteAttributesAt(newX, A.make_sequence(oldL));
for (int i = 0; i < c; i++)
newX.insertAttributeAt(i);
- newX.setDataset(getInstances());
- double[] metaDist;
- if (!getInstances().isEmpty())
- metaDist = getClassifier().distributionForInstance(newX); // Distribution on meta-labels
- else
- metaDist = new double[c];
- double[] thresholds = getThresholds();
- for (int i = 0; i < getChildren().size(); i++) {
- HOMERNode child = getChildren().get(i);
+ newX.setDataset(instances);
+ // Distribution over meta-labels
+ double[] metaDist = instances.isEmpty() ? new double[c] : classifier.distributionForInstance(newX);
+ for (int i = 0; i < children.size(); i++) {
+ HOMERNode child = children.get(i);
if (metaDist[i] > thresholds[i]) {
- if (child.getLabels().size() == 1)
- y[child.getLabels().iterator().next()] = 1;
+ if (child.labels.size() == 1)
+ y[child.labels.iterator().next()] = 1;
else
child.classify(newX, y);
} else {
- for (int l : child.getLabels())
+ for (int l : child.labels)
y[l] = 0;
}
}
}
- public List<HOMERNode> getChildren() {
- return children;
- }
-
- public Set<Integer> getLabels() {
- return labels;
- }
-
- public void setLabels(Set<Integer> labels) {
- this.labels = labels;
- }
-
- public Instances getInstances() {
- return instances;
- }
-
- public void setInstances(Instances instances) {
- this.instances = instances;
- }
-
- public double[] getThresholds() {
- return thresholds;
- }
-
- public void setThresholds(double[] thresholds) {
- this.thresholds = thresholds;
- }
-
- public int getId() {
- return id;
- }
-
- public void setSplitter(LabelSplitter splitter) {
- this.splitter = splitter;
- }
-
- @Override
- public void buildClassifier(Instances D) throws Exception {
+ private void buildTree(Instances D) throws Exception {
int L = D.classIndex();
Set<Integer> labelSet = new HashSet<>(L);
for (int l = 0; l < L; l++)
labelSet.add(l);
- setInstances(D);
- setLabels(labelSet);
+ instances = D;
+ labels = labelSet;
buildNode();
trainNode();
}
diff --git a/src/main/java/meka/classifiers/multilabel/meta/MLRF.java b/src/main/java/meka/classifiers/multilabel/meta/MLRF.java
index 04080329..f5aa3b66 100644
--- a/src/main/java/meka/classifiers/multilabel/meta/MLRF.java
+++ b/src/main/java/meka/classifiers/multilabel/meta/MLRF.java
@@ -80,6 +80,16 @@ public String numFeaturesTipText() {
return "The dimensionality reduction parameter.";
}
+ @Override
+ public String globalInfo() {
+ return "Multi-label rotation forest.";
+ }
+
+ @Override
+ protected String defaultClassifierString() {
+ return "meka.classifiers.multilabel.BR";
+ }
+
/**
* Partitions the features into disjoint subsets.
*
@@ -107,39 +117,39 @@ private List<List<Integer>> generateFeatureSubsets(int d) {
@Override
public void buildClassifier(Instances D) throws Exception {
testCapabilities(D);
- int n = D.numInstances();
+ int numInstances = D.numInstances();
int L = D.classIndex();
int d = D.numAttributes() - L;
m_Classifiers = ProblemTransformationMethod.makeCopies((MultiLabelClassifier) m_Classifier, m_NumIterations);
R = new DenseMatrix[m_NumIterations];
Instances transformedInstances = F.remove(D, A.make_sequence(L), true);
- int nNew = n * m_BagSizePercent / 100;
+ int nNew = numInstances * m_BagSizePercent / 100;
Matrix dataMatrix;
try {
- dataMatrix = new DenseMatrix(n, d);
- for (int i = 0; i < n; i++)
+ dataMatrix = new DenseMatrix(numInstances, d);
+ for (int i = 0; i < numInstances; i++)
for (int j = 0; j < d; j++)
dataMatrix.set(i, j, D.get(i).value(L + j));
} catch (OutOfMemoryError e) {
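+ // The dense matrix did not fit on the heap; rebuild it as a sparse compressed-row matrix.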
- int[][] nz = new int[n][];
- for (int i = 0; i < n; i++) {
+ int[][] nz = new int[numInstances][];
+ for (int i = 0; i < numInstances; i++) {
Instance inst = transformedInstances.get(i);
nz[i] = new int[inst.numValues()];
for (int j = 0; j < inst.numValues(); j++)
nz[i][j] = inst.index(j);
}
- dataMatrix = new CompRowMatrix(n, d, nz);
- for (int i = 0; i < n; i++) {
+ dataMatrix = new CompRowMatrix(numInstances, d, nz);
+ for (int i = 0; i < numInstances; i++) {
Instance inst = transformedInstances.get(i);
for (int j = 0; j < inst.numValues(); j++)
dataMatrix.set(i, nz[i][j], inst.valueSparse(j));
}
}
- DenseMatrix transformedData = new DenseMatrix(n, numFeatures * K);
- DenseMatrix partialSInv = new DenseMatrix(numFeatures, numFeatures);
+ DenseMatrix transformedData = new DenseMatrix(numInstances, numFeatures * K);
+ // Instance attributes after dimensionality reduction
for (int m = 0; m < numFeatures * K; m++)
transformedInstances.insertAttributeAt(new Attribute("F" + m), L + m);
transformedInstances.setClassIndex(L);
@@ -147,50 +157,66 @@ public void buildClassifier(Instances D) throws Exception {
for (int i = 0; i < m_NumIterations; i++) {
if (getDebug())
- System.out.print("Building classifier " + i + ": subsets");
+ System.out.print("Building classifier " + i + ": subsets ");
Random r = new Random(m_Seed + i);
R[i] = new DenseMatrix(d, numFeatures * K);
List<List<Integer>> subsets = generateFeatureSubsets(d);
- for (int s = 0; s < K; s++) {
+ for (int subId = 0; subId < K; subId++) {
if (getDebug())
- System.out.print(" " + s);
- List<Integer> subset = subsets.get(s);
- int m = subset.size();
- DenseMatrix mat = new DenseMatrix(nNew, m);
- for (int j = 0; j < nNew; j++)
- for (int k = 0; k < m; k++)
- mat.set(j, k, D.get(r.nextInt(n)).value(subset.get(k) + L));
-
- SVD svd = new SVD(nNew, m);
+ System.out.print(subId + " ");
+ List<Integer> subset = subsets.get(subId);
+ int subSize = subset.size();
+ DenseMatrix mat = new DenseMatrix(nNew, subSize);
+ // Bootstrap sample from data
+ for (int j = 0; j < nNew; j++) {
+ int sampId = r.nextInt(numInstances);
+ for (int k = 0; k < subSize; k++)
+ mat.set(j, k, dataMatrix.get(sampId, subset.get(k)));
+ }
+
+ // Latent semantic indexing
+ SVD svd = new SVD(nNew, subSize);
svd = svd.factor(mat);
- double[] singVals = svd.getS();
- Matrix partialVt = Matrices.getSubMatrix(svd.getVt(), A.make_sequence(numFeatures), A.make_sequence(m));
-
- for (int j = 0; j < numFeatures; j++)
- if (singVals[j] > 0)
- partialSInv.set(j, j, 1 / singVals[j]);
-
- DenseMatrix transformationMatrix = new DenseMatrix(m, numFeatures);
- partialVt.transAmult(partialSInv, transformationMatrix);
- for (int j = 0; j < m; j++) {
- Matrix row = Matrices.getSubMatrix(transformationMatrix, new int[]{j}, A.make_sequence(numFeatures));
- Matrices.getSubMatrix(R[i], new int[]{subset.get(j)}, A.make_sequence(
- numFeatures * s, numFeatures * (s + 1))).set(row);
+ DenseMatrix transformationMatrix = getTransformationMatrix(svd, subSize);
+
+ // Insert each subset transformation matrix into full rotation matrix R[i]
+ for (int j = 0; j < subSize; j++) {
+ Matrix row = Matrices.getSubMatrix(transformationMatrix,
+ new int[]{j},
+ A.make_sequence(numFeatures));
+ Matrices.getSubMatrix(R[i],
+ new int[]{subset.get(j)},
+ A.make_sequence(numFeatures * subId, numFeatures * (subId + 1))).set(row);
}
}
+ // Transform data and train classifier on transformed instances
dataMatrix.mult(R[i], transformedData);
-
for (int m = 0; m < numFeatures * K; m++)
- for (int j = 0; j < n; j++)
+ for (int j = 0; j < numInstances; j++)
transformedInstances.get(j).setValue(L + m, transformedData.get(j, m));
+ if (getDebug())
+ System.out.print("Training base multi-label classifier on transformed data.");
m_Classifiers[i].buildClassifier(transformedInstances);
if (getDebug())
System.out.println();
}
}
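+ /**
+ * Computes the LSI projection for one feature subset: columns are the top {@code numFeatures} right singular
+ * vectors of the bootstrap sample, each scaled by the inverse of its singular value.
+ */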
+ private DenseMatrix getTransformationMatrix(SVD svd, int m) {
+ double[] singVals = svd.getS();
+ Matrix partialVt = Matrices.getSubMatrix(svd.getVt(), A.make_sequence(numFeatures), A.make_sequence(m));
+ DenseMatrix partialSInv = new DenseMatrix(numFeatures, numFeatures);
+ for (int j = 0; j < numFeatures; j++)
+ if (singVals[j] > 0)
+ partialSInv.set(j, j, 1 / singVals[j]);
+
+ DenseMatrix transformationMatrix = new DenseMatrix(m, numFeatures);
+ partialVt.transAmult(partialSInv, transformationMatrix);
+ return transformationMatrix;
+ }
+
@Override
public double[] distributionForInstance(Instance x) throws Exception {
int L = x.classIndex();
@@ -234,9 +260,9 @@ public String[] getOptions() {
@Override
public void setOptions(String[] options) throws Exception {
- super.setOptions(options);
K = OptionUtils.parse(options, "K", 10);
numFeatures = OptionUtils.parse(options, "k", 10);
+ super.setOptions(options);
}
@Override
@@ -252,7 +278,8 @@ public Enumeration