Skip to content

position_dlc_orient.py

DLCOrientationParams

Bases: SpyglassMixin, Manual

Parameters for determining and smoothing the orientation of a set of bodyparts.

Source code in src/spyglass/position/v1/position_dlc_orient.py
@schema
class DLCOrientationParams(SpyglassMixin, dj.Manual):
    """
    Parameters for determining and smoothing the orientation of a set of BodyParts

    Each entry stores a ``params`` dict. ``orient_method`` selects the
    orientation function, ``orientation_smoothing_std_dev`` is the standard
    deviation used for Gaussian smoothing, and any remaining keys (e.g.
    ``bodypart1`` / ``bodypart2``) are forwarded to the orientation function
    by ``DLCOrientation.make``.
    """

    definition = """
    dlc_orientation_params_name: varchar(80) # name for this set of parameters
    ---
    params: longblob
    """

    @classmethod
    def insert_params(cls, params_name: str, params: dict, **kwargs):
        """Insert a named parameter set.

        Parameters
        ----------
        params_name : str
            Value stored as ``dlc_orientation_params_name``.
        params : dict
            Parameter dict stored in the ``params`` blob.
        **kwargs
            Forwarded to ``insert1`` (e.g. ``skip_duplicates=True``).
        """
        cls.insert1(
            {"dlc_orientation_params_name": params_name, "params": params},
            **kwargs,
        )

    @classmethod
    def insert_default(cls, **kwargs):
        """Insert the ``"default"`` parameter set.

        The default uses ``orient_method="red_green_orientation"`` on the
        ``greenLED`` / ``redLED_C`` bodyparts with a smoothing standard
        deviation of 0.001. ``**kwargs`` are forwarded to ``insert1``.
        """
        params = {
            "orient_method": "red_green_orientation",
            "bodypart1": "greenLED",
            "bodypart2": "redLED_C",
            "orientation_smoothing_std_dev": 0.001,
        }
        cls.insert_params("default", params, **kwargs)

    @classmethod
    def get_default(cls):
        """Return the ``"default"`` row as a dict, inserting it first if absent."""
        query = cls & {"dlc_orientation_params_name": "default"}
        if not query:
            cls.insert_default(skip_duplicates=True)
        # DataJoint queries are lazy, so this re-reads the (now present) row.
        return query.fetch1()

DLCOrientation

Bases: SpyglassMixin, Computed

Determines and smooths the orientation of a set of bodyparts using a specified method.

