Skip to content

position_dlc_position.py

DLCSmoothInterpParams

Bases: SpyglassMixin, Manual

Parameters for extracting the smoothed head position.

Attributes:

Name Type Description
interpolate bool, default True

whether to interpolate over NaN spans

smooth bool, default True

whether to smooth the dataset

smoothing_params dict

smoothing_duration : float, default 0.05 duration of the smoothing window; the number of frames smoothed over is num_frames = sampling_rate*smoothing_duration

interp_params dict

max_cm_to_interp : int, default 15 maximum distance between high likelihood points on either side of a NaN span to interpolate over

likelihood_thresh float, default 0.95

likelihood below which to NaN and interpolate over

Source code in src/spyglass/position/v1/position_dlc_position.py
@schema
class DLCSmoothInterpParams(SpyglassMixin, dj.Manual):
    """
    Parameters for extracting the smoothed head position.

    Defaults listed below are the values written by ``insert_default``.

    Attributes
    ----------
    interpolate : bool, default True
        whether to interpolate over NaN spans
    smooth : bool, default True
        whether to smooth the dataset
    smoothing_params : dict
        smoothing_duration : float, default 0.05
            duration of the smoothing window; the number of frames smoothed
            over is sampling_rate*smoothing_duration = num_frames
        smooth_method : str, default "moving_avg"
            name of the smoothing method; valid names are listed by
            ``get_available_methods``
    interp_params : dict
        max_cm_to_interp : int, default 15
            maximum distance between high likelihood points on either side of a
            NaN span to interpolate over
    likelihood_thresh : float, default 0.95
        likelihood below which to NaN and interpolate over
    max_cm_between_pts : int, default 20
        distance threshold between points; validated on insert
    num_inds_to_span : int, default 20
        how many indices to bridge in between good spans
        (see inds_to_span in get_good_spans)
    """

    definition = """
    dlc_si_params_name : varchar(80) # name for this set of parameters
    ---
    params: longblob # dictionary of parameters
    """

    @classmethod
    def insert_params(cls, params_name: str, params: dict, **kwargs):
        """Insert ``params`` under the name ``params_name``.

        Extra keyword arguments are forwarded to ``insert1``.
        """
        cls.insert1(
            {"dlc_si_params_name": params_name, "params": params},
            **kwargs,
        )

    @classmethod
    def insert_default(cls, **kwargs):
        """Insert the parameter set named "default".

        Extra keyword arguments are forwarded to ``insert1``.
        """
        default_params = {
            "smooth": True,
            "smoothing_params": {
                "smoothing_duration": 0.05,
                "smooth_method": "moving_avg",
            },
            "interpolate": True,
            "likelihood_thresh": 0.95,
            "interp_params": {"max_cm_to_interp": 15},
            "max_cm_between_pts": 20,
            # This is for use when finding "good spans" and is how many indices
            # to bridge in between good spans see inds_to_span in get_good_spans
            "num_inds_to_span": 20,
        }
        cls.insert1(
            {"dlc_si_params_name": "default", "params": default_params},
            **kwargs,
        )

    @classmethod
    def insert_nan_params(cls, **kwargs):
        """Insert the parameter set named "just_nan".

        This set disables both smoothing and interpolation. Extra keyword
        arguments are forwarded to ``insert1``.
        """
        nan_params = {
            "smooth": False,
            "interpolate": False,
            "likelihood_thresh": 0.95,
            "max_cm_between_pts": 20,
            "num_inds_to_span": 20,
        }
        cls.insert1(
            {"dlc_si_params_name": "just_nan", "params": nan_params}, **kwargs
        )

    @classmethod
    def get_default(cls):
        """Return the "default" row, inserting it first if it is absent."""
        restriction = {"dlc_si_params_name": "default"}
        if len(cls & restriction) == 0:
            cls.insert_default(skip_duplicates=True)
        return (cls & restriction).fetch1()

    @classmethod
    def get_nan_params(cls):
        """Return the "just_nan" row, inserting it first if it is absent."""
        restriction = {"dlc_si_params_name": "just_nan"}
        if len(cls & restriction) == 0:
            cls.insert_nan_params(skip_duplicates=True)
        return (cls & restriction).fetch1()

    @staticmethod
    def get_available_methods():
        """Return the names of the registered smoothing methods."""
        return _key_to_smooth_func_dict.keys()

    def insert1(self, key, **kwargs):
        """Validate ``key["params"]`` and then insert the row.

        Parameters
        ----------
        key : dict
            Row to insert; must hold a dict under "params".
        **kwargs
            Forwarded to the parent ``insert1``.

        Raises
        ------
        KeyError
            If ``key["params"]`` is missing or is not a dict.

        Notes
        -----
        ``validate_option`` and ``validate_smooth_params`` are defined
        elsewhere; presumably they raise on invalid values.
        """
        params = key.get("params")
        if not isinstance(params, dict):
            raise KeyError("'params' must be a dict in key")

        validate_option(
            option=params.get("max_cm_between_pts"), name="max_cm_between_pts"
        )
        validate_smooth_params(params)

        validate_option(
            params.get("likelihood_thresh"),
            name="likelihood_thresh",
            types=float,
            val_range=(0, 1),
        )

        super().insert1(key, **kwargs)

