In the spirit of exploring CV, I looked into Leave-One-Out Cross Validation (LOOXV). I used a slightly modified LOOXV where I leave out one row of pixels at a time rather than a single pixel. I was able to get good results with npix=512 and very good results with npix >= 1024. The downside of LOOXV is that it takes significantly longer than other CV methods, but so far it has been producing decent results. An upside of LOOXV is that, since it is just K-Fold CV taken to its extreme, implementing it requires no new code. Below are the results of my LOOXV method and the results of K-Fold CV with k=512. As you can see, both methods produced very comparable results. It would be interesting to see whether there is some intermediate value of k that maximizes our results while minimizing computing time (i.e., some k such that 1 <= k <= npix).
[Figure: cross-validation results with LOOXV]
[Figure: cross-validation results with K-Fold CV, k=512]
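To see why LOOXV needs no new machinery, here is a short sketch (the toy array is purely for illustration) checking that scikit-learn's KFold with n_splits equal to the number of rows produces exactly the same splits as LeaveOneOut:

import numpy as np
from sklearn.model_selection import KFold, LeaveOneOut

# toy stand-in for a gridded cube: 8 rows of pixels, 4 columns each
vis = np.arange(8 * 4).reshape(8, 4)

loo_splits = list(LeaveOneOut().split(vis))
kfold_splits = list(KFold(n_splits=len(vis)).split(vis))

# every fold holds out exactly one row, and the two split sequences match
for (loo_train, loo_test), (kf_train, kf_test) in zip(loo_splits, kfold_splits):
    assert np.array_equal(loo_train, kf_train)
    assert np.array_equal(loo_test, kf_test)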
My LOOXV code:
import numpy as np
import torch
from copy import deepcopy
from sklearn.model_selection import LeaveOneOut

def LOOXV(model, config, dataset, MODEL_PATH, writer=None):
    test_scores = []
    vis = dataset.vis_gridded[0]
    weight = dataset.weight_gridded[0]

    loo = LeaveOneOut()
    loo.get_n_splits(vis)

    for train_index, test_index in loo.split(vis):
        # work on a copy of the dataset so the original stays untouched
        dset = deepcopy(dataset)
        vis_train, vis_test = vis[train_index], vis[test_index]
        weight_train, weight_test = weight[train_index], weight[test_index]
        dset.vis_gridded = vis_train
        dset.weight_gridded = weight_train

        # reset model
        model.load_state_dict(torch.load(MODEL_PATH))

        # create a new optimizer for this k_fold
        optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

        # train for a while
        train(model.to("cuda"), dset.to("cuda"), optimizer, config, writer=writer)

        # evaluate the test metric; some cells have 0 for visibility or weight,
        # which flags a ZeroDivisionError in losses.nll_gridded()
        try:
            dset.vis_gridded = vis_test
            dset.weight_gridded = weight_test
            test_scores.append(test(model.to("cuda"), dset.to("cuda")))
        except ZeroDivisionError:
            continue

    # aggregate all test scores and sum to evaluate the cross validation metric
    test_score = np.sum(np.array(test_scores))

    # log the cross validation score
    if writer is not None:
        writer.add_scalar("Cross Validation", test_score)

    return test_score
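Since the splitter is the only moving part, exploring intermediate values of k (the 1 <= k <= npix idea above) should only require swapping LeaveOneOut() for KFold(n_splits=k). Below is a rough sketch of that variant (the function name row_kfold_xv is just for illustration), assuming the same train/test helpers, dataset layout, and imports as the LOOXV code above:

from sklearn.model_selection import KFold

def row_kfold_xv(model, config, dataset, MODEL_PATH, k, writer=None):
    # same per-fold reset/train/test loop as LOOXV, but with k folds of rows
    test_scores = []
    vis = dataset.vis_gridded[0]
    weight = dataset.weight_gridded[0]

    for train_index, test_index in KFold(n_splits=k).split(vis):
        dset = deepcopy(dataset)
        dset.vis_gridded = vis[train_index]
        dset.weight_gridded = weight[train_index]

        # reset model and optimizer for this fold
        model.load_state_dict(torch.load(MODEL_PATH))
        optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
        train(model.to("cuda"), dset.to("cuda"), optimizer, config, writer=writer)

        try:
            dset.vis_gridded = vis[test_index]
            dset.weight_gridded = weight[test_index]
            test_scores.append(test(model.to("cuda"), dset.to("cuda")))
        except ZeroDivisionError:
            continue

    return np.sum(np.array(test_scores))

A sweep over a few candidate fold counts (e.g. k in 8, 64, 512) with this variant would make it easy to plot cross-validation score against run time and look for that sweet spot.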