diff --git a/src/main/java/weka/finito/client.java b/src/main/java/weka/finito/client.java index bff92ad..722de9b 100644 --- a/src/main/java/weka/finito/client.java +++ b/src/main/java/weka/finito/client.java @@ -25,24 +25,19 @@ public final class client implements Runnable { private final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); - private final String classes_file = "classes.txt"; private final String features_file; private final int key_size; private final int precision; - private final String [] level_site_ips; private final int [] level_site_ports; private final int port; - private String classification = null; - private KeyPair dgk; private KeyPair paillier; private features feature = null; private boolean classification_complete = false; private String [] classes; - private DGKPublicKey dgk_public_key; private PaillierPublicKey paillier_public_key; private DGKPrivateKey dgk_private_key; @@ -50,7 +45,6 @@ public final class client implements Runnable { private final HashMap hashed_classification = new HashMap<>(); private final String server_ip; private final int server_port; - private Integer next_index = 0; //For k8s deployment. public static void main(String[] args) { @@ -310,11 +304,6 @@ private void communicate_with_level_site(Socket level_site) client = new bob_joye(paillier, dgk, null); client.set_socket(level_site); - // Send bool: - // 1- true, there is an encrypted index coming - // 2- false, there is NO encrypted index coming - client.writeInt(next_index); - // Get the comparison // I am not sure why I need this loop, but you will only need 1 comparison. int comparison_type; @@ -337,7 +326,6 @@ else if (comparison_type == 1) { // true - get leaf value // false - get encrypted AES index for next round classification_complete = client.readBoolean(); - System.out.println("BOOLEAN READ"); if (classification_complete) { o = client.readObject(); if (o instanceof String) { @@ -346,7 +334,10 @@ else if (comparison_type == 1) { } } else { - next_index = client.readInt(); + o = from_level_site.readObject(); + if (o instanceof features) { + this.feature = (features) o; + } } } diff --git a/src/main/java/weka/finito/level_site_evaluation_thread.java b/src/main/java/weka/finito/level_site_evaluation_thread.java index 1efa494..38e894e 100644 --- a/src/main/java/weka/finito/level_site_evaluation_thread.java +++ b/src/main/java/weka/finito/level_site_evaluation_thread.java @@ -1,5 +1,6 @@ package weka.finito; +import java.io.ObjectOutputStream; import java.lang.System; import security.socialistmillionaire.alice_joye; @@ -13,27 +14,19 @@ import static weka.finito.utils.shared.*; public class level_site_evaluation_thread implements Runnable { - private final Socket client_socket; private final level_order_site level_site_data; private final features encrypted_features; - private final Socket next_level_site; + private final ObjectOutputStream oos; // This thread is ONLY to handle evaluations public level_site_evaluation_thread(Socket client_socket, level_order_site level_site_data, - features encrypted_features) { + features encrypted_features, ObjectOutputStream oos) { // Have encrypted copy of thresholds if not done already for all nodes in level-site - this(client_socket, level_site_data, encrypted_features, null); - } - - // This thread is ONLY to handle evaluations - public level_site_evaluation_thread(Socket client_socket, level_order_site level_site_data, - features encrypted_features, Socket next_level_site) { - // Have encrypted copy of thresholds if not done already for all nodes in level-site - this.level_site_data = level_site_data; this.client_socket = client_socket; + this.level_site_data = level_site_data; this.encrypted_features = encrypted_features; - this.next_level_site = next_level_site; + this.oos = oos; } // This will run the communication with client and next level site @@ -45,10 +38,9 @@ public final void run() { niu.setDGKPublicKey(this.level_site_data.dgk_public_key); niu.setPaillierPublicKey(this.level_site_data.paillier_public_key); - level_site_data.set_current_index(niu.readInt()); - // Null, keep going down the tree, // Not null, you got the correct leaf node of your DT! + // encrypted_features will have index updated within traverse_level NodeInfo reply = traverse_level(level_site_data, encrypted_features, niu); niu.writeInt(-1); @@ -59,7 +51,9 @@ public final void run() { } else { niu.writeBoolean(false); - niu.writeInt(level_site_data.get_next_index()); + oos.writeObject(encrypted_features); + oos.flush(); + // Update Index and send it down to the next level-site } long stop_time = System.nanoTime(); double run_time = (double) (stop_time - start_time); diff --git a/src/main/java/weka/finito/level_site_server.java b/src/main/java/weka/finito/level_site_server.java index 45a03f8..aec6a3b 100644 --- a/src/main/java/weka/finito/level_site_server.java +++ b/src/main/java/weka/finito/level_site_server.java @@ -6,6 +6,7 @@ import javax.net.ssl.SSLServerSocket; import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; import java.io.IOException; import java.io.ObjectOutputStream; @@ -20,8 +21,9 @@ public class level_site_server implements Runnable { protected boolean isStopped = false; protected Thread runningThread= null; protected level_order_site level_site_parameters = null; - protected SSLServerSocketFactory factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault(); + private final SSLSocketFactory socket_factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); + private SSLSocket next_level_site; public static void main(String[] args) { setup_tls(); @@ -86,12 +88,19 @@ public void run() { // System.out.println("Level-Site received training data on Port: " + client_socket.getLocalPort()); oos.writeBoolean(true); // Create a persistent connection to next level-site and oos to send the next stuff down + /* + if(level_site_parameters.get_next_level_site() != null) { + next_level_site = (SSLSocket) socket_factory.createSocket( + level_site_parameters.get_next_level_site(), level_site_parameters.get_next_level_site_port()); + } + */ closeConnection(oos, ois, client_socket); } else if (o instanceof features) { // Start evaluating with the client level_site_evaluation_thread current_level_site_class = - new level_site_evaluation_thread(client_socket, this.level_site_parameters, (features) o); + new level_site_evaluation_thread(client_socket, this.level_site_parameters, + (features) o, oos); new Thread(current_level_site_class).start(); } else { diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index 16720c5..80f50e4 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -160,13 +160,14 @@ private void evaluate_with_client_directly(SSLSocket client_site) if (client_input instanceof features) { input = (features) client_input; } + assert input != null; long start_time = System.nanoTime(); int previous_index = 0; // Traverse DT until you hit a leaf, the client has to track the index... for (level_order_site level_site_data : all_level_sites) { - level_site_data.set_current_index(previous_index); + input.set_current_index(previous_index); // Handle at a level... NodeInfo leaf = traverse_level(level_site_data, input, Niu); @@ -184,7 +185,7 @@ private void evaluate_with_client_directly(SSLSocket client_site) } else { // Update the index for next merry-go-round - previous_index = level_site_data.get_next_index(); + previous_index = input.get_next_index(); } } } diff --git a/src/main/java/weka/finito/structs/features.java b/src/main/java/weka/finito/structs/features.java index 7fd1dcf..59d798a 100644 --- a/src/main/java/weka/finito/structs/features.java +++ b/src/main/java/weka/finito/structs/features.java @@ -14,12 +14,14 @@ public final class features implements Serializable { private static final long serialVersionUID = 6000706455545108960L; private String client_ip; private int next_index; + private int current_index; private final HashMap thresholds; public features(String path, int precision, PaillierPublicKey paillier_public_key, DGKPublicKey dgk_public_key) throws HomomorphicException, IOException { this.client_ip = ""; this.next_index = 0; + this.current_index =0; this.thresholds = read_values(path, precision, paillier_public_key, dgk_public_key); } @@ -43,6 +45,14 @@ public void set_next_index(int next_index) { this.next_index = next_index; } + public int get_current_index() { + return this.current_index; + } + + public void set_current_index(int current_index) { + this.current_index = current_index; + } + public static HashMap read_values(String path, int precision, PaillierPublicKey paillier_public_key, diff --git a/src/main/java/weka/finito/structs/level_order_site.java b/src/main/java/weka/finito/structs/level_order_site.java index 6abb14f..a620cf5 100644 --- a/src/main/java/weka/finito/structs/level_order_site.java +++ b/src/main/java/weka/finito/structs/level_order_site.java @@ -14,18 +14,12 @@ public final class level_order_site implements Serializable { private static final long serialVersionUID = 575566807906351024L; - private int index = 0; - private int next_index; private final int level; - public final PaillierPublicKey paillier_public_key; public final DGKPublicKey dgk_public_key; - private final List node_level_data = new ArrayList<>(); - // Set to value to a client, to let level-site d-1 know to talk to a client next with answer - private String next_level_site = "client"; - + private String next_level_site = null; private int next_level_site_port = -1; public level_order_site(int level, PaillierPublicKey paillier_public_key, DGKPublicKey dgk_public_key) { @@ -38,22 +32,6 @@ public List get_node_data() { return this.node_level_data; } - public void set_current_index(int index) { - this.index = index; - } - - public void set_next_index(int next_index) { - this.next_index = next_index; - } - - public int get_next_index() { - return this.next_index; - } - - public int get_current_index() { - return this.index; - } - public void set_next_level_site(String next_level_site) { this.next_level_site = next_level_site; } diff --git a/src/main/java/weka/finito/utils/shared.java b/src/main/java/weka/finito/utils/shared.java index f4a1f13..da83ff5 100644 --- a/src/main/java/weka/finito/utils/shared.java +++ b/src/main/java/weka/finito/utils/shared.java @@ -88,16 +88,16 @@ public static NodeInfo traverse_level(level_order_site level_site_data, ls = node_level_data.get(node_level_index); System.out.println("j=" + node_level_index); if (ls.isLeaf()) { - if (n == 2 * level_site_data.get_current_index() - || n == 2 * level_site_data.get_current_index() + 1) { + if (n == 2 * encrypted_features.get_current_index() + || n == 2 * encrypted_features.get_current_index() + 1) { terminalLeafFound = true; to_return = ls; } n += 2; } else { - if ((n == 2 * level_site_data.get_current_index() - || n == 2 * level_site_data.get_current_index() + 1)) { + if ((n == 2 * encrypted_features.get_current_index() + || n == 2 * encrypted_features.get_current_index() + 1)) { if (ls.comparisonType == 6) { inequalityHolds = compare(ls, 1, @@ -111,8 +111,8 @@ public static NodeInfo traverse_level(level_order_site level_site_data, if (inequalityHolds) { equalsFound = true; - level_site_data.set_next_index(next_index); - System.out.println("New index: " + level_site_data.get_next_index()); + encrypted_features.set_next_index(next_index); + System.out.println("New index: " + encrypted_features.get_next_index()); } } n++;