Skip to content

curation.py

CurationV1

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v1/curation.py
@schema
class CurationV1(SpyglassMixin, dj.Manual):
    # Rows should be created through `insert_curation` (or
    # `insert_metric_curation`), which also writes the curated sorting to an
    # analysis NWB file and registers it with AnalysisNwbfile.
    definition = """
    # Curation of a SpikeSorting. Use `insert_curation` to insert rows.
    -> SpikeSorting
    curation_id=0: int
    ---
    parent_curation_id=-1: int
    -> AnalysisNwbfile
    object_id: varchar(72)
    merges_applied: bool
    description: varchar(100)
    """

    @classmethod
    def insert_curation(
        cls,
        sorting_id: str,
        parent_curation_id: int = -1,
        labels: Union[None, Dict[str, List[str]]] = None,
        merge_groups: Union[None, List[List[str]]] = None,
        apply_merge: bool = False,
        metrics: Union[None, Dict[str, Dict[str, float]]] = None,
        description: str = "",
    ):
        """Insert a row into CurationV1.

        Parameters
        ----------
        sorting_id : str
            The sorting_id of the original SpikeSorting entry
        parent_curation_id : int, optional
            The curation id of the parent curation. -1 (the default) marks a
            curation with no parent; any value below -1 is clamped to -1.
        labels : dict or None, optional
            curation labels (e.g. good, noise, mua)
        merge_groups : list of list or None, optional
            groups of unit IDs to be merged
        apply_merge : bool, optional
            Whether to apply the merge groups to the sorting written to NWB.
            The value is stored in the `merges_applied` column. Default False.
        metrics : dict or None, optional
            Computed quality metrics, one for each neuron
        description : str, optional
            description of this curation or where it originates; e.g. FigURL

        Note
        ----
        Example curation.json (output of figurl):
        {
         "labelsByUnit":
            {"1":["noise","reject"],"10":["noise","reject"]},
         "mergeGroups":
            [[11,12],[46,48],[51,54],[50,53]]
        }

        Returns
        -------
        curation_key : dict
            The inserted row. If `parent_curation_id` is -1 and a curation
            with parent -1 already exists for this sorting, a warning is
            emitted, nothing is written, and the existing key(s) are returned
            as a list of dicts (``fetch("KEY")``).
        """
        import warnings

        AnalysisNwbfile()._creation_times["pre_create_time"] = time()

        sort_query = cls & {"sorting_id": sorting_id}
        parent_curation_id = max(parent_curation_id, -1)
        if parent_curation_id == -1:
            # Check whether this sorting was already inserted with a parent
            # of -1. If so, warn the user and return the existing entry.
            query = sort_query & {"parent_curation_id": -1}
            if query:
                # Emit the warning; merely constructing Warning(...) is a no-op.
                warnings.warn("Sorting has already been inserted.")
                return query.fetch("KEY")

        # generate curation ID: one past the largest existing id for this
        # sorting, or 0 if there are none yet
        existing_curation_ids = sort_query.fetch("curation_id")
        curation_id = max(existing_curation_ids, default=-1) + 1

        # write the curation labels, merge groups,
        # and metrics as columns in the units table of NWB
        analysis_file_name, object_id = _write_sorting_to_nwb_with_curation(
            sorting_id=sorting_id,
            labels=labels,
            merge_groups=merge_groups,
            metrics=metrics,
            apply_merge=apply_merge,
        )

        # INSERT: register the new analysis file under the parent NWB file
        AnalysisNwbfile().add(
            (SpikeSortingSelection & {"sorting_id": sorting_id}).fetch1(
                "nwb_file_name"
            ),
            analysis_file_name,
        )

        key = {
            "sorting_id": sorting_id,
            "curation_id": curation_id,
            "parent_curation_id": parent_curation_id,
            "analysis_file_name": analysis_file_name,
            "object_id": object_id,
            "merges_applied": apply_merge,
            "description": description,
        }
        cls.insert1(
            key,
            skip_duplicates=True,
        )
        AnalysisNwbfile().log(analysis_file_name, table=cls.full_table_name)

        return key

    @classmethod
    def insert_metric_curation(cls, key: Dict, apply_merge=False):
        """Insert a row into CurationV1 from a MetricCuration entry.

        The labels and merge groups computed by MetricCuration are passed to
        `insert_curation`. The curation referenced by the
        MetricCurationSelection becomes the parent.

        Parameters
        ----------
        key : Dict
            primary key of MetricCuration
        apply_merge : bool, optional
            Forwarded to `insert_curation`. Default False.

        Returns
        -------
        curation_key : Dict
        """
        from spyglass.spikesorting.v1.metric_curation import (
            MetricCuration,
            MetricCurationSelection,
        )

        sorting_id, parent_curation_id = (MetricCurationSelection & key).fetch1(
            "sorting_id", "curation_id"
        )

        curation_key = cls.insert_curation(
            sorting_id=sorting_id,
            parent_curation_id=parent_curation_id,
            # empty results are normalized to None
            labels=MetricCuration.get_labels(key) or None,
            merge_groups=MetricCuration.get_merge_groups(key) or None,
            apply_merge=apply_merge,
            description=(f"metric_curation_id: {key['metric_curation_id']}"),
        )

        return curation_key

    @classmethod
    def get_recording(cls, key: dict) -> si.BaseRecording:
        """Get recording related to this curation as spikeinterface BaseRecording

        Parameters
        ----------
        key : dict
            primary key of CurationV1 table

        Returns
        -------
        recording : si.BaseRecording
            Recording read from the SpikeSortingRecording analysis NWB file,
            with its time vector loaded and annotated as filtered.
        """

        analysis_file_name = (
            SpikeSortingRecording * SpikeSortingSelection & key
        ).fetch1("analysis_file_name")
        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            analysis_file_name
        )
        recording = se.read_nwb_recording(
            analysis_file_abs_path, load_time_vector=True
        )
        recording.annotate(is_filtered=True)

        return recording

    @classmethod
    def get_sorting(cls, key: dict) -> si.BaseSorting:
        """Get sorting in the analysis NWB file as spikeinterface BaseSorting

        Parameters
        ----------
        key : dict
            primary key of CurationV1 table

        Returns
        -------
        sorting : si.BaseSorting
            Sorting whose spike trains are expressed as sample indices into
            the associated recording's time vector.
        """
        recording = cls.get_recording(key)
        sampling_frequency = recording.get_sampling_frequency()
        # The time vector is the same for every unit: fetch it once rather
        # than once per unit inside the comprehension below.
        recording_times = recording.get_times()
        analysis_file_name = (cls & key).fetch1("analysis_file_name")
        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            analysis_file_name
        )
        with pynwb.NWBHDF5IO(
            analysis_file_abs_path, "r", load_namespaces=True
        ) as io:
            nwbf = io.read()
            units = nwbf.units.to_dataframe()
        # Map each unit's spike times onto sample indices of the recording.
        units_dict_list = [
            {
                unit_id: np.searchsorted(recording_times, spike_times)
                for unit_id, spike_times in zip(
                    units.index, units["spike_times"]
                )
            }
        ]

        sorting = si.NumpySorting.from_unit_dict(
            units_dict_list, sampling_frequency=sampling_frequency
        )

        return sorting

    @classmethod
    def get_merged_sorting(cls, key: dict) -> si.BaseSorting:
        """Get sorting with merges applied.

        Parameters
        ----------
        key : dict
            CurationV1 key

        Returns
        -------
        sorting : si.BaseSorting
            A MergeUnitsSorting built from the stored `merge_groups` column.
            If that column is absent or empty, the stored sorting is returned
            unchanged.
        """
        recording = cls.get_recording(key)

        curation_key = (cls & key).fetch1()

        sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            curation_key["analysis_file_name"]
        )
        si_sorting = se.read_nwb_sorting(
            sorting_analysis_file_abs_path,
            sampling_frequency=recording.get_sampling_frequency(),
        )

        with pynwb.NWBHDF5IO(
            sorting_analysis_file_abs_path, "r", load_namespaces=True
        ) as io:
            nwbfile = io.read()
            nwb_sorting = nwbfile.objects[curation_key["object_id"]]
            # The merge_groups column may be absent (presumably when no merge
            # groups were given at insertion -- confirm in
            # _write_sorting_to_nwb_with_curation). Indexing a missing column
            # raises, so check first; read while the file is still open.
            merge_groups = (
                nwb_sorting["merge_groups"][:]
                if "merge_groups" in nwb_sorting.colnames
                else None
            )

        # Explicit length check: `if array:` is ambiguous for numpy arrays.
        if merge_groups is not None and len(merge_groups) > 0:
            units_to_merge = _merge_dict_to_list(merge_groups)
            return sc.MergeUnitsSorting(
                parent_sorting=si_sorting, units_to_merge=units_to_merge
            )
        return si_sorting

    @classmethod
    def get_sort_group_info(cls, key: dict) -> dj.Table:
        """Returns the sort group information for the curation
        (e.g. brain region, electrode placement, etc.)

        Parameters
        ----------
        key : dict
            restriction on CurationV1 table

        Returns
        -------
        sort_group_info : Table
            The restricted CurationV1 keys joined with SpikeSortingSelection,
            SpikeSortingRecordingSelection (recording_id, sort_group_id), one
            Electrode per sort group, SortGroupElectrode and BrainRegion.
        """
        table = (
            (cls & key) * SpikeSortingSelection()
        ) * SpikeSortingRecordingSelection().proj(
            "recording_id", "sort_group_id"
        )
        electrode_restrict_list = []
        for entry in table:
            # pull just one electrode from each sort group for info
            # NOTE(review): this assumes all electrodes in a sort group share
            # the same region/placement -- confirm.
            electrode_restrict_list.extend(
                ((SortGroup.SortGroupElectrode() & entry) * Electrode).fetch(
                    limit=1
                )
            )

        sort_group_info = (
            (Electrode & electrode_restrict_list)
            * table
            * SortGroup.SortGroupElectrode()
        ) * BrainRegion()
        return (cls & key).proj() * sort_group_info

