From a7fc3c5639c2d9556639e09e54372fcd554f922f Mon Sep 17 00:00:00 2001 From: Ben Galewsky Date: Thu, 19 May 2022 10:46:34 -0500 Subject: [PATCH] Add script to run ttbar analysis with MLFlow instrumentation --- .gitignore | 6 + analyses/cms-open-data-ttbar/MLproject | 26 ++ analyses/cms-open-data-ttbar/README.md | 58 +++ analyses/cms-open-data-ttbar/analysis.py | 342 ++++++++++++++++++ analyses/cms-open-data-ttbar/conda.yaml | 19 + .../search_hyperparameter.py | 130 +++++++ .../cms-open-data-ttbar/utils/__init__.py | 5 +- 7 files changed, 582 insertions(+), 4 deletions(-) create mode 100644 analyses/cms-open-data-ttbar/MLproject create mode 100644 analyses/cms-open-data-ttbar/README.md create mode 100644 analyses/cms-open-data-ttbar/analysis.py create mode 100644 analyses/cms-open-data-ttbar/conda.yaml create mode 100644 analyses/cms-open-data-ttbar/search_hyperparameter.py diff --git a/.gitignore b/.gitignore index 6bf38b56..415f3867 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,11 @@ __pycache__ venv/ .env/ .ipynb_checkpoints +.idea/ servicex.yml +.servicex + analyses/**/*.root analyses/**/*.pdf @@ -24,3 +27,6 @@ workshops/agctools2022/statistical-inference/input # CMS ttbar analyses/cms-open-data-ttbar/workspace.json + +# MLFlow +mlruns/ diff --git a/analyses/cms-open-data-ttbar/MLproject b/analyses/cms-open-data-ttbar/MLproject new file mode 100644 index 00000000..102f1066 --- /dev/null +++ b/analyses/cms-open-data-ttbar/MLproject @@ -0,0 +1,26 @@ +name: analysis-grand-challenge + +conda_env: conda.yaml + +entry_points: + ttbar: + parameters: + num-input-files: {type: int, default: 10} + num-bins: {type: int, default: 25} + bin-low: {type: int, default: 50} + bin-high: {type: int, default: 550} + pt-threshold: {type: int, default: 25} + + command: "python analysis.py --num-input-files {num-input-files} --num-bins {num-bins} --bin-low {bin-low} --bin-high {bin-high} --pt-threshold {pt-threshold}" + + # Use Hyperopt to optimize hyperparams of the ttbar 
entry_point. + hyperopt: + parameters: + max_runs: {type: int, default: 12} + metric: {type: string, default: "ttbar_norm_bestfit"} + algo: {type: string, default: "tpe.suggest"} + command: "python -O search_hyperparameter.py + --max-runs {max_runs} + --metric {metric} + --algo {algo}" + diff --git a/analyses/cms-open-data-ttbar/README.md b/analyses/cms-open-data-ttbar/README.md new file mode 100644 index 00000000..50aec186 --- /dev/null +++ b/analyses/cms-open-data-ttbar/README.md @@ -0,0 +1,58 @@ +CMS Open Data $t\bar{t}$: from data delivery to statistical inference + +We are using [2015 CMS Open Data](https://cms.cern/news/first-cms-open-data-lhc-run-2-released) +in this demonstration to showcase an analysis pipeline. It features data +delivery and processing, histogram construction and visualization, as well as +statistical inference. + +This notebook was developed in the context of the +[IRIS-HEP AGC tools 2022 workshop](https://indico.cern.ch/e/agc-tools-2). This +work was supported by the U.S. National Science Foundation (NSF) Cooperative +Agreement OAC-1836650 (IRIS-HEP). + +This is a technical demonstration. We are including the relevant workflow +aspects that physicists need in their work, but we are not focusing on making +every piece of the demonstration physically meaningful. This concerns in +particular systematic uncertainties: we capture the workflow, but the actual +implementations are more complex in practice. If you are interested in the +physics side of analyzing top pair production, check out the latest results from +ATLAS and CMS! If you would like to see more technical demonstrations, also +check out an ATLAS Open Data example demonstrated previously. + +## Tracking Analysis Runs with MLFlow +A version of this analysis has been instrumented with +[MLFlow](https://mlflow.org) to record runs of this analysis along with the +input parameters, the fit results, and generated plots.
To use the tracking +service you will need: +* Conda +* Access to an MLFlow tracking service instance +* Environment variables set to allow the script to communicate with the tracking service and the back-end object store: + * `MLFLOW_TRACKING_URI` + * `MLFLOW_S3_ENDPOINT_URL` + * `AWS_ACCESS_KEY_ID` + * `AWS_SECRET_ACCESS_KEY` + +If you would like to install a local instance of the MLFlow tracking service on +your Kubernetes cluster, this +[helm chart](https://artifacthub.io/packages/helm/ncsa/mlflow) is a good start. + +For reproducibility, MLFlow insists on running the analysis in a conda +environment. This is defined in `conda.yaml`. + +The MLFlow project is defined in `MLproject` - this file specifies two different +_entrypoints_ + +`ttbar` is the entrypoint for running a single analysis. It offers a number +of command line parameters to control the analysis. It can be run as +```shell +mlflow run -e ttbar -P num-bins=25 -P pt-threshold=25 . +``` +### Hyperparameter Searches +MLFlow is often used in optimizing models by running with different +hyperparameters until a minimal loss function is realized. We've borrowed this +approach for optimizing an analysis. You can orchestrate a number of analysis +runs with different input settings by using the `hyperopt` entrypoint. + +```shell + mlflow run -e hyperopt -P max_runs=20 . +``` diff --git a/analyses/cms-open-data-ttbar/analysis.py b/analyses/cms-open-data-ttbar/analysis.py new file mode 100644 index 00000000..70f15a06 --- /dev/null +++ b/analyses/cms-open-data-ttbar/analysis.py @@ -0,0 +1,342 @@ +# Copyright (c) 2019, IRIS-HEP +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer.
+# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import asyncio
+import time
+import logging
+import click
+import mlflow
+
+import vector; vector.register_awkward()
+
+import awkward as ak
+import cabinetry
+from coffea import processor
+from coffea.processor import servicex
+from coffea.nanoevents import transforms
+from coffea.nanoevents.methods import base, vector
+from coffea.nanoevents.schemas.base import BaseSchema, zip_forms
+from func_adl import ObjectStream
+import hist
+import json
+import matplotlib.pyplot as plt
+import numpy as np
+import uproot
+
+import utils  # contains code for bookkeeping and cosmetics, as well as some boilerplate
+
+logging.getLogger("cabinetry").setLevel(logging.INFO)
+
+
+processor_base = servicex.Analysis
+
+# a jet counts as b-tagged when its btag score is at or above this value;
+# every b-tag requirement below uses the same ">=" comparison so that an event
+# passing the event selection always lands in exactly one region
+B_TAG_THRESHOLD = 0.5
+
+
+# functions creating systematic variations
+def flat_variation(ones):
+    """Return a two-column array of flat weights: ``ones`` scaled by 1.001 and 0.999."""
+    # 0.1% weight variations
+    return (1.0 + np.array([0.001, -0.001], dtype=np.float32)) * ones[:, None]
+
+
+def btag_weight_variation(i_jet, jet_pt):
+    """Return two weight columns (1 + 0.1 * pT/50, 1 - 0.1 * pT/50) built from the pT of jet ``i_jet``."""
+    # weight variation depending on i-th jet pT (10% as default value, multiplied by i-th jet pT / 50 GeV)
+    return 1 + np.array([0.1, -0.1]) * (ak.singletons(jet_pt[:, i_jet]) / 50).to_numpy()
+
+
+def jet_pt_resolution(pt):
+    """Return random per-jet scale factors drawn around 1 with the same jagged layout as ``pt``."""
+    # normal distribution with 5% variations, shape matches jets
+    counts = ak.num(pt)
+    pt_flat = ak.flatten(pt)
+    resolution_variation = np.random.normal(np.ones_like(pt_flat), 0.05)
+    return ak.unflatten(resolution_variation, counts)
+
+
+class TtbarAnalysis(processor_base):
+    """Processor filling one histogram with axes observable / region / process / variation."""
+
+    def __init__(self, num_bins, bin_low, bin_high, pt_threshold):
+        """Book the histogram.
+
+        :param num_bins: number of bins of the regular observable axis
+        :param bin_low: lower edge of the observable axis
+        :param bin_high: upper edge of the observable axis
+        :param pt_threshold: minimum pT required of selected electrons, muons and jets
+        """
+        name = "observable"
+        label = "observable [GeV]"
+        self.pt_threshold = pt_threshold
+        self.hist = (
+            hist.Hist.new.Reg(num_bins, bin_low, bin_high, name=name, label=label)
+            .StrCat(["4j1b", "4j2b"], name="region", label="Region")
+            .StrCat([], name="process", label="Process", growth=True)
+            .StrCat([], name="variation", label="Systematic variation", growth=True)
+            .Weight()
+        )
+
+    def process(self, events):
+        """Select events, compute the observable per region and fill a copy of the histogram.
+
+        The process name, variation name, cross-section, total event count and
+        dataset name are read from ``events.metadata``.
+
+        :param events: event record with ``electron``, ``muon`` and ``jet`` collections
+        :return: ``{"nevents": {dataset: len(events)}, "hist": filled histogram}``
+        """
+        histogram = self.hist.copy()
+
+        process = events.metadata["process"]  # "ttbar" etc.
+        variation = events.metadata["variation"]  # "nominal" etc.
+
+        # normalization for MC
+        x_sec = events.metadata["xsec"]
+        nevts_total = events.metadata["nevts"]
+        lumi = 3378  # /pb
+        if process != "data":
+            xsec_weight = x_sec * lumi / nevts_total
+        else:
+            xsec_weight = 1
+
+        #### systematics
+        # example of a simple flat weight variation, using the coffea nanoevents systematics feature
+        if process == "wjets":
+            events.add_systematic("scale_var", "UpDownSystematic", "weight", flat_variation)
+
+        # jet energy scale / resolution systematics
+        # need to adjust schema to instead use coffea add_systematic feature, especially for ServiceX
+        # cannot attach pT variations to events.jet, so attach to events directly
+        # and subsequently scale pT by these scale factors
+        events["pt_nominal"] = 1.0
+        events["pt_scale_up"] = 1.03
+        events["pt_res_up"] = jet_pt_resolution(events.jet.pt)
+
+        pt_variations = ["pt_nominal", "pt_scale_up", "pt_res_up"] if variation == "nominal" else ["pt_nominal"]
+        for pt_var in pt_variations:
+
+            ### event selection
+            # very very loosely based on https://arxiv.org/abs/2006.13076
+
+            # pT above threshold for leptons & jets
+            selected_electrons = events.electron[events.electron.pt > self.pt_threshold]
+            selected_muons = events.muon[events.muon.pt > self.pt_threshold]
+            jet_filter = events.jet.pt * events[pt_var] > self.pt_threshold
+            selected_jets = events.jet[jet_filter]
+
+            # single lepton requirement: exactly one electron or muon in total.
+            # The counts must be added; "count_e & count_mu == 1" would parse as
+            # "(count_e & count_mu) == 1" and require an odd number of both.
+            n_leptons = ak.count(selected_electrons.pt, axis=1) + ak.count(selected_muons.pt, axis=1)
+            event_filters = n_leptons == 1
+            # at least four jets
+            pt_var_modifier = events[pt_var] if "res" not in pt_var else events[pt_var][jet_filter]
+            event_filters = event_filters & (ak.count(selected_jets.pt * pt_var_modifier, axis=1) >= 4)
+            # at least one b-tagged jet
+            event_filters = event_filters & (ak.sum(selected_jets.btag >= B_TAG_THRESHOLD, axis=1) >= 1)
+
+            # apply event filters
+            selected_electrons = selected_electrons[event_filters]
+            selected_muons = selected_muons[event_filters]
+            selected_jets = selected_jets[event_filters]
+
+            for region in ["4j1b", "4j2b"]:
+                # further filtering: 4j1b CR with single b-tag, 4j2b SR with two or more tags
+                if region == "4j1b":
+                    region_filter = ak.sum(selected_jets.btag >= B_TAG_THRESHOLD, axis=1) == 1
+                    selected_jets_region = selected_jets[region_filter]
+                    # use HT (scalar sum of jet pT) as observable
+                    if "res" not in pt_var:
+                        pt_var_modifier = events[event_filters][region_filter][pt_var]
+                    else:
+                        pt_var_modifier = events[pt_var][jet_filter][event_filters][region_filter]
+                    observable = ak.sum(selected_jets_region.pt * pt_var_modifier, axis=-1)
+
+                elif region == "4j2b":
+                    region_filter = ak.sum(selected_jets.btag >= B_TAG_THRESHOLD, axis=1) >= 2
+                    selected_jets_region = selected_jets[region_filter]
+
+                    # wrap into a four-vector object to allow addition
+                    selected_jets_region = ak.zip(
+                        {
+                            "pt": selected_jets_region.pt, "eta": selected_jets_region.eta, "phi": selected_jets_region.phi,
+                            "mass": selected_jets_region.mass, "btag": selected_jets_region.btag,
+                        },
+                        with_name="Momentum4D",
+                    )
+
+                    # reconstruct hadronic top as bjj system with largest pT
+                    # the jet energy scale / resolution effect is not propagated to this observable at the moment
+                    trijet = ak.combinations(selected_jets_region, 3, fields=["j1", "j2", "j3"])  # trijet candidates
+                    trijet["p4"] = trijet.j1 + trijet.j2 + trijet.j3  # calculate four-momentum of tri-jet system
+                    trijet["max_btag"] = np.maximum(trijet.j1.btag, np.maximum(trijet.j2.btag, trijet.j3.btag))
+                    trijet = trijet[trijet.max_btag >= B_TAG_THRESHOLD]  # require at least one-btag in trijet candidates
+                    # pick trijet candidate with largest pT and calculate mass of system
+                    trijet_mass = trijet["p4"][ak.argmax(trijet.p4.pt, axis=1, keepdims=True)].mass
+                    observable = ak.flatten(trijet_mass)
+
+                ### histogram filling
+                if pt_var == "pt_nominal":
+                    # nominal pT, but including 2-point systematics
+                    histogram.fill(
+                        observable=observable, region=region, process=process, variation=variation, weight=xsec_weight
+                    )
+
+                    if variation == "nominal":
+                        # also fill weight-based variations for all nominal samples
+                        for weight_name in events.systematics.fields:
+                            for direction in ["up", "down"]:
+                                # extract the weight variations and apply all event & region filters
+                                weight_variation = events.systematics[weight_name][direction][f"weight_{weight_name}"][event_filters][region_filter]
+                                # fill histograms
+                                histogram.fill(
+                                    observable=observable, region=region, process=process, variation=f"{weight_name}_{direction}", weight=xsec_weight*weight_variation
+                                )
+
+                        # calculate additional systematics: b-tagging variations
+                        for i_var, weight_name in enumerate([f"btag_var_{i}" for i in range(4)]):
+                            for i_dir, direction in enumerate(["up", "down"]):
+                                # create systematic variations that depend on object properties (here: jet pT)
+                                if len(observable):
+                                    weight_variation = btag_weight_variation(i_var, selected_jets_region.pt)[:, 1-i_dir]
+                                else:
+                                    weight_variation = 1  # no events selected
+                                histogram.fill(
+                                    observable=observable, region=region, process=process, variation=f"{weight_name}_{direction}", weight=xsec_weight*weight_variation
+                                )
+
+                elif variation == "nominal":
+                    # pT variations for nominal samples
+                    histogram.fill(
+                        observable=observable, region=region, process=process, variation=pt_var, weight=xsec_weight
+                    )
+
+        output = {"nevents": {events.metadata["dataset"]: len(events)}, "hist": histogram}
+
+        return output
+
+    def postprocess(self, accumulator):
+        """Return the accumulator unchanged."""
+        return accumulator
+
+
+def get_query(source: ObjectStream) -> ObjectStream:
+    """Query for event / column selection: no filter, select relevant lepton and jet columns
+    """
+    return source.Select(lambda e: {
+        "electron_pt": e.electron_pt,
+        "muon_pt": e.muon_pt,
+        "jet_pt": e.jet_pt,
+        "jet_eta": e.jet_eta,
+        "jet_phi": e.jet_phi,
+        "jet_mass": e.jet_mass,
+        "jet_btag": e.jet_btag,
+    }
+    )
+
+
+def _save_and_log_figure(filename):
+    """Save the current matplotlib figure to ``filename``, clear it and log the file to MLflow."""
+    plt.savefig(filename)
+    plt.clf()
+    mlflow.log_artifact(filename)
+
+
+@click.command(
+    help="CMS Open Data t-tbar: from data delivery to statistical inference."
+)
+@click.option("--num-input-files", type=click.INT, default=10, help="input files per process, set to e.g. 10 (smaller number = faster).")
+@click.option("--num-bins", type=click.INT, default=25, help="Number of bins.")
+@click.option("--bin-low", type=click.INT, default=50, help="Bottom bin.")
+@click.option("--bin-high", type=click.INT, default=550, help="Top bin.")
+@click.option("--pt-threshold", type=click.INT, default=25, help="pt for leptons & jets (in GeV).")
+def analysis(num_input_files, num_bins, bin_low, bin_high, pt_threshold):
+    """Run the full analysis inside one MLflow run.
+
+    Produces the histograms, saves and logs the plots and ``histograms.root``
+    as artifacts, builds and fits the cabinetry workspace, and logs the
+    best-fit value and uncertainty of the parameter of interest as the
+    metrics ``ttbar_norm_bestfit`` and ``ttbar_norm_uncertainty``.
+    """
+    with mlflow.start_run():
+        fileset = utils.construct_fileset(num_input_files, use_xcache=False)
+
+        print(f"processes in fileset: {list(fileset.keys())}")
+        print(
+            f"\nexample of information in fileset:\n{{\n  'files': [{fileset['ttbar__nominal']['files'][0]}, ...],")
+        print(f"  'metadata': {fileset['ttbar__nominal']['metadata']}\n}}")
+
+        analysis_processor = TtbarAnalysis(num_bins, bin_low, bin_high, pt_threshold)
+        # produce_all_histograms is a coroutine function; run it to completion here
+        all_histograms = asyncio.run(
+            utils.produce_all_histograms(fileset, get_query, analysis_processor, use_dask=False)
+        )
+
+        print(all_histograms)
+
+        utils.set_style()
+
+        all_histograms[120j::hist.rebin(2), "4j1b", :, "nominal"].stack("process")[::-1].plot(
+            stack=True, histtype="fill", linewidth=1, edgecolor="grey")
+        plt.legend(frameon=False)
+        plt.title(">= 4 jets, 1 b-tag")
+        plt.xlabel("HT [GeV]")
+        _save_and_log_figure("1-btag.png")
+
+        all_histograms[:, "4j2b", :, "nominal"].stack("process")[::-1].plot(
+            stack=True, histtype="fill", linewidth=1, edgecolor="grey")
+        plt.legend(frameon=False)
+        plt.title(">= 4 jets, >= 2 b-tags")
+        plt.xlabel("$m_{bjj}$ [GeV]")
+        _save_and_log_figure("2-btag.png")
+
+        # b-tagging variations
+        all_histograms[120j::hist.rebin(2), "4j1b", "ttbar", "nominal"].plot(
+            label="nominal", linewidth=2)
+        for i_np in range(4):
+            all_histograms[120j::hist.rebin(2), "4j1b", "ttbar", f"btag_var_{i_np}_down"].plot(
+                label=f"NP {i_np + 1}", linewidth=2)
+        plt.legend(frameon=False)
+        plt.xlabel("HT [GeV]")
+        plt.title("b-tagging variations")
+        _save_and_log_figure("b-tag-variations.png")
+
+        # jet energy scale variations
+        all_histograms[:, "4j2b", "ttbar", "nominal"].plot(label="nominal", linewidth=2)
+        all_histograms[:, "4j2b", "ttbar", "pt_scale_up"].plot(label="scale up", linewidth=2)
+        all_histograms[:, "4j2b", "ttbar", "pt_res_up"].plot(label="resolution up", linewidth=2)
+        plt.legend(frameon=False)
+        plt.xlabel("$m_{bjj}$ [GeV]")
+        plt.title("Jet energy variations")
+        _save_and_log_figure("jet-energy-variations.png")
+
+        utils.save_histograms(all_histograms, fileset, "histograms.root")
+        mlflow.log_artifact("histograms.root")
+
+        config = cabinetry.configuration.load("cabinetry_config.yml")
+        cabinetry.templates.collect(config)
+        cabinetry.templates.postprocess(config)  # optional post-processing (e.g. smoothing)
+        ws = cabinetry.workspace.build(config)
+        cabinetry.workspace.save(ws, "workspace.json")
+
+        model, data = cabinetry.model_utils.model_and_data(ws)
+        fit_results = cabinetry.fit.fit(model, data)
+
+        fig = cabinetry.visualize.pulls(
+            fit_results, exclude="ttbar_norm", close_figure=True, save_figure=False
+        )
+        fig.savefig("ttbar_norm_fit.png")
+        plt.clf()
+        mlflow.log_artifact("ttbar_norm_fit.png")
+
+        poi_index = model.config.poi_index
+        mlflow.log_metric("ttbar_norm_bestfit", fit_results.bestfit[poi_index])
+        mlflow.log_metric("ttbar_norm_uncertainty", fit_results.uncertainty[poi_index])
+
+        print(
+            f"\nfit result for ttbar_norm: {fit_results.bestfit[poi_index]:.3f} +/- {fit_results.uncertainty[poi_index]:.3f}")
+
+
+if __name__ == "__main__":
+    analysis()
diff --git a/analyses/cms-open-data-ttbar/conda.yaml b/analyses/cms-open-data-ttbar/conda.yaml
new file mode 100644
index 00000000..c86684bf
--- /dev/null
+++ b/analyses/cms-open-data-ttbar/conda.yaml
@@ -0,0 +1,19 @@
+name: analysis-grand-challenge
+channels:
+  - conda-forge
+dependencies:
+  - python=3.8
+  - pip
+  - pip:
+    - boto3
+    - aiostream
+    - tenacity
+    - vector
+    - cabinetry
+    - backoff<2.0.0
+    - servicex_clients
+    - coffea
+    - dask[distributed]
+    - mlflow>=1.0
+    - pandas
+    - hyperopt
diff --git a/analyses/cms-open-data-ttbar/search_hyperparameter.py b/analyses/cms-open-data-ttbar/search_hyperparameter.py
new file mode 100644
index 00000000..5eef674c
--- /dev/null
+++ b/analyses/cms-open-data-ttbar/search_hyperparameter.py
@@ -0,0 +1,130 @@
+"""
+Orchestrate hyperparameter search for tt-bar analysis.
+Uses hyperopt library to control the search over two of the analysis parameters:
+    * number of histogram bins
+    * pt threshold applied to leptons and jets in the event selection
+
+The analysis will be run repeatedly in an attempt to optimize for ttbar_norm_bestfit, but
+through a command line option can be configured to optimize for ttbar_norm_uncertainty
+"""
+
+import click
+import numpy as np
+
+from hyperopt import fmin, hp, space_eval, tpe, rand
+
+import mlflow.projects
+from mlflow.tracking.client import MlflowClient
+
+_inf = np.finfo(np.float64).max
+
+
+@click.command(
+    help="Perform hyperparameter search with Hyperopt library. Optimize ttbar target."
+)
+@click.option("--num-input-files", type=click.INT, default=10, help="input files per process, set to e.g. 10 (smaller number = faster).")
+@click.option("--bin-low", type=click.INT, default=50, help="Bottom bin.")
+@click.option("--bin-high", type=click.INT, default=550, help="Top bin.")
+@click.option("--max-runs", type=click.INT, default=10, help="Maximum number of runs to evaluate.")
+@click.option("--metric", type=click.STRING, default="ttbar_norm_bestfit", help="Metric to optimize on.")
+@click.option("--algo", type=click.STRING, default="tpe.suggest", help="Optimizer algorithm.")
+def train(max_runs, metric, algo, num_input_files, bin_low, bin_high):
+    """
+    Run hyperparameter optimization.
+
+    Starts a parent MLflow run, evaluates one reference configuration first,
+    then lets hyperopt minimize ``metric`` over the number of bins and the pt
+    threshold. Every evaluation is a nested child run of the ``ttbar`` entry
+    point. The best parameters and the best child run are recorded on the
+    parent run.
+    """
+    tracking_client = MlflowClient()
+
+    def new_eval(experiment_id, null_loss, return_all=False):
+        """
+        Create a new objective function for hyperopt.
+
+        :param experiment_id: MLflow experiment the child runs are created in
+        :param null_loss: loss returned for failed runs; also an upper cap on the loss
+        :param return_all: unused, kept for call compatibility
+        :return: function mapping ``[num_bins, pt_threshold]`` to a loss.
+        """
+
+        def evaluate(params):
+            """
+            Run tt-bar analysis given an ordered set of hyperparameters.
+
+            :param params: Parameters to run the tt-bar analysis script we optimize over:
+                          number of histogram bins, pt threshold for cuts
+            :return: The value of ``metric`` logged by the analysis run, capped at ``null_loss``.
+            """
+            num_bins, pt_threshold = params
+            with mlflow.start_run(nested=True) as child_run:
+                p = mlflow.projects.run(
+                    uri=".",
+                    entry_point="ttbar",
+                    run_id=child_run.info.run_id,
+                    parameters={
+                        "num-input-files": str(num_input_files),
+                        "num-bins": str(num_bins),
+                        "bin-low": str(bin_low),
+                        "bin-high": str(bin_high),
+                        "pt-threshold": str(pt_threshold)
+                    },
+                    experiment_id=experiment_id,
+                    env_manager="local",  # We are already in the environment
+                    synchronous=False,  # Allow the run to fail if a model is not properly created
+                )
+                succeeded = p.wait()
+                mlflow.log_params({"num-bins": num_bins, "pt-threshold": pt_threshold})
+
+                if succeeded:
+                    analysis_run = tracking_client.get_run(p.run_id)
+                    metrics = analysis_run.data.metrics
+                    # cap the loss at the loss of the null model; a run that
+                    # finished without logging the metric counts as null loss
+                    loss = min(null_loss, metrics.get(metric, null_loss))
+                else:
+                    # run failed => return null loss
+                    tracking_client.set_terminated(p.run_id, "FAILED")
+                    loss = null_loss
+
+                mlflow.log_metrics({"metric": loss})
+
+            return loss
+
+        return evaluate
+
+    # Parameter search space
+    space = [
+        hp.choice("num-bins", np.arange(15, 50, dtype=int)),
+        hp.choice("pt-threshold", np.arange(20, 30, dtype=int)),
+    ]
+
+    with mlflow.start_run() as run:
+        experiment_id = run.info.experiment_id
+        # Evaluate null model first.
+        null_loss = new_eval(experiment_id, _inf, True)(params=[10, 25])
+        best = fmin(
+            fn=new_eval(experiment_id, null_loss),
+            space=space,
+            algo=tpe.suggest if algo == "tpe.suggest" else rand.suggest,
+            max_evals=max_runs,
+        )
+        # fmin reports hp.choice parameters as indices into the choice list;
+        # space_eval maps them back to the actual parameter values
+        best_num_bins, best_pt_threshold = space_eval(space, best)
+        best_params = {"num-bins": int(best_num_bins), "pt-threshold": int(best_pt_threshold)}
+        mlflow.set_tag("best params", str(best_params))
+        # find the best run, log its metrics as the final metrics of this run.
+        runs = tracking_client.search_runs(
+            [experiment_id], "tags.mlflow.parentRunId = '{run_id}' ".format(run_id=run.info.run_id)
+        )
+        best_loss = _inf
+        best_run = None
+        for r in runs:
+            if r.data.metrics.get(metric, _inf) < best_loss:
+                best_run = r
+                best_loss = r.data.metrics[metric]
+        # no child run may have logged the metric (e.g. every run failed)
+        if best_run is not None:
+            mlflow.set_tag("best_run", best_run.info.run_id)
+            mlflow.log_metrics(
+                {
+                    metric: best_loss,
+                }
+            )
+
+
+if __name__ == "__main__":
+    train()
diff --git a/analyses/cms-open-data-ttbar/utils/__init__.py b/analyses/cms-open-data-ttbar/utils/__init__.py
index 1843adab..aca8ff6f 100644
--- a/analyses/cms-open-data-ttbar/utils/__init__.py
+++ b/analyses/cms-open-data-ttbar/utils/__init__.py
@@ -128,7 +128,7 @@ def make_datasource(fileset:dict, name: str, query: ObjectStream, ignore_cache:
     )
 
 
-async def produce_all_histograms(fileset, query, procesor_class, use_dask=False, ignore_cache=False):
+async def produce_all_histograms(fileset, query, analysis_processor, use_dask=False, ignore_cache=False):
     """Runs the histogram production, processing input files with ServiceX and
     producing histograms with coffea.
     """
@@ -148,9 +148,6 @@ async def produce_all_histograms(fileset, query, procesor_class, use_dask=False,
         for ds_name in fileset.keys()
     ]
 
-    # create the analysis processor
-    analysis_processor = procesor_class()
-
     async def run_updates_stream(accumulator_stream, name):
         """Run to get the last item in the stream"""
         coffea_info = None