Training

mlcg provides tools to train its models in the scripts folder, along with example input files such as examples/train_schnet.yaml. Training is defined using the pytorch-lightning package, in particular its CLI utilities.

Scripts

Scripts that use LightningCLI come with many convenient built-in functionalities, such as a detailed help message:

python scripts/mlcg-train.py --help
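
A training run is then typically launched by pointing the fit subcommand at a configuration file (assuming the default LightningCLI subcommands are enabled, i.e. run=True):

python scripts/mlcg-train.py fit --config examples/train_schnet.yaml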

Utils for using PyTorch Lightning

class mlcg.pl.PLModel(model, loss)[source]

PL interface to train with models defined in mlcg.nn.

Parameters:
  • model (Module) – instance of a model class from mlcg.nn.

  • loss (Loss) – instance of mlcg.nn.Loss.
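
Below is a rough sketch of how a PLModel is assembled and trained outside of the CLI. The model and loss constructions are placeholders, since they depend on which mlcg.nn model is being trained; only the PLModel and Trainer calls reflect the documented interfaces.

import pytorch_lightning as pl
from mlcg.pl import PLModel

# Placeholders: build `net` from a model class in mlcg.nn (e.g. a
# SchNet-based model) and `loss` from mlcg.nn.Loss; see the mlcg.nn
# documentation for the actual constructor arguments.
net = ...   # instance of a model class from mlcg.nn
loss = ...  # instance of mlcg.nn.Loss

plmodel = PLModel(model=net, loss=loss)
trainer = pl.Trainer(max_epochs=100, default_root_dir="./train_out")
# trainer.fit(plmodel, datamodule=datamodule)  # see DataModule below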

class mlcg.pl.DataModule(dataset, log_dir, val_ratio=0.1, test_ratio=0.1, splits=None, batch_size=512, inference_batch_size=64, num_workers=1, loading_stride=1, save_local_copy=False, pin_memory=True)[source]

PL interface to train with datasets defined in mlcg.datasets.

Parameters:
  • dataset (InMemoryDataset) – a dataset from mlcg.datasets (or following the API of torch_geometric.data.InMemoryDataset)

  • log_dir (str) – directory in which to store any data produced during training.

  • val_ratio (float) – fraction of the dataset used for validation

  • test_ratio (float) – fraction of the dataset used for testing

  • splits (Optional[str]) – path to a file containing the indices for training, validation, and testing. It should be loadable with np.load and contain the fields ‘idx_train’, ‘idx_val’, and ‘idx_test’. If None, the dataset is split randomly using val_ratio and test_ratio.

  • batch_size (int) – number of structures to include in each training batch.

  • inference_batch_size (int) – number of structures to include in each validation/test batch.

  • num_workers (int) – number of CPU workers used for loading the dataset (see the torch.utils.data.DataLoader documentation for more details).

  • loading_stride (int) – stride used to subselect the dataset. Useful parameter for debugging purposes.

  • save_local_copy (bool) – if True, save a local copy of the input dataset in log_dir.

  • pin_memory (bool) – if True, the data loader copies tensors into pinned memory before returning them (see the torch.utils.data.DataLoader documentation).
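
A minimal instantiation might look as follows. ChignolinDataset is used purely as an illustration (its constructor arguments are assumed); any dataset from mlcg.datasets, or any class following the torch_geometric.data.InMemoryDataset API, can be substituted.

from mlcg.pl import DataModule
from mlcg.datasets import ChignolinDataset  # illustrative dataset class

dataset = ChignolinDataset(root="./data")  # assumed constructor arguments
datamodule = DataModule(
    dataset,
    log_dir="./train_out",
    val_ratio=0.1,
    test_ratio=0.1,
    batch_size=512,
    inference_batch_size=64,
    num_workers=4,
)
# The datamodule can then be passed to pytorch_lightning.Trainer.fit().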

class mlcg.pl.LightningCLI(model_class=None, datamodule_class=None, save_config_callback=<class 'pytorch_lightning.cli.SaveConfigCallback'>, save_config_kwargs=None, trainer_class=<class 'pytorch_lightning.trainer.trainer.Trainer'>, trainer_defaults=None, seed_everything_default=True, parser_kwargs=None, subclass_mode_model=False, subclass_mode_data=False, args=None, run=True, auto_configure_optimizers=True, **kwargs)[source]

Command line interface for training a model with PyTorch Lightning.

It adds a few functionalities to pytorch_lightning.cli.LightningCLI:

  • register torch optimizers and lr_schedulers so that they can be specified in the configuration file. Note that only a single (optimizer, lr_scheduler) pair can be specified this way; more complex patterns should be implemented in the PyTorch Lightning model definition (a child of pytorch_lightning.LightningModule; see the sketch after this list). See the PyTorch Lightning documentation for more details.

  • manually link some arguments related to the definition of the working directory. If the default_root_dir argument of pytorch_lightning.Trainer is set and the save_dir / log_dir / dirpath arguments of the loggers / data / callbacks are set to default_root_dir, then they will be set to default_root_dir / default_root_dir/data / default_root_dir/ckpt, respectively.
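
For optimization schemes that go beyond a single (optimizer, lr_scheduler) pair, the standard LightningModule hook can be overridden in a child class. The sketch below uses Adam and ReduceLROnPlateau purely as an illustration, and the monitored metric name is an assumption that should be matched to whatever PLModel actually logs:

import torch
from mlcg.pl import PLModel

class CustomOptimPLModel(PLModel):
    # Standard LightningModule hook for defining optimizers/schedulers.
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.5, patience=10
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "validation_loss",  # assumed metric name
            },
        }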