insert_curation(sorting_id, parent_curation_id=-1, labels=None, merge_groups=None, apply_merge=False, metrics=None, description='') classmethod

Insert a row into CurationV1.

Parameters:

Name Type Description Default
sorting_id str

The key for the original SpikeSorting

required
parent_curation_id int

The curation id of the parent curation

-1
labels dict or None

curation labels (e.g. good, noise, mua)

None
merge_groups list of list or None

groups of unit IDs to be merged

None
metrics dict or None

Computed quality metrics, one for each neuron

None
description str

description of this curation or where it originates; e.g. FigURL

''
Note

Example curation.json (output of figurl): { "labelsByUnit": {"1":["noise","reject"],"10":["noise","reject"]}, "mergeGroups": [[11,12],[46,48],[51,54],[50,53]] }

Returns:

Name Type Description
curation_key dict
Source code in src/spyglass/spikesorting/v1/curation.py
@classmethod
def insert_curation(
    cls,
    sorting_id: str,
    parent_curation_id: int = -1,
    labels: Union[None, Dict[str, List[str]]] = None,
    merge_groups: Union[None, List[List[str]]] = None,
    apply_merge: bool = False,
    metrics: Union[None, Dict[str, Dict[str, float]]] = None,
    description: str = "",
):
    """Insert a row into CurationV1.

    Parameters
    ----------
    sorting_id : str
        The sorting_id of the original SpikeSorting entry
    parent_curation_id : int, optional
        The curation id of the parent curation. -1 (the default) marks a
        curation with no parent; any value below -1 is clamped to -1.
    labels : dict or None, optional
        curation labels (e.g. good, noise, mua)
    merge_groups : list of list or None, optional
        groups of unit IDs to be merged
    apply_merge : bool, optional
        Whether to apply the merge groups to the sorting written to NWB.
        The value is stored in the `merges_applied` column. Default False.
    metrics : dict or None, optional
        Computed quality metrics, one for each neuron
    description : str, optional
        description of this curation or where it originates; e.g. FigURL

    Note
    ----
    Example curation.json (output of figurl):
    {
     "labelsByUnit":
        {"1":["noise","reject"],"10":["noise","reject"]},
     "mergeGroups":
        [[11,12],[46,48],[51,54],[50,53]]
    }

    Returns
    -------
    curation_key : dict
        The inserted row. If `parent_curation_id` is -1 and a curation with
        parent -1 already exists for this sorting, a warning is emitted,
        nothing is written, and the existing key(s) are returned as a list
        of dicts (``fetch("KEY")``).
    """
    import warnings

    AnalysisNwbfile()._creation_times["pre_create_time"] = time()

    sort_query = cls & {"sorting_id": sorting_id}
    parent_curation_id = max(parent_curation_id, -1)
    if parent_curation_id == -1:
        # Check whether this sorting was already inserted with a parent of
        # -1. If so, warn the user and return the existing entry.
        query = sort_query & {"parent_curation_id": -1}
        if query:
            # Emit the warning; merely constructing Warning(...) is a no-op.
            warnings.warn("Sorting has already been inserted.")
            return query.fetch("KEY")

    # generate curation ID: one past the largest existing id for this
    # sorting, or 0 if there are none yet
    existing_curation_ids = sort_query.fetch("curation_id")
    curation_id = max(existing_curation_ids, default=-1) + 1

    # write the curation labels, merge groups,
    # and metrics as columns in the units table of NWB
    analysis_file_name, object_id = _write_sorting_to_nwb_with_curation(
        sorting_id=sorting_id,
        labels=labels,
        merge_groups=merge_groups,
        metrics=metrics,
        apply_merge=apply_merge,
    )

    # INSERT: register the new analysis file under the parent NWB file
    AnalysisNwbfile().add(
        (SpikeSortingSelection & {"sorting_id": sorting_id}).fetch1(
            "nwb_file_name"
        ),
        analysis_file_name,
    )

    key = {
        "sorting_id": sorting_id,
        "curation_id": curation_id,
        "parent_curation_id": parent_curation_id,
        "analysis_file_name": analysis_file_name,
        "object_id": object_id,
        "merges_applied": apply_merge,
        "description": description,
    }
    cls.insert1(
        key,
        skip_duplicates=True,
    )
    AnalysisNwbfile().log(analysis_file_name, table=cls.full_table_name)

    return key

