docs/source/working-notes/troyraen/SuperNNova/snn_prelim_exploration.md
SuperNNova preliminary exploration
(see also: snn_train_a_model.md)
This follows SuperNNova's `run_onthefly.py` example. Install the dependencies:
# Install SuperNNova and the packages these notes use, one pip call per package.
for pkg in supernnova torch h5py natsort scikit-learn seaborn; do
    pip install "$pkg"
done
Pull some ZTF alerts from the Pub/Sub subscription, convert them to SuperNNova's light-curve format, and classify them:
from astropy.time import Time
import numpy as np
import pandas as pd
from pgb_utils import pubsub as pgbps
from supernnova.validation.validate_onthefly import classify_lcs
# Light-curve columns in the DataFrame handed to SuperNNova's classify_lcs.
COLUMN_NAMES = ["SNID", "MJD", "FLUXCAL", "FLUXCALERR", "FLT"]

# Alert fields kept from each decoded ZTF alert; the SNN columns are derived from these.
cols = ['objectId', 'jd', 'magpsf', 'sigmapsf', 'magzpsci', 'fid']

# ZTF filter id ('fid') -> filter name written to the FLT column.
ztf_fid_names = {
    1: 'g',
    2: 'r',
    3: 'i',
}

# Torch device passed to classify_lcs.
device = 'cpu'

# Path to the trained SuperNNova model (local to the author's machine).
model_file = (
    '/Users/troyraen/Documents/broker/SNN/ZTF_DMAM_V19_NoC_SNIa_vs_CC_forFink/'
    'vanilla_S_0_CLF_2_R_none_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean.pt'
)
# To inspect the raw checkpoint directly:
#   rnn_state = torch.load(model_file, map_location=lambda storage, loc: storage)

# Subscription the alerts are pulled from.
subscription = 'ztf-loop'
def reformat_to_df(pred_probs, ids=None):
    """Reformat SuperNNova predictions into a DataFrame.

    Parameters
    ----------
    pred_probs : sequence of array-like
        One entry per light curve, as returned by ``classify_lcs``. The last
        axis of each entry is treated as the class axis. If an entry holds
        more than one inference sample, only the first is used
        (TODO: support nb_inference != 1, e.g. by averaging samples).
    ids : sequence, optional
        Light-curve identifiers parallel to ``pred_probs``. When None or
        empty, each light curve's positional index is used as its SNID.

    Returns
    -------
    pandas.DataFrame
        One row per light curve with columns ``SNID``,
        ``prob_class0`` ... ``prob_class{K-1}`` and ``pred_class`` (index of
        the highest probability). Empty input gives an empty frame with only
        the ``SNID`` and ``pred_class`` columns.
    """
    if len(pred_probs) == 0:
        return pd.DataFrame(columns=["SNID", "pred_class"])

    num_classes = np.asarray(pred_probs[0]).shape[-1]
    # Fall back to positional indices when no ids are supplied (None or empty).
    use_ids = ids is not None and len(ids) > 0

    rows = []
    for idx, value in enumerate(pred_probs):
        # One row per inference sample; keep only the first sample.
        probs = np.asarray(value).reshape(-1, num_classes)[0]
        row = {"SNID": ids[idx] if use_ids else idx}
        row.update({f"prob_class{i}": p for i, p in enumerate(probs)})
        # Predicted class is derived from the same probabilities stored above.
        row["pred_class"] = int(np.argmax(probs))
        rows.append(row)
    return pd.DataFrame(rows)


def _alert_to_snn_df(msg):
    """Decode one ZTF alert message into a SuperNNova light-curve DataFrame."""
    df = pgbps.decode_ztf_alert(msg, return_format='df')
    # NOTE(review): decode_ztf_alert appears to attach objectId as a DataFrame
    # attribute rather than a column; broadcast it into a column so it
    # survives the column selection below. Confirm against pgb_utils.
    df['objectId'] = df.objectId
    # .copy() so the column assignments below act on an independent frame
    # (avoids pandas' SettingWithCopyWarning).
    df = df[cols].copy()
    # Any missing magnitude, error or zero-point would give NaN fluxes, which
    # would be fed straight into the classifier; drop those rows.
    df = df.dropna(subset=['magpsf', 'sigmapsf', 'magzpsci'])
    df['SNID'] = df['objectId']
    df['MJD'] = Time(df['jd'], format='jd').mjd
    df['FLUXCAL'] = 10 ** ((df['magzpsci'] - df['magpsf']) / 2.5)
    # F = 10**((zp - m) / 2.5)  =>  |dF/dm| = F * ln(10) / 2.5,
    # so sigma_F = F * sigma_m * ln(10) / 2.5 (not ln(10 / 2.5)).
    df['FLUXCALERR'] = df['FLUXCAL'] * df['sigmapsf'] * np.log(10) / 2.5
    df['FLT'] = df['fid'].map(ztf_fid_names)
    return df


# Pull a few alerts, convert each to SNN's input format, then classify.
# (Functions are defined above so this runs top-to-bottom without NameError.)
msgs = pgbps.pull(subscription, max_messages=10)
dflist = [_alert_to_snn_df(m) for m in msgs]
dfs = pd.concat(dflist)
ids_preds, pred_probs = classify_lcs(dfs, model_file, device)
preds_df = reformat_to_df(pred_probs, ids=ids_preds)