diff --git a/rpc-client-test/src/main/java/com/antgroup/tugraph/TuGraphDbRpcClientTest.java b/rpc-client-test/src/main/java/com/antgroup/tugraph/TuGraphDbRpcClientTest.java index f1b0a70..45f4f67 100644 --- a/rpc-client-test/src/main/java/com/antgroup/tugraph/TuGraphDbRpcClientTest.java +++ b/rpc-client-test/src/main/java/com/antgroup/tugraph/TuGraphDbRpcClientTest.java @@ -5,6 +5,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.ArrayList; public class TuGraphDbRpcClientTest { static Logger log = LoggerFactory.getLogger(TuGraphDbRpcClientTest.class); @@ -35,6 +36,16 @@ public static void loadProcedure(TuGraphDbRpcClient client) { result = client.loadProcedure("./scan_graph.so", "CPP", "scan_graph", "SO", "test scan_graph", true, "v1", "default"); log.info("loadProcedure : " + result); + + String[] multi_files = { + "../../test/test_procedures/multi_files.cpp", + "../../test/test_procedures/multi_files.h", + "../../test/test_procedures/multi_files_core.cpp" + }; + result = client.loadProcedure(multi_files, "CPP", "multi_file", "CPP", "test sortstr", true, + "v1", "default"); + log.info("loadProcedure : " + result); + assert (result); } catch (IOException e) { log.info("catch IOException : " + e.getMessage()); } catch (Exception e) { @@ -48,7 +59,7 @@ public static void listProcedures(TuGraphDbRpcClient client) throws Exception { String result = client.listProcedures("CPP", "any", "default"); log.info("testListProcedure: " + result); JSONArray jsonArray = JSONArray.parseArray(result); - assert (jsonArray.size() == 2); + assert (jsonArray.size() == 3); } catch (TuGraphDbRpcException e) { log.info("catch TuGraphDbRpcException : " + e.getMessage()); } diff --git a/rpc-client/src/main/java/com/antgroup/tugraph/TuGraphDbRpcClient.java b/rpc-client/src/main/java/com/antgroup/tugraph/TuGraphDbRpcClient.java index 295165b..9814878 100644 --- a/rpc-client/src/main/java/com/antgroup/tugraph/TuGraphDbRpcClient.java +++ b/rpc-client/src/main/java/com/antgroup/tugraph/TuGraphDbRpcClient.java @@ -185,6 +185,22 @@ public boolean loadProcedure(String sourceFile, String procedureType, String pro } } + public boolean loadProcedure(String[] sourceFiles, String procedureType, String procedureName, String codeType, + String procedureDescription, boolean readOnly, String version, String graph) throws Exception { + if (clientType == ClientType.SINGLE_CONNECTION) { + return baseClient.loadProcedure(sourceFiles, procedureType, procedureName, codeType, procedureDescription, readOnly, version, graph); + } else { + return doubleCheckQuery(()-> { + boolean succeed = leaderClient.loadProcedure(sourceFiles, procedureType, procedureName, codeType, procedureDescription, readOnly, version, graph); + //update procedure info + if (succeed) { + refreshUserDefinedProcedure(); + } + return succeed; + }); + } + } + public String listProcedures(String procedureType, String version, String graph) throws Exception { if (clientType == ClientType.SINGLE_CONNECTION) { return baseClient.listProcedures(procedureType, version, graph); @@ -705,6 +721,15 @@ public boolean loadProcedure(String sourceFile, String codeType, String procedureDescription, boolean readOnly, String version, String graph) throws IOException { + String[] sourceFiles = {sourceFile}; + return loadProcedure(sourceFiles, procedureType, procedureName, codeType, procedureDescription, readOnly, version, graph); + } + + public boolean loadProcedure(String[] sourceFiles, + String procedureType, String procedureName, + String codeType, + String procedureDescription, boolean readOnly, + String version, String graph) throws IOException { Lgraph.PluginRequest.PluginType type = "CPP".equals(procedureType) ? Lgraph.PluginRequest.PluginType.CPP : Lgraph.PluginRequest.PluginType.PYTHON; Lgraph.LoadPluginRequest.CodeType cType = @@ -712,10 +737,19 @@ public boolean loadProcedure(String sourceFile, "PY".equals(codeType) ? Lgraph.LoadPluginRequest.CodeType.PY : "CPP".equals(codeType) ? Lgraph.LoadPluginRequest.CodeType.CPP : Lgraph.LoadPluginRequest.CodeType.ZIP; - ByteString content = ByteString.copyFrom(Objects.requireNonNull(binaryFileReader(sourceFile))); + List contents = new ArrayList(); + List filenames = new ArrayList(); + for (String sourceFile : sourceFiles) { + ByteString content = ByteString.copyFrom(Objects.requireNonNull(binaryFileReader(sourceFile))); + contents.add(content); + String[] files = sourceFile.split("/"); + String fn = files[files.length - 1]; + filenames.add(fn); + } Lgraph.LoadPluginRequest lpRequest = Lgraph.LoadPluginRequest.newBuilder().setName(procedureName).setDesc(procedureDescription) - .setReadOnly(readOnly).setCode(content).setCodeType(cType).build(); + .setReadOnly(readOnly).setCodeType(cType) + .addAllCode(contents).addAllFileName(filenames).build(); Lgraph.PluginRequest req = Lgraph.PluginRequest.newBuilder().setType(type).setLoadPluginRequest(lpRequest).setGraph(graph).setVersion(version).build(); Lgraph.LGraphRequest request = diff --git a/rpc-client/src/main/proto/lgraph.proto b/rpc-client/src/main/proto/lgraph.proto index 321ee64..9d9795a 100644 --- a/rpc-client/src/main/proto/lgraph.proto +++ b/rpc-client/src/main/proto/lgraph.proto @@ -574,9 +574,10 @@ message LoadPluginRequest { required string name = 1; required bool read_only = 2; - required bytes code = 3; + repeated bytes code = 3; optional string desc = 4; optional CodeType code_type = 5; + repeated string file_name = 6; }; message LoadPluginResponse {};