Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:11

0001 #!/usr/bin/env python3
0002 from pathlib import Path
0003 from typing import Optional, Dict, List
0004 import re
0005 import enum
0006 import sys
0007 
0008 import uproot
0009 import typer
0010 import hist
0011 import pydantic
0012 import yaml
0013 import pandas
0014 import matplotlib.pyplot
0015 import awkward
0016 
0017 
class Model(pydantic.BaseModel):
    """Common base for the configuration models defined in this script."""

    class Config:
        # pydantic model setting: reject keys that are not declared as fields,
        # so that a misspelled option in the YAML config is reported as an
        # error instead of being silently ignored.
        extra = "forbid"
0021 
0022 
class HistConfig(Model):
    """Binning and labelling options for a single one-dimensional histogram."""

    # Number of bins of the regular axis.
    nbins: int = 100
    # Lower axis edge; when None, main() uses the minimum of the data it
    # sees at the time the histogram is first created.
    min: Optional[float] = None
    # Upper axis edge; when None, main() uses the maximum of that same data.
    max: Optional[float] = None
    # Axis name; when unset, main() falls back to the branch name.
    label: Optional[str] = None
0028 
0029 
class Extra(HistConfig):
    """A derived histogram that is filled from an expression, not a branch."""

    # Python expression that main() passes to eval() to obtain the values.
    expression: str
    # Key under which the histogram is stored and written to the output.
    name: str
0033 
0034 
class Config(Model):
    """Top-level layout of the optional YAML configuration file."""

    # Maps a regular expression (applied to branch names with re.match) to the
    # histogram settings to use for the branches it matches.
    histograms: Dict[str, HistConfig] = pydantic.Field(default_factory=dict)
    # Additional histograms computed from expressions.
    extra_histograms: List[Extra] = pydantic.Field(default_factory=list)
    # Regular expressions (applied with re.match) of branches to skip.
    exclude: List[str] = pydantic.Field(default_factory=list)
0039 
0040 
class Mode(str, enum.Enum):
    """How the output ROOT file is opened.

    Each value is the name of the ``uproot`` function that main() looks up
    with ``getattr`` to open the output file.
    """

    recreate = "recreate"
    update = "update"
0044 
0045 
def _book_histogram(
    nbins: int,
    lo: Optional[float],
    hi: Optional[float],
    values,
    label: str,
) -> hist.Hist:
    """Create an empty 1D histogram with a regular axis.

    Range limits that are None are taken from the minimum / maximum of
    *values*, which must therefore be non-empty. *label* becomes the axis name.
    """
    if lo is None:
        lo = awkward.min(values)
    if hi is None:
        hi = awkward.max(values)
    if lo == hi:
        # All values identical: widen the range so the axis has a finite width.
        lo -= 1
        hi += 1
    return hist.Hist(hist.axis.Regular(nbins, lo, hi, name=label))


