# Source code for mlens.base.id_train

"""ML-ENSEMBLE

:author: Sebastian Flennerhag
:copyright: 2017
:licence: MIT

Class for identifying a training set after an estimator has been fitted. Used
for determining whether a `predict` or `transform` method should use
cross validation to create predictions, or estimators fitted on full
training data.
"""

from __future__ import division, print_function

from ..utils.exceptions import NotFittedError
from ..externals.sklearn.base import BaseEstimator

from numpy import array_equal, ix_
from numpy.random import permutation
from numbers import Integral


class IdTrain(BaseEstimator):
    """Container to identify training set.

    Samples a random subset from the set passed to the `fit` method, to allow
    identification of the training set in a `transform` or `predict` method.

    Parameters
    ----------
    size : int
        size to sample. A random subset of size [size, size] will be stored
        in the instance.

    Raises
    ------
    ValueError
        If ``size`` is not an integer.
    """

    def __init__(self, size=10):
        if not isinstance(size, Integral):
            raise ValueError("'size' must be an integer. Got %r" % size)
        self.size = size

    def fit(self, X):
        """Sample a training set.

        Draws up to ``size`` random row indices and up to ``size`` random
        column indices (without replacement) and stores the sub-array at
        their intersection, together with the indices and the shape of ``X``.

        Parameters
        ----------
        X : array-like
            training set to sample observations from. Must expose ``shape``
            with at least two dimensions and support indexing with the
            output of ``numpy.ix_``.

        Returns
        -------
        self : obj
            fitted instance with stored sample.
        """
        shape = X.shape

        sample_idx = {}
        for i in range(2):
            # Never request more indices than the dimension holds.
            dim_size = min(shape[i], self.size)
            sample_idx[i] = permutation(shape[i])[:dim_size]

        sample = X[ix_(sample_idx[0], sample_idx[1])]

        # Store state only after sampling succeeded, so a failed ``fit`` does
        # not leave an instance that looks fitted but lacks its sample.
        self.train_shape = shape
        self.sample_idx_ = sample_idx
        self.sample_ = sample
        return self

    def is_train(self, X):
        """Check if an array is the training set.

        Parameters
        ----------
        X : array-like
            array to compare against the sample stored by ``fit``.

        Returns
        -------
        bool
            ``True`` if ``X`` has the same shape as the array passed to
            ``fit`` and the entries at the stored sample indices equal the
            stored sample, ``False`` otherwise.

        Raises
        ------
        NotFittedError
            If ``fit`` has not been called on this instance.
        """
        if not hasattr(self, "train_shape"):
            raise NotFittedError("This IdTrain instance is not fitted yet.")

        # Cheap rejection before touching any data.
        if not self._check_shape(X):
            return False

        idx = self.sample_idx_
        try:
            # Grab sample from `X`
            sample = X[ix_(idx[0], idx[1])]
        except IndexError:
            # If index is out of bounds, X.shape < training_set.shape
            # -> X is not the training set
            return False

        return array_equal(sample, self.sample_)

    def _check_shape(self, X):
        """Check if X has the same shape as the training set.

        Whole shape tuples are compared, so an input with a different number
        of dimensions yields ``False`` instead of raising ``IndexError``.
        """
        return tuple(X.shape) == tuple(self.train_shape)