# Source code for mlens.utils.formatting

"""ML-ENSEMBLE

:author: Sebastian Flennerhag
:copyright: 2017
:licence: MIT

Formatting of instance lists.
"""

from __future__ import division, print_function

from .checks import assert_valid_estimator
from .exceptions import LayerSpecificationError
from collections import Counter


def _format_instances(instances):
    """Format a list of instances to a list of named estimator tuples."""
    named_instances = []
    for val in instances:
        # Check that the instance appears correctly specified
        if not isinstance(val, (list, tuple, set)):
            # val is the instance
            instance = val
        else:
            # val is a list-like object. Assume instance is the last entry
            instance = val[-1]

        # Check if it appears to be an estimator
        assert_valid_estimator(instance)

        try:
            # Format instance names

            # We keep the instance as a list to change possible duplicate names
            # exploiting that lists are mutable, before switching to tuple
            if instance == val:
                tup = [instance.__class__.__name__.lower(), instance]
            else:
                tup = ['-'.join(val[0].split()).lower(), val[-1]]

            named_instances.append(tup)

        except Exception as e:
            msg = ("Could not format instance %s. Check that passed instance "
                   "iterables follow correct syntax:\n"
                   "- if multiple preprocessing cases, pass a dictionary with "
                   "instance lists as values and case name as key.\n"
                   "- else, pass list of (named) instances.\n"
                   "See documentation for further information.\n"
                   "Error details: %r")
            raise LayerSpecificationError(msg % (instance, e))

    # Check and correct duplicate names
    duplicates = Counter([tup[0] for tup in named_instances])
    duplicates = {key: val for key, val in duplicates.items() if
                  val > 1}

    out = []  # final named_instances list

    name_count = {key: 1 for key in duplicates}
    for name, instance in named_instances:
        if name in duplicates:
            current_name_count = name_count[name]  # fix before update
            name_count[name] += 1
            name += '-%d' % current_name_count  # rename
        out.append((name, instance))

    return out


def _check_format(instance_list):
    """Quick check of an instance list to see if the format is correct."""
    # Assert list instance
    if not isinstance(instance_list, list):
        return False

    # If empty list, no preprocessing case
    if len(instance_list) == 0:
        return True

    # Check if each element in instance_list is a named instance tuple
    for element in instance_list:

        # Check that element is a tuple
        if not isinstance(element, tuple) or len(element) != 2:
            return False

        # Check that the first element is a string with no spaces,
        # the latter an estimator
        is_str = isinstance(element[0], str)
        no_spa = ' ' not in element[0]
        is_est = (hasattr(element[1], 'get_params') and
                  hasattr(element[1], 'fit'))
        if not (is_str and is_est and no_spa):
            return False

        # Check that the last element is a valid estimator
        assert_valid_estimator(element[1])

    # Check that there are no duplicate names
    names = Counter([tup[0] for tup in instance_list])
    if max([val for val in names.values()]) > 1:
        return False

    # If instances passes above criterion, it's correctly specified
    return True


def _assert_format(instances):
    """Assert that a generic instances object is correctly formatted."""
    if isinstance(instances, dict):
        # Need to check every instance list across preprocessing cases
        for instance_list in instances.values():
            if not _check_format(instance_list):
                return False
        return True

    # For list, check the given list
    return _check_format(instances)


def check_instances(instances):
    """Helper to ensure all instances are named.

    Check if ``instances`` is formatted as expected, and if not convert
    formatting or raise an error if the formatting cannot be anticipated.

    Parameters
    ----------
    instances : iterable
        instance iterable to test.

    Returns
    -------
    formatted : list or dict
        formatted ``instances`` object. Will be formatted as a dict if
        preprocessing cases are detected, otherwise as a list. The dict will
        contain lists identical to those in the single preprocessing case.
        Each list is of the form ``[('name', instance)]`` and no names
        overlap.

    Raises
    ------
    LayerSpecificationError :
        Raises error if formatting fails, which is most likely due to wrong
        ordering of tuple entries, or wrong argument in the wrong position.

    See Also
    --------
    :class:`mlens.ensemble.base.Layer`
    """
    is_iterable = isinstance(instances, (list, tuple, dict))
    if instances is None or (is_iterable and len(instances) == 0):
        # If no instances specified, return empty list
        return []

    if not is_iterable:
        # Instance is the estimator, wrap in list and continue
        instances = [instances]

    if _assert_format(instances):
        # If format is ok, return as is
        return instances

    if not isinstance(instances, dict):
        return _format_instances(instances)

    # Preprocessing cases: normalize each case name and format its list
    out = {}
    for case, case_list in instances.items():
        if case_list is None:
            # Mirror the top-level handling of "no instances specified"
            case_list = []
        elif not isinstance(case_list, (list, tuple, set)):
            # A bare estimator as a case value: wrap it, as done at the top
            # level, instead of iterating over the estimator itself
            case_list = [case_list]
        out['-'.join(case.lower().split())] = _format_instances(case_list)
    return out