Skip to content

waveform_features.py

WaveformFeaturesParams

Bases: SpyglassMixin, Lookup

Defines the types of spike waveform features computed for a given spike time.

Source code in src/spyglass/decoding/v1/waveform_features.py
@schema
class WaveformFeaturesParams(SpyglassMixin, dj.Lookup):
    """Lookup table naming a set of waveform-feature and waveform-extraction
    parameters used to compute per-spike waveform features.

    Each row stores a ``params`` dict with two entries:
    ``waveform_features_params`` (feature name -> keyword arguments for that
    feature) and ``waveform_extraction_params`` (keyword arguments for
    waveform extraction).
    """

    definition = """
    features_param_name : varchar(80) # a name for this set of parameters
    ---
    params : longblob # the parameters for the waveform features
    """

    # Default keyword arguments for the "amplitude" feature.
    _default_waveform_feature_params = {
        "amplitude": {
            "peak_sign": "neg",
            "estimate_peak_time": False,
        }
    }

    # Default keyword arguments for waveform extraction (0.5 ms either side of
    # each spike; no cap on spikes per unit).
    _default_waveform_extraction_params = {
        "ms_before": 0.5,
        "ms_after": 0.5,
        "max_spikes_per_unit": None,
        "n_jobs": 5,
        "chunk_duration": "1000s",
    }

    # Rows inserted by ``insert_default``: amplitude only, and amplitude plus
    # spike location (the latter with no extra arguments).
    contents = [
        [
            "amplitude",
            {
                "waveform_features_params": _default_waveform_feature_params,
                "waveform_extraction_params": _default_waveform_extraction_params,
            },
        ],
        [
            "amplitude, spike_location",
            {
                "waveform_features_params": {
                    "amplitude": _default_waveform_feature_params["amplitude"],
                    "spike_location": {},
                },
                "waveform_extraction_params": _default_waveform_extraction_params,
            },
        ],
    ]

    @classmethod
    def insert_default(cls):
        """Insert the rows in ``contents``, ignoring ones already present."""
        default_rows = cls.contents
        cls.insert(default_rows, skip_duplicates=True)

    @staticmethod
    def check_supported_waveform_features(waveform_features: list[str]) -> bool:
        """Report whether every requested waveform feature type is supported.

        Parameters
        ----------
        waveform_features : list
            Names of the requested feature types. The value is passed through
            ``set()``, so any iterable of names works, including a dict, whose
            keys are used.

        Returns
        -------
        bool
            True when every requested name is one of the names in
            ``WAVEFORM_FEATURE_FUNCTIONS``; False otherwise.
        """
        known = set(WAVEFORM_FEATURE_FUNCTIONS)
        requested = set(waveform_features)
        return requested <= known

    @property
    def supported_waveform_features(self) -> list[str]:
        """List the waveform feature names this module can compute."""
        return [*WAVEFORM_FEATURE_FUNCTIONS]

check_supported_waveform_features(waveform_features) staticmethod

Checks whether the requested waveform feature types are supported.

Parameters:

Name | Type | Description | Default
waveform_features | list | Names of the requested waveform feature types | required
Source code in src/spyglass/decoding/v1/waveform_features.py
@staticmethod
def check_supported_waveform_features(waveform_features: list[str]) -> bool:
    """Report whether every requested waveform feature type is supported.

    Parameters
    ----------
    waveform_features : list
        Names of the requested feature types. The value is passed through
        ``set()``, so any iterable of names works, including a dict, whose
        keys are used.

    Returns
    -------
    bool
        True when every requested name is one of the names in
        ``WAVEFORM_FEATURE_FUNCTIONS``; False otherwise.
    """
    known = set(WAVEFORM_FEATURE_FUNCTIONS)
    requested = set(waveform_features)
    return requested <= known

supported_waveform_features: list[str] property

Returns the list of supported waveform features

UnitWaveformFeatures

Bases: SpyglassMixin, Computed

For each spike time, compute a spike waveform feature associated with that spike. Used for clusterless decoding.

