"""The general KIM class."""
# Author: Peishi Jiang <shixijps@gmail.com>
from .data import Data
from .mapping_model.loss_func import loss_mse
from .mapping_model import train_ensemble
from .mapping_model import MLP
from .utils import compute_metrics
import json
import random
import pickle
import itertools
from copy import deepcopy
from pathlib import PosixPath, Path
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
# from jaxlib.xla_extension import Device
from jax import Device
from typing import Optional
from jaxtyping import Array
# TODO: Need a great way to pass the computational device
class KIM(object):
"""The class for knowledge-informed mapping training, prediction, saving and loading.
Attributes:
----------
data : Data
the copy of the __init__ argument
map_configs : dict
the copy of the __init__ argument
map_option : str
the copy of the __init__ argument
mask_option : str
the copy of the __init__ argument
trained : bool
whether KIM has been trained
loaded_from_other_sources : bool
whether KIM is loaded from other sources.
Ns : int
the number of ensemble members (from data.Ns)
Nx : int
the number of input features (from data.Nx)
Ny : int
the number of output features (from data.Ny)
mask : Array
the boolean mask array with shape (Nx, Ny); True marks an input as sensitive to the corresponding output
_n_maps : int
the number of maps
_maps : list
the trained maps (one Map per output in "many2one" mode, a single Map in "many2many" mode)
"""
def __init__(
self, data: Data, map_configs: dict,
mask_option: str='cond_sensitivity',
map_option: str='many2one',
other_mask: Optional[Array]=None,
name: str='kim'
):
"""Initialization function.
Args:
data (Data): the Data object containing the ensemble data and sensitivity analysis result.
map_configs (dict): the mapping configuration, including all the arguments of Map class except x and y.
mask_option (str): the masking option, either "sensitivity" (using data.sensitivity_mask) or "cond_sensitivity" (using data.cond_sensitivity_mask).
map_option (str): the mapping option, either "many2one" (knowledge-informed mapping that uses the sensitivity analysis result to filter the inputs of each output) or "many2many" (an ordinary mapping that is not knowledge-informed).
other_mask (Optional[Array]): an additional boolean mask of size Nx applied on top of self.mask; inputs marked False are excluded for every output. Defaults to None.
name (str): the name of the KIM object
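Example:
A minimal usage sketch (assuming a prepared Data object named ``data``; the model hyperparameter names 'width_size' and 'depth' are illustrative guesses for the default MLP and may differ)::

map_configs = {
'n_model': 4,
'ensemble_type': 'ens_random',
'model_hp_choices': {'width_size': [32, 64]},
'model_hp_fixed': {'depth': 2},
'dl_hp_fixed': {'batch_size': 32, 'num_train_sample': 800},
}
kim = KIM(data, map_configs, mask_option='cond_sensitivity', map_option='many2one')
kim.train(verbose=0)
y_ens, y_mean, y_mean_w, y_std_w, weights = kim.predict()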
"""
self.name = name
# Check whether sensitivity has been performed in Data
if not data.sensitivity_done and map_option == "many2one":
raise Exception(
"Sensitivity analysis has not been performed. So, KIM can not be executed in many to one mode."
)
self.data = data
self.trained = False
self.loaded_from_other_sources = False
self.Ns, self.Nx, self.Ny = data.Ns, data.Nx, data.Ny
self.mask_option = mask_option
self.other_mask = other_mask
if mask_option == "sensitivity":
self.mask = data.sensitivity_mask
elif mask_option == "cond_sensitivity":
self.mask = data.cond_sensitivity_mask
else:
raise Exception("Unknown mask_option: %s" % mask_option)
# Check whether additional masks are needed
if self.other_mask is not None:
assert len(self.other_mask) == self.Nx
for i in range(self.Ny):
self.mask[~self.other_mask,i] = False
# Initialize variables/attributes for mappings
if map_option == "many2one":
n_maps = self.Ny
elif map_option == "many2many":
n_maps = 1
else:
raise Exception("Unknown mapping option: %s" % map_option)
self.map_configs = map_configs
self.map_option = map_option
self._n_maps = n_maps
@property
def maps(self):
if self.trained:
return self._maps
else:
print("KIM has not been trained yet.")
@property
def n_maps(self):
return self._n_maps
def train(self, verbose: int=0):
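"""Train all the mappings.
Args:
verbose (int): the verbosity level (0: normal, 1: debug)
"""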
# Initialize
if self.map_option == "many2many":
maps = self._init_map_many2many()
elif self.map_option == "many2one":
maps = self._init_map_many2one()
# Train
for one_map in maps:
one_map.train(verbose=verbose)
self._maps = maps
self.trained = True
def _init_map_many2many(self):
x, y = self.data.xdata_scaled, self.data.ydata_scaled
x, y = jnp.array(x), jnp.array(y) # convert to jnp array
map_configs = deepcopy(self.map_configs)
one_map = Map(x, y, **map_configs)
# one_map.train(verbose=0)
return [one_map]
def _init_map_many2one(self):
if not self.data.sensitivity_done:
raise Exception(
"The sensitivity analysis is not done. "
"We can't train the knowledge-informed mapping."
)
xall, yall = self.data.xdata_scaled, self.data.ydata_scaled
mask_all = self.mask
maps = []
for i in range(self.n_maps):
# Get the masked inputs and the outputs
mask = mask_all[:,i]
if mask.sum() == 0:
print(f"There is no sensitive input for output {i}.")
one_map = None
else:
x, y = xall[:, mask], yall[:, [i]]
x, y = jnp.array(x), jnp.array(y) # convert to jnp array
# Initialize and train the mapping
# print(self.map_configs, x.shape)
map_configs = deepcopy(self.map_configs)
one_map = Map(x, y, **map_configs)
# one_map.train(verbose=0)
maps.append(one_map)
return maps
def evaluate_maps_on_givendata(self):
"""Perform predictions on the given dataset
"""
# TODO
# Make the prediction
y_ens, y_mean, y_mean_w, y_std_w, weights = self.predict(x=None)
y_true = self.data.ydata
# Separate them into training, validation, and test sets
if 'num_train_sample' in self.map_configs['dl_hp_fixed'] and \
'num_val_sample' in self.map_configs['dl_hp_fixed']:
Ns_train = self.map_configs['dl_hp_fixed']['num_train_sample']
Ns_val = self.map_configs['dl_hp_fixed']['num_val_sample']
sep1, sep2 = Ns_train, Ns_train+Ns_val
y_ens_train, y_ens_val, y_ens_test = y_ens[:,:sep1,...], y_ens[:,sep1:sep2,...], y_ens[:,sep2:,...]
y_true_train, y_true_val, y_true_test = y_true[:sep1,...], y_true[sep1:sep2,...], y_true[sep2:,...]
y_mw_train, y_mw_val, y_mw_test = y_mean_w[:sep1,...], y_mean_w[sep1:sep2,...], y_mean_w[sep2:,...]
y_stdw_train, y_stdw_val, y_stdw_test = y_std_w[:sep1,...], y_std_w[sep1:sep2,...], y_std_w[sep2:,...]
elif 'num_train_sample' in self.map_configs['dl_hp_fixed'] and \
'num_val_sample' not in self.map_configs['dl_hp_fixed']:
Ns_train = self.map_configs['dl_hp_fixed']['num_train_sample']
sep1 = Ns_train
y_ens_train, y_ens_val, y_ens_test = y_ens[:,:sep1,...], None, y_ens[:,sep1:,...]
y_true_train, y_true_val, y_true_test = y_true[:sep1,...], None, y_true[sep1:,...]
y_mw_train, y_mw_val, y_mw_test = y_mean_w[:sep1,...], None, y_mean_w[sep1:,...]
y_stdw_train, y_stdw_val, y_stdw_test = y_std_w[:sep1,...], None, y_std_w[sep1:,...]
else:
y_ens_train, y_ens_val, y_ens_test = y_ens, None, None
y_true_train, y_true_val, y_true_test = y_true, None, None
y_mw_train, y_mw_val, y_mw_test = y_mean_w, None, None
y_stdw_train, y_stdw_val, y_stdw_test = y_std_w, None, None
# Calculate the performance metrics
Nens, Ny = y_ens.shape[0], self.Ny
if 'num_train_sample' in self.map_configs['dl_hp_fixed'] and \
'num_val_sample' in self.map_configs['dl_hp_fixed']:
rmse_train, mkge_train = np.zeros([Nens,Ny]), np.zeros([Nens,Ny])
rmse_val, mkge_val = np.zeros([Nens,Ny]), np.zeros([Nens,Ny])
rmse_test, mkge_test = np.zeros([Nens,Ny]), np.zeros([Nens,Ny])
for i in range(Nens):
for j in range(Ny):
metrics = compute_metrics(y_ens_train[i,...,j], y_true_train[...,j])
rmse_train[i,j] = metrics['rmse']
mkge_train[i,j] = metrics['mkge']
metrics = compute_metrics(y_ens_val[i,...,j], y_true_val[...,j])
rmse_val[i,j] = metrics['rmse']
mkge_val[i,j] = metrics['mkge']
metrics = compute_metrics(y_ens_test[i,...,j], y_true_test[...,j])
rmse_test[i,j] = metrics['rmse']
mkge_test[i,j] = metrics['mkge']
elif 'num_train_sample' in self.map_configs['dl_hp_fixed'] and \
'num_val_sample' not in self.map_configs['dl_hp_fixed']:
rmse_train, mkge_train = np.zeros([Nens,Ny]), np.zeros([Nens,Ny])
rmse_val, mkge_val = None, None
rmse_test, mkge_test = np.zeros([Nens,Ny]), np.zeros([Nens,Ny])
for i in range(Nens):
for j in range(Ny):
metrics = compute_metrics(y_ens_train[i,...,j], y_true_train[...,j])
rmse_train[i,j] = metrics['rmse']
mkge_train[i,j] = metrics['mkge']
metrics = compute_metrics(y_ens_test[i,...,j], y_true_test[...,j])
rmse_test[i,j] = metrics['rmse']
mkge_test[i,j] = metrics['mkge']
else:
rmse_train, mkge_train = np.zeros([Nens,Ny]), np.zeros([Nens,Ny])
rmse_val, mkge_val = None, None
rmse_test, mkge_test = None, None
for i in range(Nens):
for j in range(Ny):
metrics = compute_metrics(y_ens_train[i,...,j], y_true_train[...,j])
rmse_train[i,j] = metrics['rmse']
mkge_train[i,j] = metrics['mkge']
ens_predict = {'train': y_ens_train, 'val': y_ens_val, 'test': y_ens_test}
wm_predict = {'train': y_mw_train, 'val': y_mw_val, 'test': y_mw_test}
wstd_predict = {'train': y_stdw_train, 'val': y_stdw_val, 'test': y_stdw_test}
true = {'train': y_true_train, 'val': y_true_val, 'test': y_true_test}
rmse = {'train': rmse_train, 'val': rmse_val, 'test': rmse_test}
mkge = {'train': mkge_train, 'val': mkge_val, 'test': mkge_test}
# Calculate bias and uncertainty (splits that are not defined stay None)
wbias = {
'train': np.mean(np.abs(y_true_train-y_mw_train), axis=0),
'val': None if y_true_val is None else np.mean(np.abs(y_true_val-y_mw_val), axis=0),
'test': None if y_true_test is None else np.mean(np.abs(y_true_test-y_mw_test), axis=0)
}
wrelauncert = {
'train': np.mean(y_stdw_train/y_true_train, axis=0),
'val': None if y_true_val is None else np.mean(y_stdw_val/y_true_val, axis=0),
'test': None if y_true_test is None else np.mean(y_stdw_test/y_true_test, axis=0)
}
return {
'ens predict': ens_predict,
'weights': weights,
'weighted mean predict': wm_predict,
'weighted std predict': wstd_predict,
'weighted bias': wbias,
'weighted relative uncertainty': wrelauncert,
'true': true,
'rmse': rmse,
'mkge': mkge
}
def predict(self, x: Optional[Array]=None):
"""Prediction using the trained KIM.
Args:
x (Array): predictors with shape (Ns,...,Nx)
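Returns:
tuple: (y_ens, y_mean, y_mean_w, y_std_w, weights). y_ens has shape
(n_model, Ns, Ny); y_mean, y_mean_w, and y_std_w have shape (Ns, Ny) and
hold the ensemble mean, the loss-weighted mean, and the loss-weighted
standard deviation in the original (unscaled) units; weights are the
per-model weights used for the weighted statistics. Outputs without any
sensitive inputs are filled with NaN.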
"""
if x is not None:
assert x.shape[-1] == self.Nx # The same dimension
assert len(x.shape) >= 2 # At least 2 dimensions with the leading batch dimension
else:
x = self.data.xdata
xraw = x
xscaler, yscaler = self.data.xscaler, self.data.yscaler
x = xscaler.transform(xraw)
Ns = x.shape[0]
n_ens = self.map_configs['n_model']
Ny = self.Ny
# n_maps = self.n_maps
# xshape = list(x.shape)
if self.map_option == "many2many":
one_map = self._maps[0]
y_ens, y_mean, y_mean_w, weights = one_map.predict(x)
weights = np.stack([weights]*Ny, axis=-1)
elif self.map_option == "many2one":
y_ens, y_mean, y_mean_w, weights = [], [], [], []
for i,one_map in enumerate(self._maps):
one_mask = self.mask[:,i]
if one_mask.sum() == 0:
assert one_map is None
y_e = np.empty([n_ens, Ns, 1]) + np.nan
w = np.empty([n_ens, Ns, 1]) + np.nan
y_m = np.empty([Ns, 1]) + np.nan
y_mw = np.empty([Ns, 1]) + np.nan
else:
y_e, y_m, y_mw, w = one_map.predict(x[:, one_mask])
w = np.expand_dims(w, axis=-1)
y_ens.append(np.array(y_e))
y_mean.append(np.array(y_m))
weights.append(np.array(w))
y_mean_w.append(np.array(y_mw))
y_ens = np.concatenate(y_ens, axis=-1)
y_mean = np.concatenate(y_mean, axis=-1)
y_mean_w = np.concatenate(y_mean_w, axis=-1)
weights = np.concatenate(weights, axis=-1)
# Scale back
y_ens = np.array([yscaler.inverse_transform(y) for y in y_ens])
y_mean = yscaler.inverse_transform(y_mean)
y_mean_w= yscaler.inverse_transform(y_mean_w)
# Calculate the weighted standard deviation
def calculate_wstd(yens, ymw, w):
return np.sqrt(np.average((yens-ymw)**2, weights=w, axis=0))
# y_std_w = np.sqrt(np.average((y_ens-y_mean_w)**2, weights=weights, axis=0))
y_std_w = []
for i in range(weights.shape[-1]):
yens, ymw, w = y_ens[...,i], y_mean_w[...,i], weights[...,i]
y_std_w.append(calculate_wstd(yens, ymw, w))
y_std_w = np.stack(y_std_w, axis=-1)
# calculate_wstd = np.vectorize(calculate_wstd, signature='(m,n),(n),(m)->(n)')
# y_std_w = calculate_wstd(y_ens, y_mean_w, weights)
return y_ens, y_mean, y_mean_w, y_std_w, weights
def save(self, rootpath: PosixPath=Path('./')):
"""Save the KIM, including:
- the data object
- all the mappings
- the remaining configurations
Args:
rootpath (PosixPath): the root path where data will be saved
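The resulting layout under rootpath is: the Data object saved under "data",
one "map{i}" directory per mapping (written by Map.save), and the remaining
KIM configurations in "configs.pkl".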
"""
if not self.trained:
raise Exception("KIM has not been trained yet.")
if not rootpath.exists():
rootpath.mkdir(parents=True)
# Save the data object
f_data = rootpath / "data"
self.data.save(f_data)
# Save all the mappings
f_map_set = [rootpath / f'map{i}' for i in range(self._n_maps)]
for i,one_map in enumerate(self._maps):
if one_map is not None:
one_map.save(f_map_set[i])
# Save the remaining configurations
f_configs = rootpath / "configs.pkl"
configs = {
"name": self.name,
"map_configs": self.map_configs,
"map_option": self.map_option,
"n_maps": self._n_maps,
"other_mask": self.other_mask
}
with open(f_configs, "wb") as f:
pickle.dump(configs, f)
def load(self, rootpath: PosixPath=Path("./")):
"""load the trained KIM from specified location.
Args:
rootpath (PosixPath): the root path where KIM will be loaded
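Note:
The KIM object must be constructed first (with a Data object and the
desired mask_option, which is not stored on disk); load() then restores
the data, configurations, and trained mappings written by save().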
"""
# Load the overall configurations
f_configs = rootpath / "configs.pkl"
with open(f_configs, "rb") as f:
configs = pickle.load(f)
self.name = configs["name"]
self.map_configs = configs["map_configs"]
self.map_option = configs["map_option"]
self._n_maps = configs["n_maps"]
self.other_mask = configs["other_mask"]
# Load the data object
f_data = rootpath / "data"
self.data.load(f_data, overwrite=True)
self.Ns, self.Nx, self.Ny = self.data.Ns, self.data.Nx, self.data.Ny
if self.mask_option == "sensitivity":
self.mask = self.data.sensitivity_mask
elif self.mask_option == "cond_sensitivity":
self.mask = self.data.cond_sensitivity_mask
else:
raise Exception("Unknown mask_option: %s" % self.mask_option)
# Apply the additional mask (if any) after the sensitivity mask has been restored
if self.other_mask is not None:
assert len(self.other_mask) == self.Nx
for i in range(self.Ny):
self.mask[~self.other_mask,i] = False
# Load the trained mappings
if self.map_option == "many2many":
f_mapping = rootpath / 'map0'
one_map = self._init_map_many2many()[0]
one_map.load(f_mapping)
maps = [one_map]
elif self.map_option == "many2one":
f_mapping_set = [rootpath / f'map{i}' for i in range(self._n_maps)]
maps = self._init_map_many2one()
for i,one_map in enumerate(maps):
if one_map is not None:
one_map.load(f_mapping_set[i])
self._maps = maps
self.trained = True
self.loaded_from_other_sources = True
class Map(object):
"""The class for one mapping training, prediction, saving and loading. Ensemble training is supported through either serial or parallel way, using joblib.
Attributes
----------
x : array_like
the copy of the __init__ argument
y : array_like
the copy of the __init__ argument
n_model : int
the copy of the __init__ argument
model_type : type
the copy of the __init__ argument
ensemble_type : str
the copy of the __init__ argument
model_hp_choices : dict
the copy of the __init__ argument
model_hp_fixed : dict
the copy of the __init__ argument
optax_hp_choices : dict
the copy of the __init__ argument
optax_hp_fixed : dict
the copy of the __init__ argument
dl_hp_choices : dict
the copy of the __init__ argument
dl_hp_fixed : dict
the copy of the __init__ argument
training_parallel : bool
the copy of the __init__ argument
ens_seed : Optional[int], optional
the copy of the __init__ argument
parallel_config : Optional[dict], optional
the copy of the __init__ argument
device : Optional[Device], optional
the copy of the __init__ argument
trained : bool
whether the mapping has been trained
loaded_from_other_sources : bool
whether the mapping is loaded from other sources.
Ns : int
number of samples
Nx : int
number of input features
Ny : int
number of output features
model_configs : list
model hyperparameters for all ensemble models
optax_configs : list
optimizer hyperparameters for all ensemble models
dl_configs : list
dataloader hyperparameters for all ensemble models
model_ens : list
the list of trained ensemble models
loss_train_ens : list
the training loss histories (over steps) of the ensemble members
loss_val_ens : list
the validation loss histories (over steps) of the ensemble members
"""
def __init__(
self, x: Array, y: Array, model_type: type=MLP,
n_model: int=0, ensemble_type: str='single',
model_hp_choices: dict={}, model_hp_fixed: dict={},
optax_hp_choices: dict={}, optax_hp_fixed: dict={},
dl_hp_choices: dict={}, dl_hp_fixed: dict={},
training_parallel: bool=True,
ens_seed: Optional[int]=None,
parallel_config: Optional[dict]=None,
device: Optional[Device]=None
):
"""Initialization function.
Args:
x (array-like): the predictors with shape (Ns, Nx)
y (array-like): the predictands with shape (Ns, Ny)
model_type (type): the equinox model class
n_model (int): the number of ensemble models (used by 'ens_random'; ignored by 'single' and 'ens_grid')
ensemble_type (str): the ensemble type, one of 'single' (a single model built from the fixed hyperparameters), 'ens_random' (n_model members with randomly sampled tunable hyperparameters), or 'ens_grid' (one member per combination of the tunable hyperparameters)
model_hp_choices (dict): the tunable model hyperparameters, in dictionary format {key: [value1, value2,...]}. The model hyperparameters must follow the arguments of the specified model_type
model_hp_fixed (dict): the fixed model hyperparameters, in dictionary format {key: value}. The model hyperparameters must follow the arguments of the specified model_type
optax_hp_choices (dict): the tunable optimizer hyperparameters, in dictionary format {key: [value1, value2,...]}. The optimizer hyperparameters must follow the arguments of the specified optax optimizer. Key hyperparameters: 'optimizer_type' (str), 'nsteps' (int), and 'loss_func' (callable)
optax_hp_fixed (dict): the fixed optimizer hyperparameters, in dictionary format {key: value}. The optimizer hyperparameters must follow the arguments of the specified optax optimizer. Key hyperparameters: 'optimizer_type' (str), 'nsteps' (int), and 'loss_func' (callable)
dl_hp_choices (dict): the tunable dataloader hyperparameters, in dictionary format {key: [value1, value2,...]}. The dataloader hyperparameters must follow the arguments of make_big_data_loader. Key hyperparameters: 'batch_size' (int) and 'num_train_sample' (int)
dl_hp_fixed (dict): the fixed dataloader hyperparameters, in dictionary format {key: value}. The dataloader hyperparameters must follow the arguments of make_big_data_loader. Key hyperparameters: 'batch_size' (int) and 'num_train_sample' (int)
training_parallel (bool): whether to perform parallel training
ens_seed (Optional[int], optional): the random seed for generating ensemble configurations.
parallel_config (Optional[dict], optional): the parallel training configurations following the arguments of joblib.Parallel
device (Optional[Device], optional): the computing device to be set
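Example:
A minimal sketch of a grid ensemble (the model hyperparameter names 'width_size' and 'depth' are illustrative guesses for MLP and may differ)::

one_map = Map(
x, y, model_type=MLP, ensemble_type='ens_grid',
model_hp_choices={'width_size': [32, 64], 'depth': [2, 3]},
optax_hp_fixed={'nsteps': 200},
dl_hp_fixed={'batch_size': 32, 'num_train_sample': 800},
)
one_map.train(verbose=0)
y_ens, y_mean, y_mean_w, weights = one_map.predict(x)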
"""
# TODO: Need a great way to pass the computational device
# somehow coupled to the parallel training
# for now, the parallel training uses joblib through multiple CPUs
self.x, self.y = x, y
self.training_parallel = training_parallel
self.parallel_config = parallel_config
self.device = device
self.trained = False
self.loaded_from_other_sources = False
# Set up the random seed for ensemble generation
random.seed(ens_seed)
# Get the data dimensions
assert self.x.shape[0] == self.y.shape[0]
self.Ns, self.Nx, self.Ny = x.shape[0], x.shape[-1], y.shape[-1]
# Get model configs
self.model_type = model_type
self.ensemble_type = ensemble_type
self.model_hp_choices = model_hp_choices
self.model_hp_fixed = model_hp_fixed
self.optax_hp_choices = optax_hp_choices
self.optax_hp_fixed = optax_hp_fixed
self.dl_hp_choices = dl_hp_choices
self.dl_hp_fixed = dl_hp_fixed
self.n_model_init = n_model
self._get_model_configs()
@property
def n_model(self):
return len(self._model_configs)
@property
def model_configs(self):
return self._model_configs
@property
def optax_configs(self):
return self._optax_configs
@property
def dl_configs(self):
return self._dl_configs
@property
def model_ens(self):
if self.trained:
return self._model_ens
else:
print("Mapping has not been trained yet.")
@property
def loss_train_ens(self):
if self.trained:
return self._loss_train_ens
else:
print("Mapping has not been trained yet.")
@property
def loss_val_ens(self):
if self.trained:
return self._loss_val_ens
else:
print("Mapping has not been trained yet.")
def _get_model_configs(self):
# Check key configs
# TODO: A naming convention should be implemented in KIM.
# e.g., the input and output parameters used in the DNN models.
# Numbers of model inputs and outputs should be fixed
if "in_size" in self.model_hp_choices:
raise Exception("Input size should not be tunabled!")
if "out_size" in self.model_hp_choices:
raise Exception("Output size should not be tunabled!")
if 'in_size' in self.model_hp_fixed and self.model_hp_fixed['in_size'] != self.Nx:
raise Exception("The fixed 'in_size' (%s) does not match the number of input features (%s)." % (self.model_hp_fixed['in_size'], self.Nx))
if 'out_size' in self.model_hp_fixed and self.model_hp_fixed['out_size'] != self.Ny:
raise Exception("The fixed 'out_size' (%s) does not match the number of output features (%s)." % (self.model_hp_fixed['out_size'], self.Ny))
self.model_hp_fixed['in_size'] = self.Nx
self.model_hp_fixed['out_size'] = self.Ny
if 'optimizer_type' not in self.optax_hp_fixed and \
'optimizer_type' not in self.optax_hp_choices:
self.optax_hp_fixed['optimizer_type'] = 'Adam'
if 'nsteps' not in self.optax_hp_fixed and \
'nsteps' not in self.optax_hp_choices:
self.optax_hp_fixed['nsteps'] = 100
if 'loss_func' not in self.optax_hp_fixed and \
'loss_func' not in self.optax_hp_choices:
self.optax_hp_fixed['loss_func'] = loss_mse
if 'batch_size' not in self.dl_hp_fixed and \
'batch_size' not in self.dl_hp_choices:
self.dl_hp_fixed['batch_size'] = 32
if 'num_train_sample' not in self.dl_hp_fixed and \
'num_train_sample' not in self.dl_hp_choices:
self.dl_hp_fixed['num_train_sample'] = self.Ns
if 'num_val_sample' not in self.dl_hp_fixed and \
'num_val_sample' not in self.dl_hp_choices:
self.dl_hp_fixed['num_val_sample'] = self.Ns - self.dl_hp_fixed['num_train_sample']
# if 'device' not in self.dl_hp_fixed:
# self.dl_hp_fixed['device'] = self.de
# Generate ensemble configurations
n_model, model_configs, optax_configs, dl_configs = generate_ensemble_configs(
self.model_hp_choices, self.model_hp_fixed,
self.optax_hp_choices, self.optax_hp_fixed,
self.dl_hp_choices, self.dl_hp_fixed,
self.n_model_init, self.ensemble_type,
)
# self.n_model = n_model
self._model_configs = model_configs
self._optax_configs = optax_configs
self._dl_configs = dl_configs
# _, self.model_configs = generate_ensemble_configs(
# self.model_hp_choices, self.model_hp_fixed, self.n_model, self.ensemble_type
# )
# _, self.optax_configs = generate_ensemble_configs(
# self.optax_hp_choices, self.optax_hp_fixed, self.n_model, self.ensemble_type
# )
# self.n_model, self.dl_configs = generate_ensemble_configs(
# self.dl_hp_choices, self.dl_hp_fixed, self.n_model, self.ensemble_type
# )
def train(self, verbose: int=0):
"""Mapping training.
Args:
verbose (int): the verbosity level (0: normal, 1: debug)
"""
if self.trained:
raise Exception("The mapping has already been trained!")
model_ens, loss_train_ens, loss_val_ens = train_ensemble(
self.x, self.y, self.model_type,
self.model_configs, self.optax_configs, self.dl_configs,
self.training_parallel, self.parallel_config, verbose
)
self._model_ens = model_ens
self._loss_train_ens = loss_train_ens
self._loss_val_ens = loss_val_ens
self.trained = True
def predict(self, x: Array):
"""Prediction using the trained mapping.
Args:
x (Array): predictors with shape (Ns,...,Nx)
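Returns:
tuple: (y_ens, y_mean, y_mean_w, weights). y_ens stacks the per-model
predictions with shape (n_model, Ns, Ny); y_mean is their unweighted mean;
y_mean_w is the weighted mean, with weights proportional to the inverse of
each model's final validation (or, if unavailable, training) loss; weights
has shape (n_model,). Predictions are in the same space as the y used for
training (no scaling is applied here).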
"""
assert x.shape[-1] == self.Nx # The same dimension
assert len(x.shape) >= 2 # At least 2 dimensions with the leading batch dimension
# Perform predictions on all models
y_ens = []
for i in range(self.n_model):
y = jax.vmap(self.model_ens[i])(x)
y_ens.append(y)
y_ens = jnp.array(y_ens)
# Calculate mean
y_mean = jnp.array(y_ens).mean(axis=0)
# print(y_mean)
# Calculate weighted mean based on loss
# loss_ens = self.loss_val_ens if len(self.loss_val_ens)>0 else self.loss_train_ens
loss_ens = self.loss_val_ens if self.loss_val_ens[0] is not None else self.loss_train_ens
loss = jnp.array([l_all[-1] for l_all in loss_ens])
weights = 1./loss / jnp.sum(1./loss)
weighted_product = jax.vmap(lambda w, y: w*y, in_axes=(0,0))
y_ens_w = weighted_product(weights, y_ens)
y_mean_w = y_ens_w.sum(axis=0)
return y_ens, y_mean, y_mean_w, weights
def save(self, rootpath: PosixPath=Path("./")):
"""Save the trained mapping to specified location, including:
- trained models
- model/optax/dl configurations
- loss values for both training and validation sets
Args:
rootpath (PosixPath): the root path where mappings will be saved
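The resulting layout under rootpath is: the overall configuration in
"configs.pkl" plus one numbered subdirectory per ensemble member containing
"model.eqx", "configs.pkl", and "loss.pkl".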
"""
if not self.trained:
raise Exception("Mapping has not been trained yet.")
if not rootpath.exists():
rootpath.mkdir(parents=True)
# Dump overall configurations
f_overall_configs = rootpath / "configs.pkl"
overall_configs = {
"n_model": self.n_model,
"ensemble_type": self.ensemble_type,
"training_parallel": self.training_parallel,
"parallel_config": self.parallel_config,
"device": self.device,
"Ns": self.Ns, "Nx": self.Nx, "Ny": self.Ny,
"model_type": self.model_type,
}
with open(f_overall_configs, "wb") as f:
pickle.dump(overall_configs, f)
# Dump each model, its configuration, and its loss values
for i, model in enumerate(self.model_ens):
model_dir = rootpath / str(i)
if not model_dir.exists():
model_dir.mkdir(parents=True)
f_model = model_dir / "model.eqx"
f_configs = model_dir / "configs.pkl"
f_loss = model_dir / "loss.pkl"
# Save the trained model
model_configs = self.model_configs[i]
save_model(f_model, model_configs, self.model_ens[i])
# Save the configuration
configs = {
"model_configs": self.model_configs[i],
"optax_configs": self.optax_configs[i],
"dl_configs": self.dl_configs[i],
}
with open(f_configs, "wb") as f:
pickle.dump(configs, f)
# Save its loss values
loss = {
"train": self.loss_train_ens[i],
"val": self.loss_val_ens[i]
}
with open(f_loss, "wb") as f:
pickle.dump(loss, f)
def load(self, rootpath: PosixPath=Path("./")):
"""load the trained mapping from specified location.
Args:
rootpath (PosixPath): the root path where mappings will be loaded
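Note:
The Map object must be constructed first with x and y whose feature
dimensions match the saved mapping (Nx and Ny are checked against the
loaded configuration).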
"""
if self.trained:
raise Exception("Mapping has already been trained.")
# Load the overall configuration
f_overall_configs = rootpath / "configs.pkl"
with open(f_overall_configs, "rb") as f:
overall_configs = pickle.load(f)
n_model = overall_configs["n_model"]
self.ensemble_type = overall_configs["ensemble_type"]
self.training_parallel = overall_configs["training_parallel"]
self.parallel_config = overall_configs["parallel_config"]
self.device = overall_configs["device"]
Ns, Nx, Ny = overall_configs["Ns"], overall_configs["Nx"], overall_configs["Ny"]
self.model_type = overall_configs["model_type"]
assert Nx == self.Nx
assert Ny == self.Ny
# Load each model, its configuration, and its loss values
model_ens = []
model_configs, optax_configs, dl_configs = [], [], []
loss_train_ens, loss_val_ens = [], []
for i in range(n_model):
f_model = rootpath / str(i) / "model.eqx"
f_configs = rootpath / str(i) / "configs.pkl"
f_loss = rootpath / str(i) / "loss.pkl"
# Load the trained model
m = load_model(f_model, self.model_type)
model_ens.append(m)
# Load the configuration
with open(f_configs, "rb") as f:
configs = pickle.load(f)
model_configs.append(configs["model_configs"])
optax_configs.append(configs["optax_configs"])
dl_configs.append(configs["dl_configs"])
# Load its loss values
with open(f_loss, "rb") as f:
loss = pickle.load(f)
loss_train_ens.append(loss["train"])
loss_val_ens.append(loss["val"])
self._model_ens = model_ens
self._model_configs = model_configs
self._optax_configs = optax_configs
self._dl_configs = dl_configs
self._loss_train_ens = loss_train_ens
self._loss_val_ens = loss_val_ens
self.loaded_from_other_sources = True
self.trained = True
def generate_ensemble_configs(
model_hp_choices: dict, model_hp_fixed: dict,
optax_hp_choices: dict, optax_hp_fixed: dict,
dl_hp_choices: dict, dl_hp_fixed: dict,
n_model: int, ens_type: str,
):
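"""Generate the per-member hyperparameter configurations of the ensemble.
For ens_type 'single', a single member built from the fixed hyperparameters is
returned; for 'ens_random', n_model members are drawn by randomly sampling one
value for each tunable hyperparameter; for 'ens_grid', one member is created for
every combination of the tunable choices (n_model then equals the grid size).
For example, model_hp_choices={'width_size': [32, 64]} together with
dl_hp_choices={'batch_size': [16, 32]} under 'ens_grid' yields 4 members.
Returns:
tuple: (n_model, model_configs, optax_configs, dl_configs), where each of the
three config lists has length n_model.
"""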
hp_all = [(model_hp_choices, model_hp_fixed),
(optax_hp_choices, optax_hp_fixed),
(dl_hp_choices, dl_hp_fixed)]
# Check that there are no overlapping keys between hp_choices and hp_fixed
for hp_choices, hp_fixed in hp_all:
hp_keys1 = list(hp_choices.keys())
hp_keys2 = list(hp_fixed.keys())
assert all(i not in hp_keys1 for i in hp_keys2)
for key, value in hp_choices.items():
assert isinstance(value, list)
for key, value in hp_fixed.items():
assert isinstance(value, (float, int, str)) or callable(value)
# Generate the ensemble configs
model_configs, optax_configs, dl_configs = [], [], []
if ens_type == 'single':
n_model = 1
model_configs = [model_hp_fixed]
optax_configs = [optax_hp_fixed]
dl_configs = [dl_hp_fixed]
elif ens_type == 'ens_random':
# Get the configurations for each ensemble member
for i in range(n_model):
config_three = []
for hp_choices, hp_fixed in hp_all:
config = {}
# Fixed configurations
for key, value in hp_fixed.items():
config[key] = value
# Tuned configurations
for key, choices in hp_choices.items():
value = random.sample(choices, 1)[0]
config[key] = value
config_three.append(config)
model_configs.append(config_three[0])
optax_configs.append(config_three[1])
dl_configs.append(config_three[2])
elif ens_type == 'ens_grid':
# Get all the combinations of tuned configurations
hp_choices_three = {
**model_hp_choices, **optax_hp_choices, **dl_hp_choices
}
keys_c, options_c = zip(*hp_choices_three.items())
combinations = list(itertools.product(*options_c))
n_model = len(combinations)
# Get the configurations for each ensemble member
for i in range(n_model):
config_three = []
tuned_config = dict(zip(keys_c, combinations[i]))
for hp_choices, hp_fixed in hp_all:
config = {}
# Fixed configurations
for key, value in hp_fixed.items():
config[key] = value
# Tuned configurations
for key, choices in hp_choices.items():
config[key] = tuned_config[key]
config_three.append(config)
model_configs.append(config_three[0])
optax_configs.append(config_three[1])
dl_configs.append(config_three[2])
else:
raise Exception('Unknown ensemble type %s' % ens_type)
return n_model, model_configs, optax_configs, dl_configs
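# Model serialization helpers: save_model writes one JSON line containing the
# model hyperparameters followed by the equinox-serialised weights; load_model
# reads the JSON line back, rebuilds the model with model_type(**hyperparams),
# and then deserialises the weights into it.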
def save_model(filename, hyperparams, model):
with open(filename, "wb") as f:
hyperparam_str = json.dumps(hyperparams)
# hyperparam_str = json.dumps({})
f.write((hyperparam_str + "\n").encode())
eqx.tree_serialise_leaves(f, model)
def load_model(filename, model_type):
with open(filename, "rb") as f:
hyperparams = json.loads(f.readline().decode())
# print("hyperparameters: ")
# print(hyperparams)
# print("Model type: ")
# print(model_type)
model = model_type(**hyperparams)
return eqx.tree_deserialise_leaves(f, model)