Training your model
If your model fits well with our structures and does not have specific needs, our Trainers should already be sufficient for you.
Advantages of using our trainers
Checkpoints: Our trainers save the model state at each epoch if it is the best one so far, in a folder best_model. They also always save the model state and optimizer state in a checkpoint. This way, if anything happens and your training is stopped, you can continue training from the latest checkpoint.
Data management: Our trainers know how to interact with your data in the HDF5 file and with your model. For instance, they can use the BatchSampler to sample streamlines for each batch, and the BatchLoader to interpolate the diffusion data at each coordinate. This way, your model class stays as simple as possible, with purely AI-based layers and without the rigmarole and shenanigans of data management.
Logs and visualization: Our trainers save many metrics as logs on your computer, which you can visualize with our scripts. They also send data to comet.ml. See Visualizing logs for more information.
Heavy-data ready: They can manage GPU usage and choose how streamlines are sampled in order to limit the loading of heavy data.
Training options: They prepare torch’s optimizer (e.g., Adam, SGD, RAdam), define the learning rate, etc.
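The checkpoint behavior listed above (save the best model when the epoch is the best so far, always save a checkpoint, resume from the latest checkpoint) can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's implementation: the helpers save_epoch and resume_epoch, the file names checkpoint.json and model_state.json, and the use of JSON instead of torch's serialization are all hypothetical.

```python
import json
import os
import tempfile


def save_epoch(exp_dir, epoch, loss, model_state, best_loss):
    """Always write a checkpoint; also write best_model if the loss improved.

    Hypothetical helper: names and file layout are illustrative only.
    Returns the (possibly updated) best loss.
    """
    if loss < best_loss:
        best_dir = os.path.join(exp_dir, "best_model")
        os.makedirs(best_dir, exist_ok=True)
        with open(os.path.join(best_dir, "model_state.json"), "w") as f:
            json.dump(model_state, f)
        best_loss = loss

    # The checkpoint is overwritten at every epoch, best or not.
    checkpoint = {"epoch": epoch, "best_loss": best_loss,
                  "model_state": model_state}
    with open(os.path.join(exp_dir, "checkpoint.json"), "w") as f:
        json.dump(checkpoint, f)
    return best_loss


def resume_epoch(exp_dir):
    """Return the epoch at which training should restart (0 if none saved)."""
    path = os.path.join(exp_dir, "checkpoint.json")
    if not os.path.exists(path):
        return 0
    with open(path) as f:
        return json.load(f)["epoch"] + 1


exp_dir = tempfile.mkdtemp()
best = float("inf")
for epoch, loss in enumerate([0.9, 0.5, 0.7]):
    best = save_epoch(exp_dir, epoch, loss, {"weights": [loss]}, best)

next_epoch = resume_epoch(exp_dir)
with open(os.path.join(exp_dir, "best_model", "model_state.json")) as f:
    best_state = json.load(f)
```

In this toy run, the best_model folder keeps the state of the second epoch (lowest loss), while the checkpoint holds the third epoch, so an interrupted training would restart at the fourth.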
Overview of the process
This is an example of a basic script that you could create to train your model with our trainer. It requires:
Your model
An instance of our object MultiSubjectDataset: the Trainer knows how to get data in the hdf5, possibly in a lazy way, and store it in a MultiSubjectDataset. See The MultisubjectDataset for more information.
An instance of a BatchSampler: the Trainer knows how to sample a list of chosen streamlines for a batch. See Batch sampler for more information.
An instance of a BatchLoader: the Trainer knows how to load the data using the MultiSubjectDataset, and how to modify the streamlines based on your model’s requirements, for instance adding noise or compressing, changing the step size, and reversing or splitting the streamlines. See Batch loader for more information.
For instance, if you need a dMRI input, your final Python script could look like this:
# Loading the data, possibly with lazy option
dataset = MultiSubjectDataset(hdf5_file)
dataset.load_data()

# Preparing your model
model = myModel(args)

# Preparing the BatchSampler
batch_sampler = DWIMLBatchIDSampler(
    dataset=dataset, streamline_group_name=streamline_group_name)

# Preparing the BatchLoader.
batch_loader = DWIMLBatchLoaderOneInput(
    dataset=dataset, model=model,
    input_group_name=input_group_name,
    streamline_group_name=streamline_group_name)

# Preparing your trainer
trainer = DWIMLTrainerOneInput(
    model=model, experiments_path=experiments_path,
    experiment_name=experiment_name, batch_sampler=batch_sampler,
    batch_loader=batch_loader)

# Run the training!
trainer.train_and_validate()
Once all objects are ready, the Trainer’s method train_and_validate can be used to iterate over epochs until a maximum number of epochs is reached, or until a maximum number of bad epochs (epochs in which the monitored loss did not improve) is reached.
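The two stopping criteria can be sketched as follows. This is a minimal illustration, not the Trainer's code: the function run_epochs and its parameters are hypothetical, the list val_losses stands in for the loss computed at each epoch, and counting bad epochs consecutively (resetting the counter on every improvement) is an assumption.

```python
def run_epochs(val_losses, max_epochs, max_bad_epochs):
    """Iterate over epochs and report when and why the loop stops.

    Hypothetical helper: val_losses stands in for the loss that a real
    trainer would compute during validation at each epoch.
    """
    best_loss = float("inf")
    nb_bad_epochs = 0
    for epoch in range(max_epochs):
        loss = val_losses[epoch]
        if loss < best_loss:
            # Improvement: remember it and reset the bad-epoch counter.
            best_loss = loss
            nb_bad_epochs = 0
        else:
            nb_bad_epochs += 1
        if nb_bad_epochs >= max_bad_epochs:
            return epoch, best_loss, "early stopping"
    return max_epochs - 1, best_loss, "max epochs reached"


losses = [1.0, 0.8, 0.85, 0.9, 0.95, 0.7]
stopped_early = run_epochs(losses, max_epochs=6, max_bad_epochs=3)
ran_to_the_end = run_epochs(losses, max_epochs=6, max_bad_epochs=5)
```

With a patience of 3 bad epochs, the loop stops at the fifth epoch and never sees the later improvement to 0.7; with a patience of 5, it runs all six epochs. This shows the usual trade-off when choosing the maximum number of bad epochs.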
Our choices of trainers
DWIMLTrainer
This is the main class. For every batch, it loads the chosen streamlines and uses the model, as explained in the section Trainers: the code explained below.
DWIMLTrainerOneInput
This trainer additionally loads one volume group and accesses its data at the coordinates of each point of your streamlines, or possibly in a neighborhood around each coordinate. Of note, this is done as a separate step, and not through torch’s DataLoaders (see explanation in Batch loader), because data interpolation is faster on GPU, if you have access to one, but DataLoaders always work on CPU.
This trainer is expected to be used with a child of ModelWithOneInput (see page You may inherit from many models!).
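To illustrate what accessing a volume at each streamline coordinate means, here is a minimal pure-Python sketch of trilinear interpolation. It is not the library's implementation, which works on torch tensors and may use a neighborhood. The function trilinear, the nested-list volume, and the convention that coordinates are in voxel space with integer values falling on voxel positions are all assumptions made for the example.

```python
import math


def trilinear(volume, x, y, z):
    """Interpolate a 3D volume (nested lists, indexed [i][j][k]) at (x, y, z).

    Hypothetical helper. Coordinates are assumed to be in voxel space and
    to lie inside the volume.
    """
    i0, j0, k0 = int(math.floor(x)), int(math.floor(y)), int(math.floor(z))
    # Clamp the upper corner so points on the last voxel stay in bounds.
    i1 = min(i0 + 1, len(volume) - 1)
    j1 = min(j0 + 1, len(volume[0]) - 1)
    k1 = min(k0 + 1, len(volume[0][0]) - 1)
    dx, dy, dz = x - i0, y - j0, z - k0

    # Weighted sum over the 8 surrounding voxels.
    value = 0.0
    for i, wx in ((i0, 1 - dx), (i1, dx)):
        for j, wy in ((j0, 1 - dy), (j1, dy)):
            for k, wz in ((k0, 1 - dz), (k1, dz)):
                value += wx * wy * wz * volume[i][j][k]
    return value


# A 2x2x2 volume whose value is simply the index along the first axis.
volume = [[[0.0, 0.0], [0.0, 0.0]],
          [[1.0, 1.0], [1.0, 1.0]]]
streamline = [(0.0, 0.0, 0.0), (0.25, 0.5, 0.5), (1.0, 1.0, 1.0)]
inputs = [trilinear(volume, *point) for point in streamline]
```

Each streamline point yields one interpolated value here. With a real diffusion volume, each point would yield one feature vector, and the same arithmetic applied to whole tensors at once is what makes a GPU worthwhile.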
DWIMLTrainerOneInputWithGVPhase
We will soon publish how we have used a new generation-validation phase to supervise our models.
Trainers: the code explained
The Trainer’s main method is train_and_validate. It is summarized below.
def train_and_validate(self):
    for epoch in range(nb_epochs):
        # 1) Set the learning rate
        ...

        # 2) Train
        self.train_one_epoch()

        # 3) Validate
        self.validate_one_epoch()

        # 4) Save the model if it's the best epoch
        if this_is_the_best_epoch:
            ...

        # 5) Save a checkpoint
        self.save_checkpoint()
Other steps managed in this method include creating the torch DataLoader from the data_loaders. The DataLoader’s collate_fn will be the sampler’s load_batch() method.
The train_one_epoch and validate_one_epoch methods are similar, but validation skips back-propagation.
def train_one_epoch(self):
    for batch in batches:
        self.run_one_batch()

        # If training: back-prop includes:
        #  - clip gradients
        #  - update torch's optimizer: self.optimizer.step()
        #  - reset torch's gradients: self.optimizer.zero_grad(set_to_none=True)
        self.back_propagation()
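The three back-propagation steps (clip, update, reset) can be illustrated without torch. The sketch below is an assumption-laden stand-in, not the Trainer's code: it clips by global L2 norm, which is what torch.nn.utils.clip_grad_norm_ does, uses a plain SGD update in place of optimizer.step(), and resets gradients to zero, whereas zero_grad(set_to_none=True) sets them to None. The helper names and the learning rate are hypothetical, and whether the Trainer clips by norm or by value is not stated here.

```python
import math


def clip_grad_norm(grads, max_norm):
    """Rescale grads in place so their global L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for i in range(len(grads)):
            grads[i] *= scale
    return total_norm


def back_propagation(params, grads, lr, max_norm):
    """Clip the gradients, apply a plain SGD update, then reset them."""
    norm_before = clip_grad_norm(grads, max_norm)
    for i in range(len(params)):
        params[i] -= lr * grads[i]
    for i in range(len(grads)):
        grads[i] = 0.0
    return norm_before


params = [1.0, 2.0]
grads = [3.0, 4.0]  # global L2 norm is 5
norm = back_propagation(params, grads, lr=0.1, max_norm=1.0)
```

Clipping keeps the direction of the gradient and only shrinks its length, so one unusually large batch gradient cannot move the parameters by more than lr times max_norm.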
Finally, run_one_batch depends on your model. For instance, in DWIMLTrainerOneInput, it interpolates the input at each point and calls the model:
def run_one_batch(self):
    # 1) Send data to GPU if available
    ...

    # 2) Format the streamlines if required by the model
    #    ex: SOS, EOS
    ...

    # 3) Interpolate the input (done in the BatchLoader)
    batch_inputs = self.batch_loader.load_batch_inputs(
        streamlines, ids_per_subj)

    # 4) Data augmentation if required
    streamlines_f = self.batch_loader.add_noise_streamlines_forward(
        streamlines, self.device)

    # 5) Call the model
    model_outputs = self.model(batch_inputs, streamlines_f)

    # 6) Compute the loss
    mean_loss, n = self.model.compute_loss(model_outputs, targets,
                                           average_results=True)

    return mean_loss, n
If this is not right for you, you can create a child class of DWIMLTrainer and override this last method.
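The override pattern can be sketched as follows. BaseTrainer is a hypothetical, heavily reduced stand-in for DWIMLTrainer (the real class has a different constructor and more responsibilities), and the scaling step is an arbitrary example of custom per-batch logic. Only the principle is meant to carry over: redefine run_one_batch in a child class and inherit the epoch loop.

```python
class BaseTrainer:
    """Hypothetical stand-in for DWIMLTrainer, reduced to the epoch loop."""

    def __init__(self, model):
        self.model = model

    def run_one_batch(self, batch):
        # Default behavior: call the model and return (mean loss, n).
        inputs, targets = batch
        outputs = [self.model(x) for x in inputs]
        errors = [(o - t) ** 2 for o, t in zip(outputs, targets)]
        return sum(errors) / len(errors), len(errors)

    def train_one_epoch(self, batches):
        return [self.run_one_batch(batch) for batch in batches]


class TrainerWithScaledInputs(BaseTrainer):
    """Overrides only run_one_batch; the epoch loop is inherited."""

    def run_one_batch(self, batch):
        inputs, targets = batch
        scaled = [x / 10.0 for x in inputs]  # custom preprocessing step
        return super().run_one_batch((scaled, targets))


trainer = TrainerWithScaledInputs(model=lambda x: 2 * x)
results = trainer.train_one_epoch([([10.0, 20.0], [2.0, 5.0])])
```

Calling super().run_one_batch keeps the loss computation in one place. A child class that needs a completely different batch logic would simply not call it.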