Source code in src/spyglass/position/v1/position_dlc_orient.py
@schema
class DLCOrientation(SpyglassMixin, dj.Computed):
    """
    Determines and smooths the orientation of a set of bodyparts using a specified method
    """

    definition = """
    -> DLCOrientationSelection
    ---
    -> AnalysisNwbfile
    dlc_orientation_object_id : varchar(80)
    """

    def _get_pos_df(self, key):
        """Concatenate the per-bodypart position dataframes for this cohort.

        Each ``DLCSmoothInterpCohort.BodyPart`` entry matching ``key`` is
        fetched as a dataframe and concatenated column-wise, keyed by the
        bodypart name (so the bodypart name is the outer column level).
        """
        cohort_entries = DLCSmoothInterpCohort.BodyPart & key
        pos_df = pd.concat(
            {
                bodypart: (
                    DLCSmoothInterpCohort.BodyPart
                    & {**key, "bodypart": bodypart}
                ).fetch1_dataframe()
                for bodypart in cohort_entries.fetch("bodypart")
            },
            axis=1,
        )
        return pos_df

    @staticmethod
    def _smooth_orientation(orientation, index, std_dev, sampling_rate):
        """Unwrap, interpolate over NaN spans, Gaussian-smooth, and rewrap.

        Parameters
        ----------
        orientation : array-like
            Raw orientation values (radians), possibly containing NaNs.
        index : pandas.Index
            Index used for the intermediate dataframe (the position index).
        std_dev : float
            Smoothing standard deviation passed to ``gaussian_smooth``.
        sampling_rate : float
            Sampling rate passed to ``gaussian_smooth``.

        Returns
        -------
        numpy.ndarray
            Smoothed orientation mapped back into (-pi, pi].
        """
        is_nan = np.isnan(orientation)
        unwrap_orientation = orientation.copy()
        # Only unwrap non nan values, while keeping nans in dataset for interpolation
        unwrap_orientation[~is_nan] = np.unwrap(orientation[~is_nan])
        unwrap_df = pd.DataFrame(
            unwrap_orientation, columns=["orientation"], index=index
        )
        nan_spans = get_span_start_stop(np.where(is_nan)[0])
        orient_df = interp_orientation(
            unwrap_df,
            nan_spans,
        )
        smoothed = gaussian_smooth(
            orient_df["orientation"].to_numpy(),
            std_dev,
            sampling_rate,
            axis=0,
            truncate=8,
        )
        # convert back to between -pi and pi
        return np.angle(np.exp(1j * smoothed))

    def make(self, key):
        """Compute, smooth, and store the orientation for one selection key.

        Fetches the cohort's bodypart positions and applies the orientation
        function named by ``params["orient_method"]``. Unless the method is
        ``"none"``, the result is unwrapped, interpolated and smoothed. It is
        then written to a new analysis NWB file as a ``CompassDirection``
        spatial series, and the row is inserted.

        Raises
        ------
        KeyError
            If ``params["orient_method"]`` is not a known orientation method.
        """
        AnalysisNwbfile()._creation_times["pre_create_time"] = time()
        pos_df = self._get_pos_df(key)

        # Get orientation method and smoothing settings from Parameters table
        params = (DLCOrientationParams() & key).fetch1("params")
        # Pop the smoothing std dev so the remaining entries can be forwarded
        # to the orientation function as keyword arguments.
        # NOTE(review): if the key is missing this is None, which the
        # smoothing step below will presumably reject — confirm intended.
        orientation_smoothing_std_dev = params.pop(
            "orientation_smoothing_std_dev", None
        )
        # NOTE(review): assumes pos_df is indexed by time in seconds (Hz result).
        sampling_rate = 1 / np.median(np.diff(pos_df.index.to_numpy()))

        orient_method = params["orient_method"]
        try:
            orient_func = _key_to_func_dict[orient_method]
        except KeyError:
            raise KeyError(
                f"Unknown orient_method {orient_method!r}; "
                f"expected one of {list(_key_to_func_dict)}"
            ) from None
        orientation = orient_func(pos_df, **params)

        if orient_method != "none":
            orientation = self._smooth_orientation(
                orientation,
                pos_df.index,
                orientation_smoothing_std_dev,
                sampling_rate,
            )

        final_df = pd.DataFrame(
            orientation, columns=["orientation"], index=pos_df.index
        )
        key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
            key["nwb_file_name"]
        )
        # if spatial series exists, get metadata from there
        if query := (RawPosition & key):
            spatial_series = query.fetch_nwb()[0]["raw_position"]
        else:
            spatial_series = None

        # Separate name so the NWB container does not shadow the orientation data.
        compass = pynwb.behavior.CompassDirection()
        compass.create_spatial_series(
            name="orientation",
            timestamps=final_df.index.to_numpy(),
            conversion=1.0,
            data=final_df["orientation"].to_numpy(),
            reference_frame=getattr(spatial_series, "reference_frame", ""),
            comments=getattr(spatial_series, "comments", "no comments"),
            description="orientation",
        )
        nwb_analysis_file = AnalysisNwbfile()
        key["dlc_orientation_object_id"] = nwb_analysis_file.add_nwb_object(
            key["analysis_file_name"], compass
        )

        nwb_analysis_file.add(
            nwb_file_name=key["nwb_file_name"],
            analysis_file_name=key["analysis_file_name"],
        )

        self.insert1(key)
        AnalysisNwbfile().log(key, table=self.full_table_name)

    def fetch1_dataframe(self):
        """Return the stored orientation as a DataFrame.

        The result has a ``time`` index (the spatial series timestamps) and a
        single ``orientation`` column. Only the first fetched NWB entry is used.
        """
        nwb_data = self.fetch_nwb()[0]
        series = nwb_data["dlc_orientation"].get_spatial_series()
        index = pd.Index(np.asarray(series.timestamps), name="time")
        COLUMNS = ["orientation"]
        return pd.DataFrame(
            np.asarray(series.data)[:, np.newaxis],
            columns=COLUMNS,
            index=index,
        )