From 08a88cc2264431df3f98c9a3d5967e4aaee62afa Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 15 Oct 2023 22:20:09 -0400 Subject: [PATCH 01/12] I got this to run, but the current version misclassifies though --- src/main/java/weka/finito/client.java | 113 ++++++++++++--- src/main/java/weka/finito/server.java | 189 +++++++++++++++++++++++++- src/test/java/PrivacyTest.java | 56 +++++++- 3 files changed, 332 insertions(+), 26 deletions(-) diff --git a/src/main/java/weka/finito/client.java b/src/main/java/weka/finito/client.java index 8660337..86a04d1 100644 --- a/src/main/java/weka/finito/client.java +++ b/src/main/java/weka/finito/client.java @@ -50,7 +50,6 @@ public final class client implements Runnable { private DGKPrivateKey dgk_private_key; private PaillierPrivateKey paillier_private_key; private final HashMap hashed_classification = new HashMap<>(); - private boolean talk_to_server_site; private final String server_ip; private final int server_port; @@ -100,7 +99,17 @@ public static void main(String[] args) { client test = null; if (args.length == 1) { - test = new client(key_size, args[0], level_domains, port, precision,server_ip, port); + // Test with level-sites + test = new client(key_size, args[0], level_domains, port, precision, server_ip, port); + } + else if (args.length == 2) { + // Test with just a server-site directly + if (args[1].equalsIgnoreCase("--server")) { + test = new client(key_size, args[0], precision, server_ip, port); + } + else { + test = new client(key_size, args[0], level_domains, port, precision, server_ip, port); + } } else { System.out.println("Missing Testing Data set as an argument parameter"); @@ -136,6 +145,24 @@ public client(int key_size, String features_file, String [] level_site_ips, int this.server_port = server_port; } + // Testing using only a single server, no level-sites + public client(int key_size, String features_file, + int precision, String server_ip, int server_port) { + this.key_size = key_size; + this.features_file = features_file; + this.level_site_ips = null; + this.level_site_ports = null; + this.port = -1; + this.precision = precision; + this.server_ip = server_ip; + this.server_port = server_port; + } + + // Get Classification after Evaluation + public String getClassification() { + return this.classification; + } + public void generate_keys() { // Generate Key Pairs DGKKeyPairGenerator p = new DGKKeyPairGenerator(); @@ -199,7 +226,7 @@ private boolean need_keys() { } // Used for set-up - private void communicate_with_server_site(PaillierPublicKey paillier, DGKPublicKey dgk) + private void setup_with_server_site(PaillierPublicKey paillier, DGKPublicKey dgk) throws IOException, ClassNotFoundException { System.out.println("Connecting to " + server_ip + ":" + server_port); try (Socket server_site = new Socket(server_ip, server_port)) { @@ -217,12 +244,6 @@ private void communicate_with_server_site(PaillierPublicKey paillier, DGKPublicK } } - // Get Classification after Evaluation - public String getClassification() { - return this.classification; - } - - // Evaluation private Hashtable read_features(String path, PaillierPublicKey paillier_public_key, @@ -262,7 +283,47 @@ private Hashtable read_features(String path, } } - // Function used to Evaluate + private void evaluate_with_server_site(Socket server_site) throws IOException, HomomorphicException, ClassNotFoundException { + // Communicate with each Level-Site + Object o; + bob client; + + // Create I/O streams + ObjectOutputStream to_server_site = new ObjectOutputStream(server_site.getOutputStream()); + ObjectInputStream from_server_site = new ObjectInputStream(server_site.getInputStream()); + + // Send the encrypted data to Level-Site + to_server_site.writeObject(this.feature); + to_server_site.flush(); + + // Send the Public Keys using Alice and Bob + client = new bob(server_site, paillier, dgk); + + // Work with the comparison + int comparison_type; + while(true) { + comparison_type = from_server_site.readInt(); + if (comparison_type == -1) { + this.classification_complete = true; + break; + } + else if (comparison_type == 0) { + client.setDGKMode(false); + } + else if (comparison_type == 1) { + client.setDGKMode(true); + } + client.Protocol4(); + } + + o = from_server_site.readObject(); + if (o instanceof String) { + classification = (String) o; + classification = hashed_classification.get(classification); + } + } + + // Function used to Evaluate for each level-site private void communicate_with_level_site(Socket level_site) throws IOException, ClassNotFoundException, HomomorphicException { // Communicate with each Level-Site @@ -336,9 +397,9 @@ else if (comparison_type == 1) { } } - // Function used to Evaluate + // Function used to Train (if needed) and Evaluate public void run() { - this.talk_to_server_site = this.need_keys(); + boolean talk_to_server_site = this.need_keys(); try { // Don't regenerate keys if you are just using a different VALUES file @@ -356,7 +417,7 @@ public void run() { // Client needs to know all possible classes... if (talk_to_server_site) { // Don't send keys to server-site to ask for classes since now it is assumed level-sites are up - communicate_with_server_site(paillier_public_key, dgk_public_key); + setup_with_server_site(paillier_public_key, dgk_public_key); for (String aClass : classes) { hashed_classification.put(hash(aClass), aClass); } @@ -374,8 +435,26 @@ public void run() { int connection_port; long start_time = System.nanoTime(); - try { + // If you are just evaluating directly with the server-site + if (level_site_ips == null) { + try(Socket server_site = new Socket(server_ip, server_port)) { + System.out.println("Client connected to sever-site with PPDT"); + evaluate_with_server_site(server_site); + long end_time = System.nanoTime(); + System.out.println("The Classification is: " + classification); + double run_time = (double) (end_time - start_time); + run_time = run_time/1000000; + System.out.printf("It took %f ms to classify\n", run_time); + finish_evaluation(); + } + catch (HomomorphicException | IOException | ClassNotFoundException e) { + throw new RuntimeException(e); + } + return; + } + + try { for (int i = 0; i < level_site_ips.length; i++) { if (classification_complete) { break; @@ -399,11 +478,14 @@ public void run() { double run_time = (double) (end_time - start_time); run_time = run_time/1000000; System.out.printf("It took %f ms to classify\n", run_time); + finish_evaluation(); } catch (Exception e) { throw new RuntimeException(e); } + } + private void finish_evaluation() throws IOException { // At the end, write your keys... dgk_public_key.writeKey("dgk.pub"); paillier_public_key.writeKey("paillier.pub"); @@ -417,8 +499,5 @@ public void run() { writer.write("\n"); } } - catch (IOException e) { - throw new RuntimeException(e); - } } } diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index 47631da..2e30659 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -8,17 +8,21 @@ import java.io.FileReader; import java.io.IOException; +import java.math.BigInteger; import java.net.ServerSocket; import java.net.Socket; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.*; +import java.util.Map.Entry; import java.lang.System; import java.util.concurrent.TimeUnit; import security.DGK.DGKPublicKey; +import security.misc.HomomorphicException; +import security.socialistmillionaire.alice; import security.paillier.PaillierPublicKey; import weka.classifiers.trees.j48.BinC45ModelSelection; import weka.classifiers.trees.j48.C45PruneableClassifierTree; @@ -26,6 +30,7 @@ import weka.core.Instances; import weka.core.SerializationHelper; +import weka.finito.structs.BigIntegers; import weka.finito.structs.level_order_site; import weka.finito.structs.NodeInfo; @@ -45,13 +50,23 @@ public final class server implements Runnable { private final List all_level_sites = new ArrayList<>(); private final int server_port; + private final boolean use_level_sites; public static void main(String[] args) { int port = 0; int precision = 0; + String training_data = null; + boolean use_level_sites = false; // Get data for training. - if (args.length != 1) { + if (args.length == 1) { + training_data = args[0]; + } + else if (args.length == 2){ + training_data = args[0]; + use_level_sites = args[1].equalsIgnoreCase("--client"); + } + else { System.out.println("Missing Training Data set as an argument parameter"); System.exit(1); } @@ -80,11 +95,20 @@ public static void main(String[] args) { // Create and run the server. System.out.println("Server Initialized and started running"); - server server = new server(args[0], level_domains, port, precision, port); + + server server; + // Pick either Level-Sites or no Level-site + if (use_level_sites) { + server = new server(training_data, level_domains, port, precision, port); + + } + else { + server = new server(training_data, precision, port); + } server.run(); - } + } - // For local host testing, (GitHub Actions CI, on PrivacyTest.java) + // For local host testing, (GitHub Actions CI, on PrivacyTest.java, for level-sites) public server(String training_data, String [] level_site_ips, int [] level_site_ports, int precision, int server_port) { this.training_data = training_data; @@ -92,15 +116,158 @@ public server(String training_data, String [] level_site_ips, int [] level_site_ this.level_site_ports = level_site_ports; this.precision = precision; this.server_port = server_port; + this.use_level_sites = true; } - // For Cloud environment, (Testing with Kubernetes) + // For Cloud environment, (Testing with Kubernetes, for level-sites) public server(String training_data, String [] level_site_domains, int port, int precision, int server_port) { this.training_data = training_data; this.level_site_ips = level_site_domains; this.port = port; this.precision = precision; this.server_port = server_port; + this.use_level_sites = true; + } + + // For testing, but having just a client and server + public server(String training_data, int precision, int server_port) { + this.training_data = training_data; + this.level_site_ips = null; + this.precision = precision; + this.server_port = server_port; + this.use_level_sites = false; + } + + private boolean compare(NodeInfo ld, int comparisonType, + Hashtable encrypted_features, + ObjectOutputStream toClient, alice Niu) + throws ClassNotFoundException, HomomorphicException, IOException { + + long start_time = System.nanoTime(); + + BigIntegers encrypted_values = encrypted_features.get(ld.variable_name); + BigInteger encrypted_client_value = null; + BigInteger encrypted_thresh = null; + + // Encrypt the thresh-hold correctly + if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) { + encrypted_thresh = ld.getPaillier(); + encrypted_client_value = encrypted_values.getIntegerValuePaillier(); + toClient.writeInt(0); + Niu.setDGKMode(false); + } + else if ((comparisonType == 3) || (comparisonType == 5)) { + encrypted_thresh = ld.getDGK(); + encrypted_client_value = encrypted_values.getIntegerValueDGK(); + toClient.writeInt(1); + Niu.setDGKMode(true); + } + toClient.flush(); + assert encrypted_client_value != null; + long stop_time = System.nanoTime(); + + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Comparison took %f ms\n", run_time); + if (((comparisonType == 1) && (ld.threshold == 0)) + || (comparisonType == 4) || (comparisonType == 5)) { + return Niu.Protocol4(encrypted_thresh, encrypted_client_value); + } + else { + return Niu.Protocol4(encrypted_client_value, encrypted_thresh); + } + } + + private void evaluate(int server_port) throws IOException, HomomorphicException { + ServerSocket serverSocket = new ServerSocket(server_port); + System.out.println("Server will be waiting for direct evaluation from client"); + Object client_input; + Hashtable features = new Hashtable<>(); + try (Socket client_site = serverSocket.accept()) { + ObjectOutputStream to_client_site = new ObjectOutputStream(client_site.getOutputStream()); + ObjectInputStream from_client_site = new ObjectInputStream(client_site.getInputStream()); + + // Get encrypted features + client_input = from_client_site.readObject(); + if (client_input instanceof Hashtable) { + for (Entry entry: ((Hashtable) client_input).entrySet()){ + if (entry.getKey() instanceof String && entry.getValue() instanceof BigIntegers) { + features.put((String) entry.getKey(), (BigIntegers) entry.getValue()); + } + } + } + alice Niu = new alice(client_site); + Niu.setPaillierPublicKey(paillier_public); + Niu.setDGKPublicKey(dgk_public); + + List node_level_data; + NodeInfo ls; + long start_time = System.nanoTime(); + + // Traverse DT until you hit a leaf, the client has to track the index... + for (level_order_site level_site_data : all_level_sites) { + node_level_data = level_site_data.get_node_data(); + + // Handle at a level... + int node_level_index = 0; + int n = 0; + int next_index = 0; + boolean equalsFound = false; + boolean inequalityHolds = false; + + while (!equalsFound) { + ls = node_level_data.get(node_level_index); + System.out.println("j=" + node_level_index); + if (ls.isLeaf()) { + if (n == 2 * level_site_data.get_current_index() + || n == 2 * level_site_data.get_current_index() + 1) { + // Tell the client the value + to_client_site.writeInt(-1); + to_client_site.writeObject(ls.getVariableName()); + to_client_site.flush(); + long stop_time = System.nanoTime(); + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Total Server-Site run-time took %f ms\n", run_time); + break; + } + n += 2; + } else { + if ((n == 2 * level_site_data.get_current_index() + || n == 2 * level_site_data.get_current_index() + 1)) { + if (ls.comparisonType == 6) { + boolean firstInequalityHolds = compare(ls, 3, features, + to_client_site, Niu); + if (firstInequalityHolds) { + inequalityHolds = true; + } else { + boolean secondInequalityHolds = compare(ls, 5, features, + to_client_site, Niu); + if (secondInequalityHolds) { + inequalityHolds = true; + } + } + } else { + inequalityHolds = compare(ls, ls.comparisonType, features, + to_client_site, Niu); + } + + if (inequalityHolds) { + equalsFound = true; + level_site_data.set_next_index(next_index); + System.out.println("New index:" + level_site_data.get_current_index()); + } + } + n++; + next_index++; + } + node_level_index++; + } + } + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + serverSocket.close(); } private static String hash(String text) throws NoSuchAlgorithmException { @@ -371,7 +538,19 @@ public void run() { ObjectInputStream from_level_site; int port_to_connect; + // If we are testing without level-sites do this... + if (!use_level_sites) { + try { + evaluate(this.server_port); + } + catch (IOException | HomomorphicException e) { + throw new RuntimeException(e); + } + return; + } + // There should be at least 1 IP Address for each level site + assert this.level_site_ips != null; if(this.level_site_ips.length < all_level_sites.size()) { String error = String.format("Please create more level-sites for the " + "decision tree trained from %s", training_data); diff --git a/src/test/java/PrivacyTest.java b/src/test/java/PrivacyTest.java index e7082bc..a4289cd 100644 --- a/src/test/java/PrivacyTest.java +++ b/src/test/java/PrivacyTest.java @@ -51,7 +51,55 @@ public void read_properties() throws IOException { } @Test - public void test_all() throws Exception { + public void test_single_site() throws Exception { + String answer_path = new File(data_directory, "answers.csv").toString(); + // Parse CSV file with various tests + try (BufferedReader br = new BufferedReader(new FileReader(answer_path))) { + String line; + while ((line = br.readLine()) != null) { + String [] values = line.split(","); + String data_set = values[0]; + String features = values[1]; + String expected_classification = values[2]; + String full_feature_path = new File(data_directory, features).toString(); + String full_data_set_path = new File(data_directory, data_set).toString(); + System.out.println(full_data_set_path); + String classification = test_server_case(full_data_set_path, full_feature_path, key_size, precision, + server_ip, server_port); + System.out.println(expected_classification + " =!= " + classification); + assertEquals(expected_classification, classification); + } + } + } + + public static String test_server_case(String training_data, String features_file, + int key_size, int precision, String server_ip, int server_port) + throws InterruptedException { + + // Create the server + server cloud = new server(training_data, precision, server_port); + Thread server = new Thread(cloud); + server.start(); + + // Create client + client evaluate = new client(key_size, features_file, precision, server_ip, server_port); + Thread client = new Thread(evaluate); + client.start(); + + // Programmatically wait until classification is done. + server.join(); + client.join(); + + // Be sure to delete any keys you made... + for (String file: delete_files) { + delete_file(file); + } + return evaluate.getClassification(); + } + + // Use this to test with level-sites + @Test + public void test_all_level_sites() throws Exception { String answer_path = new File(data_directory, "answers.csv").toString(); // Parse CSV file with various tests try (BufferedReader br = new BufferedReader(new FileReader(answer_path))) { @@ -64,7 +112,7 @@ public void test_all() throws Exception { String full_feature_path = new File(data_directory, features).toString(); String full_data_set_path = new File(data_directory, data_set).toString(); System.out.println(full_data_set_path); - String classification = test_case(full_data_set_path, full_feature_path, levels, key_size, precision, + String classification = test_level_site(full_data_set_path, full_feature_path, levels, key_size, precision, level_site_ips, level_site_ports_string, server_ip, server_port); System.out.println(expected_classification + " =!= " + classification); assertEquals(expected_classification, classification); @@ -72,7 +120,7 @@ public void test_all() throws Exception { } } - public static String test_case(String training_data, String features_file, int levels, + public static String test_level_site(String training_data, String features_file, int levels, int key_size, int precision, String [] level_site_ips, String [] level_site_ports_string, String server_ip, int server_port) throws InterruptedException, NoSuchPaddingException, NoSuchAlgorithmException { @@ -121,7 +169,7 @@ public static void delete_file(String file_name){ if (myObj.delete()) { System.out.println("Deleted the file: " + myObj.getName()); } else { - System.out.println("Failed to delete the file."); + System.out.println("Failed to delete the file: " + myObj.getName()); } } } From 24411fcb27b8dca34bdd9ea6b01dfadcd50371e7 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 15 Oct 2023 22:39:48 -0400 Subject: [PATCH 02/12] Going to create a Java class to store all shared code. Also can confirm I did NOT break the original level-sites approach, seems I got an indexing mistake going on. --- src/main/java/weka/finito/client.java | 8 +--- .../java/weka/finito/level_site_thread.java | 1 + src/main/java/weka/finito/server.java | 43 +++++++++++-------- src/main/java/weka/finito/utils/Shared.java | 14 ++++++ 4 files changed, 42 insertions(+), 24 deletions(-) create mode 100644 src/main/java/weka/finito/utils/Shared.java diff --git a/src/main/java/weka/finito/client.java b/src/main/java/weka/finito/client.java index 86a04d1..e6eb729 100644 --- a/src/main/java/weka/finito/client.java +++ b/src/main/java/weka/finito/client.java @@ -25,6 +25,8 @@ import security.socialistmillionaire.bob; import weka.finito.structs.BigIntegers; +import static weka.finito.utils.Shared.hash; + public final class client implements Runnable { private final String classes_file = "classes.txt"; private final String features_file; @@ -179,12 +181,6 @@ public void generate_keys() { paillier_private_key = (PaillierPrivateKey) paillier.getPrivate(); } - public static String hash(String text) throws NoSuchAlgorithmException { - MessageDigest digest = MessageDigest.getInstance("SHA-256"); - byte[] hash = digest.digest(text.getBytes(StandardCharsets.UTF_8)); - return Base64.getEncoder().encodeToString(hash); - } - private boolean need_keys() { try { dgk_public_key = DGKPublicKey.readKey("dgk.pub"); diff --git a/src/main/java/weka/finito/level_site_thread.java b/src/main/java/weka/finito/level_site_thread.java index d2817f5..fd2f745 100644 --- a/src/main/java/weka/finito/level_site_thread.java +++ b/src/main/java/weka/finito/level_site_thread.java @@ -143,6 +143,7 @@ public final void run() { } // Level Data is the Node Data... + // it is set to 0 by default... if (previous_index != null) { this.level_site_data.set_current_index(Integer.parseInt(previous_index)); } diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index 2e30659..8316897 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -11,9 +11,6 @@ import java.math.BigInteger; import java.net.ServerSocket; import java.net.Socket; -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; import java.util.*; import java.util.Map.Entry; @@ -34,6 +31,8 @@ import weka.finito.structs.level_order_site; import weka.finito.structs.NodeInfo; +import static weka.finito.utils.Shared.hash; + public final class server implements Runnable { @@ -201,19 +200,23 @@ private void evaluate(int server_port) throws IOException, HomomorphicException Niu.setDGKPublicKey(dgk_public); List node_level_data; - NodeInfo ls; + NodeInfo ls = null; long start_time = System.nanoTime(); + int previous_index = 0; // Traverse DT until you hit a leaf, the client has to track the index... for (level_order_site level_site_data : all_level_sites) { node_level_data = level_site_data.get_node_data(); + level_site_data.set_current_index(previous_index); + // Handle at a level... int node_level_index = 0; int n = 0; int next_index = 0; boolean equalsFound = false; boolean inequalityHolds = false; + boolean terminalLeafFound = false; while (!equalsFound) { ls = node_level_data.get(node_level_index); @@ -221,14 +224,8 @@ private void evaluate(int server_port) throws IOException, HomomorphicException if (ls.isLeaf()) { if (n == 2 * level_site_data.get_current_index() || n == 2 * level_site_data.get_current_index() + 1) { - // Tell the client the value - to_client_site.writeInt(-1); - to_client_site.writeObject(ls.getVariableName()); - to_client_site.flush(); - long stop_time = System.nanoTime(); - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Total Server-Site run-time took %f ms\n", run_time); + terminalLeafFound = true; + System.out.println("Terminal leaf:" + ls.getVariableName()); break; } n += 2; @@ -263,6 +260,22 @@ private void evaluate(int server_port) throws IOException, HomomorphicException } node_level_index++; } + + if (terminalLeafFound) { + // Tell the client the value + to_client_site.writeInt(-1); + to_client_site.writeObject(ls.getVariableName()); + to_client_site.flush(); + long stop_time = System.nanoTime(); + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Total Server-Site run-time took %f ms\n", run_time); + break; + } + else { + // Remember this for next loop... + previous_index = next_index; + } } } catch (ClassNotFoundException e) { throw new RuntimeException(e); @@ -270,12 +283,6 @@ private void evaluate(int server_port) throws IOException, HomomorphicException serverSocket.close(); } - private static String hash(String text) throws NoSuchAlgorithmException { - MessageDigest digest = MessageDigest.getInstance("SHA-256"); - byte[] hash = digest.digest(text.getBytes(StandardCharsets.UTF_8)); - return Base64.getEncoder().encodeToString(hash); - } - private void client_communication() throws Exception { ServerSocket serverSocket = new ServerSocket(server_port); System.out.println("Server ready to get public keys from client"); diff --git a/src/main/java/weka/finito/utils/Shared.java b/src/main/java/weka/finito/utils/Shared.java new file mode 100644 index 0000000..0422f7d --- /dev/null +++ b/src/main/java/weka/finito/utils/Shared.java @@ -0,0 +1,14 @@ +package weka.finito.utils; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; + +public class Shared { + public static String hash(String text) throws NoSuchAlgorithmException { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(text.getBytes(StandardCharsets.UTF_8)); + return Base64.getEncoder().encodeToString(hash); + } +} From b7c929eef649e403abc322417cd5be326061bb86 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 15 Oct 2023 23:19:12 -0400 Subject: [PATCH 03/12] got being able to do everytihng in server-site complete. Also migrated compare() to shared. Likely will do same with overall computing index, and will need to start messing around with log4j --- src/main/java/weka/finito/client.java | 5 +- .../java/weka/finito/level_site_thread.java | 83 ++++++------------- src/main/java/weka/finito/server.java | 81 ++++-------------- src/main/java/weka/finito/utils/Shared.java | 14 ---- src/main/java/weka/finito/utils/shared.java | 65 +++++++++++++++ 5 files changed, 108 insertions(+), 140 deletions(-) delete mode 100644 src/main/java/weka/finito/utils/Shared.java create mode 100644 src/main/java/weka/finito/utils/shared.java diff --git a/src/main/java/weka/finito/client.java b/src/main/java/weka/finito/client.java index e6eb729..b689e11 100644 --- a/src/main/java/weka/finito/client.java +++ b/src/main/java/weka/finito/client.java @@ -3,11 +3,8 @@ import java.io.*; import java.math.BigInteger; import java.net.Socket; -import java.nio.charset.StandardCharsets; import java.security.KeyPair; -import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; -import java.util.Base64; import java.util.HashMap; import java.util.Hashtable; @@ -25,7 +22,7 @@ import security.socialistmillionaire.bob; import weka.finito.structs.BigIntegers; -import static weka.finito.utils.Shared.hash; +import static weka.finito.utils.shared.hash; public final class client implements Runnable { private final String classes_file = "classes.txt"; diff --git a/src/main/java/weka/finito/level_site_thread.java b/src/main/java/weka/finito/level_site_thread.java index fd2f745..4eb79f2 100644 --- a/src/main/java/weka/finito/level_site_thread.java +++ b/src/main/java/weka/finito/level_site_thread.java @@ -2,7 +2,6 @@ import java.lang.System; -import security.misc.HomomorphicException; import security.socialistmillionaire.alice; import weka.finito.structs.BigIntegers; import weka.finito.structs.NodeInfo; @@ -11,10 +10,12 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; -import java.math.BigInteger; import java.net.Socket; import java.util.Hashtable; import java.util.List; +import java.util.Map; + +import static weka.finito.utils.shared.compare; public class level_site_thread implements Runnable { @@ -24,9 +25,7 @@ public class level_site_thread implements Runnable { private level_order_site level_site_data = null; - private alice Niu = null; - - private Hashtable encrypted_features; + private final Hashtable encrypted_features = new Hashtable<>(); private final AES crypto; public level_site_thread(Socket client_socket, level_order_site level_site_data, AES crypto) { @@ -46,7 +45,11 @@ public level_site_thread(Socket client_socket, level_order_site level_site_data, this.toClient.writeBoolean(true); closeClientConnection(); } else if (x instanceof Hashtable) { - encrypted_features = (Hashtable) x; + for (Map.Entry entry: ((Hashtable) x).entrySet()){ + if (entry.getKey() instanceof String && entry.getValue() instanceof BigIntegers) { + encrypted_features.put((String) entry.getKey(), (BigIntegers) entry.getValue()); + } + } // Have encrypted copy of thresholds if not done already for all nodes in level-site this.level_site_data = level_site_data; } else { @@ -70,45 +73,6 @@ private void closeClientConnection() throws IOException { } } - // a - from CLIENT, should already be encrypted... - private boolean compare(NodeInfo ld, int comparisonType) - throws HomomorphicException, ClassNotFoundException, IOException { - - long start_time = System.nanoTime(); - - BigIntegers encrypted_values = this.encrypted_features.get(ld.variable_name); - BigInteger encrypted_client_value = null; - BigInteger encrypted_thresh = null; - - // Encrypt the thresh-hold correctly - if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) { - encrypted_thresh = ld.getPaillier(); - encrypted_client_value = encrypted_values.getIntegerValuePaillier(); - toClient.writeInt(0); - Niu.setDGKMode(false); - } - else if ((comparisonType == 3) || (comparisonType == 5)) { - encrypted_thresh = ld.getDGK(); - encrypted_client_value = encrypted_values.getIntegerValueDGK(); - toClient.writeInt(1); - Niu.setDGKMode(true); - } - toClient.flush(); - assert encrypted_client_value != null; - long stop_time = System.nanoTime(); - - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Comparison took %f ms\n", run_time); - if (((comparisonType == 1) && (ld.threshold == 0)) - || (comparisonType == 4) || (comparisonType == 5)) { - return Niu.Protocol4(encrypted_thresh, encrypted_client_value); - } - else { - return Niu.Protocol4(encrypted_client_value, encrypted_thresh); - } - } - // This will run the communication with client and next level site public final void run() { Object o; @@ -118,7 +82,7 @@ public final void run() { long start_time = System.nanoTime(); try { - Niu = new alice(client_socket); + alice niu = new alice(client_socket); if (this.level_site_data == null) { toClient.writeInt(-2); closeClientConnection(); @@ -126,8 +90,8 @@ public final void run() { } List node_level_data = this.level_site_data.get_node_data(); - Niu.setDGKPublicKey(this.level_site_data.dgk_public_key); - Niu.setPaillierPublicKey(this.level_site_data.paillier_public_key); + niu.setDGKPublicKey(this.level_site_data.dgk_public_key); + niu.setPaillierPublicKey(this.level_site_data.paillier_public_key); get_previous_index = fromClient.readBoolean(); if (get_previous_index) { @@ -161,32 +125,39 @@ public final void run() { ls = node_level_data.get(node_level_index); System.out.println("j=" + node_level_index); if (ls.isLeaf()) { - if (n == 2 * this.level_site_data.get_current_index() || n == 2 * this.level_site_data.get_current_index() + 1) { + if (n == 2 * this.level_site_data.get_current_index() + || n == 2 * this.level_site_data.get_current_index() + 1) { terminalLeafFound = true; System.out.println("Terminal leaf:" + ls.getVariableName()); } n += 2; } else { - if ((n==2 * this.level_site_data.get_current_index() || n == 2 * this.level_site_data.get_current_index() + 1)) { + if ((n==2 * this.level_site_data.get_current_index() + || n == 2 * this.level_site_data.get_current_index() + 1)) { if (ls.comparisonType == 6) { - boolean firstInequalityHolds = compare(ls, 3); + boolean firstInequalityHolds = compare(ls, 3, + encrypted_features, toClient, niu); if (firstInequalityHolds) { inequalityHolds = true; - } else { - boolean secondInequalityHolds = compare(ls, 5); + } + else { + boolean secondInequalityHolds = compare(ls, 5, + encrypted_features, toClient, niu); if (secondInequalityHolds) { inequalityHolds = true; } } - } else { - inequalityHolds = compare(ls, ls.comparisonType); + } + else { + inequalityHolds = compare(ls, ls.comparisonType, + encrypted_features, toClient, niu); } if (inequalityHolds) { equalsFound = true; this.level_site_data.set_next_index(next_index); - System.out.println("New index:" + this.level_site_data.get_current_index()); + System.out.println("New index: " + this.level_site_data.get_next_index()); } } n++; diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index 8316897..fd41374 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -8,7 +8,6 @@ import java.io.FileReader; import java.io.IOException; -import java.math.BigInteger; import java.net.ServerSocket; import java.net.Socket; import java.util.*; @@ -31,7 +30,7 @@ import weka.finito.structs.level_order_site; import weka.finito.structs.NodeInfo; -import static weka.finito.utils.Shared.hash; +import static weka.finito.utils.shared.*; public final class server implements Runnable { @@ -137,46 +136,6 @@ public server(String training_data, int precision, int server_port) { this.use_level_sites = false; } - private boolean compare(NodeInfo ld, int comparisonType, - Hashtable encrypted_features, - ObjectOutputStream toClient, alice Niu) - throws ClassNotFoundException, HomomorphicException, IOException { - - long start_time = System.nanoTime(); - - BigIntegers encrypted_values = encrypted_features.get(ld.variable_name); - BigInteger encrypted_client_value = null; - BigInteger encrypted_thresh = null; - - // Encrypt the thresh-hold correctly - if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) { - encrypted_thresh = ld.getPaillier(); - encrypted_client_value = encrypted_values.getIntegerValuePaillier(); - toClient.writeInt(0); - Niu.setDGKMode(false); - } - else if ((comparisonType == 3) || (comparisonType == 5)) { - encrypted_thresh = ld.getDGK(); - encrypted_client_value = encrypted_values.getIntegerValueDGK(); - toClient.writeInt(1); - Niu.setDGKMode(true); - } - toClient.flush(); - assert encrypted_client_value != null; - long stop_time = System.nanoTime(); - - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Comparison took %f ms\n", run_time); - if (((comparisonType == 1) && (ld.threshold == 0)) - || (comparisonType == 4) || (comparisonType == 5)) { - return Niu.Protocol4(encrypted_thresh, encrypted_client_value); - } - else { - return Niu.Protocol4(encrypted_client_value, encrypted_thresh); - } - } - private void evaluate(int server_port) throws IOException, HomomorphicException { ServerSocket serverSocket = new ServerSocket(server_port); System.out.println("Server will be waiting for direct evaluation from client"); @@ -200,32 +159,36 @@ private void evaluate(int server_port) throws IOException, HomomorphicException Niu.setDGKPublicKey(dgk_public); List node_level_data; - NodeInfo ls = null; + NodeInfo ls; long start_time = System.nanoTime(); int previous_index = 0; // Traverse DT until you hit a leaf, the client has to track the index... for (level_order_site level_site_data : all_level_sites) { node_level_data = level_site_data.get_node_data(); - level_site_data.set_current_index(previous_index); // Handle at a level... int node_level_index = 0; int n = 0; int next_index = 0; - boolean equalsFound = false; boolean inequalityHolds = false; - boolean terminalLeafFound = false; - while (!equalsFound) { + while (true) { ls = node_level_data.get(node_level_index); System.out.println("j=" + node_level_index); if (ls.isLeaf()) { if (n == 2 * level_site_data.get_current_index() || n == 2 * level_site_data.get_current_index() + 1) { - terminalLeafFound = true; System.out.println("Terminal leaf:" + ls.getVariableName()); + // Tell the client the value + to_client_site.writeInt(-1); + to_client_site.writeObject(ls.getVariableName()); + to_client_site.flush(); + long stop_time = System.nanoTime(); + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Total Server-Site run-time took %f ms\n", run_time); break; } n += 2; @@ -250,9 +213,11 @@ private void evaluate(int server_port) throws IOException, HomomorphicException } if (inequalityHolds) { - equalsFound = true; level_site_data.set_next_index(next_index); - System.out.println("New index:" + level_site_data.get_current_index()); + System.out.println("Next index:" + next_index); + // Remember this for next loop... + previous_index = next_index; + break; } } n++; @@ -260,22 +225,6 @@ private void evaluate(int server_port) throws IOException, HomomorphicException } node_level_index++; } - - if (terminalLeafFound) { - // Tell the client the value - to_client_site.writeInt(-1); - to_client_site.writeObject(ls.getVariableName()); - to_client_site.flush(); - long stop_time = System.nanoTime(); - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Total Server-Site run-time took %f ms\n", run_time); - break; - } - else { - // Remember this for next loop... - previous_index = next_index; - } } } catch (ClassNotFoundException e) { throw new RuntimeException(e); diff --git a/src/main/java/weka/finito/utils/Shared.java b/src/main/java/weka/finito/utils/Shared.java deleted file mode 100644 index 0422f7d..0000000 --- a/src/main/java/weka/finito/utils/Shared.java +++ /dev/null @@ -1,14 +0,0 @@ -package weka.finito.utils; - -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.Base64; - -public class Shared { - public static String hash(String text) throws NoSuchAlgorithmException { - MessageDigest digest = MessageDigest.getInstance("SHA-256"); - byte[] hash = digest.digest(text.getBytes(StandardCharsets.UTF_8)); - return Base64.getEncoder().encodeToString(hash); - } -} diff --git a/src/main/java/weka/finito/utils/shared.java b/src/main/java/weka/finito/utils/shared.java new file mode 100644 index 0000000..d8debc5 --- /dev/null +++ b/src/main/java/weka/finito/utils/shared.java @@ -0,0 +1,65 @@ +package weka.finito.utils; + +import security.misc.HomomorphicException; +import security.socialistmillionaire.alice; +import weka.finito.structs.BigIntegers; +import weka.finito.structs.NodeInfo; + +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Hashtable; + +public class shared { + // Used by server-site to hash leaves and client-site to find the leaf + public static String hash(String text) throws NoSuchAlgorithmException { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(text.getBytes(StandardCharsets.UTF_8)); + return Base64.getEncoder().encodeToString(hash); + } + + // Used by level-site and server-site to compare with a client + public static boolean compare(NodeInfo ld, int comparisonType, + Hashtable encrypted_features, + ObjectOutputStream toClient, alice Niu) + throws ClassNotFoundException, HomomorphicException, IOException { + + long start_time = System.nanoTime(); + + BigIntegers encrypted_values = encrypted_features.get(ld.variable_name); + BigInteger encrypted_client_value = null; + BigInteger encrypted_thresh = null; + + // Encrypt the thresh-hold correctly + if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) { + encrypted_thresh = ld.getPaillier(); + encrypted_client_value = encrypted_values.getIntegerValuePaillier(); + toClient.writeInt(0); + Niu.setDGKMode(false); + } + else if ((comparisonType == 3) || (comparisonType == 5)) { + encrypted_thresh = ld.getDGK(); + encrypted_client_value = encrypted_values.getIntegerValueDGK(); + toClient.writeInt(1); + Niu.setDGKMode(true); + } + toClient.flush(); + assert encrypted_client_value != null; + long stop_time = System.nanoTime(); + + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Comparison took %f ms\n", run_time); + if (((comparisonType == 1) && (ld.threshold == 0)) + || (comparisonType == 4) || (comparisonType == 5)) { + return Niu.Protocol4(encrypted_thresh, encrypted_client_value); + } + else { + return Niu.Protocol4(encrypted_client_value, encrypted_thresh); + } + } +} From 18d8044309608daa1fd90cafe2eb5265b7ca42c2 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 00:02:04 -0400 Subject: [PATCH 04/12] Successfully started the switch to TLS. I think next goal is to have server-site loop to take as many requests as possible for evaluation. Also, I think now I can investigate keeping connections alive between level-sites to save time and handshakes, and client having to pass the index. This info would be useful for a future paper on PPDT I got in mind --- .gitattributes | 3 + .github/workflows/build-gradle-project.yml | 7 +++ .gitignore | 4 ++ create_certificate.sh | 21 +++++++ k8/server-deploy/server_deployment.yaml | 40 +++++++++++++ k8/server-deploy/server_service.yaml | 12 ++++ k8/{server => server-job}/server_service.yaml | 0 .../server_training_job.yaml | 0 src/main/java/weka/finito/client.java | 18 +++++- .../java/weka/finito/level_site_server.java | 23 +++++--- .../java/weka/finito/level_site_thread.java | 4 +- src/main/java/weka/finito/server.java | 59 +++++++++++++++---- .../weka/finito/structs/level_order_site.java | 39 +++++++++--- src/main/java/weka/finito/utils/shared.java | 18 ++++++ src/test/java/PrivacyTest.java | 4 ++ 15 files changed, 225 insertions(+), 27 deletions(-) create mode 100644 .gitattributes create mode 100644 create_certificate.sh create mode 100644 k8/server-deploy/server_deployment.yaml create mode 100644 k8/server-deploy/server_service.yaml rename k8/{server => server-job}/server_service.yaml (100%) rename k8/{server => server-job}/server_training_job.yaml (100%) diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..4cb403f --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +# On GitHub Actions with Windows, git will use CRLF by default when checking out +# the repo. If this file has CRLF line endings +configure.ac text eol=lf \ No newline at end of file diff --git a/.github/workflows/build-gradle-project.yml b/.github/workflows/build-gradle-project.yml index e933734..ffacb80 100644 --- a/.github/workflows/build-gradle-project.yml +++ b/.github/workflows/build-gradle-project.yml @@ -6,6 +6,12 @@ on: jobs: build-gradle-project: runs-on: ubuntu-latest + env: + ALIAS: andrew + KEYSTORE: andrew_keystore + PASSWORD: ${{ secrets.PASSWORD }} + CERTIFICATE: andrew_certificate + steps: - name: Checkout project sources uses: actions/checkout@v3 @@ -19,6 +25,7 @@ jobs: distribution: 'oracle' java-version: '17' cache: 'gradle' + - run: sh gradlew build - name: Upload coverage reports to Codecov diff --git a/.gitignore b/.gitignore index 208f30f..37efd50 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# Making certificates, dont mess around here +.env +andrew_* + # No Binaries bin diff --git a/create_certificate.sh b/create_certificate.sh new file mode 100644 index 0000000..4d4fbc8 --- /dev/null +++ b/create_certificate.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# https://docs.oracle.com/cd/E54932_01/doc.705/e54936/cssg_create_ssl_cert.htm#CSVSG179 + +# Generate the self-signed certificate and place it in the KeyStore +keytool -genkey -noprompt \ + -alias "${ALIAS}" \ + -dname "CN=mqttserver.ibm.com, OU=ID, O=IBM, L=Hursley, S=Hants, C=US" \ + -keystore "${KEYSTORE}" \ + -storepass "${PASSWORD}" \ + -keypass "${PASSWORD}" \ + -keyalg RSA \ + -sigalg SHA256withRSA + +# Export the certificate needs to a certificate file, use this on both client and server to run +# Technically you just need the key store between level-sites, instead of hashing, just encrypt the leaf value +# with Paillier? + +# -Djavax.net.ssl.keyStore=${KEYSTORE} -Djavax.net.ssl.keyStorePassword=${PASSWORD} +# -Djavax.net.ssl.trustStore=${KEYSTORE} -Djavax.net.ssl.trustStorePassword=${PASSWORD} + +keytool -export -alias "${ALIAS}" -keystore "${KEYSTORE}" -rfc -file "${CERTIFICATE}" -storepass "${PASSWORD}" diff --git a/k8/server-deploy/server_deployment.yaml b/k8/server-deploy/server_deployment.yaml new file mode 100644 index 0000000..15bb73a --- /dev/null +++ b/k8/server-deploy/server_deployment.yaml @@ -0,0 +1,40 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ppdt-server-deploy + labels: + app: ppdt-server-deploy + role: client +spec: + replicas: 1 + selector: + matchLabels: + pod: ppdt-server-deploy + template: + metadata: + labels: + pod: ppdt-server-deploy + spec: + containers: + - name: ppdt-server-deploy + image: andrewquijano92/ppdt + ports: + - containerPort: 9000 + env: + - name: TREE_ROLE + value: "SERVER" + + - name: PRECISION + value: "2" + + - name: PORT_NUM + value: "9000" + + - name: LEVEL_SITE_DOMAINS + value: "ppdt-level-site-01-service,ppdt-level-site-02-service,ppdt-level-site-03-service,ppdt-level-site-04-service,ppdt-level-site-05-service,ppdt-level-site-06-service,ppdt-level-site-07-service,ppdt-level-site-08-service,ppdt-level-site-09-service,ppdt-level-site-10-service" + + - name: CLIENT + value: "ppdt-client-service" + + - name: GRADLE_USER_HOME + value: "gradle_user_home" diff --git a/k8/server-deploy/server_service.yaml b/k8/server-deploy/server_service.yaml new file mode 100644 index 0000000..a1456b3 --- /dev/null +++ b/k8/server-deploy/server_service.yaml @@ -0,0 +1,12 @@ +kind: Service +apiVersion: v1 +metadata: + name: ppdt-server-service +spec: + selector: + pod: ppdt-server-deploy + ports: + - protocol: TCP + port: 9000 + targetPort: 9000 + type: NodePort \ No newline at end of file diff --git a/k8/server/server_service.yaml b/k8/server-job/server_service.yaml similarity index 100% rename from k8/server/server_service.yaml rename to k8/server-job/server_service.yaml diff --git a/k8/server/server_training_job.yaml b/k8/server-job/server_training_job.yaml similarity index 100% rename from k8/server/server_training_job.yaml rename to k8/server-job/server_training_job.yaml diff --git a/src/main/java/weka/finito/client.java b/src/main/java/weka/finito/client.java index b689e11..ab5a041 100644 --- a/src/main/java/weka/finito/client.java +++ b/src/main/java/weka/finito/client.java @@ -22,7 +22,10 @@ import security.socialistmillionaire.bob; import weka.finito.structs.BigIntegers; -import static weka.finito.utils.shared.hash; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +import static weka.finito.utils.shared.*; public final class client implements Runnable { private final String classes_file = "classes.txt"; @@ -392,6 +395,10 @@ else if (comparison_type == 1) { // Function used to Train (if needed) and Evaluate public void run() { + + // Step : 1 + SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); + boolean talk_to_server_site = this.need_keys(); try { @@ -461,7 +468,14 @@ public void run() { connection_port = port; } - try(Socket level_site = new Socket(level_site_ips[i], connection_port)) { + //try(Socket level_site = new Socket(level_site_ips[i], connection_port)) { + try(SSLSocket level_site = (SSLSocket) factory.createSocket(level_site_ips[i], connection_port)) { + // Step : 3 + level_site.setEnabledProtocols(protocols); + level_site.setEnabledCipherSuites(cipher_suites); + + // Step : 4 {optional} + level_site.startHandshake(); System.out.println("Client connected to level " + i); communicate_with_level_site(level_site); } diff --git a/src/main/java/weka/finito/level_site_server.java b/src/main/java/weka/finito/level_site_server.java index f116025..68ce20f 100644 --- a/src/main/java/weka/finito/level_site_server.java +++ b/src/main/java/weka/finito/level_site_server.java @@ -3,23 +3,28 @@ import weka.finito.structs.level_order_site; import javax.crypto.NoSuchPaddingException; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLServerSocketFactory; +import javax.net.ssl.SSLSocket; import java.io.IOException; -import java.net.ServerSocket; -import java.net.Socket; import java.lang.System; import java.security.NoSuchAlgorithmException; +import static weka.finito.utils.shared.cipher_suites; +import static weka.finito.utils.shared.protocols; + public class level_site_server implements Runnable { protected int serverPort; - protected ServerSocket serverSocket = null; + protected SSLServerSocket serverSocket = null; protected boolean isStopped = false; protected Thread runningThread= null; protected level_order_site level_site_parameters = null; protected int precision; protected AES crypto; + protected SSLServerSocketFactory factory; public static void main(String[] args) throws NoSuchPaddingException, NoSuchAlgorithmException { int our_port = 0; @@ -63,10 +68,10 @@ public void run() { openServerSocket(); while(! isStopped()) { - Socket clientSocket; + SSLSocket clientSocket; try { System.out.println("Ready to accept connections at: " + this.serverPort); - clientSocket = this.serverSocket.accept(); + clientSocket = (SSLSocket) this.serverSocket.accept(); } catch (IOException e) { if(isStopped()) { @@ -84,7 +89,7 @@ public void run() { } else { // Received new data from server-site, overwrite the existing copy if it is new - if (this.level_site_parameters.compareTo(new_data) != 0) { + if (!this.level_site_parameters.equals(new_data)) { this.level_site_parameters = new_data; // System.out.println("New Training Data received...Overwriting now..."); } @@ -116,7 +121,11 @@ public synchronized void stop(){ private void openServerSocket() { try { - this.serverSocket = new ServerSocket(this.serverPort); + // Step : 1 + factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault(); + serverSocket = (SSLServerSocket) factory.createServerSocket(this.serverPort); + serverSocket.setEnabledProtocols(protocols); + serverSocket.setEnabledCipherSuites(cipher_suites); } catch (IOException e) { throw new RuntimeException("Cannot open port " + this.serverPort, e); diff --git a/src/main/java/weka/finito/level_site_thread.java b/src/main/java/weka/finito/level_site_thread.java index 4eb79f2..69d227c 100644 --- a/src/main/java/weka/finito/level_site_thread.java +++ b/src/main/java/weka/finito/level_site_thread.java @@ -31,6 +31,8 @@ public class level_site_thread implements Runnable { public level_site_thread(Socket client_socket, level_order_site level_site_data, AES crypto) { this.client_socket = client_socket; this.crypto = crypto; + String clientIpAddress = client_socket.getInetAddress().getHostAddress(); + System.out.println("Level-Site got connection from client: " + clientIpAddress); Object x; try { @@ -53,7 +55,7 @@ public level_site_thread(Socket client_socket, level_order_site level_site_data, // Have encrypted copy of thresholds if not done already for all nodes in level-site this.level_site_data = level_site_data; } else { - System.out.println("Wrong Object Received: " + x.getClass().toString()); + System.out.println("Wrong Object Received: " + x.getClass()); closeClientConnection(); } } catch (Exception e) { diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index fd41374..a43dd7e 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -30,6 +30,9 @@ import weka.finito.structs.level_order_site; import weka.finito.structs.NodeInfo; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + import static weka.finito.utils.shared.*; @@ -50,6 +53,8 @@ public final class server implements Runnable { private final int server_port; private final boolean use_level_sites; + private String client; + public static void main(String[] args) { int port = 0; int precision = 0; @@ -91,6 +96,9 @@ else if (args.length == 2){ } String[] level_domains = level_domains_str.split(","); + // Figure out Client IP for last level-site to send back + String client = System.getenv("CLIENT"); + // Create and run the server. System.out.println("Server Initialized and started running"); @@ -98,10 +106,9 @@ else if (args.length == 2){ // Pick either Level-Sites or no Level-site if (use_level_sites) { server = new server(training_data, level_domains, port, precision, port); - } else { - server = new server(training_data, precision, port); + server = new server(training_data, precision, port, client); } server.run(); } @@ -136,6 +143,16 @@ public server(String training_data, int precision, int server_port) { this.use_level_sites = false; } + public server(String training_data, int precision, int server_port, String client) { + this.training_data = training_data; + this.level_site_ips = null; + this.precision = precision; + this.server_port = server_port; + this.use_level_sites = false; + this.client = client; + } + + private void evaluate(int server_port) throws IOException, HomomorphicException { ServerSocket serverSocket = new ServerSocket(server_port); System.out.println("Server will be waiting for direct evaluation from client"); @@ -246,7 +263,6 @@ private void client_communication() throws Exception { o = from_client_site.readObject(); this.dgk_public = (DGKPublicKey) o; - System.out.println("Server collected keys from client"); // Train level-sites @@ -480,6 +496,9 @@ else if (node_info.comparisonType == 6) { public void run() { + // Step : 1 + SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); + try { // Train the DT ppdt = train_decision_tree(this.training_data); @@ -492,7 +511,7 @@ public void run() { ObjectOutputStream to_level_site; ObjectInputStream from_level_site; - int port_to_connect; + int connection_port; // If we are testing without level-sites do this... if (!use_level_sites) { @@ -518,22 +537,42 @@ public void run() { level_order_site current_level_site = all_level_sites.get(i); if (port == -1) { - port_to_connect = this.level_site_ports[i]; + connection_port = this.level_site_ports[i]; + } + else { + connection_port = this.port; + } + + if (i + 1 != all_level_sites.size()) { + current_level_site.set_next_level_site(level_site_ips[(i + 1) % level_site_ips.length]); } else { - port_to_connect = this.port; + current_level_site.set_next_level_site(client); } + current_level_site.set_next_level_site_port(connection_port); + + // System.out.println("Current level-site: " + level_site_ips[i] + ":" + port_to_connect); + // System.out.println("Next is: " + current_level_site.get_next_level_site() + // + ":" + current_level_site.get_next_level_site_port()); + + //try (Socket level_site = new Socket(level_site_ips[i], port_to_connect)) { + try(SSLSocket level_site = (SSLSocket) factory.createSocket(level_site_ips[i], connection_port)) { + // Step : 3 + level_site.setEnabledProtocols(protocols); + level_site.setEnabledCipherSuites(cipher_suites); + + // Step : 4 {optional} + level_site.startHandshake(); - try (Socket level_site = new Socket(level_site_ips[i], port_to_connect)) { - System.out.println("training level-site " + i + " on port:" + port_to_connect); + System.out.println("training level-site " + i + " on port:" + connection_port); to_level_site = new ObjectOutputStream(level_site.getOutputStream()); from_level_site = new ObjectInputStream(level_site.getInputStream()); to_level_site.writeObject(current_level_site); if(from_level_site.readBoolean()) { - System.out.println("Training Successful on port:" + port_to_connect); + System.out.println("Training Successful on port:" + connection_port); } else { - System.out.println("Training NOT Successful on port:" + port_to_connect); + System.out.println("Training NOT Successful on port:" + connection_port); } } catch (IOException e) { diff --git a/src/main/java/weka/finito/structs/level_order_site.java b/src/main/java/weka/finito/structs/level_order_site.java index 8c5eec3..7f9ddd1 100644 --- a/src/main/java/weka/finito/structs/level_order_site.java +++ b/src/main/java/weka/finito/structs/level_order_site.java @@ -12,7 +12,7 @@ * This class contains all the necessary information for a Level-site to complete evaluation */ -public final class level_order_site implements Serializable, Comparable { +public final class level_order_site implements Serializable { private static final long serialVersionUID = 575566807906351024L; private int index = 0; @@ -24,6 +24,11 @@ public final class level_order_site implements Serializable, Comparable node_level_data = new ArrayList<>(); + // Set to value to client, to let level-site d-1 know to talk to client next with answer + private String next_level_site = "client"; + + private int next_level_site_port = -1; + public level_order_site(int level, PaillierPublicKey paillier_public_key, DGKPublicKey dgk_public_key) { this.level = level; this.paillier_public_key = paillier_public_key; @@ -49,7 +54,23 @@ public int get_next_index() { public int get_current_index() { return this.index; } - + + public void set_next_level_site(String next_level_site) { + this.next_level_site = next_level_site; + } + + public String get_next_level_site() { + return this.next_level_site; + } + + public void set_next_level_site_port(int next_level_site_port) { + this.next_level_site_port = next_level_site_port; + } + + public int get_next_level_site_port() { + return this.next_level_site_port; + } + public void append_data(NodeInfo info) { node_level_data.add(info); } @@ -71,23 +92,27 @@ public String toString() { return output.toString(); } - public int compareTo(level_order_site o) { + public boolean equals(level_order_site o) { + if (o == this) { + return true; + } + if (this.node_level_data.size() == o.node_level_data.size()) { for (int i = 0; i < node_level_data.size(); i++) { NodeInfo left = this.node_level_data.get(i); NodeInfo right = o.node_level_data.get(i); int comparison = left.compareTo(right); if (comparison != 0) { - return comparison; + return false; } } - return 0; + return true; } else if (this.node_level_data.size() > o.node_level_data.size()) { - return 1; + return false; } else { - return -1; + return false; } } } diff --git a/src/main/java/weka/finito/utils/shared.java b/src/main/java/weka/finito/utils/shared.java index d8debc5..0844cc4 100644 --- a/src/main/java/weka/finito/utils/shared.java +++ b/src/main/java/weka/finito/utils/shared.java @@ -13,8 +13,14 @@ import java.security.NoSuchAlgorithmException; import java.util.Base64; import java.util.Hashtable; +import java.util.Properties; public class shared { + + public static final String[] protocols = new String[]{"TLSv1.3"}; + public static final String[] cipher_suites = new String[]{"TLS_AES_128_GCM_SHA256"}; + + // Used by server-site to hash leaves and client-site to find the leaf public static String hash(String text) throws NoSuchAlgorithmException { MessageDigest digest = MessageDigest.getInstance("SHA-256"); @@ -22,6 +28,18 @@ public static String hash(String text) throws NoSuchAlgorithmException { return Base64.getEncoder().encodeToString(hash); } + public static void setup_tls() { + // If you get a null pointer, you forgot to populate environment variables... + String keystore = System.getenv("KEYSTORE"); + String password = System.getenv("PASSWORD"); + Properties systemProps = System.getProperties(); + systemProps.put("javax.net.ssl.keyStorePassword", password); + systemProps.put("javax.net.ssl.keyStore", keystore); + systemProps.put("javax.net.ssl.trustStore", keystore); + systemProps.put("javax.net.ssl.trustStorePassword", password); + System.setProperties(systemProps); + } + // Used by level-site and server-site to compare with a client public static boolean compare(NodeInfo ld, int comparisonType, Hashtable encrypted_features, diff --git a/src/test/java/PrivacyTest.java b/src/test/java/PrivacyTest.java index a4289cd..e068570 100644 --- a/src/test/java/PrivacyTest.java +++ b/src/test/java/PrivacyTest.java @@ -14,6 +14,7 @@ import java.util.Properties; import static org.junit.Assert.assertEquals; +import static weka.finito.utils.shared.setup_tls; public final class PrivacyTest { @@ -29,6 +30,9 @@ public final class PrivacyTest { private final static String [] delete_files = {"dgk", "dgk.pub", "paillier", "paillier.pub", "classes.txt"}; @Before public void read_properties() throws IOException { + + setup_tls(); + // Arguments: System.out.println("Running Full Local Test..."); System.out.println("Working Directory = " + System.getProperty("user.dir")); From 21873625c379d0ad0c05a407488066c4d7aecf90 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 19:39:54 -0400 Subject: [PATCH 05/12] level-site-threat and server-site both use same method to classify on the level. I will add TLS for server part as well shortly/ --- build.gradle | 3 + .../java/weka/finito/level_site_thread.java | 62 ++-------------- src/main/java/weka/finito/server.java | 73 ++++--------------- src/main/java/weka/finito/utils/shared.java | 64 ++++++++++++++++ 4 files changed, 88 insertions(+), 114 deletions(-) diff --git a/build.gradle b/build.gradle index 1391203..a01a1d2 100644 --- a/build.gradle +++ b/build.gradle @@ -21,6 +21,9 @@ dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1' + // https://mvnrepository.com/artifact/commons-io/commons-io + implementation 'commons-io:commons-io:2.14.0' + implementation fileTree(dir: 'libs', include: ['*.jar']) } diff --git a/src/main/java/weka/finito/level_site_thread.java b/src/main/java/weka/finito/level_site_thread.java index 69d227c..709b427 100644 --- a/src/main/java/weka/finito/level_site_thread.java +++ b/src/main/java/weka/finito/level_site_thread.java @@ -16,6 +16,7 @@ import java.util.Map; import static weka.finito.utils.shared.compare; +import static weka.finito.utils.shared.traverse_level; public class level_site_thread implements Runnable { @@ -91,7 +92,6 @@ public final void run() { return; } - List node_level_data = this.level_site_data.get_node_data(); niu.setDGKPublicKey(this.level_site_data.dgk_public_key); niu.setPaillierPublicKey(this.level_site_data.paillier_public_key); @@ -114,68 +114,20 @@ public final void run() { this.level_site_data.set_current_index(Integer.parseInt(previous_index)); } - boolean equalsFound = false; - boolean inequalityHolds = false; - boolean terminalLeafFound = false; - int node_level_index = 0; - int n = 0; - int next_index = 0; - NodeInfo ls = null; - String encrypted_next_index; + // Null, keep going down the tree, + // Not-null, you got the correct leaf node of your DT! + NodeInfo reply = traverse_level(level_site_data, encrypted_features, toClient, niu); - while ((!equalsFound) && (!terminalLeafFound)) { - ls = node_level_data.get(node_level_index); - System.out.println("j=" + node_level_index); - if (ls.isLeaf()) { - if (n == 2 * this.level_site_data.get_current_index() - || n == 2 * this.level_site_data.get_current_index() + 1) { - terminalLeafFound = true; - System.out.println("Terminal leaf:" + ls.getVariableName()); - } - n += 2; - } - else { - if ((n==2 * this.level_site_data.get_current_index() - || n == 2 * this.level_site_data.get_current_index() + 1)) { - if (ls.comparisonType == 6) { - boolean firstInequalityHolds = compare(ls, 3, - encrypted_features, toClient, niu); - if (firstInequalityHolds) { - inequalityHolds = true; - } - else { - boolean secondInequalityHolds = compare(ls, 5, - encrypted_features, toClient, niu); - if (secondInequalityHolds) { - inequalityHolds = true; - } - } - } - else { - inequalityHolds = compare(ls, ls.comparisonType, - encrypted_features, toClient, niu); - } - - if (inequalityHolds) { - equalsFound = true; - this.level_site_data.set_next_index(next_index); - System.out.println("New index: " + this.level_site_data.get_next_index()); - } - } - n++; - next_index++; - } - node_level_index++; - } + String encrypted_next_index; // Place -1 to break Protocol4 loop toClient.writeInt(-1); toClient.flush(); - if (terminalLeafFound) { + if (reply != null) { // Tell the client the value toClient.writeBoolean(true); - toClient.writeObject(ls.getVariableName()); + toClient.writeObject(reply.getVariableName()); } else { toClient.writeBoolean(false); diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index a43dd7e..45c696d 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -175,73 +175,28 @@ private void evaluate(int server_port) throws IOException, HomomorphicException Niu.setPaillierPublicKey(paillier_public); Niu.setDGKPublicKey(dgk_public); - List node_level_data; - NodeInfo ls; long start_time = System.nanoTime(); int previous_index = 0; // Traverse DT until you hit a leaf, the client has to track the index... for (level_order_site level_site_data : all_level_sites) { - node_level_data = level_site_data.get_node_data(); level_site_data.set_current_index(previous_index); // Handle at a level... - int node_level_index = 0; - int n = 0; - int next_index = 0; - boolean inequalityHolds = false; - - while (true) { - ls = node_level_data.get(node_level_index); - System.out.println("j=" + node_level_index); - if (ls.isLeaf()) { - if (n == 2 * level_site_data.get_current_index() - || n == 2 * level_site_data.get_current_index() + 1) { - System.out.println("Terminal leaf:" + ls.getVariableName()); - // Tell the client the value - to_client_site.writeInt(-1); - to_client_site.writeObject(ls.getVariableName()); - to_client_site.flush(); - long stop_time = System.nanoTime(); - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Total Server-Site run-time took %f ms\n", run_time); - break; - } - n += 2; - } else { - if ((n == 2 * level_site_data.get_current_index() - || n == 2 * level_site_data.get_current_index() + 1)) { - if (ls.comparisonType == 6) { - boolean firstInequalityHolds = compare(ls, 3, features, - to_client_site, Niu); - if (firstInequalityHolds) { - inequalityHolds = true; - } else { - boolean secondInequalityHolds = compare(ls, 5, features, - to_client_site, Niu); - if (secondInequalityHolds) { - inequalityHolds = true; - } - } - } else { - inequalityHolds = compare(ls, ls.comparisonType, features, - to_client_site, Niu); - } - - if (inequalityHolds) { - level_site_data.set_next_index(next_index); - System.out.println("Next index:" + next_index); - // Remember this for next loop... - previous_index = next_index; - break; - } - } - n++; - next_index++; - } - node_level_index++; - } + NodeInfo leaf = traverse_level(level_site_data, features, to_client_site, Niu); + + // You found a leaf! No more traversing needed! + if (leaf != null) { + // Tell the client the value + to_client_site.writeInt(-1); + to_client_site.writeObject(leaf.getVariableName()); + to_client_site.flush(); + long stop_time = System.nanoTime(); + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Total Server-Site run-time took %f ms\n", run_time); + break; + } } } catch (ClassNotFoundException e) { throw new RuntimeException(e); diff --git a/src/main/java/weka/finito/utils/shared.java b/src/main/java/weka/finito/utils/shared.java index 0844cc4..ac1a7f1 100644 --- a/src/main/java/weka/finito/utils/shared.java +++ b/src/main/java/weka/finito/utils/shared.java @@ -4,6 +4,7 @@ import security.socialistmillionaire.alice; import weka.finito.structs.BigIntegers; import weka.finito.structs.NodeInfo; +import weka.finito.structs.level_order_site; import java.io.IOException; import java.io.ObjectOutputStream; @@ -13,6 +14,7 @@ import java.security.NoSuchAlgorithmException; import java.util.Base64; import java.util.Hashtable; +import java.util.List; import java.util.Properties; public class shared { @@ -40,6 +42,68 @@ public static void setup_tls() { System.setProperties(systemProps); } + public static NodeInfo traverse_level(level_order_site level_site_data, + Hashtable encrypted_features, + ObjectOutputStream toClient, alice niu) + throws HomomorphicException, IOException, ClassNotFoundException { + + List node_level_data = level_site_data.get_node_data(); + boolean terminalLeafFound = false; + boolean equalsFound = false; + boolean inequalityHolds = false; + int node_level_index = 0; + int n = 0; + int next_index = 0; + NodeInfo ls; + NodeInfo to_return = null; + + while ((!equalsFound) && (!terminalLeafFound)) { + ls = node_level_data.get(node_level_index); + System.out.println("j=" + node_level_index); + if (ls.isLeaf()) { + if (n == 2 * level_site_data.get_current_index() + || n == 2 * level_site_data.get_current_index() + 1) { + terminalLeafFound = true; + to_return = ls; + } + n += 2; + } + else { + if ((n == 2 * level_site_data.get_current_index() + || n == 2 * level_site_data.get_current_index() + 1)) { + if (ls.comparisonType == 6) { + boolean firstInequalityHolds = compare(ls, 3, + encrypted_features, toClient, niu); + if (firstInequalityHolds) { + inequalityHolds = true; + } + else { + boolean secondInequalityHolds = compare(ls, 5, + encrypted_features, toClient, niu); + if (secondInequalityHolds) { + inequalityHolds = true; + } + } + } + else { + inequalityHolds = compare(ls, ls.comparisonType, + encrypted_features, toClient, niu); + } + + if (inequalityHolds) { + equalsFound = true; + level_site_data.set_next_index(next_index); + System.out.println("New index: " + level_site_data.get_next_index()); + } + } + n++; + next_index++; + } + node_level_index++; + } + return to_return; + } + // Used by level-site and server-site to compare with a client public static boolean compare(NodeInfo ld, int comparisonType, Hashtable encrypted_features, From fa726f2905b4d3df5471adfc1a3f8101fe1cbf21 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 20:15:51 -0400 Subject: [PATCH 06/12] Client/Server will use TLS too. Also, fixed the bug so client/server can evaluate just fine used the shared code --- src/main/java/weka/finito/client.java | 16 ++- .../java/weka/finito/level_site_server.java | 5 +- src/main/java/weka/finito/server.java | 124 ++++++++++-------- 3 files changed, 81 insertions(+), 64 deletions(-) diff --git a/src/main/java/weka/finito/client.java b/src/main/java/weka/finito/client.java index ab5a041..b87eedf 100644 --- a/src/main/java/weka/finito/client.java +++ b/src/main/java/weka/finito/client.java @@ -396,7 +396,7 @@ else if (comparison_type == 1) { // Function used to Train (if needed) and Evaluate public void run() { - // Step : 1 + // Step: 1 SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); boolean talk_to_server_site = this.need_keys(); @@ -438,7 +438,14 @@ public void run() { // If you are just evaluating directly with the server-site if (level_site_ips == null) { - try(Socket server_site = new Socket(server_ip, server_port)) { + try(SSLSocket server_site = (SSLSocket) factory.createSocket(server_ip, server_port)) { + // Step: 3 + server_site.setEnabledProtocols(protocols); + server_site.setEnabledCipherSuites(cipher_suites); + + // Step: 4 {optional} + server_site.startHandshake(); + System.out.println("Client connected to sever-site with PPDT"); evaluate_with_server_site(server_site); long end_time = System.nanoTime(); @@ -468,13 +475,12 @@ public void run() { connection_port = port; } - //try(Socket level_site = new Socket(level_site_ips[i], connection_port)) { try(SSLSocket level_site = (SSLSocket) factory.createSocket(level_site_ips[i], connection_port)) { - // Step : 3 + // Step: 3 level_site.setEnabledProtocols(protocols); level_site.setEnabledCipherSuites(cipher_suites); - // Step : 4 {optional} + // Step: 4 {optional} level_site.startHandshake(); System.out.println("Client connected to level " + i); communicate_with_level_site(level_site); diff --git a/src/main/java/weka/finito/level_site_server.java b/src/main/java/weka/finito/level_site_server.java index 68ce20f..dab61ff 100644 --- a/src/main/java/weka/finito/level_site_server.java +++ b/src/main/java/weka/finito/level_site_server.java @@ -24,7 +24,7 @@ public class level_site_server implements Runnable { protected int precision; protected AES crypto; - protected SSLServerSocketFactory factory; + protected SSLServerSocketFactory factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault(); public static void main(String[] args) throws NoSuchPaddingException, NoSuchAlgorithmException { int our_port = 0; @@ -121,8 +121,7 @@ public synchronized void stop(){ private void openServerSocket() { try { - // Step : 1 - factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault(); + // Step: 1 serverSocket = (SSLServerSocket) factory.createServerSocket(this.serverPort); serverSocket.setEnabledProtocols(protocols); serverSocket.setEnabledCipherSuites(cipher_suites); diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index 45c696d..9b205ec 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -30,6 +30,8 @@ import weka.finito.structs.level_order_site; import weka.finito.structs.NodeInfo; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; @@ -38,6 +40,7 @@ public final class server implements Runnable { + private final SSLServerSocketFactory factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault(); private static final String os = System.getProperty("os.name").toLowerCase(); private final String training_data; private final String [] level_site_ips; @@ -152,56 +155,67 @@ public server(String training_data, int precision, int server_port, String clien this.client = client; } + private void run_one_time(int port) throws IOException, HomomorphicException, ClassNotFoundException { + try (SSLServerSocket serverSocket = (SSLServerSocket) factory.createServerSocket(port);) { + serverSocket.setEnabledProtocols(protocols); + serverSocket.setEnabledCipherSuites(cipher_suites); + + System.out.println("Server will be waiting for direct evaluation from client"); + try (SSLSocket client_site = (SSLSocket) serverSocket.accept()) { + evaluate(client_site); + } + } + } + + private void evaluate(SSLSocket client_site) + throws IOException, HomomorphicException, ClassNotFoundException { + + ObjectOutputStream to_client_site = new ObjectOutputStream(client_site.getOutputStream()); + ObjectInputStream from_client_site = new ObjectInputStream(client_site.getInputStream()); - private void evaluate(int server_port) throws IOException, HomomorphicException { - ServerSocket serverSocket = new ServerSocket(server_port); - System.out.println("Server will be waiting for direct evaluation from client"); Object client_input; Hashtable features = new Hashtable<>(); - try (Socket client_site = serverSocket.accept()) { - ObjectOutputStream to_client_site = new ObjectOutputStream(client_site.getOutputStream()); - ObjectInputStream from_client_site = new ObjectInputStream(client_site.getInputStream()); - // Get encrypted features - client_input = from_client_site.readObject(); - if (client_input instanceof Hashtable) { - for (Entry entry: ((Hashtable) client_input).entrySet()){ - if (entry.getKey() instanceof String && entry.getValue() instanceof BigIntegers) { - features.put((String) entry.getKey(), (BigIntegers) entry.getValue()); - } + // Get encrypted features + client_input = from_client_site.readObject(); + if (client_input instanceof Hashtable) { + for (Entry entry: ((Hashtable) client_input).entrySet()){ + if (entry.getKey() instanceof String && entry.getValue() instanceof BigIntegers) { + features.put((String) entry.getKey(), (BigIntegers) entry.getValue()); } } - alice Niu = new alice(client_site); - Niu.setPaillierPublicKey(paillier_public); - Niu.setDGKPublicKey(dgk_public); - - long start_time = System.nanoTime(); - int previous_index = 0; - - // Traverse DT until you hit a leaf, the client has to track the index... - for (level_order_site level_site_data : all_level_sites) { - level_site_data.set_current_index(previous_index); - - // Handle at a level... - NodeInfo leaf = traverse_level(level_site_data, features, to_client_site, Niu); - - // You found a leaf! No more traversing needed! - if (leaf != null) { - // Tell the client the value - to_client_site.writeInt(-1); - to_client_site.writeObject(leaf.getVariableName()); - to_client_site.flush(); - long stop_time = System.nanoTime(); - double run_time = (double) (stop_time - start_time); - run_time = run_time / 1000000; - System.out.printf("Total Server-Site run-time took %f ms\n", run_time); - break; - } - } - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); - } - serverSocket.close(); + } + alice Niu = new alice(client_site); + Niu.setPaillierPublicKey(paillier_public); + Niu.setDGKPublicKey(dgk_public); + + long start_time = System.nanoTime(); + int previous_index = 0; + + // Traverse DT until you hit a leaf, the client has to track the index... + for (level_order_site level_site_data : all_level_sites) { + level_site_data.set_current_index(previous_index); + + // Handle at a level... + NodeInfo leaf = traverse_level(level_site_data, features, to_client_site, Niu); + + // You found a leaf! No more traversing needed! + if (leaf != null) { + // Tell the client the value + to_client_site.writeInt(-1); + to_client_site.writeObject(leaf.getVariableName()); + to_client_site.flush(); + long stop_time = System.nanoTime(); + double run_time = (double) (stop_time - start_time); + run_time = run_time / 1000000; + System.out.printf("Total Server-Site run-time took %f ms\n", run_time); + break; + } + else { + // Update the index for next merry-go-round + previous_index = level_site_data.get_next_index(); + } + } } private void client_communication() throws Exception { @@ -446,12 +460,15 @@ else if (node_info.comparisonType == 6) { } // While n > 0 (nodes > 0) all_level_sites.add(Level_Order_S); ++level; - } // While Tree Not Empty + } // While a tree is not empty } + // Run should + // Train, ONLY if necessary + // Evaluate, be prepared for either level-site or no level-site case, 1 time public void run() { - // Step : 1 + // Step 1 SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); try { @@ -471,12 +488,12 @@ public void run() { // If we are testing without level-sites do this... if (!use_level_sites) { try { - evaluate(this.server_port); + run_one_time(this.server_port); } - catch (IOException | HomomorphicException e) { + catch (IOException | HomomorphicException | ClassNotFoundException e) { throw new RuntimeException(e); } - return; + return; } // There should be at least 1 IP Address for each level site @@ -506,17 +523,12 @@ public void run() { } current_level_site.set_next_level_site_port(connection_port); - // System.out.println("Current level-site: " + level_site_ips[i] + ":" + port_to_connect); - // System.out.println("Next is: " + current_level_site.get_next_level_site() - // + ":" + current_level_site.get_next_level_site_port()); - - //try (Socket level_site = new Socket(level_site_ips[i], port_to_connect)) { try(SSLSocket level_site = (SSLSocket) factory.createSocket(level_site_ips[i], connection_port)) { - // Step : 3 + // Step: 3 level_site.setEnabledProtocols(protocols); level_site.setEnabledCipherSuites(cipher_suites); - // Step : 4 {optional} + // Step: 4 {optional} level_site.startHandshake(); System.out.println("training level-site " + i + " on port:" + connection_port); From 24b4ec89f3e66b178f7a3ec24ab58195306fef63 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 21:21:42 -0400 Subject: [PATCH 07/12] Ugh I am likely going to sin, because I need to have Docker have my keystore up...so I just wonder how to do this... --- .github/workflows/build-gradle-project.yml | 7 +- create_certificate.sh => create_keystore.sh | 0 src/main/java/weka/finito/client.java | 6 +- .../java/weka/finito/level_site_thread.java | 11 +-- src/main/java/weka/finito/server.java | 83 +++++++++---------- 5 files changed, 49 insertions(+), 58 deletions(-) rename create_certificate.sh => create_keystore.sh (100%) diff --git a/.github/workflows/build-gradle-project.yml b/.github/workflows/build-gradle-project.yml index ffacb80..35800a0 100644 --- a/.github/workflows/build-gradle-project.yml +++ b/.github/workflows/build-gradle-project.yml @@ -10,7 +10,6 @@ jobs: ALIAS: andrew KEYSTORE: andrew_keystore PASSWORD: ${{ secrets.PASSWORD }} - CERTIFICATE: andrew_certificate steps: - name: Checkout project sources @@ -26,7 +25,11 @@ jobs: java-version: '17' cache: 'gradle' - - run: sh gradlew build + - name: Create keystore for local use + run: sh create_keystore.sh + + - name: Run Gradle Testing + run: sh gradlew build - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/create_certificate.sh b/create_keystore.sh similarity index 100% rename from create_certificate.sh rename to create_keystore.sh diff --git a/src/main/java/weka/finito/client.java b/src/main/java/weka/finito/client.java index b87eedf..490e779 100644 --- a/src/main/java/weka/finito/client.java +++ b/src/main/java/weka/finito/client.java @@ -28,6 +28,8 @@ import static weka.finito.utils.shared.*; public final class client implements Runnable { + private final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); + private final String classes_file = "classes.txt"; private final String features_file; private final int key_size; @@ -204,7 +206,7 @@ private boolean need_keys() { } private String [] read_classes() { - // Don't forget to remember the classes of DT as well + // Remember the classes of DT as well StringBuilder content = new StringBuilder(); String line; @@ -397,8 +399,6 @@ else if (comparison_type == 1) { public void run() { // Step: 1 - SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); - boolean talk_to_server_site = this.need_keys(); try { diff --git a/src/main/java/weka/finito/level_site_thread.java b/src/main/java/weka/finito/level_site_thread.java index 709b427..3d297c0 100644 --- a/src/main/java/weka/finito/level_site_thread.java +++ b/src/main/java/weka/finito/level_site_thread.java @@ -12,10 +12,8 @@ import java.io.ObjectOutputStream; import java.net.Socket; import java.util.Hashtable; -import java.util.List; import java.util.Map; -import static weka.finito.utils.shared.compare; import static weka.finito.utils.shared.traverse_level; public class level_site_thread implements Runnable { @@ -32,8 +30,6 @@ public class level_site_thread implements Runnable { public level_site_thread(Socket client_socket, level_order_site level_site_data, AES crypto) { this.client_socket = client_socket; this.crypto = crypto; - String clientIpAddress = client_socket.getInetAddress().getHostAddress(); - System.out.println("Level-Site got connection from client: " + clientIpAddress); Object x; try { @@ -42,7 +38,7 @@ public level_site_thread(Socket client_socket, level_order_site level_site_data, x = fromClient.readObject(); if (x instanceof level_order_site) { - // Traffic from Server. Level-Site alone will manage closing this. + // Traffic from Server. Level-Site alone will manage to close this. this.level_site_data = (level_order_site) x; // System.out.println("Level-Site received training data on Port: " + client_socket.getLocalPort()); this.toClient.writeBoolean(true); @@ -115,7 +111,7 @@ public final void run() { } // Null, keep going down the tree, - // Not-null, you got the correct leaf node of your DT! + // Not null, you got the correct leaf node of your DT! NodeInfo reply = traverse_level(level_site_data, encrypted_features, toClient, niu); String encrypted_next_index; @@ -130,8 +126,9 @@ public final void run() { toClient.writeObject(reply.getVariableName()); } else { + toClient.writeBoolean(false); - // encrypt with AES, send to client which will send to next level-site + // encrypt with AES, send to the client which will send to next level-site encrypted_next_index = crypto.encrypt(String.valueOf(this.level_site_data.get_next_index())); iv = crypto.getIV(); toClient.writeObject(encrypted_next_index); diff --git a/src/main/java/weka/finito/server.java b/src/main/java/weka/finito/server.java index 9b205ec..c6d1b70 100644 --- a/src/main/java/weka/finito/server.java +++ b/src/main/java/weka/finito/server.java @@ -41,6 +41,7 @@ public final class server implements Runnable { private final SSLServerSocketFactory factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault(); + private final SSLSocketFactory socket_factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); private static final String os = System.getProperty("os.name").toLowerCase(); private final String training_data; private final String [] level_site_ips; @@ -49,14 +50,13 @@ public final class server implements Runnable { private PaillierPublicKey paillier_public; private DGKPublicKey dgk_public; private final int precision; - private ClassifierTree ppdt; + private ClassifierTree ppdt = null; private final List leaves = new ArrayList<>(); private final List all_level_sites = new ArrayList<>(); private final int server_port; - private final boolean use_level_sites; - private String client; + private int evaluations = 1; public static void main(String[] args) { int port = 0; @@ -99,9 +99,6 @@ else if (args.length == 2){ } String[] level_domains = level_domains_str.split(","); - // Figure out Client IP for last level-site to send back - String client = System.getenv("CLIENT"); - // Create and run the server. System.out.println("Server Initialized and started running"); @@ -111,7 +108,7 @@ else if (args.length == 2){ server = new server(training_data, level_domains, port, precision, port); } else { - server = new server(training_data, precision, port, client); + server = new server(training_data, precision, port); } server.run(); } @@ -124,50 +121,44 @@ public server(String training_data, String [] level_site_ips, int [] level_site_ this.level_site_ports = level_site_ports; this.precision = precision; this.server_port = server_port; - this.use_level_sites = true; } - // For Cloud environment, (Testing with Kubernetes, for level-sites) + // For Cloud environment, (Testing with Kubernetes/EKS, for level-sites) public server(String training_data, String [] level_site_domains, int port, int precision, int server_port) { this.training_data = training_data; this.level_site_ips = level_site_domains; this.port = port; this.precision = precision; this.server_port = server_port; - this.use_level_sites = true; + // I will likely want more than 1 test with server-site! + this.evaluations = 100; } - // For testing, but having just a client and server + // For testing, but having just a client and server, just do one evaluation for the sake of testing. public server(String training_data, int precision, int server_port) { this.training_data = training_data; this.level_site_ips = null; this.precision = precision; this.server_port = server_port; - this.use_level_sites = false; } - public server(String training_data, int precision, int server_port, String client) { - this.training_data = training_data; - this.level_site_ips = null; - this.precision = precision; - this.server_port = server_port; - this.use_level_sites = false; - this.client = client; - } - - private void run_one_time(int port) throws IOException, HomomorphicException, ClassNotFoundException { - try (SSLServerSocket serverSocket = (SSLServerSocket) factory.createServerSocket(port);) { + private void run_server_site(int port) throws IOException, HomomorphicException, ClassNotFoundException { + int count = 0; + try (SSLServerSocket serverSocket = (SSLServerSocket) factory.createServerSocket(port)) { serverSocket.setEnabledProtocols(protocols); serverSocket.setEnabledCipherSuites(cipher_suites); System.out.println("Server will be waiting for direct evaluation from client"); - try (SSLSocket client_site = (SSLSocket) serverSocket.accept()) { - evaluate(client_site); + while (count < evaluations) { + try (SSLSocket client_site = (SSLSocket) serverSocket.accept()) { + evaluate_with_client_directly(client_site); + } + ++count; } } } - private void evaluate(SSLSocket client_site) + private void evaluate_with_client_directly(SSLSocket client_site) throws IOException, HomomorphicException, ClassNotFoundException { ObjectOutputStream to_client_site = new ObjectOutputStream(client_site.getOutputStream()); @@ -468,37 +459,39 @@ else if (node_info.comparisonType == 6) { // Evaluate, be prepared for either level-site or no level-site case, 1 time public void run() { - // Step 1 - SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); - try { - // Train the DT - ppdt = train_decision_tree(this.training_data); - // Get Public Keys from Client AND train level-sites - client_communication(); + // Train the DT if you have to. + if (ppdt == null) { + ppdt = train_decision_tree(this.training_data); + // Get Public Keys from Client AND train level-sites + client_communication(); + } } catch (Exception e) { throw new RuntimeException(e); } - ObjectOutputStream to_level_site; - ObjectInputStream from_level_site; - int connection_port; - // If we are testing without level-sites do this... - if (!use_level_sites) { + if (this.level_site_ips != null) { + train_level_sites(); + } else { try { - run_one_time(this.server_port); + run_server_site(this.server_port); } catch (IOException | HomomorphicException | ClassNotFoundException e) { throw new RuntimeException(e); } - return; } + } + + private void train_level_sites() { + ObjectOutputStream to_level_site; + ObjectInputStream from_level_site; + int connection_port; // There should be at least 1 IP Address for each level site - assert this.level_site_ips != null; - if(this.level_site_ips.length < all_level_sites.size()) { + assert this.level_site_ips != null; + if(this.level_site_ips.length < all_level_sites.size()) { String error = String.format("Please create more level-sites for the " + "decision tree trained from %s", training_data); throw new RuntimeException(error); @@ -518,12 +511,10 @@ public void run() { if (i + 1 != all_level_sites.size()) { current_level_site.set_next_level_site(level_site_ips[(i + 1) % level_site_ips.length]); } - else { - current_level_site.set_next_level_site(client); - } + current_level_site.set_next_level_site_port(connection_port); - try(SSLSocket level_site = (SSLSocket) factory.createSocket(level_site_ips[i], connection_port)) { + try(SSLSocket level_site = (SSLSocket) socket_factory.createSocket(level_site_ips[i], connection_port)) { // Step: 3 level_site.setEnabledProtocols(protocols); level_site.setEnabledCipherSuites(cipher_suites); From ab0ea445510129c2998156a36a9b4f8afecd2232 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 21:32:33 -0400 Subject: [PATCH 08/12] Install keytool for me too please --- .github/workflows/build-gradle-project.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-gradle-project.yml b/.github/workflows/build-gradle-project.yml index 35800a0..de8f3d6 100644 --- a/.github/workflows/build-gradle-project.yml +++ b/.github/workflows/build-gradle-project.yml @@ -15,9 +15,12 @@ jobs: - name: Checkout project sources uses: actions/checkout@v3 - - name: Install packages + - name: Install Graphviz to visualize trees run: sudo apt-get install -y graphviz + - name: Install Java to have KeyTool + run: sudo apt-get install -y java-11-amazon-corretto-jdk + - name: Setup Gradle uses: actions/setup-java@v3 with: From c65118c945a583e5d0b42cdae49a63a7d7fa5f9b Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 21:36:34 -0400 Subject: [PATCH 09/12] Maybe just default JRE/JDK can get me keytool --- .github/workflows/build-gradle-project.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-gradle-project.yml b/.github/workflows/build-gradle-project.yml index de8f3d6..55d6d8b 100644 --- a/.github/workflows/build-gradle-project.yml +++ b/.github/workflows/build-gradle-project.yml @@ -19,7 +19,7 @@ jobs: run: sudo apt-get install -y graphviz - name: Install Java to have KeyTool - run: sudo apt-get install -y java-11-amazon-corretto-jdk + run: sudo apt-get install -y default-jdk, default-jre, graphviz, curl, python3-pip - name: Setup Gradle uses: actions/setup-java@v3 From 63b580eabffafb3a200474b7e7a433035a4d94c9 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 21:39:38 -0400 Subject: [PATCH 10/12] Never mind, I don't get how this ran WITHOUT any issue of needing a key store? --- .github/workflows/build-gradle-project.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/build-gradle-project.yml b/.github/workflows/build-gradle-project.yml index 55d6d8b..3a95999 100644 --- a/.github/workflows/build-gradle-project.yml +++ b/.github/workflows/build-gradle-project.yml @@ -18,9 +18,6 @@ jobs: - name: Install Graphviz to visualize trees run: sudo apt-get install -y graphviz - - name: Install Java to have KeyTool - run: sudo apt-get install -y default-jdk, default-jre, graphviz, curl, python3-pip - - name: Setup Gradle uses: actions/setup-java@v3 with: @@ -28,9 +25,6 @@ jobs: java-version: '17' cache: 'gradle' - - name: Create keystore for local use - run: sh create_keystore.sh - - name: Run Gradle Testing run: sh gradlew build From bd1cdd84a9cec8784ac4d44199ffac2189c873e0 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 21:44:14 -0400 Subject: [PATCH 11/12] Maybe I should put it after java setup --- .github/workflows/build-gradle-project.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build-gradle-project.yml b/.github/workflows/build-gradle-project.yml index 3a95999..793d656 100644 --- a/.github/workflows/build-gradle-project.yml +++ b/.github/workflows/build-gradle-project.yml @@ -25,6 +25,9 @@ jobs: java-version: '17' cache: 'gradle' + - name: Create Key Store + run: sh create_keystore.sh + - name: Run Gradle Testing run: sh gradlew build From 0cde7d5554dac57683ab8233effd93b0ce1e4571 Mon Sep 17 00:00:00 2001 From: AndrewQuijano Date: Sun, 22 Oct 2023 21:46:37 -0400 Subject: [PATCH 12/12] Oh, I don't need the damn certificate, just keystore --- create_keystore.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/create_keystore.sh b/create_keystore.sh index 4d4fbc8..bb507e9 100644 --- a/create_keystore.sh +++ b/create_keystore.sh @@ -18,4 +18,4 @@ keytool -genkey -noprompt \ # -Djavax.net.ssl.keyStore=${KEYSTORE} -Djavax.net.ssl.keyStorePassword=${PASSWORD} # -Djavax.net.ssl.trustStore=${KEYSTORE} -Djavax.net.ssl.trustStorePassword=${PASSWORD} -keytool -export -alias "${ALIAS}" -keystore "${KEYSTORE}" -rfc -file "${CERTIFICATE}" -storepass "${PASSWORD}" +# keytool -export -alias "${ALIAS}" -keystore "${KEYSTORE}" -rfc -file "${CERTIFICATE}" -storepass "${PASSWORD}"