recording.py

SortGroup

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v1/recording.py
@schema
class SortGroup(SpyglassMixin, dj.Manual):
    definition = """
    # Set of electrodes to spike sort together
    -> Session
    sort_group_id: int
    ---
    sort_reference_electrode_id = -1: int  # the electrode to use for referencing
                                           # -1: no reference, -2: common median
    """

    class SortGroupElectrode(SpyglassMixin, dj.Part):
        definition = """
        -> SortGroup
        -> Electrode
        """

    @classmethod
    def set_group_by_shank(
        cls,
        nwb_file_name: str,
        references: dict = None,
        omit_ref_electrode_group=False,
        omit_unitrode=True,
    ):
        """Divides electrodes into groups based on their shank position.

        * Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a
          single group
        * Electrodes from probes with multiple shanks (e.g. polymer probes) are
          placed in one group per shank
        * Bad channels are omitted

        Parameters
        ----------
        nwb_file_name : str
            the name of the NWB file whose electrodes should be put into
            sorting groups
        references : dict, optional
            If passed, used to set references. Otherwise, references set using
            original reference electrodes from config. Keys: electrode groups.
            Values: reference electrode.
        omit_ref_electrode_group : bool
            Optional. If True, no sort group is defined for electrode group of
            reference.
        omit_unitrode : bool
            Optional. If True, no sort groups are defined for unitrodes.
        """
        # delete any current groups
        (SortGroup & {"nwb_file_name": nwb_file_name}).delete()
        # get the electrodes from this NWB file
        electrodes = (
            Electrode()
            & {"nwb_file_name": nwb_file_name}
            & {"bad_channel": "False"}
        ).fetch()
        e_groups = list(np.unique(electrodes["electrode_group_name"]))
        e_groups.sort(key=int)  # sort electrode groups numerically
        sort_group = 0
        sg_key = dict()
        sge_key = dict()
        sg_key["nwb_file_name"] = sge_key["nwb_file_name"] = nwb_file_name
        for e_group in e_groups:
            # for each electrode group, get a list of the unique shank numbers
            shank_list = np.unique(
                electrodes["probe_shank"][
                    electrodes["electrode_group_name"] == e_group
                ]
            )
            sge_key["electrode_group_name"] = e_group
            # get the indices of all electrodes in this group / shank and set their sorting group
            for shank in shank_list:
                sg_key["sort_group_id"] = sge_key["sort_group_id"] = sort_group
                # specify reference electrode. Use 'references' if passed, otherwise use reference from config
                if not references:
                    shank_elect_ref = electrodes[
                        "original_reference_electrode"
                    ][
                        np.logical_and(
                            electrodes["electrode_group_name"] == e_group,
                            electrodes["probe_shank"] == shank,
                        )
                    ]
                    if np.max(shank_elect_ref) == np.min(shank_elect_ref):
                        sg_key["sort_reference_electrode_id"] = shank_elect_ref[
                            0
                        ]
                    else:
                        raise ValueError(
                            f"Error in electrode group {e_group}: reference "
                            + "electrodes are not all the same"
                        )
                else:
                    if e_group not in references.keys():
                        raise Exception(
                            f"electrode group {e_group} not a key in "
                            + "references, so cannot set reference"
                        )
                    else:
                        sg_key["sort_reference_electrode_id"] = references[
                            e_group
                        ]
                # Insert sort group and sort group electrodes
                reference_electrode_group = electrodes[
                    electrodes["electrode_id"]
                    == sg_key["sort_reference_electrode_id"]
                ][
                    "electrode_group_name"
                ]  # reference for this electrode group
                if (
                    len(reference_electrode_group) == 1
                ):  # unpack single reference
                    reference_electrode_group = reference_electrode_group[0]
                elif (int(sg_key["sort_reference_electrode_id"]) > 0) and (
                    len(reference_electrode_group) != 1
                ):
                    raise Exception(
                        "Should have found exactly one electrode group for "
                        + "reference electrode, but found "
                        + f"{len(reference_electrode_group)}."
                    )
                if omit_ref_electrode_group and (
                    str(e_group) == str(reference_electrode_group)
                ):
                    logger.warn(
                        f"Omitting electrode group {e_group} from sort groups "
                        + "because contains reference."
                    )
                    continue
                shank_elect = electrodes["electrode_id"][
                    np.logical_and(
                        electrodes["electrode_group_name"] == e_group,
                        electrodes["probe_shank"] == shank,
                    )
                ]
                if (
                    omit_unitrode and len(shank_elect) == 1
                ):  # omit unitrodes if indicated
                    logger.warn(
                        f"Omitting electrode group {e_group}, shank {shank} "
                        + "from sort groups because unitrode."
                    )
                    continue
                cls.insert1(sg_key, skip_duplicates=True)
                for elect in shank_elect:
                    sge_key["electrode_id"] = elect
                    cls.SortGroupElectrode().insert1(
                        sge_key, skip_duplicates=True
                    )
                sort_group += 1

set_group_by_shank(nwb_file_name, references=None, omit_ref_electrode_group=False, omit_unitrode=True) classmethod

Divides electrodes into groups based on their shank position.

  • Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a single group
  • Electrodes from probes with multiple shanks (e.g. polymer probes) are placed in one group per shank
  • Bad channels are omitted

Parameters:

    nwb_file_name : str (required)
        the name of the NWB file whose electrodes should be put into sorting groups
    references : dict (default: None)
        If passed, used to set references. Otherwise, references set using original reference electrodes from config. Keys: electrode groups. Values: reference electrode.
    omit_ref_electrode_group : bool (default: False)
        If True, no sort group is defined for the electrode group of the reference.
    omit_unitrode : bool (default: True)
        If True, no sort groups are defined for unitrodes.
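
A minimal usage sketch; the NWB file name and the reference mapping below are hypothetical and must already exist in Session and Electrode:

from spyglass.spikesorting.v1.recording import SortGroup

# Group electrodes by shank; map electrode group names to reference electrode ids.
# Omit `references` to fall back to the original reference electrodes from the config.
SortGroup.set_group_by_shank(
    nwb_file_name="example_session_.nwb",
    references={"0": 5, "1": 5},
    omit_ref_electrode_group=False,
    omit_unitrode=True,
)
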
Source code in src/spyglass/spikesorting/v1/recording.py
@classmethod
def set_group_by_shank(
    cls,
    nwb_file_name: str,
    references: dict = None,
    omit_ref_electrode_group=False,
    omit_unitrode=True,
):
    """Divides electrodes into groups based on their shank position.

    * Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a
      single group
    * Electrodes from probes with multiple shanks (e.g. polymer probes) are
      placed in one group per shank
    * Bad channels are omitted

    Parameters
    ----------
    nwb_file_name : str
        the name of the NWB file whose electrodes should be put into
        sorting groups
    references : dict, optional
        If passed, used to set references. Otherwise, references set using
        original reference electrodes from config. Keys: electrode groups.
        Values: reference electrode.
    omit_ref_electrode_group : bool
        Optional. If True, no sort group is defined for electrode group of
        reference.
    omit_unitrode : bool
        Optional. If True, no sort groups are defined for unitrodes.
    """
    # delete any current groups
    (SortGroup & {"nwb_file_name": nwb_file_name}).delete()
    # get the electrodes from this NWB file
    electrodes = (
        Electrode()
        & {"nwb_file_name": nwb_file_name}
        & {"bad_channel": "False"}
    ).fetch()
    e_groups = list(np.unique(electrodes["electrode_group_name"]))
    e_groups.sort(key=int)  # sort electrode groups numerically
    sort_group = 0
    sg_key = dict()
    sge_key = dict()
    sg_key["nwb_file_name"] = sge_key["nwb_file_name"] = nwb_file_name
    for e_group in e_groups:
        # for each electrode group, get a list of the unique shank numbers
        shank_list = np.unique(
            electrodes["probe_shank"][
                electrodes["electrode_group_name"] == e_group
            ]
        )
        sge_key["electrode_group_name"] = e_group
        # get the indices of all electrodes in this group / shank and set their sorting group
        for shank in shank_list:
            sg_key["sort_group_id"] = sge_key["sort_group_id"] = sort_group
            # specify reference electrode. Use 'references' if passed, otherwise use reference from config
            if not references:
                shank_elect_ref = electrodes[
                    "original_reference_electrode"
                ][
                    np.logical_and(
                        electrodes["electrode_group_name"] == e_group,
                        electrodes["probe_shank"] == shank,
                    )
                ]
                if np.max(shank_elect_ref) == np.min(shank_elect_ref):
                    sg_key["sort_reference_electrode_id"] = shank_elect_ref[
                        0
                    ]
                else:
                    raise ValueError(
                        f"Error in electrode group {e_group}: reference "
                        + "electrodes are not all the same"
                    )
            else:
                if e_group not in references.keys():
                    raise Exception(
                        f"electrode group {e_group} not a key in "
                        + "references, so cannot set reference"
                    )
                else:
                    sg_key["sort_reference_electrode_id"] = references[
                        e_group
                    ]
            # Insert sort group and sort group electrodes
            reference_electrode_group = electrodes[
                electrodes["electrode_id"]
                == sg_key["sort_reference_electrode_id"]
            ][
                "electrode_group_name"
            ]  # reference for this electrode group
            if (
                len(reference_electrode_group) == 1
            ):  # unpack single reference
                reference_electrode_group = reference_electrode_group[0]
            elif (int(sg_key["sort_reference_electrode_id"]) > 0) and (
                len(reference_electrode_group) != 1
            ):
                raise Exception(
                    "Should have found exactly one electrode group for "
                    + "reference electrode, but found "
                    + f"{len(reference_electrode_group)}."
                )
            if omit_ref_electrode_group and (
                str(e_group) == str(reference_electrode_group)
            ):
                logger.warn(
                    f"Omitting electrode group {e_group} from sort groups "
                    + "because contains reference."
                )
                continue
            shank_elect = electrodes["electrode_id"][
                np.logical_and(
                    electrodes["electrode_group_name"] == e_group,
                    electrodes["probe_shank"] == shank,
                )
            ]
            if (
                omit_unitrode and len(shank_elect) == 1
            ):  # omit unitrodes if indicated
                logger.warn(
                    f"Omitting electrode group {e_group}, shank {shank} "
                    + "from sort groups because unitrode."
                )
                continue
            cls.insert1(sg_key, skip_duplicates=True)
            for elect in shank_elect:
                sge_key["electrode_id"] = elect
                cls.SortGroupElectrode().insert1(
                    sge_key, skip_duplicates=True
                )
            sort_group += 1

SpikeSortingRecordingSelection

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v1/recording.py
@schema
class SpikeSortingRecordingSelection(SpyglassMixin, dj.Manual):
    definition = """
    # Raw voltage traces and parameters. Use `insert_selection` method to insert rows.
    recording_id: uuid
    ---
    -> Raw
    -> SortGroup
    -> IntervalList
    -> SpikeSortingPreprocessingParameters
    -> LabTeam
    """

    _parallel_make = True

    @classmethod
    def insert_selection(cls, key: dict):
        """Insert a row into SpikeSortingRecordingSelection with an
        automatically generated unique recording ID as the sole primary key.

        Parameters
        ----------
        key : dict
            primary key of Raw, SortGroup, IntervalList,
            SpikeSortingPreprocessingParameters, LabTeam tables

        Returns
        -------
            primary key of SpikeSortingRecordingSelection table
        """
        query = cls & key
        if query:
            logger.warn("Similar row(s) already inserted.")
            return query.fetch(as_dict=True)
        key["recording_id"] = uuid.uuid4()
        cls.insert1(key, skip_duplicates=True)
        return key

insert_selection(key) classmethod

Insert a row into SpikeSortingRecordingSelection with an automatically generated unique recording ID as the sole primary key.

Parameters:

    key : dict (required)
        primary key of Raw, SortGroup, IntervalList, SpikeSortingPreprocessingParameters, LabTeam tables

Returns:

    primary key of SpikeSortingRecordingSelection table
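
A hedged sketch of inserting a selection row. The upstream field names below (e.g. preproc_params_name, team_name) are assumptions for illustration and must match the primary keys of Raw, SortGroup, IntervalList, SpikeSortingPreprocessingParameters, and LabTeam in your database:

from spyglass.spikesorting.v1.recording import SpikeSortingRecordingSelection

key = {
    "nwb_file_name": "example_session_.nwb",  # Raw / Session
    "sort_group_id": 0,                       # SortGroup
    "interval_list_name": "01_s1",            # IntervalList
    "preproc_params_name": "default",         # SpikeSortingPreprocessingParameters (field name assumed)
    "team_name": "My Team",                   # LabTeam (field name assumed)
}
selection_key = SpikeSortingRecordingSelection.insert_selection(key)
# selection_key includes the auto-generated "recording_id" UUID; if a similar row
# already exists, the existing row(s) are returned instead.
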
Source code in src/spyglass/spikesorting/v1/recording.py
@classmethod
def insert_selection(cls, key: dict):
    """Insert a row into SpikeSortingRecordingSelection with an
    automatically generated unique recording ID as the sole primary key.

    Parameters
    ----------
    key : dict
        primary key of Raw, SortGroup, IntervalList,
        SpikeSortingPreprocessingParameters, LabTeam tables

    Returns
    -------
        primary key of SpikeSortingRecordingSelection table
    """
    query = cls & key
    if query:
        logger.warn("Similar row(s) already inserted.")
        return query.fetch(as_dict=True)
    key["recording_id"] = uuid.uuid4()
    cls.insert1(key, skip_duplicates=True)
    return key

SpikeSortingRecording

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v1/recording.py
@schema
class SpikeSortingRecording(SpyglassMixin, dj.Computed):
    definition = """
    # Processed recording.
    -> SpikeSortingRecordingSelection
    ---
    -> AnalysisNwbfile
    object_id: varchar(40) # Object ID for the processed recording in NWB file
    """

    def make(self, key):
        AnalysisNwbfile()._creation_times["pre_create_time"] = time()
        # DO:
        # - get valid times for sort interval
        # - preprocess recording
        # - write recording to NWB file
        sort_interval_valid_times = self._get_sort_interval_valid_times(key)
        recording, timestamps = self._get_preprocessed_recording(key)
        recording_nwb_file_name, recording_object_id = _write_recording_to_nwb(
            recording,
            timestamps,
            (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"),
        )
        key["analysis_file_name"] = recording_nwb_file_name
        key["object_id"] = recording_object_id

        # INSERT:
        # - valid times into IntervalList
        # - analysis NWB file holding processed recording into AnalysisNwbfile
        # - entry into SpikeSortingRecording
        IntervalList.insert1(
            {
                "nwb_file_name": (SpikeSortingRecordingSelection & key).fetch1(
                    "nwb_file_name"
                ),
                "interval_list_name": key["recording_id"],
                "valid_times": sort_interval_valid_times,
                "pipeline": "spikesorting_recording_v1",
            }
        )
        AnalysisNwbfile().add(
            (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"),
            key["analysis_file_name"],
        )
        AnalysisNwbfile().log(
            recording_nwb_file_name, table=self.full_table_name
        )
        self.insert1(key)

    @classmethod
    def get_recording(cls, key: dict) -> si.BaseRecording:
        """Get recording related to this curation as spikeinterface BaseRecording

        Parameters
        ----------
        key : dict
            primary key of SpikeSorting table
        """

        analysis_file_name = (cls & key).fetch1("analysis_file_name")
        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            analysis_file_name
        )
        recording = se.read_nwb_recording(
            analysis_file_abs_path, load_time_vector=True
        )

        return recording

    @staticmethod
    def _get_recording_timestamps(recording):
        if recording.get_num_segments() > 1:
            frames_per_segment = [0]
            for i in range(recording.get_num_segments()):
                frames_per_segment.append(
                    recording.get_num_frames(segment_index=i)
                )

            cumsum_frames = np.cumsum(frames_per_segment)
            total_frames = np.sum(frames_per_segment)

            timestamps = np.zeros((total_frames,))
            for i in range(recording.get_num_segments()):
                timestamps[cumsum_frames[i] : cumsum_frames[i + 1]] = (
                    recording.get_times(segment_index=i)
                )
        else:
            timestamps = recording.get_times()
        return timestamps

    def _get_sort_interval_valid_times(self, key: dict):
        """Identifies the intersection between sort interval specified by the user
        and the valid times (times for which neural data exist, excluding e.g. dropped packets).

        Parameters
        ----------
        key: dict
            primary key of SpikeSortingRecordingSelection table

        Returns
        -------
        sort_interval_valid_times: ndarray of tuples
            (start, end) times for valid intervals in the sort interval

        """
        # FETCH: - sort interval - valid times - preprocessing parameters
        nwb_file_name, sort_interval_name, params = (
            SpikeSortingPreprocessingParameters * SpikeSortingRecordingSelection
            & key
        ).fetch1("nwb_file_name", "interval_list_name", "preproc_params")

        sort_interval = (
            IntervalList
            & {
                "nwb_file_name": nwb_file_name,
                "interval_list_name": sort_interval_name,
            }
        ).fetch1("valid_times")

        valid_interval_times = (
            IntervalList
            & {
                "nwb_file_name": nwb_file_name,
                "interval_list_name": "raw data valid times",
            }
        ).fetch1("valid_times")

        # DO: - take intersection between sort interval and valid times
        return interval_list_intersect(
            sort_interval,
            valid_interval_times,
            min_length=params["min_segment_length"],
        )

    def _get_preprocessed_recording(self, key: dict):
        """Filters and references a recording.

        - Loads the NWB file created during insertion as a spikeinterface Recording
        - Slices recording in time (interval) and space (channels);
          recording chunks from disjoint intervals are concatenated
        - Applies referencing and bandpass filtering

        Parameters
        ----------
        key: dict
            primary key of SpikeSortingRecordingSelection table

        Returns
        -------
        recording: si.Recording
        """
        # FETCH:
        # - full path to NWB file
        # - channels to be included in the sort
        # - the reference channel
        # - probe type
        # - filter parameters
        nwb_file_name = (SpikeSortingRecordingSelection & key).fetch1(
            "nwb_file_name"
        )
        sort_group_id = (SpikeSortingRecordingSelection & key).fetch1(
            "sort_group_id"
        )
        nwb_file_abs_path = Nwbfile().get_abs_path(nwb_file_name)
        channel_ids = (
            SortGroup.SortGroupElectrode
            & {
                "nwb_file_name": nwb_file_name,
                "sort_group_id": sort_group_id,
            }
        ).fetch("electrode_id")
        ref_channel_id = (
            SortGroup
            & {
                "nwb_file_name": nwb_file_name,
                "sort_group_id": sort_group_id,
            }
        ).fetch1("sort_reference_electrode_id")
        recording_channel_ids = np.setdiff1d(channel_ids, ref_channel_id)
        all_channel_ids = np.unique(np.append(channel_ids, ref_channel_id))

        probe_type_by_channel = []
        electrode_group_by_channel = []
        for channel_id in channel_ids:
            probe_type_by_channel.append(
                (
                    Electrode * Probe
                    & {
                        "nwb_file_name": nwb_file_name,
                        "electrode_id": channel_id,
                    }
                ).fetch1("probe_type")
            )
            electrode_group_by_channel.append(
                (
                    Electrode
                    & {
                        "nwb_file_name": nwb_file_name,
                        "electrode_id": channel_id,
                    }
                ).fetch1("electrode_group_name")
            )
        probe_type = np.unique(probe_type_by_channel)
        filter_params = (
            SpikeSortingPreprocessingParameters * SpikeSortingRecordingSelection
            & key
        ).fetch1("preproc_params")

        # DO:
        # - load NWB file as a spikeinterface Recording
        # - slice the recording object in time and channels
        # - apply referencing depending on the option chosen by the user
        # - apply bandpass filter
        # - set probe to recording
        recording = se.read_nwb_recording(
            nwb_file_abs_path, load_time_vector=True
        )
        all_timestamps = recording.get_times()

        # TODO: make sure the following works for recordings that don't have explicit timestamps
        valid_sort_times = self._get_sort_interval_valid_times(key)
        valid_sort_times_indices = _consolidate_intervals(
            valid_sort_times, all_timestamps
        )

        # slice in time; concatenate disjoint sort intervals
        if len(valid_sort_times_indices) > 1:
            recordings_list = []
            timestamps = []
            for interval_indices in valid_sort_times_indices:
                recording_single = recording.frame_slice(
                    start_frame=interval_indices[0],
                    end_frame=interval_indices[1],
                )
                recordings_list.append(recording_single)
                timestamps.extend(
                    all_timestamps[interval_indices[0] : interval_indices[1]]
                )
            recording = si.concatenate_recordings(recordings_list)
        else:
            recording = recording.frame_slice(
                start_frame=valid_sort_times_indices[0][0],
                end_frame=valid_sort_times_indices[0][1],
            )
            timestamps = all_timestamps[
                valid_sort_times_indices[0][0] : valid_sort_times_indices[0][1]
            ]

        # slice in channels; include ref channel in first slice, then exclude it in second slice
        if ref_channel_id >= 0:
            recording = recording.channel_slice(channel_ids=all_channel_ids)
            recording = si.preprocessing.common_reference(
                recording,
                reference="single",
                ref_channel_ids=ref_channel_id,
                dtype=np.float64,
            )
            recording = recording.channel_slice(
                channel_ids=recording_channel_ids
            )
        elif ref_channel_id == -2:
            recording = recording.channel_slice(
                channel_ids=recording_channel_ids
            )
            recording = si.preprocessing.common_reference(
                recording,
                reference="global",
                operator="median",
                dtype=np.float64,
            )
        elif ref_channel_id == -1:
            recording = recording.channel_slice(
                channel_ids=recording_channel_ids
            )
        else:
            raise ValueError(
                "Invalid reference channel ID. Use -1 to skip referencing. Use "
                + "-2 to reference via global median. Use positive integer to "
                + "reference to a specific channel."
            )

        recording = si.preprocessing.bandpass_filter(
            recording,
            freq_min=filter_params["frequency_min"],
            freq_max=filter_params["frequency_max"],
            dtype=np.float64,
        )

        # if the sort group is a tetrode, change the channel location
        # (necessary because the channel location for tetrodes are not set properly)
        if (
            len(probe_type) == 1
            and probe_type[0] == "tetrode_12.5"
            and len(recording_channel_ids) == 4
            and len(np.unique(electrode_group_by_channel)) == 1
        ):
            tetrode = pi.Probe(ndim=2)
            position = [[0, 0], [0, 12.5], [12.5, 0], [12.5, 12.5]]
            tetrode.set_contacts(
                position, shapes="circle", shape_params={"radius": 6.25}
            )
            tetrode.set_contact_ids(channel_ids)
            tetrode.set_device_channel_indices(np.arange(4))
            recording = recording.set_probe(tetrode, in_place=True)

        return recording, np.asarray(timestamps)
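
With a selection row in place, the processed recording is generated through DataJoint's standard populate mechanism; a brief sketch, using the selection_key returned by the insert_selection sketch above:

# Compute and store the preprocessed recording for the selected key.
SpikeSortingRecording.populate(selection_key)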

get_recording(key) classmethod

Get recording related to this curation as spikeinterface BaseRecording

Parameters:

    key : dict (required)
        primary key of SpikeSorting table
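
A short retrieval sketch; the recording_id below is a placeholder UUID and should be replaced with a real key:

from spyglass.spikesorting.v1.recording import SpikeSortingRecording

recording = SpikeSortingRecording.get_recording(
    {"recording_id": "00000000-0000-0000-0000-000000000000"}  # placeholder
)
print(recording.get_num_channels(), recording.get_sampling_frequency())
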
Source code in src/spyglass/spikesorting/v1/recording.py
@classmethod
def get_recording(cls, key: dict) -> si.BaseRecording:
    """Get recording related to this curation as spikeinterface BaseRecording

    Parameters
    ----------
    key : dict
        primary key of SpikeSorting table
    """

    analysis_file_name = (cls & key).fetch1("analysis_file_name")
    analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
        analysis_file_name
    )
    recording = se.read_nwb_recording(
        analysis_file_abs_path, load_time_vector=True
    )

    return recording

SpikeInterfaceRecordingDataChunkIterator

Bases: GenericDataChunkIterator

DataChunkIterator specifically for use on RecordingExtractor objects.

Source code in src/spyglass/spikesorting/v1/recording.py
class SpikeInterfaceRecordingDataChunkIterator(GenericDataChunkIterator):
    """DataChunkIterator specifically for use on RecordingExtractor objects."""

    def __init__(
        self,
        recording: si.BaseRecording,
        segment_index: int = 0,
        return_scaled: bool = False,
        buffer_gb: Optional[float] = None,
        buffer_shape: Optional[tuple] = None,
        chunk_mb: Optional[float] = None,
        chunk_shape: Optional[tuple] = None,
        display_progress: bool = False,
        progress_bar_options: Optional[dict] = None,
    ):
        """
        Initialize an Iterable object which returns DataChunks with data and their selections on each iteration.

        Parameters
        ----------
        recording : si.BaseRecording
            The SpikeInterfaceRecording object which handles the data access.
        segment_index : int, optional
            The recording segment to iterate on.
            Defaults to 0.
        return_scaled : bool, optional
            Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False).
            Defaults to False.
        buffer_gb : float, optional
            The upper bound on size in gigabytes (GB) of each selection from the iteration.
            The buffer_shape will be set implicitly by this argument.
            Cannot be set if `buffer_shape` is also specified.
            The default is 1GB.
        buffer_shape : tuple, optional
            Manual specification of buffer shape to return on each iteration.
            Must be a multiple of chunk_shape along each axis.
            Cannot be set if `buffer_gb` is also specified.
            The default is None.
        chunk_mb : float, optional
            The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset.
            The chunk_shape will be set implicitly by this argument.
            Cannot be set if `chunk_shape` is also specified.
            The default is 1MB, as recommended by the HDF5 group. For more details, see
            https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf
        chunk_shape : tuple, optional
            Manual specification of the internal chunk shape for the HDF5 dataset.
            Cannot be set if `chunk_mb` is also specified.
            The default is None.
        display_progress : bool, optional
            Display a progress bar with iteration rate and estimated completion time.
        progress_bar_options : dict, optional
            Dictionary of keyword arguments to be passed directly to tqdm.
            See https://github.com/tqdm/tqdm#parameters for options.
        """
        self.recording = recording
        self.segment_index = segment_index
        self.return_scaled = return_scaled
        self.channel_ids = recording.get_channel_ids()
        super().__init__(
            buffer_gb=buffer_gb,
            buffer_shape=buffer_shape,
            chunk_mb=chunk_mb,
            chunk_shape=chunk_shape,
            display_progress=display_progress,
            progress_bar_options=progress_bar_options,
        )

    def _get_data(self, selection: Tuple[slice]) -> Iterable:
        return self.recording.get_traces(
            segment_index=self.segment_index,
            channel_ids=self.channel_ids[selection[1]],
            start_frame=selection[0].start,
            end_frame=selection[0].stop,
            return_scaled=self.return_scaled,
        )

    def _get_dtype(self):
        return self.recording.get_dtype()

    def _get_maxshape(self):
        return (
            self.recording.get_num_samples(segment_index=self.segment_index),
            self.recording.get_num_channels(),
        )

__init__(recording, segment_index=0, return_scaled=False, buffer_gb=None, buffer_shape=None, chunk_mb=None, chunk_shape=None, display_progress=False, progress_bar_options=None)

Initialize an Iterable object which returns DataChunks with data and their selections on each iteration.

Parameters:

    recording : BaseRecording (required)
        The SpikeInterfaceRecording object which handles the data access.
    segment_index : int (default: 0)
        The recording segment to iterate on.
    return_scaled : bool (default: False)
        Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False).
    buffer_gb : float (default: None)
        The upper bound on size in gigabytes (GB) of each selection from the iteration. The buffer_shape will be set implicitly by this argument. Cannot be set if buffer_shape is also specified. The default is 1GB.
    buffer_shape : tuple (default: None)
        Manual specification of buffer shape to return on each iteration. Must be a multiple of chunk_shape along each axis. Cannot be set if buffer_gb is also specified.
    chunk_mb : float (default: None)
        The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset. The chunk_shape will be set implicitly by this argument. Cannot be set if chunk_shape is also specified. The default is 1MB, as recommended by the HDF5 group. For more details, see https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf
    chunk_shape : tuple (default: None)
        Manual specification of the internal chunk shape for the HDF5 dataset. Cannot be set if chunk_mb is also specified.
    display_progress : bool (default: False)
        Display a progress bar with iteration rate and estimated completion time.
    progress_bar_options : dict (default: None)
        Dictionary of keyword arguments to be passed directly to tqdm. See https://github.com/tqdm/tqdm#parameters for options.
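
An illustrative sketch of wrapping a SpikeInterface recording for chunked writing, assuming spikeinterface's generate_recording helper is available for a toy recording:

import spikeinterface.core as si
from spyglass.spikesorting.v1.recording import SpikeInterfaceRecordingDataChunkIterator

recording = si.generate_recording(num_channels=4, durations=[10.0])
iterator = SpikeInterfaceRecordingDataChunkIterator(
    recording=recording,
    return_scaled=False,
    display_progress=False,
)
# Each iteration yields a DataChunk of traces with shape (frames, channels),
# suitable as the data argument of an hdmf/pynwb dataset (e.g. an ElectricalSeries).
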
Source code in src/spyglass/spikesorting/v1/recording.py
def __init__(
    self,
    recording: si.BaseRecording,
    segment_index: int = 0,
    return_scaled: bool = False,
    buffer_gb: Optional[float] = None,
    buffer_shape: Optional[tuple] = None,
    chunk_mb: Optional[float] = None,
    chunk_shape: Optional[tuple] = None,
    display_progress: bool = False,
    progress_bar_options: Optional[dict] = None,
):
    """
    Initialize an Iterable object which returns DataChunks with data and their selections on each iteration.

    Parameters
    ----------
    recording : si.BaseRecording
        The SpikeInterfaceRecording object which handles the data access.
    segment_index : int, optional
        The recording segment to iterate on.
        Defaults to 0.
    return_scaled : bool, optional
        Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False).
        Defaults to False.
    buffer_gb : float, optional
        The upper bound on size in gigabytes (GB) of each selection from the iteration.
        The buffer_shape will be set implicitly by this argument.
        Cannot be set if `buffer_shape` is also specified.
        The default is 1GB.
    buffer_shape : tuple, optional
        Manual specification of buffer shape to return on each iteration.
        Must be a multiple of chunk_shape along each axis.
        Cannot be set if `buffer_gb` is also specified.
        The default is None.
    chunk_mb : float, optional
        The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset.
        The chunk_shape will be set implicitly by this argument.
        Cannot be set if `chunk_shape` is also specified.
        The default is 1MB, as recommended by the HDF5 group. For more details, see
        https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf
    chunk_shape : tuple, optional
        Manual specification of the internal chunk shape for the HDF5 dataset.
        Cannot be set if `chunk_mb` is also specified.
        The default is None.
    display_progress : bool, optional
        Display a progress bar with iteration rate and estimated completion time.
    progress_bar_options : dict, optional
        Dictionary of keyword arguments to be passed directly to tqdm.
        See https://github.com/tqdm/tqdm#parameters for options.
    """
    self.recording = recording
    self.segment_index = segment_index
    self.return_scaled = return_scaled
    self.channel_ids = recording.get_channel_ids()
    super().__init__(
        buffer_gb=buffer_gb,
        buffer_shape=buffer_shape,
        chunk_mb=chunk_mb,
        chunk_shape=chunk_shape,
        display_progress=display_progress,
        progress_bar_options=progress_bar_options,
    )

TimestampsDataChunkIterator

Bases: GenericDataChunkIterator

DataChunkIterator specifically for use on RecordingExtractor objects.

Source code in src/spyglass/spikesorting/v1/recording.py
class TimestampsDataChunkIterator(GenericDataChunkIterator):
    """DataChunkIterator specifically for use on RecordingExtractor objects."""

    def __init__(
        self,
        recording: si.BaseRecording,
        segment_index: int = 0,
        return_scaled: bool = False,
        buffer_gb: Optional[float] = None,
        buffer_shape: Optional[tuple] = None,
        chunk_mb: Optional[float] = None,
        chunk_shape: Optional[tuple] = None,
        display_progress: bool = False,
        progress_bar_options: Optional[dict] = None,
    ):
        """
        Initialize an Iterable object which returns DataChunks with data and their selections on each iteration.

        Parameters
        ----------
        recording : SpikeInterfaceRecording
            The SpikeInterfaceRecording object (RecordingExtractor or BaseRecording) which handles the data access.
        segment_index : int, optional
            The recording segment to iterate on.
            Defaults to 0.
        return_scaled : bool, optional
            Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False).
            Defaults to False.
        buffer_gb : float, optional
            The upper bound on size in gigabytes (GB) of each selection from the iteration.
            The buffer_shape will be set implicitly by this argument.
            Cannot be set if `buffer_shape` is also specified.
            The default is 1GB.
        buffer_shape : tuple, optional
            Manual specification of buffer shape to return on each iteration.
            Must be a multiple of chunk_shape along each axis.
            Cannot be set if `buffer_gb` is also specified.
            The default is None.
        chunk_mb : float, optional
            The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset.
            The chunk_shape will be set implicitly by this argument.
            Cannot be set if `chunk_shape` is also specified.
            The default is 1MB, as recommended by the HDF5 group. For more details, see
            https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf
        chunk_shape : tuple, optional
            Manual specification of the internal chunk shape for the HDF5 dataset.
            Cannot be set if `chunk_mb` is also specified.
            The default is None.
        display_progress : bool, optional
            Display a progress bar with iteration rate and estimated completion time.
        progress_bar_options : dict, optional
            Dictionary of keyword arguments to be passed directly to tqdm.
            See https://github.com/tqdm/tqdm#parameters for options.
        """
        self.recording = recording
        self.segment_index = segment_index
        self.return_scaled = return_scaled
        self.channel_ids = recording.get_channel_ids()
        super().__init__(
            buffer_gb=buffer_gb,
            buffer_shape=buffer_shape,
            chunk_mb=chunk_mb,
            chunk_shape=chunk_shape,
            display_progress=display_progress,
            progress_bar_options=progress_bar_options,
        )

    # change channel id to always be first channel
    def _get_data(self, selection: Tuple[slice]) -> Iterable:
        return self.recording.get_traces(
            segment_index=self.segment_index,
            channel_ids=[0],
            start_frame=selection[0].start,
            end_frame=selection[0].stop,
            return_scaled=self.return_scaled,
        )

    def _get_dtype(self):
        return self.recording.get_dtype()

    # remove the last dim for the timestamps since it is always just a 1D vector
    def _get_maxshape(self):
        return (
            self.recording.get_num_samples(segment_index=self.segment_index),
        )

__init__(recording, segment_index=0, return_scaled=False, buffer_gb=None, buffer_shape=None, chunk_mb=None, chunk_shape=None, display_progress=False, progress_bar_options=None)

Initialize an Iterable object which returns DataChunks with data and their selections on each iteration.

Parameters:

    recording : SpikeInterfaceRecording (required)
        The SpikeInterfaceRecording object (RecordingExtractor or BaseRecording) which handles the data access.
    segment_index : int (default: 0)
        The recording segment to iterate on.
    return_scaled : bool (default: False)
        Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False).
    buffer_gb : float (default: None)
        The upper bound on size in gigabytes (GB) of each selection from the iteration. The buffer_shape will be set implicitly by this argument. Cannot be set if buffer_shape is also specified. The default is 1GB.
    buffer_shape : tuple (default: None)
        Manual specification of buffer shape to return on each iteration. Must be a multiple of chunk_shape along each axis. Cannot be set if buffer_gb is also specified.
    chunk_mb : float (default: None)
        The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset. The chunk_shape will be set implicitly by this argument. Cannot be set if chunk_shape is also specified. The default is 1MB, as recommended by the HDF5 group. For more details, see https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf
    chunk_shape : tuple (default: None)
        Manual specification of the internal chunk shape for the HDF5 dataset. Cannot be set if chunk_mb is also specified.
    display_progress : bool (default: False)
        Display a progress bar with iteration rate and estimated completion time.
    progress_bar_options : dict (default: None)
        Dictionary of keyword arguments to be passed directly to tqdm. See https://github.com/tqdm/tqdm#parameters for options.
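
A brief construction sketch. Per the comments in _get_data and _get_maxshape, this iterator reads the first channel's trace and exposes it as a 1-D vector of length num_samples; in the pipeline it is applied to a recording-like object whose single channel carries the timestamps, so the plain toy recording below only illustrates construction:

import spikeinterface.core as si
from spyglass.spikesorting.v1.recording import TimestampsDataChunkIterator

recording = si.generate_recording(num_channels=1, durations=[10.0])
timestamps_iterator = TimestampsDataChunkIterator(recording=recording)
# Yields chunks of a 1-D dataset of length recording.get_num_samples().
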
Source code in src/spyglass/spikesorting/v1/recording.py
def __init__(
    self,
    recording: si.BaseRecording,
    segment_index: int = 0,
    return_scaled: bool = False,
    buffer_gb: Optional[float] = None,
    buffer_shape: Optional[tuple] = None,
    chunk_mb: Optional[float] = None,
    chunk_shape: Optional[tuple] = None,
    display_progress: bool = False,
    progress_bar_options: Optional[dict] = None,
):
    """
    Initialize an Iterable object which returns DataChunks with data and their selections on each iteration.

    Parameters
    ----------
    recording : SpikeInterfaceRecording
        The SpikeInterfaceRecording object (RecordingExtractor or BaseRecording) which handles the data access.
    segment_index : int, optional
        The recording segment to iterate on.
        Defaults to 0.
    return_scaled : bool, optional
        Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False).
        Defaults to False.
    buffer_gb : float, optional
        The upper bound on size in gigabytes (GB) of each selection from the iteration.
        The buffer_shape will be set implicitly by this argument.
        Cannot be set if `buffer_shape` is also specified.
        The default is 1GB.
    buffer_shape : tuple, optional
        Manual specification of buffer shape to return on each iteration.
        Must be a multiple of chunk_shape along each axis.
        Cannot be set if `buffer_gb` is also specified.
        The default is None.
    chunk_mb : float, optional
        The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset.
        The chunk_shape will be set implicitly by this argument.
        Cannot be set if `chunk_shape` is also specified.
        The default is 1MB, as recommended by the HDF5 group. For more details, see
        https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf
    chunk_shape : tuple, optional
        Manual specification of the internal chunk shape for the HDF5 dataset.
        Cannot be set if `chunk_mb` is also specified.
        The default is None.
    display_progress : bool, optional
        Display a progress bar with iteration rate and estimated completion time.
    progress_bar_options : dict, optional
        Dictionary of keyword arguments to be passed directly to tqdm.
        See https://github.com/tqdm/tqdm#parameters for options.
    """
    self.recording = recording
    self.segment_index = segment_index
    self.return_scaled = return_scaled
    self.channel_ids = recording.get_channel_ids()
    super().__init__(
        buffer_gb=buffer_gb,
        buffer_shape=buffer_shape,
        chunk_mb=chunk_mb,
        chunk_shape=chunk_shape,
        display_progress=display_progress,
        progress_bar_options=progress_bar_options,
    )