diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
index 479768d0..90b6feb9 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
@@ -157,13 +157,13 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
         new GenericRow(Array[Any]())
       }
     } else {
-      val filteredSchema = {
-        val requiredColumnsSet = Set(requiredColumns: _*)
-        val filteredFields = schema.fields
-          .filter { f =>
-            requiredColumnsSet.contains(f.name)
-          }
-        StructType(filteredFields)
+      // filter schema columns, it should be in the same order as given 'requiredColumns'
+      val requiredSchema = {
+        val fieldsMap = schema.fields.map(f => (f.name, f)).toMap
+        val requiredFields = requiredColumns.map { c =>
+          fieldsMap(c)
+        }
+        StructType(requiredFields)
       }
 
       val keyType = if (persistenceModel == SqlOptionModelBinary) {
@@ -173,12 +173,12 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
       }
       keysRdd.mapPartitions { partition =>
         // grouped iterator to only allocate memory for a portion of rows
-        partition.grouped(iteratorGroupingSize).map { batch =>
+        partition.grouped(iteratorGroupingSize).flatMap { batch =>
           groupKeysByNode(redisConfig.hosts, batch.iterator)
             .flatMap { case (node, keys) =>
-              scanRows(node, keys, keyType, filteredSchema, requiredColumns)
+              scanRows(node, keys, keyType, requiredSchema, requiredColumns)
             }
-        }.flatten
+        }
       }
     }
   }
diff --git a/src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala b/src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala
index 2c61513d..ab05a460 100644
--- a/src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala
+++ b/src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala
@@ -1,18 +1,21 @@
 package com.redislabs.provider.redis.df
 
 import java.sql.{Date, Timestamp}
+import java.util.UUID
 
 import com.redislabs.provider.redis.toRedisContext
 import com.redislabs.provider.redis.util.Person.{data, _}
 import com.redislabs.provider.redis.util.TestUtils._
 import com.redislabs.provider.redis.util.{EntityId, Person}
 import org.apache.spark.SparkException
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.redis.RedisSourceRelation.tableDataKeyPattern
 import org.apache.spark.sql.redis._
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StructField, _}
 import org.scalatest.Matchers
 
+import scala.util.Random
+
 /**
   * @author The Viet Nguyen
   */
@@ -295,6 +298,37 @@ trait HashDataframeSuite extends RedisDataframeSuite with Matchers {
     }
   }
 
+  /**
+    * A test case for https://github.com/RedisLabs/spark-redis/issues/132
+    */
+  test("RedisSourceRelation.buildScan columns ordering") {
+    val schema = {
+      StructType(Array(
+        StructField("id", StringType),
+        StructField("int", IntegerType),
+        StructField("float", FloatType),
+        StructField("double", DoubleType),
+        StructField("str", StringType)))
+    }
+
+    val rowsNum = 8
+    val rdd = spark.sparkContext.parallelize(1 to rowsNum, 2).map { _ =>
+      def genStr = UUID.randomUUID().toString
+      def genInt = Random.nextInt()
+      def genDouble = Random.nextDouble()
+      def genFloat = Random.nextFloat()
+      Row.fromSeq(Seq(genStr, genInt, genFloat, genDouble, genStr))
+    }
+
+    val df = spark.createDataFrame(rdd, schema)
+    val tableName = generateTableName("cols-ordering")
+    df.write.format(RedisFormat).option(SqlOptionTableName, tableName).save()
+    val loadedDf = spark.read.format(RedisFormat).option(SqlOptionTableName, tableName).load()
+    loadedDf.schema shouldBe schema
+    loadedDf.collect().length shouldBe rowsNum
+    loadedDf.show()
+  }
+
   def saveMap(tableName: String): Unit = {
     Person.dataMaps.foreach { person =>
       saveMap(tableName, person("name"), person)
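
Note (not part of the patch): the first hunk addresses the fact that Spark may request `requiredColumns` in an order that differs from the stored schema. The old `Set`-based filter returned fields in schema order, so the scanned values could end up misaligned with the columns Spark asked for; the new lookup preserves the requested order. Below is a minimal standalone sketch (plain Scala using only `spark-sql` types, no Redis; the object name `ColumnOrderingSketch` is made up for illustration) contrasting the two approaches:

import org.apache.spark.sql.types._

object ColumnOrderingSketch {
  def main(args: Array[String]): Unit = {
    // Stored table schema, analogous to the relation's full schema.
    val schema = StructType(Seq(
      StructField("id", StringType),
      StructField("int", IntegerType),
      StructField("str", StringType)))

    // Spark can ask for a projection in a different order than the schema declares.
    val requiredColumns = Array("str", "id")

    // Old approach: filtering schema.fields keeps schema order ("id", "str"),
    // so rows built in requiredColumns order no longer match their fields.
    val requiredColumnsSet = Set(requiredColumns: _*)
    val filteredSchema = StructType(schema.fields.filter(f => requiredColumnsSet.contains(f.name)))

    // New approach: look each required column up by name, preserving the requested order.
    val fieldsMap = schema.fields.map(f => (f.name, f)).toMap
    val requiredSchema = StructType(requiredColumns.map(fieldsMap))

    println(filteredSchema.fieldNames.mkString(", ")) // id, str
    println(requiredSchema.fieldNames.mkString(", ")) // str, id
  }
}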