diff --git a/coffea/processor/executor.py b/coffea/processor/executor.py
index 22ef9e106..013e0112c 100644
--- a/coffea/processor/executor.py
+++ b/coffea/processor/executor.py
@@ -338,8 +338,10 @@ def run_spark_job(fileset, processor_instance, executor, executor_args={},
     executor_args.setdefault('config', None)
     executor_args.setdefault('file_type', 'parquet')
-    executor_args.setdefault('laurelin_version', '0.1.0')
+    executor_args.setdefault('laurelin_version', '0.3.0')
+    executor_args.setdefault('treeName', 'Events')
     file_type = executor_args['file_type']
+    treeName = executor_args['treeName']

     if executor_args['config'] is None:
         executor_args.pop('config')
@@ -356,7 +358,7 @@ def run_spark_job(fileset, processor_instance, executor, executor_args={},
         raise ValueError("Expected 'spark' to be a pyspark.sql.session.SparkSession")

     dfslist = _spark_make_dfs(spark, fileset, partitionsize, processor_instance.columns,
-                              thread_workers, file_type)
+                              thread_workers, file_type, treeName)

     output = processor_instance.accumulator.identity()
     executor(spark, dfslist, processor_instance, output, thread_workers)
diff --git a/coffea/processor/spark/detail.py b/coffea/processor/spark/detail.py
index fcb6f45c6..6d70ff9fb 100644
--- a/coffea/processor/spark/detail.py
+++ b/coffea/processor/spark/detail.py
@@ -32,7 +32,7 @@ def _spark_initialize(config=_default_config, **kwargs):
         cfg_actual = cfg_actual.config('spark.ui.showConsoleProgress', 'false')

     # always load laurelin even if we may not use it
-    kwargs.setdefault('laurelin_version', '0.1.0')
+    kwargs.setdefault('laurelin_version', '0.3.0')
     laurelin = kwargs['laurelin_version']
     cfg_actual = cfg_actual.config('spark.jars.packages',
                                    'edu.vanderbilt.accre:laurelin:%s' % laurelin)
@@ -48,16 +48,21 @@ def _spark_initialize(config=_default_config, **kwargs):
     return session


-def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize, file_type, treeName='Events'):
-    if not isinstance(files_or_dirs, Sequence):
+def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize, file_type, treeName):
+    flist = files_or_dirs
+    tname = treeName
+    if isinstance(files_or_dirs, dict):
+        tname = files_or_dirs['treename']
+        flist = files_or_dirs['files']
+    if not isinstance(flist, Sequence):
         raise ValueError('spark dataset file list must be a Sequence (like list())')
     df = None
     if file_type == 'parquet':
-        df = spark.read.parquet(*files_or_dirs)
+        df = spark.read.parquet(flist)
     else:
         df = spark.read.format(file_type) \
-                       .option('tree', treeName) \
-                       .load(*files_or_dirs)
+                       .option('tree', tname) \
+                       .load(flist)

     count = df.count()
     df_cols = set(df.columns)
@@ -74,7 +79,8 @@ def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize, file_type,
     return df, dataset, count


-def _spark_make_dfs(spark, fileset, partitionsize, columns, thread_workers, file_type, status=True):
+def _spark_make_dfs(spark, fileset, partitionsize, columns, thread_workers, file_type,
+                    treeName, status=True):
     dfs = {}
     ana_cols = set(columns)
@@ -84,7 +90,8 @@ def dfs_accumulator(total, result):
     with ThreadPoolExecutor(max_workers=thread_workers) as executor:
         futures = set(executor.submit(_read_df, spark, ds, files,
-                                      ana_cols, partitionsize, file_type) for ds, files in fileset.items())
+                                      ana_cols, partitionsize, file_type,
+                                      treeName) for ds, files in fileset.items())

         futures_handler(futures, dfs, status, 'datasets', 'loading',
                         futures_accumulator=dfs_accumulator)
diff --git a/tests/test_spark.py b/tests/test_spark.py
index 701731d9c..3d514cac9 100644
--- a/tests/test_spark.py
+++ b/tests/test_spark.py
@@ -40,8 +40,8 @@ def test_spark_executor():
     spark = _spark_initialize(config=spark_config,log_level='ERROR',spark_progress=False)

-    filelist = {'ZJets': ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dy.root')],
-                'Data' : ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dimuon.root')]
+    filelist = {'ZJets': {'files': ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dy.root')], 'treename': 'Events' },
+                'Data' : {'files': ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dimuon.root')], 'treename': 'Events'}
     }

     from coffea.processor.test_items import NanoTestProcessor
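
Usage note (not part of the patch): after this change, each entry in the fileset may be either a plain list of files, which falls back to executor_args['treeName'], or a dict carrying its own per-dataset 'treename'. Below is a minimal sketch of driving run_spark_job with the new form, mirroring the updated test above; the local file paths are placeholders, and it assumes run_spark_job returns the accumulated output as coffea did at the time.

    # Sketch only: mirrors tests/test_spark.py above; paths are placeholders.
    from coffea import processor
    from coffea.processor.spark.detail import _spark_initialize
    from coffea.processor.spark.spark_executor import spark_executor
    from coffea.processor.test_items import NanoTestProcessor

    spark = _spark_initialize()  # now pulls laurelin 0.3.0 by default

    fileset = {
        # new dict form: this dataset names its own tree
        'ZJets': {'files': ['file:/path/to/nano_dy.root'], 'treename': 'Events'},
        # plain list still works; _read_df falls back to executor_args['treeName']
        'Data': ['file:/path/to/nano_dimuon.root'],
    }

    output = processor.run_spark_job(fileset, NanoTestProcessor(), spark_executor,
                                     spark=spark, thread_workers=1,
                                     executor_args={'file_type': 'root'})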