import os.path as osp
import warnings
import sys
import pickle as pkl
from mlcg_tk.input_generator.prior_gen import PriorBuilder
from mlcg_tk.input_generator.utils import get_output_tag
from jsonargparse import CLI
from typing import List, Optional
[docs]
def merge_statistics(
save_dir: str,
prior_tag: str,
prior_builders: List[PriorBuilder],
names: List[str],
tag: Optional[str] = None,
mol_num_batches: Optional[int] = 1,
):
"""
Merges statistics computed for separate datasets or for individual samples of the same dataset.
Parameters
----------
save_dir : str
Path to directory in which output will be saved
names : List[str]
List of either sample names or dataset names for which statistics will be merged
prior_tag : str
String identifying the specific combination of prior terms
prior_builders : List[PriorBuilder]
List of PriorBuilder objects and their corresponding parameters
tag : str
Optional label included to specify dataset for which sample statistics will be merged
"""
all_stats = []
if mol_num_batches > 1:
names = [f"{n}_batch_{b}" for b in range(mol_num_batches) for n in names]
for name in names:
stats_fn = osp.join(
save_dir,
f"{get_output_tag([tag, name, prior_tag], placement='before')}prior_builders.pck",
)
if osp.exists(stats_fn):
with open(stats_fn, "rb") as ifile:
stats = pkl.load(ifile)
all_stats.append(stats)
else:
warnings.warn(
f"Sample {name} has no saved statistics - This entry will be skipped"
)
continue
fnout = osp.join(
save_dir,
f"{get_output_tag([tag, prior_tag], placement='before')}prior_builders.pck",
)
builder_dict = {}
for prior_builder in prior_builders:
builder_dict[prior_builder.name] = prior_builder
for statistics in all_stats:
for builder in statistics:
combined_builder = builder_dict[builder.name]
for nl_name in list(builder.histograms.data.keys()):
if nl_name not in builder.nl_builder.nl_names:
continue
hists = builder.histograms[nl_name]
for k, hist in hists.items():
combined_builder.histograms.data[nl_name][k] += hist
with open(fnout, "wb") as ofile:
pkl.dump(prior_builders, ofile)
def main():
CLI([merge_statistics], as_positional=False)
if __name__ == "__main__":
main()