Source code in src/spyglass/decoding/v1/waveform_features.py
@schema
class UnitWaveformFeatures(SpyglassMixin, dj.Computed):
    """For each spike time, compute a spike waveform feature associated with that
    spike. Used for clusterless decoding.
    """

    definition = """
    -> UnitWaveformFeaturesSelection
    ---
    -> AnalysisNwbfile
    object_id: varchar(40) # the NWB object that stores the waveforms
    """

    _parallel_make = True

    def make(self, key):
        """Compute the configured waveform features for one sorting and store them.

        Extracts waveforms for the merged spike sorting, computes each
        requested feature per unit, writes the features together with the
        spike times to an analysis NWB file, and inserts the resulting row.

        Parameters
        ----------
        key : dict
            Entry key. It must contain ``spikesorting_merge_id`` and is also
            used to restrict ``WaveformFeaturesParams``.

        Raises
        ------
        NotImplementedError
            If any requested feature name is not in
            ``WAVEFORM_FEATURE_FUNCTIONS``.
        """
        AnalysisNwbfile()._creation_times["pre_create_time"] = time()
        # get the list of feature parameters
        params = (WaveformFeaturesParams & key).fetch1("params")
        requested_features = params["waveform_features_params"]

        # check that the feature type is supported
        if not WaveformFeaturesParams.check_supported_waveform_features(
            requested_features
        ):
            # Report only the offending names, not every requested feature.
            unsupported = set(requested_features) - set(
                WAVEFORM_FEATURE_FUNCTIONS
            )
            raise NotImplementedError(
                f"Features {unsupported} are not supported"
            )

        merge_key = {"merge_id": key["spikesorting_merge_id"]}
        waveform_extractor = self._fetch_waveform(
            merge_key, params["waveform_extraction_params"]
        )

        source_key = SpikeSortingOutput().merge_get_parent(merge_key).fetch1()
        if "sorter" in source_key and "nwb_file_name" in source_key:
            # v0 pipeline: the parent entry carries sorter and file name directly.
            sorter = source_key["sorter"]
            nwb_file_name = source_key["nwb_file_name"]
            analysis_nwb_key = "units"
        else:
            # v1 pipeline: look them up through the sorting selection.
            sorting_id = (SpikeSortingOutput.CurationV1 & merge_key).fetch1(
                "sorting_id"
            )
            sorter, nwb_file_name = (
                SpikeSortingSelection & {"sorting_id": sorting_id}
            ).fetch1("sorter", "nwb_file_name")
            analysis_nwb_key = "object_id"

        waveform_features = {
            feature: self._compute_waveform_features(
                waveform_extractor,
                feature,
                feature_params,
                sorter,
            )
            for feature, feature_params in requested_features.items()
        }

        spike_times = SpikeSortingOutput().fetch_nwb(merge_key)[0][
            analysis_nwb_key
        ]["spike_times"]

        (
            key["analysis_file_name"],
            key["object_id"],
        ) = _write_waveform_features_to_nwb(
            nwb_file_name,
            waveform_extractor,
            spike_times,
            waveform_features,
        )

        AnalysisNwbfile().add(
            nwb_file_name,
            key["analysis_file_name"],
        )
        AnalysisNwbfile().log(key, table=self.full_table_name)

        self.insert1(key)

    @staticmethod
    def _fetch_waveform(
        merge_key: dict, waveform_extraction_params: dict
    ) -> si.WaveformExtractor:
        """Extract waveforms for the recording/sorting behind ``merge_key``.

        Multi-segment recordings are first concatenated into one segment.
        Waveforms are written to a per-merge-id folder under ``temp_dir``,
        overwriting any previous extraction there.
        """
        # get the recording from the parent table
        recording = SpikeSortingOutput().get_recording(merge_key)
        if recording.get_num_segments() > 1:
            recording = si.concatenate_recordings([recording])
        # get the sorting from the parent table
        sorting = SpikeSortingOutput().get_sorting(merge_key)

        waveforms_temp_dir = os.path.join(temp_dir, str(merge_key["merge_id"]))
        os.makedirs(waveforms_temp_dir, exist_ok=True)

        return si.extract_waveforms(
            recording=recording,
            sorting=sorting,
            folder=waveforms_temp_dir,
            overwrite=True,
            **waveform_extraction_params,
        )

    @staticmethod
    def _compute_waveform_features(
        waveform_extractor: si.WaveformExtractor,
        feature: str,
        feature_params: dict,
        sorter: str,
    ) -> dict:
        """Compute ``feature`` for every unit in the extractor's sorting.

        Returns a dict mapping unit id to the output of the feature function.
        The caller's ``feature_params`` is never modified.
        """
        feature_func = WAVEFORM_FEATURE_FUNCTIONS[feature]
        if sorter == "clusterless_thresholder" and feature == "amplitude":
            # Peak-time estimation is forced off for clusterless_thresholder
            # sortings. Build a new dict so the caller's params stay untouched.
            feature_params = {**feature_params, "estimate_peak_time": False}

        return {
            unit_id: feature_func(waveform_extractor, unit_id, **feature_params)
            for unit_id in waveform_extractor.sorting.get_unit_ids()
        }

    def fetch_data(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
        """Fetches the spike times and features for each unit.

        Returns
        -------
        spike_times : sequence of np.ndarray
            Spike times for each unit (a tuple; empty if no rows match)
        features : sequence of np.ndarray
            Features for each unit, aligned with ``spike_times`` (a tuple;
            empty if no rows match)

        """
        unit_pairs = list(
            chain.from_iterable(
                self._convert_data(data) for data in self.fetch_nwb()
            )
        )
        if not unit_pairs:
            # zip(*[]) would yield (), which cannot be unpacked into two names.
            return ((), ())
        return tuple(zip(*unit_pairs))

    @staticmethod
    def _convert_data(nwb_data) -> list[tuple[np.ndarray, np.ndarray]]:
        """Turn one fetched NWB entry into (spike_times, features) per unit.

        Every column other than ``spike_times`` is treated as a feature. Each
        unit's feature values are concatenated along axis 1.
        """
        feature_df = nwb_data["object_id"]

        feature_columns = [
            column for column in feature_df.columns if column != "spike_times"
        ]

        # NOTE(review): axis=1 assumes each feature value is at least 2-D,
        # presumably (n_spikes, n_dims). TODO confirm against the writer.
        return [
            (
                unit.spike_times,
                np.concatenate(unit[feature_columns].to_numpy(), axis=1),
            )
            for _, unit in feature_df.iterrows()
        ]

fetch_data()

Fetches the spike times and features for each unit.

Returns:

Name | Type | Description
spike_times | list of np.ndarray | List of spike times for each unit
features | list of np.ndarray | List of features for each unit

Source code in src/spyglass/decoding/v1/waveform_features.py
def fetch_data(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
    """Fetches the spike times and features for each unit.

    Returns
    -------
    spike_times : sequence of np.ndarray
        Spike times for each unit (a tuple; empty if no rows match)
    features : sequence of np.ndarray
        Features for each unit, aligned with ``spike_times`` (a tuple;
        empty if no rows match)

    """
    unit_pairs = list(
        chain.from_iterable(
            self._convert_data(data) for data in self.fetch_nwb()
        )
    )
    if not unit_pairs:
        # zip(*[]) would yield (), which cannot be unpacked into two names.
        return ((), ())
    return tuple(zip(*unit_pairs))