insert_metric_curation(key, apply_merge=False) classmethod

Insert a row into CurationV1.

Parameters:

Name Type Description Default
key Dict

primary key of MetricCuration

required
apply_merge bool

whether to apply the merge groups; forwarded to insert_curation

False

Returns:

Name Type Description
curation_key Dict
Source code in src/spyglass/spikesorting/v1/curation.py
@classmethod
def insert_metric_curation(cls, key: Dict, apply_merge=False):
    """Create a CurationV1 row from the output of a MetricCuration.

    The labels and merge groups computed by MetricCuration are handed to
    `insert_curation`. The curation recorded in MetricCurationSelection is
    used as the parent.

    Parameters
    ----------
    key : Dict
        primary key of MetricCuration
    apply_merge : bool, optional
        Forwarded unchanged to `insert_curation`. Default False.

    Returns
    -------
    curation_key : Dict
        Whatever `insert_curation` returns for the new curation.
    """
    from spyglass.spikesorting.v1.metric_curation import (
        MetricCuration,
        MetricCurationSelection,
    )

    selection = MetricCurationSelection & key
    sorting_id, parent_curation_id = selection.fetch1(
        "sorting_id", "curation_id"
    )

    # Empty label / merge results are normalized to None.
    metric_labels = MetricCuration.get_labels(key) or None
    metric_merge_groups = MetricCuration.get_merge_groups(key) or None
    curation_description = f"metric_curation_id: {key['metric_curation_id']}"

    return cls.insert_curation(
        sorting_id=sorting_id,
        parent_curation_id=parent_curation_id,
        labels=metric_labels,
        merge_groups=metric_merge_groups,
        apply_merge=apply_merge,
        description=curation_description,
    )

