Skip to content

Commit

Permalink
Fix ERFH. Code cleanup and style.
Browse files Browse the repository at this point in the history
  • Loading branch information
agkphysics committed Oct 13, 2023
1 parent c6f26c8 commit 831796d
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 215 deletions.
20 changes: 15 additions & 5 deletions src/main/java/meka/classifiers/multilabel/meta/ERFH.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ public class ERFH extends MetaProblemTransformationMethod implements Randomizabl
*/
protected double threshold = 0.4;

public ERFH() {
this.m_Classifier = new BR();
}

public static void main(String[] args) {
ProblemTransformationMethod.runClassifier(new ERFH(), args);
}
Expand All @@ -51,7 +55,13 @@ public void setThreshold(double threshold) {
}

public String thresholdTipText() {
return "Prediction threshold for the multi-label classifier distribution.";
return "Prediction threshold";
}

@Override
protected String defaultClassifierString() {
// default classifier for CLI
return "meka.classifiers.multilabel.BR";
}

/**
Expand Down Expand Up @@ -88,11 +98,12 @@ public void buildClassifier(Instances D) throws Exception {

// Modify i'th HOMER tree
HOMER homer = new HOMER();
homer.setDebug(getDebug());
homer.setClassifier(AbstractMultiLabelClassifier.makeCopy(m_Classifier));
homer.setLabelSplitter(new VariableKLabelSplitter(m_Seed + i));
homer.setSeed(m_Seed + i);
if (getDebug())
System.out.print(" " + i);
System.out.print(" " + i + ":\n");
m_Classifiers[i] = homer;
m_Classifiers[i].buildClassifier(bag);
}
Expand All @@ -119,16 +130,15 @@ public String[] getOptions() {

@Override
public void setOptions(String[] options) throws Exception {
setThreshold(OptionUtils.parse(options, "T", 0.4));
super.setOptions(options);
threshold = OptionUtils.parse(options, "T", 0.4);
}

@Override
public Enumeration<Option> listOptions() {
Vector<Option> options = new Vector<>(2);
options.add(new Option(thresholdTipText(), "threshold", 1, "-T threshold"));
OptionUtils.add(options, super.listOptions());
//OptionUtils.add(options, homerClassifier.listOptions());
return options.elements();
}

Expand Down Expand Up @@ -165,7 +175,7 @@ public VariableKLabelSplitter(int seed) {
}

@Override
public Collection<Set<Integer>> splitLabels(int k, Collection<Integer> labels, Instances D) {
public Collection<Set<Integer>> splitLabels(int k, Set<Integer> labels, Instances D) {
return super.splitLabels(rng.nextInt(5) + 2, labels, D);
}
}
Expand Down
Loading

0 comments on commit 831796d

Please sign in to comment.