Training
mlcg provides tools to train its models in the scripts folder, together with example input files such as examples/train_schnet.yaml. The training is defined using the pytorch-lightning package, in particular its CLI utilities.
Scripts
Scripts that use LightningCLI come with many convenient built-in functionalities, such as a detailed help message:
python scripts/mlcg-train.py --help
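For instance, a training run with the provided example input could be launched as below. This is a sketch assuming the script exposes the standard LightningCLI subcommands (the default run=True implies subcommands such as fit); check the --help output for the exact interface.

python scripts/mlcg-train.py fit --config examples/train_schnet.yaml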
Utils for using PyTorch Lightning
- class mlcg.pl.PLModel(model, loss)[source]
PL interface to train with models defined in mlcg.nn.
- Parameters:
- model (Module) – instance of a model class from mlcg.nn.
- loss (Loss) – instance of mlcg.nn.Loss.
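As a minimal sketch, PLModel simply pairs a network from mlcg.nn with a loss so that pytorch-lightning can drive the training. The constructor arguments below are placeholders, not the actual signatures:

# Hypothetical sketch: combine an mlcg.nn model with a loss for training.
from mlcg.pl import PLModel
from mlcg.nn import SchNet, Loss  # assumed to be exposed at this level

model = SchNet(...)  # placeholder arguments; see mlcg.nn for the real signature
loss = Loss(...)     # placeholder arguments; see mlcg.nn.Loss for the real signature
pl_model = PLModel(model=model, loss=loss)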
- class mlcg.pl.DataModule(dataset, log_dir, val_ratio=0.1, test_ratio=0.1, splits=None, batch_size=512, inference_batch_size=64, num_workers=1, loading_stride=1, save_local_copy=False, pin_memory=True)[source]
PL interface to train with datasets defined in mlcg.datasets.
- Parameters:
- dataset (InMemoryDataset) – a dataset from mlcg.datasets (or one following the API of torch_geometric.data.InMemoryDataset).
- log_dir (str) – where to store the data that might be produced during training.
- val_ratio (float) – fraction of the dataset used for validation.
- test_ratio (float) – fraction of the dataset used for testing.
- splits (Optional[str]) – filename of a file containing the indices for training, validation, and testing. It should be compatible with np.load and contain the fields ‘idx_train’, ‘idx_val’, and ‘idx_test’. If None, the dataset is split randomly using val_ratio and test_ratio.
- batch_size (int) – number of structures to include in each training batch.
- inference_batch_size (int) – number of structures to include in each validation/test batch.
- num_workers (int) – number of CPU workers used for loading the dataset (see the torch.utils.data.DataLoader documentation for more details).
- loading_stride (int) – stride used to subselect the dataset. Useful for debugging purposes.
- save_local_copy (bool) – saves the input dataset in log_dir.
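Since the splits file format is fully specified above (loadable with np.load, exposing the fields ‘idx_train’, ‘idx_val’, and ‘idx_test’), such a file can be generated with plain numpy. The dataset size and split proportions below are made up for illustration.

# Sketch: write a splits file usable via the `splits` argument of DataModule.
import numpy as np

n_frames = 10000  # hypothetical dataset size
rng = np.random.default_rng(seed=0)
idx = rng.permutation(n_frames)
np.savez(
    "splits.npz",
    idx_train=idx[:8000],    # 80% training
    idx_val=idx[8000:9000],  # 10% validation
    idx_test=idx[9000:],     # 10% testing
)
# Passing splits="splits.npz" to DataModule will then use these indices.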
- class mlcg.pl.LightningCLI(model_class=None, datamodule_class=None, save_config_callback=<class 'pytorch_lightning.cli.SaveConfigCallback'>, save_config_kwargs=None, trainer_class=<class 'pytorch_lightning.trainer.trainer.Trainer'>, trainer_defaults=None, seed_everything_default=True, parser_kwargs=None, subclass_mode_model=False, subclass_mode_data=False, args=None, run=True, auto_configure_optimizers=True, **kwargs)[source]
Command line interface for training a model with pytorch lightning.
It adds a few functionalities to pytorch_lightning.utilities.cli.LightningCLI:
- registers torch optimizers and lr_schedulers so that they can be specified in the configuration file. Note that only a single (optimizer, lr_scheduler) pair can be specified this way; more complex patterns should be implemented in the pytorch_lightning model definition (a child of pytorch_lightning.LightningModule). See the pytorch-lightning documentation for more details, and the configuration sketch after this list.
- manually links some arguments related to the definition of the work directory: if the default_root_dir argument of pytorch_lightning.Trainer is set and the save_dir / log_dir / dirpath arguments of the loggers / data / callbacks are set to default_root_dir, then they will be set to default_root_dir / default_root_dir/data / default_root_dir/ckpt, respectively.
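A minimal sketch of the corresponding pieces of a configuration file, assuming the standard LightningCLI class_path / init_args conventions; the optimizer and scheduler choices are only illustrative:

# Hypothetical excerpt of a training configuration (cf. examples/train_schnet.yaml)
trainer:
  default_root_dir: ./train_run  # logger / data / checkpoint paths get linked under this directory
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
lr_scheduler:
  class_path: torch.optim.lr_scheduler.ExponentialLR
  init_args:
    gamma: 0.99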