get_recording(key) classmethod

Get recording related to this curation as spikeinterface BaseRecording

Parameters:

Name Type Description Default
key dict

primary key of CurationV1 table

required
Source code in src/spyglass/spikesorting/v1/curation.py
@classmethod
def get_recording(cls, key: dict) -> si.BaseRecording:
    """Load the recording associated with this curation's sorting.

    Parameters
    ----------
    key : dict
        primary key of CurationV1 table

    Returns
    -------
    recording : si.BaseRecording
        Recording read from the SpikeSortingRecording analysis NWB file, with
        its time vector loaded and annotated with ``is_filtered=True``.
    """
    recording_entry = (SpikeSortingRecording * SpikeSortingSelection) & key
    recording_path = AnalysisNwbfile.get_abs_path(
        recording_entry.fetch1("analysis_file_name")
    )
    recording = se.read_nwb_recording(recording_path, load_time_vector=True)
    # Mark the recording as already filtered (presumably done upstream in
    # SpikeSortingRecording -- confirm).
    recording.annotate(is_filtered=True)
    return recording

get_sorting(key) classmethod

Get sorting in the analysis NWB file as spikeinterface BaseSorting

Parameters:

Name Type Description Default
key dict

primary key of CurationV1 table

required

Returns:

Name Type Description
sorting BaseSorting
Source code in src/spyglass/spikesorting/v1/curation.py
@classmethod
def get_sorting(cls, key: dict) -> si.BaseSorting:
    """Get sorting in the analysis NWB file as spikeinterface BaseSorting

    Parameters
    ----------
    key : dict
        primary key of CurationV1 table

    Returns
    -------
    sorting : si.BaseSorting
        Sorting whose spike trains are expressed as sample indices into the
        associated recording's time vector.
    """
    recording = cls.get_recording(key)
    sampling_frequency = recording.get_sampling_frequency()
    # The time vector is the same for every unit: fetch it once rather than
    # once per unit inside the comprehension below.
    recording_times = recording.get_times()
    analysis_file_name = (cls & key).fetch1("analysis_file_name")
    analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
        analysis_file_name
    )
    with pynwb.NWBHDF5IO(
        analysis_file_abs_path, "r", load_namespaces=True
    ) as io:
        nwbf = io.read()
        units = nwbf.units.to_dataframe()
    # Map each unit's spike times onto sample indices of the recording.
    units_dict_list = [
        {
            unit_id: np.searchsorted(recording_times, spike_times)
            for unit_id, spike_times in zip(
                units.index, units["spike_times"]
            )
        }
    ]

    sorting = si.NumpySorting.from_unit_dict(
        units_dict_list, sampling_frequency=sampling_frequency
    )

    return sorting