DLCSmoothInterp

Bases: SpyglassMixin, Computed

Interpolates across low likelihood periods and smooths the position. Can take a few minutes.

Source code in src/spyglass/position/v1/position_dlc_position.py
@schema
class DLCSmoothInterp(SpyglassMixin, dj.Computed):
    """
    Interpolates across low likelihood periods and smooths the position.

    Can take a few minutes.
    """

    definition = """
    -> DLCSmoothInterpSelection
    ---
    -> AnalysisNwbfile
    dlc_smooth_interp_position_object_id : varchar(80)
    dlc_smooth_interp_info_object_id : varchar(80)
    """
    log_path = None

    def make(self, key):
        """Set the log file path for ``key`` and run the logged computation."""
        self.log_path = (
            Path(infer_output_dir(key=key, makedir=False)) / "log.log"
        )
        self._logged_make(key)
        logger.info("inserted entry into DLCSmoothInterp")

    @file_log(logger, console=False)
    def _logged_make(self, key):
        """NaN low-quality samples, optionally interpolate and smooth, then
        store the result in a new analysis NWB file and insert ``key``.

        ``key`` is extended in place with the analysis file name and the two
        NWB object ids before it is inserted.
        """
        METERS_PER_CM = 0.01

        logger.info("-----------------------")
        idx = pd.IndexSlice
        # Get labels to smooth from Parameters table
        params = (DLCSmoothInterpParams() & key).fetch1("params")
        # Get DLC output dataframe
        logger.info("fetching Pose Estimation Dataframe")

        bp_key = key.copy()
        if test_mode:  # during testing, analysis_file not in BodyPart table
            bp_key.pop("analysis_file_name", None)

        dlc_df = (DLCPoseEstimation.BodyPart() & bp_key).fetch1_dataframe()
        # Median spacing of the index; its inverse is the sampling rate
        # handed to the smoothing function.
        dt = np.median(np.diff(dlc_df.index.to_numpy()))
        logger.info("Identifying indices to NaN")
        df_w_nans, bad_inds = nan_inds(
            dlc_df.copy(),
            max_dist_between=params["max_cm_between_pts"],
            likelihood_thresh=params["likelihood_thresh"],
            inds_to_span=params["num_inds_to_span"],
        )

        nan_spans = get_span_start_stop(np.where(bad_inds)[0])

        if params.get("interpolate"):
            logger.info("interpolating across low likelihood times")
            # "interpolate" is only the on/off flag; the keyword arguments
            # for interp_pos live under "interp_params".
            interp_kwargs = params.get("interp_params") or {}
            interp_df = interp_pos(df_w_nans.copy(), nan_spans, **interp_kwargs)
        else:
            interp_df = df_w_nans.copy()
            logger.info("skipping interpolation")

        if params.get("smooth"):
            smooth_params = params.get("smoothing_params")
            smooth_method = smooth_params.get("smooth_method")
            smooth_func = _key_to_smooth_func_dict[smooth_method]

            # smoothing_duration is passed explicitly below, so it must not
            # also be forwarded via ** (duplicate keyword -> TypeError).
            extra_smooth_kwargs = {
                name: value
                for name, value in smooth_params.items()
                if name != "smoothing_duration"
            }
            logger.info(f"Smoothing using method: {smooth_method}")
            smooth_df = smooth_func(
                interp_df,
                smoothing_duration=smooth_params.get("smoothing_duration"),
                sampling_rate=1 / dt,
                **extra_smooth_kwargs,
            )
        else:
            smooth_df = interp_df.copy()
            logger.info("skipping smoothing")

        final_df = smooth_df.drop(["likelihood"], axis=1)
        final_df = final_df.rename_axis("time").reset_index()
        # Source spatial series; its reference_frame and comments are copied
        # onto the new position object.
        position_nwb_data = (
            (DLCPoseEstimation.BodyPart() & bp_key)
            .fetch_nwb()[0]["dlc_pose_estimation_position"]
            .get_spatial_series()
        )
        key["analysis_file_name"] = AnalysisNwbfile().create(
            key["nwb_file_name"]
        )

        # Add dataframe to AnalysisNwbfile
        nwb_analysis_file = AnalysisNwbfile()
        position = pynwb.behavior.Position()
        video_frame_ind = pynwb.behavior.BehavioralTimeSeries()
        logger.info("Creating NWB objects")
        position.create_spatial_series(
            name="position",
            timestamps=final_df.time.to_numpy(),
            conversion=METERS_PER_CM,
            data=final_df.loc[:, idx[("x", "y")]].to_numpy(),
            reference_frame=position_nwb_data.reference_frame,
            comments=position_nwb_data.comments,
            description="x_position, y_position",
        )
        video_frame_ind.create_timeseries(
            name="video_frame_ind",
            timestamps=final_df.time.to_numpy(),
            data=final_df.loc[:, idx["video_frame_ind"]].to_numpy(),
            unit="index",
            comments="no comments",
            description="video_frame_ind",
        )
        key["dlc_smooth_interp_position_object_id"] = (
            nwb_analysis_file.add_nwb_object(
                analysis_file_name=key["analysis_file_name"],
                nwb_object=position,
            )
        )
        key["dlc_smooth_interp_info_object_id"] = (
            nwb_analysis_file.add_nwb_object(
                analysis_file_name=key["analysis_file_name"],
                nwb_object=video_frame_ind,
            )
        )
        nwb_analysis_file.add(
            nwb_file_name=key["nwb_file_name"],
            analysis_file_name=key["analysis_file_name"],
        )
        self.insert1(key)
        AnalysisNwbfile().log(key, table=self.full_table_name)

    def fetch1_dataframe(self):
        """Return the stored result as a DataFrame indexed by "time".

        Columns are "video_frame_ind", "x" and "y". The index comes from the
        timestamps of the stored position spatial series; the frame indices
        are read as ``int`` before being concatenated with the position data.
        """
        nwb_data = self.fetch_nwb()[0]
        index = pd.Index(
            np.asarray(
                nwb_data["dlc_smooth_interp_position"]
                .get_spatial_series()
                .timestamps
            ),
            name="time",
        )
        COLUMNS = [
            "video_frame_ind",
            "x",
            "y",
        ]
        return pd.DataFrame(
            np.concatenate(
                (
                    np.asarray(
                        nwb_data["dlc_smooth_interp_info"]
                        .time_series["video_frame_ind"]
                        .data,
                        dtype=int,
                    )[:, np.newaxis],
                    np.asarray(
                        nwb_data["dlc_smooth_interp_position"]
                        .get_spatial_series()
                        .data
                    ),
                ),
                axis=1,
            ),
            columns=COLUMNS,
            index=index,
        )

