diff --git a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala index 5bca2235..b54cf954 100644 --- a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala +++ b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala @@ -7,6 +7,7 @@ package com.vesoft.nebula.examples.connector import com.facebook.thrift.protocol.TCompactProtocol import com.vesoft.nebula.connector.connector.NebulaDataFrameReader +import com.vesoft.nebula.connector.ssl.SSLSignType import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig} import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession @@ -132,4 +133,35 @@ object NebulaSparkReaderExample { LOG.info("edge rdd count: {}", edgeRDD.count()) } + /** + * read Nebula vertex with SSL + */ + def readVertexWithSSL(spark: SparkSession): Unit = { + LOG.info("start to read nebula vertices with ssl") + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withEnableMetaSSL(true) + .withEnableStorageSSL(true) + .withSSLSignType(SSLSignType.CA) + .withCaSSLSignParam("example/src/main/resources/ssl/casigned.pem", + "example/src/main/resources/ssl/casigned.crt", + "example/src/main/resources/ssl/casigned.key") + .withConenctionRetry(2) + .build() + val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("test") + .withLabel("person") + .withNoColumn(false) + .withReturnCols(List("birthday")) + .withLimit(10) + .withPartitionNum(10) + .build() + val vertex = spark.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF() + vertex.printSchema() + vertex.show(20) + println("vertex count: " + vertex.count()) + } } diff --git a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala index bf988d89..2fe34482 100644 --- a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala +++ b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala @@ -8,12 +8,12 @@ package com.vesoft.nebula.examples.connector import com.facebook.thrift.protocol.TCompactProtocol import com.vesoft.nebula.connector.{ NebulaConnectionConfig, - SSLSignType, WriteMode, WriteNebulaEdgeConfig, WriteNebulaVertexConfig } import com.vesoft.nebula.connector.connector.NebulaDataFrameWriter +import com.vesoft.nebula.connector.ssl.SSLSignType import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.storage.StorageLevel @@ -63,6 +63,7 @@ object NebulaSparkWriterExample { .withMetaAddress("127.0.0.1:9559") .withGraphAddress("127.0.0.1:9669") .withConenctionRetry(2) + .withEnableMetaSSL(true) .withEnableGraphSSL(true) .withSSLSignType(SSLSignType.CA) .withCaSSLSignParam("example/src/main/resources/ssl/casigned.pem", @@ -77,6 +78,7 @@ object NebulaSparkWriterExample { .withMetaAddress("127.0.0.1:9559") .withGraphAddress("127.0.0.1:9669") .withConenctionRetry(2) + .withEnableMetaSSL(true) .withEnableGraphSSL(true) .withSSLSignType(SSLSignType.SELF) .withSelfSSLSignParam("example/src/main/resources/ssl/selfsigned.pem", diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala index d1979f88..e2fad798 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaConfig.scala @@ -5,8 +5,7 @@ package com.vesoft.nebula.connector -import com.vesoft.nebula.client.graph.data.{CASignedSSLParam, SelfSignedSSLParam} -import com.vesoft.nebula.connector.NebulaConnectionConfig.ConfigBuilder +import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams} import org.slf4j.{Logger, LoggerFactory} import scala.collection.mutable.ListBuffer @@ -20,8 +19,8 @@ class NebulaConnectionConfig(metaAddress: String, enableGraphSSL: Boolean, enableStorageSSL: Boolean, signType: SSLSignType.Value, - caSignParam: CASignedSSLParam, - selfSignParam: SelfSignedSSLParam) + caSignParam: CASSLSignParams, + selfSignParam: SelfSSLSignParams) extends Serializable { def getMetaAddress = metaAddress def getGraphAddress = graphAddress @@ -31,9 +30,13 @@ class NebulaConnectionConfig(metaAddress: String, def getEnableMetaSSL = enableMetaSSL def getEnableGraphSSL = enableGraphSSL def getEnableStorageSSL = enableStorageSSL - def getSignType = signType - def getCaSignParam = caSignParam - def getSelfSignParam = selfSignParam + def getSignType = signType.toString + def getCaSignParam: String = { + caSignParam.caCrtFilePath + "," + caSignParam.crtFilePath + "," + caSignParam.keyFilePath + } + def getSelfSignParam: String = { + selfSignParam.crtFilePath + "," + selfSignParam.keyFilePath + "," + selfSignParam.password + } } object NebulaConnectionConfig { @@ -46,12 +49,12 @@ object NebulaConnectionConfig { protected var connectionRetry: Int = 1 protected var executeRetry: Int = 1 - protected var enableMetaSSL: Boolean = false - protected var enableGraphSSL: Boolean = false - protected var enableStorageSSL: Boolean = false - protected var sslSignType: SSLSignType.Value = _ - protected var caSignParam: CASignedSSLParam = null - protected var selfSignParam: SelfSignedSSLParam = null + protected var enableMetaSSL: Boolean = false + protected var enableGraphSSL: Boolean = false + protected var enableStorageSSL: Boolean = false + protected var sslSignType: SSLSignType.Value = _ + protected var caSignParam: CASSLSignParams = null + protected var selfSignParam: SelfSSLSignParams = null def withMetaAddress(metaAddress: String): ConfigBuilder = { this.metaAddress = metaAddress @@ -91,8 +94,7 @@ object NebulaConnectionConfig { * set enableMetaSSL, enableMetaSSL is optional */ def withEnableMetaSSL(enableMetaSSL: Boolean): ConfigBuilder = { - LOG.warn("metaSSL is not supported yet.") - this.enableMetaSSL = false + this.enableMetaSSL = enableMetaSSL this } @@ -108,8 +110,7 @@ object NebulaConnectionConfig { * set enableStorageSSL, enableStorageSSL is optional */ def withEnableStorageSSL(enableStorageSSL: Boolean): ConfigBuilder = { - LOG.warn("storageSSL is not supported yet.") - this.enableStorageSSL = false + this.enableStorageSSL = enableStorageSSL this } @@ -127,8 +128,7 @@ object NebulaConnectionConfig { def withCaSSLSignParam(caCrtFilePath: String, crtFilePath: String, keyFilePath: String): ConfigBuilder = { - val caSignParam = new CASignedSSLParam(caCrtFilePath, crtFilePath, keyFilePath) - this.caSignParam = caSignParam + this.caSignParam = CASSLSignParams(caCrtFilePath, crtFilePath, keyFilePath) this } @@ -138,8 +138,7 @@ object NebulaConnectionConfig { def withSelfSSLSignParam(crtFilePath: String, keyFilePath: String, password: String): ConfigBuilder = { - val selfSignParam = new SelfSignedSSLParam(crtFilePath, keyFilePath, password) - this.selfSignParam = selfSignParam + this.selfSignParam = SelfSSLSignParams(crtFilePath, keyFilePath, password) this } @@ -156,22 +155,21 @@ object NebulaConnectionConfig { // check ssl param if (enableMetaSSL || enableGraphSSL || enableStorageSSL) { assert( - (enableStorageSSL && enableMetaSSL && enableGraphSSL) - || (!enableStorageSSL && !enableMetaSSL && enableGraphSSL), - "ssl priority order: storage > meta > graph " + - "please make sure graph ssl is enable when storage and meta ssl is enable." + !enableStorageSSL || enableStorageSSL && enableMetaSSL, + "ssl priority order: storage > meta = graph " + + "please make sure meta ssl is enabled when storage ssl is enabled." ) sslSignType match { case SSLSignType.CA => assert( - caSignParam != null && caSignParam.getCaCrtFilePath != null - && caSignParam.getCrtFilePath != null && caSignParam.getKeyFilePath != null, + caSignParam != null && caSignParam.caCrtFilePath != null + && caSignParam.crtFilePath != null && caSignParam.keyFilePath != null, "ssl sign type is CA, param can not be null" ) case SSLSignType.SELF => assert( - selfSignParam != null && selfSignParam.getCrtFilePath != null - && selfSignParam.getKeyFilePath != null && selfSignParam.getPassword != null, + selfSignParam != null && selfSignParam.crtFilePath != null + && selfSignParam.keyFilePath != null && selfSignParam.password != null, "ssl sign type is SELF, param can not be null" ) case _ => assert(false, "SSLSignType config is null") diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaEnum.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaEnum.scala index 55af9d58..e64bbf40 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaEnum.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaEnum.scala @@ -37,10 +37,3 @@ object WriteMode extends Enumeration { val UPDATE = Value("update") val DELETE = Value("delete") } - -object SSLSignType extends Enumeration { - - type signType = Value - val CA = Value("ca") - val SELF = Value("self") -} diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala index 0326e660..11051a68 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala @@ -8,8 +8,8 @@ package com.vesoft.nebula.connector import java.util.Properties import com.google.common.net.HostAndPort -import com.vesoft.nebula.client.graph.data.{CASignedSSLParam, SelfSignedSSLParam} import com.vesoft.nebula.connector.connector.Address +import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams} import org.apache.commons.lang.StringUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -67,19 +67,21 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])( parameters.getOrElse(ENABLE_GRAPH_SSL, DEFAULT_ENABLE_GRAPH_SSL).toString.toBoolean val enableMetaSSL: Boolean = parameters.getOrElse(ENABLE_META_SSL, DEFAULT_ENABLE_META_SSL).toString.toBoolean - var sslSignType: String = null - var caSignParam: CASignedSSLParam = null - var selfSignParam: SelfSignedSSLParam = null + val enableStorageSSL: Boolean = + parameters.getOrElse(ENABLE_STORAGE_SSL, DEFAULT_ENABLE_STORAGE_SSL).toString.toBoolean + var sslSignType: String = _ + var caSignParam: CASSLSignParams = _ + var selfSignParam: SelfSSLSignParams = _ if (enableGraphSSL || enableMetaSSL) { sslSignType = parameters.get(SSL_SIGN_TYPE).get SSLSignType.withName(sslSignType) match { case SSLSignType.CA => { val params = parameters.get(CA_SIGN_PARAM).get.split(",") - caSignParam = new CASignedSSLParam(params(0), params(1), params(2)) + caSignParam = new CASSLSignParams(params(0), params(1), params(2)) } case SSLSignType.SELF => { val params = parameters.get(SELF_SIGN_PARAM).get.split(",") - selfSignParam = new SelfSignedSSLParam(params(0), params(1), params(2)) + selfSignParam = new SelfSSLSignParams(params(0), params(1), params(2)) } } } @@ -213,17 +215,18 @@ object NebulaOptions { val LABEL: String = "label" /** connection config */ - val TIMEOUT: String = "timeout" - val CONNECTION_RETRY: String = "connectionRetry" - val EXECUTION_RETRY: String = "executionRetry" - val RATE_TIME_OUT: String = "reteTimeOut" - val USER_NAME: String = "user" - val PASSWD: String = "passwd" - val ENABLE_GRAPH_SSL: String = "enableGraphSSL" - val ENABLE_META_SSL: String = "enableMetaSSL" - val SSL_SIGN_TYPE: String = "sslSignType" - val CA_SIGN_PARAM: String = "caSignParam" - val SELF_SIGN_PARAM: String = "selfSignParam" + val TIMEOUT: String = "timeout" + val CONNECTION_RETRY: String = "connectionRetry" + val EXECUTION_RETRY: String = "executionRetry" + val RATE_TIME_OUT: String = "reteTimeOut" + val USER_NAME: String = "user" + val PASSWD: String = "passwd" + val ENABLE_GRAPH_SSL: String = "enableGraphSSL" + val ENABLE_META_SSL: String = "enableMetaSSL" + val ENABLE_STORAGE_SSL: String = "enableStorageSSL" + val SSL_SIGN_TYPE: String = "sslSignType" + val CA_SIGN_PARAM: String = "caSignParam" + val SELF_SIGN_PARAM: String = "selfSignParam" /** read config */ val RETURN_COLS: String = "returnCols" @@ -254,8 +257,9 @@ object NebulaOptions { val DEFAULT_USER_NAME: String = "root" val DEFAULT_PASSWD: String = "nebula" - val DEFAULT_ENABLE_GRAPH_SSL: Boolean = false - val DEFAULT_ENABLE_META_SSL: Boolean = false + val DEFAULT_ENABLE_GRAPH_SSL: Boolean = false + val DEFAULT_ENABLE_META_SSL: Boolean = false + val DEFAULT_ENABLE_STORAGE_SSL: Boolean = false val DEFAULT_LIMIT: Int = 1000 diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/nebula/GraphProvider.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/nebula/GraphProvider.scala index ac4b6b41..70ca0093 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/nebula/GraphProvider.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/nebula/GraphProvider.scala @@ -13,9 +13,9 @@ import com.vesoft.nebula.client.graph.data.{ SelfSignedSSLParam } import com.vesoft.nebula.client.graph.net.{NebulaPool, Session} -import com.vesoft.nebula.connector.SSLSignType import com.vesoft.nebula.connector.connector.Address import com.vesoft.nebula.connector.exception.GraphConnectException +import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams} import org.apache.log4j.Logger import scala.collection.JavaConverters._ @@ -25,10 +25,11 @@ import scala.collection.mutable.ListBuffer * GraphProvider for Nebula Graph Service */ class GraphProvider(addresses: List[Address], + timeout: Int, enableSSL: Boolean = false, sslSignType: String = null, - caSignParam: CASignedSSLParam = null, - selfSignParam: SelfSignedSSLParam = null) + caSignParam: CASSLSignParams = null, + selfSignParam: SelfSSLSignParams = null) extends AutoCloseable with Serializable { private[this] lazy val LOG = Logger.getLogger(this.getClass) @@ -41,13 +42,22 @@ class GraphProvider(addresses: List[Address], address.append(new HostAddress(addr._1, addr._2)) } nebulaPoolConfig.setMaxConnSize(1) + nebulaPoolConfig.setTimeout(timeout) if (enableSSL) { nebulaPoolConfig.setEnableSsl(enableSSL) SSLSignType.withName(sslSignType) match { - case SSLSignType.CA => nebulaPoolConfig.setSslParam(caSignParam) - case SSLSignType.SELF => nebulaPoolConfig.setSslParam(selfSignParam) - case _ => throw new IllegalArgumentException("ssl sign type is not supported") + case SSLSignType.CA => + nebulaPoolConfig.setSslParam( + new CASignedSSLParam(caSignParam.caCrtFilePath, + caSignParam.crtFilePath, + caSignParam.keyFilePath)) + case SSLSignType.SELF => + nebulaPoolConfig.setSslParam( + new SelfSignedSSLParam(selfSignParam.crtFilePath, + selfSignParam.keyFilePath, + selfSignParam.password)) + case _ => throw new IllegalArgumentException("ssl sign type is not supported") } } pool.init(address.asJava, nebulaPoolConfig) diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/nebula/MetaProvider.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/nebula/MetaProvider.scala index 7120a17f..46162e85 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/nebula/MetaProvider.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/nebula/MetaProvider.scala @@ -5,10 +5,16 @@ package com.vesoft.nebula.connector.nebula -import com.vesoft.nebula.client.graph.data.HostAddress +import com.vesoft.nebula.client.graph.data.{ + CASignedSSLParam, + HostAddress, + SSLParam, + SelfSignedSSLParam +} import com.vesoft.nebula.client.meta.MetaClient import com.vesoft.nebula.connector.connector.Address import com.vesoft.nebula.connector.DataTypeEnum +import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams} import com.vesoft.nebula.meta.{PropertyType, Schema} import scala.collection.JavaConverters._ @@ -17,11 +23,32 @@ import scala.collection.mutable class MetaProvider(addresses: List[Address], timeout: Int, connectionRetry: Int, - executionRetry: Int) + executionRetry: Int, + enableSSL: Boolean, + sslSignType: String = null, + caSignParam: CASSLSignParams, + selfSignParam: SelfSSLSignParams) extends AutoCloseable { - val metaAddress = addresses.map(address => new HostAddress(address._1, address._2)).asJava - val client = new MetaClient(metaAddress, timeout, connectionRetry, executionRetry) + val metaAddress = addresses.map(address => new HostAddress(address._1, address._2)).asJava + var client: MetaClient = null + var sslParam: SSLParam = null + if (enableSSL) { + SSLSignType.withName(sslSignType) match { + case SSLSignType.CA => + sslParam = new CASignedSSLParam(caSignParam.caCrtFilePath, + caSignParam.crtFilePath, + caSignParam.keyFilePath) + case SSLSignType.SELF => + sslParam = new SelfSignedSSLParam(selfSignParam.crtFilePath, + selfSignParam.keyFilePath, + selfSignParam.password) + case _ => throw new IllegalArgumentException("ssl sign type is not supported") + } + client = new MetaClient(metaAddress, timeout, connectionRetry, executionRetry, true, sslParam) + } else { + client = new MetaClient(metaAddress, timeout, connectionRetry, executionRetry) + } client.connect() def getPartitionNumber(space: String): Int = { diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala index 4167df03..59482660 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -5,6 +5,7 @@ package com.vesoft.nebula.connector +import com.vesoft.nebula.connector.ssl.SSLSignType import com.vesoft.nebula.connector.writer.NebulaExecutor import org.apache.commons.codec.digest.MurmurHash2 import org.apache.spark.rdd.RDD @@ -105,7 +106,7 @@ package object connector { def loadVerticesToDF(): DataFrame = { assert(connectionConfig != null && readConfig != null, "nebula config is not set, please call nebula() before loadVerticesToDF") - reader + val dfReader = reader .format(classOf[NebulaDataSource].getName) .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString) .option(NebulaOptions.SPACE_NAME, readConfig.getSpace) @@ -118,7 +119,20 @@ package object connector { .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry) .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry) - .load() + .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL) + .option(NebulaOptions.ENABLE_STORAGE_SSL, connectionConfig.getEnableStorageSSL) + + if (connectionConfig.getEnableStorageSSL || connectionConfig.getEnableMetaSSL) { + dfReader.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType) + SSLSignType.withName(connectionConfig.getSignType) match { + case SSLSignType.CA => + dfReader.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam) + case SSLSignType.SELF => + dfReader.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam) + } + } + + dfReader.load() } /** @@ -129,7 +143,7 @@ package object connector { assert(connectionConfig != null && readConfig != null, "nebula config is not set, please call nebula() before loadEdgesToDF") - reader + val dfReader = reader .format(classOf[NebulaDataSource].getName) .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString) .option(NebulaOptions.SPACE_NAME, readConfig.getSpace) @@ -142,7 +156,20 @@ package object connector { .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry) .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry) - .load() + .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL) + .option(NebulaOptions.ENABLE_STORAGE_SSL, connectionConfig.getEnableStorageSSL) + + if (connectionConfig.getEnableStorageSSL || connectionConfig.getEnableMetaSSL) { + dfReader.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType) + SSLSignType.withName(connectionConfig.getSignType) match { + case SSLSignType.CA => + dfReader.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam) + case SSLSignType.SELF => + dfReader.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam) + } + } + + dfReader.load() } /** @@ -226,7 +253,7 @@ package object connector { assert(connectionConfig != null && writeNebulaConfig != null, "nebula config is not set, please call nebula() before writeVertices") val writeConfig = writeNebulaConfig.asInstanceOf[WriteNebulaVertexConfig] - writer + val dfWriter = writer .format(classOf[NebulaDataSource].getName) .mode(SaveMode.Overwrite) .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString) @@ -244,7 +271,20 @@ package object connector { .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry) .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry) - .save() + .option(NebulaOptions.ENABLE_GRAPH_SSL, connectionConfig.getEnableGraphSSL) + .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL) + + if (connectionConfig.getEnableGraphSSL || connectionConfig.getEnableMetaSSL) { + dfWriter.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType) + SSLSignType.withName(connectionConfig.getSignType) match { + case SSLSignType.CA => + dfWriter.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam) + case SSLSignType.SELF => + dfWriter.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam) + } + } + + dfWriter.save() } /** @@ -255,7 +295,7 @@ package object connector { assert(connectionConfig != null && writeNebulaConfig != null, "nebula config is not set, please call nebula() before writeEdges") val writeConfig = writeNebulaConfig.asInstanceOf[WriteNebulaEdgeConfig] - writer + val dfWriter = writer .format(classOf[NebulaDataSource].getName) .mode(SaveMode.Overwrite) .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString) @@ -278,7 +318,20 @@ package object connector { .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry) .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry) - .save() + .option(NebulaOptions.ENABLE_GRAPH_SSL, connectionConfig.getEnableGraphSSL) + .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL) + + if (connectionConfig.getEnableGraphSSL || connectionConfig.getEnableMetaSSL) { + dfWriter.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType) + SSLSignType.withName(connectionConfig.getSignType) match { + case SSLSignType.CA => + dfWriter.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam) + case SSLSignType.SELF => + dfWriter.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam) + } + } + + dfWriter.save() } } diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala index cfb2a6bb..7191b9be 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala @@ -5,13 +5,20 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.client.graph.data.{HostAddress, ValueWrapper} +import com.vesoft.nebula.client.graph.data.{ + CASignedSSLParam, + HostAddress, + SSLParam, + SelfSignedSSLParam, + ValueWrapper +} import com.vesoft.nebula.client.storage.StorageClient import com.vesoft.nebula.client.storage.data.{BaseTableRow, VertexTableRow} import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter import com.vesoft.nebula.connector.exception.GraphConnectException import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils} import com.vesoft.nebula.connector.nebula.MetaProvider +import com.vesoft.nebula.connector.ssl.SSLSignType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.sources.v2.reader.InputPartitionReader @@ -45,17 +52,49 @@ abstract class NebulaPartitionReader extends InputPartitionReader[InternalRow] { this() this.schema = schema - metaProvider = new MetaProvider(nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry) + metaProvider = new MetaProvider( + nebulaOptions.getMetaAddress, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + nebulaOptions.enableMetaSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress] for (addr <- nebulaOptions.getMetaAddress) { address.append(new HostAddress(addr._1, addr._2)) } - this.storageClient = new StorageClient(address.asJava) + var sslParam: SSLParam = null + if (nebulaOptions.enableStorageSSL) { + SSLSignType.withName(nebulaOptions.sslSignType) match { + case SSLSignType.CA => { + val caSSLSignParams = nebulaOptions.caSignParam + sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath, + caSSLSignParams.crtFilePath, + caSSLSignParams.keyFilePath) + } + case SSLSignType.SELF => { + val selfSSLSignParams = nebulaOptions.selfSignParam + sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath, + selfSSLSignParams.keyFilePath, + selfSSLSignParams.password) + } + case _ => throw new IllegalArgumentException("ssl sign type is not supported") + } + this.storageClient = new StorageClient(address.asJava, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + true, + sslParam) + } else { + this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout) + } + if (!storageClient.connect()) { throw new GraphConnectException("storage connect failed.") } diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala index 21f8ed0d..f6da55e4 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala @@ -41,10 +41,16 @@ abstract class NebulaSourceReader(nebulaOptions: NebulaOptions) extends DataSour val returnCols = nebulaOptions.getReturnCols val noColumn = nebulaOptions.noColumn val fields: ListBuffer[StructField] = new ListBuffer[StructField] - val metaProvider = new MetaProvider(nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry) + val metaProvider = new MetaProvider( + nebulaOptions.getMetaAddress, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + nebulaOptions.enableMetaSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) import scala.collection.JavaConverters._ var schemaCols: Seq[ColumnDef] = Seq() diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/ssl/SSLEnum.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/ssl/SSLEnum.scala new file mode 100644 index 00000000..de297729 --- /dev/null +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/ssl/SSLEnum.scala @@ -0,0 +1,13 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.ssl + +object SSLSignType extends Enumeration { + + type signType = Value + val CA = Value("ca") + val SELF = Value("self") +} diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/ssl/SSLSignParams.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/ssl/SSLSignParams.scala new file mode 100644 index 00000000..4c9617d8 --- /dev/null +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/ssl/SSLSignParams.scala @@ -0,0 +1,10 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.ssl + +case class CASSLSignParams(caCrtFilePath: String, crtFilePath: String, keyFilePath: String) + +case class SelfSSLSignParams(crtFilePath: String, keyFilePath: String, password: String) diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala index a27c754f..05718536 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala @@ -19,15 +19,24 @@ class NebulaWriter(nebulaOptions: NebulaOptions) extends Serializable { val failedExecs: ListBuffer[String] = new ListBuffer[String] - val metaProvider = new MetaProvider(nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry) - val graphProvider = new GraphProvider(nebulaOptions.getGraphAddress, - nebulaOptions.enableGraphSSL, - nebulaOptions.sslSignType, - nebulaOptions.caSignParam, - nebulaOptions.selfSignParam) + val metaProvider = new MetaProvider( + nebulaOptions.getMetaAddress, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + nebulaOptions.enableMetaSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) + val graphProvider = new GraphProvider( + nebulaOptions.getGraphAddress, + nebulaOptions.timeout, + nebulaOptions.enableGraphSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) val isVidStringType = metaProvider.getVidType(nebulaOptions.spaceName) == VidType.STRING def prepareSpace(): Unit = { diff --git a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/NebulaConfigSuite.scala b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/NebulaConfigSuite.scala index 486d0686..347a2565 100644 --- a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/NebulaConfigSuite.scala +++ b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/NebulaConfigSuite.scala @@ -5,6 +5,7 @@ package com.vesoft.nebula.connector +import com.vesoft.nebula.connector.ssl.SSLSignType import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite @@ -35,6 +36,19 @@ class NebulaConfigSuite extends AnyFunSuite with BeforeAndAfterAll { .build() } + test("test correct ssl config with wrong ssl priority") { + assertThrows[AssertionError]( + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withGraphAddress("127.0.0.1:9669") + .withEnableStorageSSL(true) + .withEnableMetaSSL(false) + .withSSLSignType(SSLSignType.CA) + .withCaSSLSignParam("caCrtFile", "crtFile", "keyFile") + .build()) + } + test("test correct ssl config with no sign type param") { assertThrows[AssertionError]( NebulaConnectionConfig @@ -43,7 +57,7 @@ class NebulaConfigSuite extends AnyFunSuite with BeforeAndAfterAll { .withGraphAddress("127.0.0.1:9669") .withEnableGraphSSL(true) .withEnableMetaSSL(true) - .withCaSSLSignParam("cacrtFile", "crtFile", "keyFile") + .withCaSSLSignParam("caCrtFile", "crtFile", "keyFile") .build()) } diff --git a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/nebula/GraphProviderTest.scala b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/nebula/GraphProviderTest.scala index d199d6ad..f1a3aca8 100644 --- a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/nebula/GraphProviderTest.scala +++ b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/nebula/GraphProviderTest.scala @@ -16,7 +16,7 @@ class GraphProviderTest extends AnyFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) - graphProvider = new GraphProvider(addresses) + graphProvider = new GraphProvider(addresses, 3000) val graphMock = new NebulaGraphMock graphMock.mockIntIdGraph() graphMock.mockStringIdGraph() diff --git a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/nebula/MetaProviderTest.scala b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/nebula/MetaProviderTest.scala index 2b94b4d9..307865fb 100644 --- a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/nebula/MetaProviderTest.scala +++ b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/nebula/MetaProviderTest.scala @@ -17,7 +17,7 @@ class MetaProviderTest extends AnyFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val addresses: List[Address] = List(new Address("127.0.0.1", 9559)) - metaProvider = new MetaProvider(addresses, 6000, 3, 3) + metaProvider = new MetaProvider(addresses, 6000, 3, 3, false, null, null, null) val graphMock = new NebulaGraphMock graphMock.mockStringIdGraph() diff --git a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala index 9c8f7f1f..3f0111ea 100644 --- a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala +++ b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala @@ -25,7 +25,7 @@ class WriteDeleteSuite extends AnyFunSuite with BeforeAndAfterAll { test("write vertex into test_write_string space with delete mode") { SparkMock.deleteVertex() val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) - val graphProvider = new GraphProvider(addresses) + val graphProvider = new GraphProvider(addresses, 3000) graphProvider.switchSpace("root", "nebula", "test_write_string") val resultSet: ResultSet = @@ -37,7 +37,7 @@ class WriteDeleteSuite extends AnyFunSuite with BeforeAndAfterAll { test("write edge into test_write_string space with delete mode") { SparkMock.deleteEdge() val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) - val graphProvider = new GraphProvider(addresses) + val graphProvider = new GraphProvider(addresses, 3000) graphProvider.switchSpace("root", "nebula", "test_write_string") val resultSet: ResultSet = diff --git a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala index b5865a46..124ddc9e 100644 --- a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala +++ b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala @@ -24,7 +24,7 @@ class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { test("write vertex into test_write_string space with insert mode") { SparkMock.writeVertex() val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) - val graphProvider = new GraphProvider(addresses) + val graphProvider = new GraphProvider(addresses, 3000) graphProvider.switchSpace("root", "nebula", "test_write_string") val createIndexResult: ResultSet = graphProvider.submit( @@ -49,7 +49,7 @@ class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { SparkMock.writeEdge() val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) - val graphProvider = new GraphProvider(addresses) + val graphProvider = new GraphProvider(addresses, 3000) graphProvider.switchSpace("root", "nebula", "test_write_string") val createIndexResult: ResultSet = graphProvider.submit(