Skip to content

dj_decoder_conversion.py

Converts decoder classes into dictionaries so that DataJoint can store them in tables, and converts those dictionaries back into classes when they are read.

restore_classes(params)

Converts a dictionary of stored parameters back into a dictionary of class instances. The parameters are stored as plain dictionaries because DataJoint cannot store classes directly.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `dict` | The parameters to convert | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `converted_params` | `dict` | The converted parameters |

Source code in src/spyglass/decoding/v1/dj_decoder_conversion.py
def restore_classes(params: dict) -> dict:
    """Rebuild decoder class instances from their stored dictionary form.

    DataJoint cannot store class instances, so decoder parameters are kept
    as plain dictionaries. This reverses that conversion on a deep copy of
    ``params``; the caller's dictionary is never modified.

    Parameters
    ----------
    params : dict
        Stored decoder parameters. Must contain the keys ``"environments"``,
        ``"continuous_transition_types"``, ``"discrete_transition_type"``,
        ``"continuous_initial_conditions_types"`` and ``"observation_models"``.

    Returns
    -------
    converted_params : dict
        A copy of ``params`` in which those entries hold class instances.
        Any other keys are copied through unchanged.

    Raises
    ------
    KeyError
        If one of the required keys listed above is missing.
    """
    restored = copy.deepcopy(params)

    # Class-name -> class lookups for each family of decoder components.
    cst_lookup = _map_class_name_to_class(cst)
    dst_lookup = _map_class_name_to_class(dst)
    ic_lookup = _map_class_name_to_class(ic)

    restored["environments"] = list(
        map(_convert_env_dict, restored["environments"])
    )

    # Continuous transitions are a nested (list of lists) structure, so each
    # inner row is converted element by element.
    transition_rows = restored["continuous_transition_types"]
    restored["continuous_transition_types"] = [
        [_convert_dict_to_class(transition, cst_lookup) for transition in row]
        for row in transition_rows
    ]

    restored["discrete_transition_type"] = _convert_dict_to_class(
        restored["discrete_transition_type"], dst_lookup
    )

    initial_conditions = restored["continuous_initial_conditions_types"]
    restored["continuous_initial_conditions_types"] = [
        _convert_dict_to_class(condition, ic_lookup)
        for condition in initial_conditions
    ]

    # ``None`` means "no observation models" and is passed through as-is.
    observation_models = restored["observation_models"]
    if observation_models is not None:
        restored["observation_models"] = [
            ObservationModel(**model_kwargs)
            for model_kwargs in observation_models
        ]

    return restored

convert_classes_to_dict(params)

Converts the classifier parameters into a dictionary of plain values so that DataJoint can store it.

Source code in src/spyglass/decoding/v1/dj_decoder_conversion.py
def convert_classes_to_dict(params: dict) -> dict:
    """Converts the classifier parameters into a dictionary so that datajoint can store it.

    Works on a deep copy of ``params``, so the caller's dictionary (and the
    class instances inside it) are never modified.

    Parameters
    ----------
    params : dict
        Classifier parameters holding class instances. Must contain the keys
        ``"environments"``, ``"continuous_transition_types"``,
        ``"discrete_transition_type"``,
        ``"continuous_initial_conditions_types"`` and ``"observation_models"``.
        ``"environments"`` may be either a single environment or an iterable
        of environments. ``"clusterless_algorithm_params"`` is optional.

    Returns
    -------
    converted_params : dict
        A copy of ``params`` with those entries converted to dictionary form.
        ``"environments"`` is always a list in the result.

    Raises
    ------
    KeyError
        If one of the required keys listed above is missing.
    """
    params = copy.deepcopy(params)

    # Accept a bare environment as well as an iterable of them. Only the
    # iterability probe is inside the ``try``. A TypeError raised while
    # converting an individual environment must propagate instead of being
    # mistaken for "this is a single, non-iterable environment".
    environments = params["environments"]
    try:
        iter(environments)
    except TypeError:
        environments = [environments]
    params["environments"] = [
        _convert_environment_to_dict(env) for env in environments
    ]

    params["continuous_transition_types"] = _convert_transitions_to_dict(
        params["continuous_transition_types"]
    )
    params["discrete_transition_type"] = _to_dict(
        params["discrete_transition_type"]
    )
    params["continuous_initial_conditions_types"] = [
        _to_dict(cont_ic)
        for cont_ic in params["continuous_initial_conditions_types"]
    ]

    # ``None`` means "no observation models" and is stored as-is.
    if params["observation_models"] is not None:
        params["observation_models"] = [
            vars(obs) for obs in params["observation_models"]
        ]

    # This key is optional. Check for it explicitly rather than catching
    # KeyError, which would also hide KeyErrors raised inside the converter.
    if "clusterless_algorithm_params" in params:
        params["clusterless_algorithm_params"] = _convert_algorithm_params(
            params["clusterless_algorithm_params"]
        )

    return params