Merge pull request #166 from lgray/spark_tree_names
Allow spark jobs to deal with multiple tree names
lgray authored Sep 19, 2019
2 parents 36b3fd0 + e96aef2 commit f029dd7
Showing 3 changed files with 21 additions and 12 deletions.
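
In practice, this means a fileset entry passed to run_spark_job can now either stay a plain list of files (read with the executor_args['treeName'] default) or be a dict that names its own tree per dataset. A minimal sketch of the two accepted shapes, where the dataset names, paths, and the 'otherTree' name are purely illustrative:

    fileset = {
        # plain list: read using executor_args['treeName'], which now defaults to 'Events'
        'ZJets': ['file:/path/to/nano_dy.root'],
        # dict form: 'treename' overrides the default for this dataset only
        'Data': {'files': ['file:/path/to/nano_dimuon.root'], 'treename': 'otherTree'},
    }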
6 changes: 4 additions & 2 deletions coffea/processor/executor.py
@@ -338,8 +338,10 @@ def run_spark_job(fileset, processor_instance, executor, executor_args={},
 
     executor_args.setdefault('config', None)
     executor_args.setdefault('file_type', 'parquet')
-    executor_args.setdefault('laurelin_version', '0.1.0')
+    executor_args.setdefault('laurelin_version', '0.3.0')
+    executor_args.setdefault('treeName', 'Events')
     file_type = executor_args['file_type']
+    treeName = executor_args['treeName']
 
     if executor_args['config'] is None:
         executor_args.pop('config')
@@ -356,7 +358,7 @@ def run_spark_job(fileset, processor_instance, executor, executor_args={},
         raise ValueError("Expected 'spark' to be a pyspark.sql.session.SparkSession")
 
     dfslist = _spark_make_dfs(spark, fileset, partitionsize, processor_instance.columns,
-                              thread_workers, file_type)
+                              thread_workers, file_type, treeName)
 
     output = processor_instance.accumulator.identity()
     executor(spark, dfslist, processor_instance, output, thread_workers)
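
For orientation, a hedged sketch of how the updated run_spark_job might be called end to end; the import paths, the 'root' format short name, and the spark=/thread_workers= keyword arguments are assumptions for illustration, not part of this diff:

    from coffea.processor import run_spark_job
    from coffea.processor.spark.detail import _spark_initialize
    from coffea.processor.spark.spark_executor import spark_executor  # assumed import path
    from coffea.processor.test_items import NanoTestProcessor

    spark = _spark_initialize()  # now pulls in laurelin 0.3.0 by default
    output = run_spark_job(fileset, NanoTestProcessor(), spark_executor,
                           spark=spark, thread_workers=1,
                           executor_args={'file_type': 'root',    # assumed Laurelin format name
                                          'treeName': 'Events'})  # new default shown above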
23 changes: 15 additions & 8 deletions coffea/processor/spark/detail.py
@@ -32,7 +32,7 @@ def _spark_initialize(config=_default_config, **kwargs):
     cfg_actual = cfg_actual.config('spark.ui.showConsoleProgress', 'false')
 
     # always load laurelin even if we may not use it
-    kwargs.setdefault('laurelin_version', '0.1.0')
+    kwargs.setdefault('laurelin_version', '0.3.0')
     laurelin = kwargs['laurelin_version']
     cfg_actual = cfg_actual.config('spark.jars.packages',
                                    'edu.vanderbilt.accre:laurelin:%s' % laurelin)
@@ -48,16 +48,21 @@ def _spark_initialize(config=_default_config, **kwargs):
     return session
 
 
-def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize, file_type, treeName='Events'):
-    if not isinstance(files_or_dirs, Sequence):
+def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize, file_type, treeName):
+    flist = files_or_dirs
+    tname = treeName
+    if isinstance(files_or_dirs, dict):
+        tname = files_or_dirs['treename']
+        flist = files_or_dirs['files']
+    if not isinstance(flist, Sequence):
         raise ValueError('spark dataset file list must be a Sequence (like list())')
     df = None
     if file_type == 'parquet':
-        df = spark.read.parquet(*files_or_dirs)
+        df = spark.read.parquet(flist)
     else:
         df = spark.read.format(file_type) \
-            .option('tree', treeName) \
-            .load(*files_or_dirs)
+            .option('tree', tname) \
+            .load(flist)
     count = df.count()
 
     df_cols = set(df.columns)
@@ -74,7 +79,8 @@ def _read_df(spark, dataset, files_or_dirs, ana_cols, partitionsize, file_type,
     return df, dataset, count
 
 
-def _spark_make_dfs(spark, fileset, partitionsize, columns, thread_workers, file_type, status=True):
+def _spark_make_dfs(spark, fileset, partitionsize, columns, thread_workers, file_type,
+                    treeName, status=True):
     dfs = {}
     ana_cols = set(columns)
 
@@ -84,7 +90,8 @@ def dfs_accumulator(total, result):
 
     with ThreadPoolExecutor(max_workers=thread_workers) as executor:
         futures = set(executor.submit(_read_df, spark, ds, files,
-                                      ana_cols, partitionsize, file_type) for ds, files in fileset.items())
+                                      ana_cols, partitionsize, file_type,
+                                      treeName) for ds, files in fileset.items())
 
     futures_handler(futures, dfs, status, 'datasets', 'loading', futures_accumulator=dfs_accumulator)
 
4 changes: 2 additions & 2 deletions tests/test_spark.py
@@ -40,8 +40,8 @@ def test_spark_executor():
 
     spark = _spark_initialize(config=spark_config,log_level='ERROR',spark_progress=False)
 
-    filelist = {'ZJets': ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dy.root')],
-                'Data' : ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dimuon.root')]
+    filelist = {'ZJets': {'files': ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dy.root')], 'treename': 'Events' },
+                'Data' : {'files': ['file:'+osp.join(os.getcwd(),'tests/samples/nano_dimuon.root')], 'treename': 'Events'}
                 }
 
     from coffea.processor.test_items import NanoTestProcessor