def main(
    infile: Path = typer.Argument(
        ..., exists=True, dir_okay=False, help="The input ROOT file"
    ),
    treename: str = typer.Argument(..., help="The tree to look up branched from"),
    outpath: Path = typer.Argument(
        "outfile", dir_okay=False, help="The output ROOT file"
    ),
    config_file: Optional[Path] = typer.Option(
        None,
        "--config",
        "-c",
        exists=True,
        dir_okay=False,
        help="A config file following the input spec. By default, all branches will be plotted.",
    ),
    mode: Mode = typer.Option(Mode.recreate, help="Mode to open ROOT file in"),
    plots: Optional[Path] = typer.Option(
        None,
        "--plots",
        "-p",
        file_okay=False,
        help="If set, output plots individually to this directory",
    ),
    plot_format: str = typer.Option(
        "pdf", "--plot-format", "-f", help="Format to write plots in if --plots is set"
    ),
    silent: bool = typer.Option(
        False, "--silent", "-s", help="Do not print any output"
    ),
    dump_yml: bool = typer.Option(False, help="Print axis ranges as yml"),
):
    """
    Script to plot all branches in a TTree from a ROOT file, with optional configurable binning and ranges.
    Also allows setting extra expressions to be plotted as well.
    """
    # NOTE: the docstring above doubles as the command-line help text, and the
    # parameters are documented through their typer ``help=`` strings.

    # Read the configuration first so that an invalid config file fails before
    # any ROOT file is opened. An empty YAML document means "all defaults".
    if config_file is None:
        config = Config()
    else:
        with config_file.open() as fh:
            config = Config.parse_obj(yaml.safe_load(fh) or {})

    if not silent:
        print(config.extra_histograms, file=sys.stderr)

    # Histograms by output key (branch name or Extra.name). Each one is booked
    # when its first non-empty batch arrives and then filled by every batch.
    histograms = {}

    with uproot.open(infile) as rf:
        tree = rf[treename]

        # The batch must stay bound to the name ``df``: the expressions of the
        # extra histograms are evaluated in this scope and refer to it.
        for df in tree.iterate(library="ak", how=dict):
            for col, column in df.items():
                if any(re.match(ex, col) for ex in config.exclude):
                    continue

                values = awkward.flatten(column, axis=None)
                if len(values) == 0:
                    # Diagnostics go to stderr so they cannot end up inside
                    # the YAML that --dump-yml prints on stdout.
                    if not silent:
                        print(
                            f"WARNING: Branch '{col}' is empty. Skipped.",
                            file=sys.stderr,
                        )
                    continue

                h = histograms.get(col)
                if h is None:
                    # Look for a matching configuration entry; when several
                    # patterns match, the last one in the config wins.
                    found = None
                    for ex, data in config.histograms.items():
                        if re.match(ex, col):
                            found = data
                            if not silent:
                                print(
                                    "Found HistConfig",
                                    ex,
                                    "for",
                                    col,
                                    ":",
                                    found,
                                    file=sys.stderr,
                                )
                    if found is None:
                        found = HistConfig()

                    h = _book_histogram(
                        found.nbins, found.min, found.max, values, found.label or col
                    )
                    histograms[col] = h
                h.fill(values)

            # Extra histograms are evaluated once per batch, independent of
            # which branches were skipped, and accumulate over all batches.
            for extra in config.extra_histograms:
                # SECURITY: eval() executes arbitrary Python taken from the
                # config file. Only use configuration files you trust.
                calc = eval(extra.expression)
                values = awkward.flatten(calc, axis=None)
                if len(values) == 0:
                    if not silent:
                        print(
                            f"WARNING: Extra histogram '{extra.name}' is empty. Skipped.",
                            file=sys.stderr,
                        )
                    continue

                h = histograms.get(extra.name)
                if h is None:
                    h = _book_histogram(
                        extra.nbins,
                        extra.min,
                        extra.max,
                        values,
                        extra.label or extra.name,
                    )
                    histograms[extra.name] = h
                h.fill(values)

    if plots is not None:
        plots.mkdir(parents=True, exist_ok=True)

    # The output file is opened only after the input was read successfully, and
    # the context manager guarantees that it is closed again.
    with getattr(uproot, mode.value)(outpath) as outfile:
        for k, h in histograms.items():
            if not silent:
                axis = h.axes[0]
                if dump_yml:
                    print(
                        f"\n{k}:\n"
                        f"  nbins: {len(axis.edges) - 1}\n"
                        f"  min: {axis.edges[0]}\n"
                        f"  max: {axis.edges[-1]}\n"
                    )
                else:
                    print(k, axis)
            outfile[k] = h

            if plots is not None:
                fig, plot_ax = matplotlib.pyplot.subplots()

                h.plot(ax=plot_ax, flow=None)

                fig.tight_layout()
                fig.savefig(str(plots / f"{k}.{plot_format}"))
                # Close this figure explicitly so figures do not pile up.
                matplotlib.pyplot.close(fig)
0203 
0204 
# Script entry point: typer builds the command-line interface from the
# signature of main() and invokes it with the parsed arguments.
if __name__ == "__main__":
    typer.run(main)