diff --git a/src/main/java/weka/finito/client.java b/src/main/java/weka/finito/client.java index e6eb729..b689e11 100644 --- a/src/main/java/weka/finito/client.java +++ b/src/main/java/weka/finito/client.java @@ -3,11 +3,8 @@ import java.io.*; import java.math.BigInteger; import java.net.Socket; -import java.nio.charset.StandardCharsets; import java.security.KeyPair; -import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; -import java.util.Base64; import java.util.HashMap; import java.util.Hashtable; @@ -25,7 +22,7 @@ import security.socialistmillionaire.bob; import weka.finito.structs.BigIntegers; -import static weka.finito.utils.Shared.hash; +import static weka.finito.utils.shared.hash; public final class client implements Runnable { private final String classes_file = "classes.txt"; diff --git a/src/main/java/weka/finito/level_site_thread.java b/src/main/java/weka/finito/level_site_thread.java index fd2f745..4eb79f2 100644 --- a/src/main/java/weka/finito/level_site_thread.java +++ b/src/main/java/weka/finito/level_site_thread.java @@ -2,7 +2,6 @@ import java.lang.System; -import security.misc.HomomorphicException; import security.socialistmillionaire.alice; import weka.finito.structs.BigIntegers; import weka.finito.structs.NodeInfo; @@ -11,10 +10,12 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; -import java.math.BigInteger; import java.net.Socket; import java.util.Hashtable; import java.util.List; +import java.util.Map; + +import static weka.finito.utils.shared.compare; public class level_site_thread implements Runnable { @@ -24,9 +25,7 @@ public class level_site_thread implements Runnable { private level_order_site level_site_data = null; - private alice Niu = null; - - private Hashtable encrypted_features; + private final Hashtable encrypted_features = new Hashtable<>(); private final AES crypto; public level_site_thread(Socket client_socket, level_order_site level_site_data, AES crypto) { @@ -46,7 +45,11 @@ public level_site_thread(Socket client_socket, level_order_site level_site_data, this.toClient.writeBoolean(true); closeClientConnection(); } else if (x instanceof Hashtable) { - encrypted_features = (Hashtable) x; + for (Map.Entry entry: ((Hashtable) x).entrySet()){ + if (entry.getKey() instanceof String && entry.getValue() instanceof BigIntegers) { + encrypted_features.put((String) entry.getKey(), (BigIntegers) entry.getValue()); + } + } // Have encrypted copy of thresholds if not done already for all nodes in level-site this.level_site_data = level_site_data; } else { @@ -70,45 +73,6 @@ private void closeClientConnection() throws IOException { } } - // a - from CLIENT, should already be encrypted... - private boolean compare(NodeInfo ld, int comparisonType) - throws HomomorphicException, ClassNotFoundException, IOException { - - long start_time = System.nanoTime(); - - BigIntegers encrypted_values = this.encrypted_features.get(ld.variable_name); - BigInteger encrypted_client_value = null; - BigInteger encrypted_thresh = null; - - // Encrypt the thresh-hold correctly - if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) { - encrypted_thresh = ld.getPaillier(); - encrypted_client_value = encrypted_values.getIntegerValuePaillier(); - toClient.writeInt(0); - Niu.setDGKMode(false); - } - else if ((comparisonType == 3) || (comparisonType == 5)) { - encrypted_thresh = ld.getDGK(); - encrypted_client_value = encrypted_values.getIntegerValueDGK(); - toClient.writeInt(1); - Niu.setDGKMode(true); - } - toClient.flush(); - assert encrypted_client_value != null; - long stop_time = System.nanoTime(); - - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Comparison took %f ms\n", run_time); - if (((comparisonType == 1) && (ld.threshold == 0)) - || (comparisonType == 4) || (comparisonType == 5)) { - return Niu.Protocol4(encrypted_thresh, encrypted_client_value); - } - else { - return Niu.Protocol4(encrypted_client_value, encrypted_thresh); - } - } - // This will run the communication with client and next level site public final void run() { Object o; @@ -118,7 +82,7 @@ public final void run() { long start_time = System.nanoTime(); try { - Niu = new alice(client_socket); + alice niu = new alice(client_socket); if (this.level_site_data == null) { toClient.writeInt(-2); closeClientConnection(); @@ -126,8 +90,8 @@ public final void run() { } List node_level_data = this.level_site_data.get_node_data(); - Niu.setDGKPublicKey(this.level_site_data.dgk_public_key); - Niu.setPaillierPublicKey(this.level_site_data.paillier_public_key); + niu.setDGKPublicKey(this.level_site_data.dgk_public_key); + niu.setPaillierPublicKey(this.level_site_data.paillier_public_key); get_previous_index = fromClient.readBoolean(); if (get_previous_index) { @@ -161,32 +125,39 @@ public final void run() { ls = node_level_data.get(node_level_index); System.out.println("j=" + node_level_index); if (ls.isLeaf()) { - if (n == 2 * this.level_site_data.get_current_index() || n == 2 * this.level_site_data.get_current_index() + 1) { + if (n == 2 * this.level_site_data.get_current_index() + || n == 2 * this.level_site_data.get_current_index() + 1) { terminalLeafFound = true; System.out.println("Terminal leaf:" + ls.getVariableName()); } n += 2; } else { - if ((n==2 * this.level_site_data.get_current_index() || n == 2 * this.level_site_data.get_current_index() + 1)) { + if ((n==2 * this.level_site_data.get_current_index() + || n == 2 * this.level_site_data.get_current_index() + 1)) { if (ls.comparisonType == 6) { - boolean firstInequalityHolds = compare(ls, 3); + boolean firstInequalityHolds = compare(ls, 3, + encrypted_features, toClient, niu); if (firstInequalityHolds) { inequalityHolds = true; - } else { - boolean secondInequalityHolds = compare(ls, 5); + } + else { + boolean secondInequalityHolds = compare(ls, 5, + encrypted_features, toClient, niu); if (secondInequalityHolds) { inequalityHolds = true; } } - } else { - inequalityHolds = compare(ls, ls.comparisonType); + } + else { + inequalityHolds = compare(ls, ls.comparisonType, + encrypted_features, toClient, niu); } if (inequalityHolds) { equalsFound = true; this.level_site_data.set_next_index(next_index); - System.out.println("New index:" + this.level_site_data.get_current_index()); + System.out.println("New index: " + this.level_site_data.get_next_index()); } } n++; diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index 8316897..fd41374 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -8,7 +8,6 @@ import java.io.FileReader; import java.io.IOException; -import java.math.BigInteger; import java.net.ServerSocket; import java.net.Socket; import java.util.*; @@ -31,7 +30,7 @@ import weka.finito.structs.level_order_site; import weka.finito.structs.NodeInfo; -import static weka.finito.utils.Shared.hash; +import static weka.finito.utils.shared.*; public final class server implements Runnable { @@ -137,46 +136,6 @@ public server(String training_data, int precision, int server_port) { this.use_level_sites = false; } - private boolean compare(NodeInfo ld, int comparisonType, - Hashtable encrypted_features, - ObjectOutputStream toClient, alice Niu) - throws ClassNotFoundException, HomomorphicException, IOException { - - long start_time = System.nanoTime(); - - BigIntegers encrypted_values = encrypted_features.get(ld.variable_name); - BigInteger encrypted_client_value = null; - BigInteger encrypted_thresh = null; - - // Encrypt the thresh-hold correctly - if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) { - encrypted_thresh = ld.getPaillier(); - encrypted_client_value = encrypted_values.getIntegerValuePaillier(); - toClient.writeInt(0); - Niu.setDGKMode(false); - } - else if ((comparisonType == 3) || (comparisonType == 5)) { - encrypted_thresh = ld.getDGK(); - encrypted_client_value = encrypted_values.getIntegerValueDGK(); - toClient.writeInt(1); - Niu.setDGKMode(true); - } - toClient.flush(); - assert encrypted_client_value != null; - long stop_time = System.nanoTime(); - - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Comparison took %f ms\n", run_time); - if (((comparisonType == 1) && (ld.threshold == 0)) - || (comparisonType == 4) || (comparisonType == 5)) { - return Niu.Protocol4(encrypted_thresh, encrypted_client_value); - } - else { - return Niu.Protocol4(encrypted_client_value, encrypted_thresh); - } - } - private void evaluate(int server_port) throws IOException, HomomorphicException { ServerSocket serverSocket = new ServerSocket(server_port); System.out.println("Server will be waiting for direct evaluation from client"); @@ -200,32 +159,36 @@ private void evaluate(int server_port) throws IOException, HomomorphicException Niu.setDGKPublicKey(dgk_public); List node_level_data; - NodeInfo ls = null; + NodeInfo ls; long start_time = System.nanoTime(); int previous_index = 0; // Traverse DT until you hit a leaf, the client has to track the index... for (level_order_site level_site_data : all_level_sites) { node_level_data = level_site_data.get_node_data(); - level_site_data.set_current_index(previous_index); // Handle at a level... int node_level_index = 0; int n = 0; int next_index = 0; - boolean equalsFound = false; boolean inequalityHolds = false; - boolean terminalLeafFound = false; - while (!equalsFound) { + while (true) { ls = node_level_data.get(node_level_index); System.out.println("j=" + node_level_index); if (ls.isLeaf()) { if (n == 2 * level_site_data.get_current_index() || n == 2 * level_site_data.get_current_index() + 1) { - terminalLeafFound = true; System.out.println("Terminal leaf:" + ls.getVariableName()); + // Tell the client the value + to_client_site.writeInt(-1); + to_client_site.writeObject(ls.getVariableName()); + to_client_site.flush(); + long stop_time = System.nanoTime(); + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Total Server-Site run-time took %f ms\n", run_time); break; } n += 2; @@ -250,9 +213,11 @@ private void evaluate(int server_port) throws IOException, HomomorphicException } if (inequalityHolds) { - equalsFound = true; level_site_data.set_next_index(next_index); - System.out.println("New index:" + level_site_data.get_current_index()); + System.out.println("Next index:" + next_index); + // Remember this for next loop... + previous_index = next_index; + break; } } n++; @@ -260,22 +225,6 @@ private void evaluate(int server_port) throws IOException, HomomorphicException } node_level_index++; } - - if (terminalLeafFound) { - // Tell the client the value - to_client_site.writeInt(-1); - to_client_site.writeObject(ls.getVariableName()); - to_client_site.flush(); - long stop_time = System.nanoTime(); - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Total Server-Site run-time took %f ms\n", run_time); - break; - } - else { - // Remember this for next loop... - previous_index = next_index; - } } } catch (ClassNotFoundException e) { throw new RuntimeException(e); diff --git a/src/main/java/weka/finito/utils/Shared.java b/src/main/java/weka/finito/utils/Shared.java deleted file mode 100644 index 0422f7d..0000000 --- a/src/main/java/weka/finito/utils/Shared.java +++ /dev/null @@ -1,14 +0,0 @@ -package weka.finito.utils; - -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.Base64; - -public class Shared { - public static String hash(String text) throws NoSuchAlgorithmException { - MessageDigest digest = MessageDigest.getInstance("SHA-256"); - byte[] hash = digest.digest(text.getBytes(StandardCharsets.UTF_8)); - return Base64.getEncoder().encodeToString(hash); - } -} diff --git a/src/main/java/weka/finito/utils/shared.java b/src/main/java/weka/finito/utils/shared.java new file mode 100644 index 0000000..d8debc5 --- /dev/null +++ b/src/main/java/weka/finito/utils/shared.java @@ -0,0 +1,65 @@ +package weka.finito.utils; + +import security.misc.HomomorphicException; +import security.socialistmillionaire.alice; +import weka.finito.structs.BigIntegers; +import weka.finito.structs.NodeInfo; + +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Hashtable; + +public class shared { + // Used by server-site to hash leaves and client-site to find the leaf + public static String hash(String text) throws NoSuchAlgorithmException { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(text.getBytes(StandardCharsets.UTF_8)); + return Base64.getEncoder().encodeToString(hash); + } + + // Used by level-site and server-site to compare with a client + public static boolean compare(NodeInfo ld, int comparisonType, + Hashtable encrypted_features, + ObjectOutputStream toClient, alice Niu) + throws ClassNotFoundException, HomomorphicException, IOException { + + long start_time = System.nanoTime(); + + BigIntegers encrypted_values = encrypted_features.get(ld.variable_name); + BigInteger encrypted_client_value = null; + BigInteger encrypted_thresh = null; + + // Encrypt the thresh-hold correctly + if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) { + encrypted_thresh = ld.getPaillier(); + encrypted_client_value = encrypted_values.getIntegerValuePaillier(); + toClient.writeInt(0); + Niu.setDGKMode(false); + } + else if ((comparisonType == 3) || (comparisonType == 5)) { + encrypted_thresh = ld.getDGK(); + encrypted_client_value = encrypted_values.getIntegerValueDGK(); + toClient.writeInt(1); + Niu.setDGKMode(true); + } + toClient.flush(); + assert encrypted_client_value != null; + long stop_time = System.nanoTime(); + + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Comparison took %f ms\n", run_time); + if (((comparisonType == 1) && (ld.threshold == 0)) + || (comparisonType == 4) || (comparisonType == 5)) { + return Niu.Protocol4(encrypted_thresh, encrypted_client_value); + } + else { + return Niu.Protocol4(encrypted_client_value, encrypted_thresh); + } + } +}