get_merged_sorting(key) classmethod

Get sorting with merges applied.

Parameters:

Name Type Description Default
key dict

CurationV1 key

required

Returns:

Name Type Description
sorting BaseSorting
Source code in src/spyglass/spikesorting/v1/curation.py
@classmethod
def get_merged_sorting(cls, key: dict) -> si.BaseSorting:
    """Get sorting with merges applied.

    Parameters
    ----------
    key : dict
        CurationV1 key

    Returns
    -------
    sorting : si.BaseSorting
        A MergeUnitsSorting built from the stored `merge_groups` column. If
        that column is absent or empty, the stored sorting is returned
        unchanged.
    """
    recording = cls.get_recording(key)

    curation_key = (cls & key).fetch1()

    sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
        curation_key["analysis_file_name"]
    )
    si_sorting = se.read_nwb_sorting(
        sorting_analysis_file_abs_path,
        sampling_frequency=recording.get_sampling_frequency(),
    )

    with pynwb.NWBHDF5IO(
        sorting_analysis_file_abs_path, "r", load_namespaces=True
    ) as io:
        nwbfile = io.read()
        nwb_sorting = nwbfile.objects[curation_key["object_id"]]
        # The merge_groups column may be absent (presumably when no merge
        # groups were given at insertion -- confirm in
        # _write_sorting_to_nwb_with_curation). Indexing a missing column
        # raises, so check first; read while the file is still open.
        merge_groups = (
            nwb_sorting["merge_groups"][:]
            if "merge_groups" in nwb_sorting.colnames
            else None
        )

    # Explicit length check: `if array:` is ambiguous for numpy arrays.
    if merge_groups is not None and len(merge_groups) > 0:
        units_to_merge = _merge_dict_to_list(merge_groups)
        return sc.MergeUnitsSorting(
            parent_sorting=si_sorting, units_to_merge=units_to_merge
        )
    return si_sorting

get_sort_group_info(key) classmethod

Returns the sort group information for the curation (e.g. brain region, electrode placement, etc.)

Parameters:

Name Type Description Default
key dict

restriction on CurationV1 table

required

Returns:

Name Type Description
sort_group_info Table

Table with information about the sort groups

Source code in src/spyglass/spikesorting/v1/curation.py
@classmethod
def get_sort_group_info(cls, key: dict) -> dj.Table:
    """Returns the sort group information for the curation
    (e.g. brain region, electrode placement, etc.)

    Parameters
    ----------
    key : dict
        restriction on CurationV1 table

    Returns
    -------
    sort_group_info : Table
        The restricted CurationV1 keys joined with SpikeSortingSelection,
        SpikeSortingRecordingSelection (recording_id, sort_group_id), one
        Electrode per sort group, SortGroupElectrode and BrainRegion.
    """
    table = (
        (cls & key) * SpikeSortingSelection()
    ) * SpikeSortingRecordingSelection().proj(
        "recording_id", "sort_group_id"
    )
    electrode_restrict_list = []
    for entry in table:
        # pull just one electrode from each sort group for info
        # NOTE(review): this assumes all electrodes in a sort group share the
        # same region/placement -- confirm.
        electrode_restrict_list.extend(
            ((SortGroup.SortGroupElectrode() & entry) * Electrode).fetch(
                limit=1
            )
        )

    sort_group_info = (
        (Electrode & electrode_restrict_list)
        * table
        * SortGroup.SortGroupElectrode()
    ) * BrainRegion()
    return (cls & key).proj() * sort_group_info