docs/source/working-notes/troyraen/SuperNNova/snn_train_a_model.md

Train a SuperNNova Model

Links

setup

following the two quickstarts in readthedocs:

gcloud compute instances start troy
gcloud compute ssh troy

sudo mkdir /home/snn
sudo chown -R troyraen /home/snn
cd /home/snn

# clone the repo to get the env file and test data, and to have a reference
# (will pip install so i can use it as a module)
git clone https://github.com/supernnova/supernnova.git
cd supernnova/env
conda create --name snnenv --file conda_env_cpu_linux64.txt
conda activate snnenv
pip install supernnova

ipython
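
Once inside ipython, it can help to confirm that the pip-installed package is visible from the active environment before going further. A minimal sketch using only the standard library; the helper name `is_importable` is made up for these notes and is not part of SuperNNova:

```python
import importlib.util


def is_importable(module_name):
    """Return True if the top-level module `module_name` can be found."""
    # find_spec returns None when a top-level module cannot be located
    return importlib.util.find_spec(module_name) is not None


# hypothetical check, meant to be run inside the snnenv environment
print("supernnova importable:", is_importable("supernnova"))
```

If this prints False, the most likely cause is that ipython was started outside the `snnenv` environment.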

train and validate a model using the test set

gcloud compute instances start troy
gcloud compute ssh troy
cd /home/snn/supernnova
ipython

following the pip quickstart guide in readthedocs:

import supernnova.conf as conf
from supernnova.data import make_dataset
from supernnova.training import train_rnn
from supernnova.validation import validate_rnn


#--- build the database
args = conf.get_args()  # get config args
args.data = True  # making new dataset
args.dump_dir = "tests/dump"  # where the dataset will be saved
args.raw_dir = "tests/raw"  # where raw photometry files are saved
args.fits_dir = "tests/fits"  # conf: where salt2fits are saved
settings = conf.get_settings(args)
make_dataset.make_dataset(settings)
# [troy] this puts the following files in
    # tests/dump/figures: multiviolin_test.png
    # tests/dump/processed: SNID.pickle, database.h5, hostspe_SNID.pickle

# train RNN
args = conf.get_args()  # [troy] necessary to clear the settings
args.train_rnn = True
args.dump_dir = "tests/dump"  # where the dataset is saved
args.nb_epoch = 2  # training epochs
settings = conf.get_settings(args)  # set settings
train_rnn.train(settings)  # train rnn
# [troy] creates the following dir with some .json, .png, and .pt files in it
    # dump/models/vanilla_S_0_CLF_2_R_none_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean

# validate RNN
args = conf.get_args()
args.validate_rnn = True  # validate rnn
args.dump_dir = "tests/dump"  # where the dataset is saved
settings = conf.get_settings(args)
validate_rnn.get_predictions(settings)  # classify test set
# [troy] adds the following file to dump/models/vanilla_S_0_CLF_2_R_none_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean:
    # PRED_vanilla_S_0_CLF_2_R_none_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean.pickle
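
To look at the results afterwards, the prediction file can be located and loaded back in. A rough sketch, assuming the model directories sit under `<dump_dir>/models/` and that prediction file names start with `PRED_`, as in the note above. The helper names `find_prediction_files` and `load_pickle` are made up here. I have not checked the structure of the unpickled object; if it is a pandas object, pandas must be installed in the environment for the load to work.

```python
import pickle
from pathlib import Path


def find_prediction_files(dump_dir):
    """Return sorted paths of PRED_*.pickle files under <dump_dir>/models."""
    models_dir = Path(dump_dir) / "models"
    if not models_dir.is_dir():
        # nothing has been trained or validated yet
        return []
    # rglob searches every model subdirectory
    return sorted(models_dir.rglob("PRED_*.pickle"))


def load_pickle(path):
    """Unpickle one file. Only use this on files you created yourself."""
    with open(path, "rb") as f:
        return pickle.load(f)


# assumed location, taken from args.dump_dir above
for pred_file in find_prediction_files("tests/dump"):
    preds = load_pickle(pred_file)
    print(pred_file.name, type(preds))
```

Printing only the type keeps the sketch independent of whatever columns or keys the predictions actually contain.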