get_good_spans(bad_inds_mask, inds_to_span=50)

This function takes in a boolean mask of good and bad indices and determines spans of consecutive good indices. It combines two neighboring spans with a separation of less than inds_to_span and treats them as a single good span.

Parameters:

Name Type Description Default
bad_inds_mask boolean mask

A boolean mask where True is a bad index and False is a good index.

required
inds_to_span int

This indicates how many indices between two good spans should be bridged to form a single good span. For instance if span A is (1500, 2350) and span B is (2370, 3700), then span A and span B would be combined into span A (1500, 3700) since one would want to identify potential jumps in the space in between the original A and B.

50

Returns:

Name Type Description
good_spans list

List of spans of good indices, unmodified.

modified_spans list

spans that are amended to bridge up to inds_to_span consecutive bad indices

Source code in src/spyglass/position/v1/position_dlc_position.py
def get_good_spans(bad_inds_mask, inds_to_span: int = 50):
    """
    Find runs of consecutive good indices and merge runs that lie close
    together.

    Parameters
    ----------
    bad_inds_mask : boolean mask
        True marks a bad index, False marks a good index.
    inds_to_span : int, default 50
        Largest gap (stop of one good span to start of the next) that is
        still bridged. For instance, spans (1500, 2350) and (2370, 3700) are
        merged into (1500, 3700) so that potential jumps inside the gap can
        be examined.

    Returns
    -------
    good_spans : list or None
        The good spans exactly as found. None when there is no good index.
    modified_spans : list
        The spans after bridging gaps of up to inds_to_span indices. When
        there are fewer than two good spans this is the unmodified span list.
    """
    good_positions = np.arange(len(bad_inds_mask))[~bad_inds_mask]
    good = get_span_start_stop(good_positions)

    span_count = len(good)
    if span_count == 0:
        return None, good
    if span_count == 1:  # if all good, no need to modify
        return good, good

    modified_spans = []
    for (start1, stop1), (start2, stop2) in zip(good, good[1:]):
        gap_is_bridgeable = (start2 - stop1) <= inds_to_span

        # Position of the first recorded span whose extent, padded by
        # inds_to_span, already contains the start of the current span.
        hit = next(
            (
                pos
                for pos, (lo, hi) in enumerate(modified_spans)
                if start1 in range(lo - inds_to_span, hi + inds_to_span)
            ),
            None,
        )

        if hit is None:
            # Nothing recorded covers this span yet: record it, merged with
            # its successor if the gap is small enough.
            if gap_is_bridgeable:
                modified_spans.append((start1, stop2))
            else:
                modified_spans.append((start1, stop1))
                modified_spans.append((start2, stop2))
            continue

        # Extend (or close off) the recorded span, keeping its start.
        recorded_start = modified_spans[hit][0]
        if gap_is_bridgeable:
            modified_spans[hit] = (recorded_start, stop2)
        else:
            modified_spans[hit] = (recorded_start, stop1)
            modified_spans.append((start2, stop2))

    return good, modified_spans