Skip to content

visualization_1D_view.py

create_1D_decode_view(posterior, linear_position=None, ref_time_sec=None)

Creates a view of an interactive heatmap of position vs. time.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| posterior | DataArray, shape (n_time, n_position_bins) | | required |
| linear_position | ndarray, shape (n_time,) | | None |
| ref_time_sec | float64 | Reference time used to offset the start time | None |

Returns:

| Name | Type | Description |
|---|---|---|
| view | DecodedLinearPositionData | |
Source code in src/spyglass/decoding/v0/visualization_1D_view.py
def create_1D_decode_view(
    posterior: xr.DataArray,
    linear_position: np.ndarray = None,
    ref_time_sec: Union[np.float64, None] = None,
) -> vvf.DecodedLinearPositionData:
    """Build an interactive position-vs-time heatmap view of a 1D decode.

    Parameters
    ----------
    posterior : xr.DataArray, shape (n_time, n_position_bins)
        Decoded posterior; its ``time`` and ``position`` coordinates are
        read to build the view.
    linear_position : np.ndarray, shape (n_time, ), optional
        Observed linearized position. Converted to an array and squeezed
        (singleton dimensions removed) before being attached to the view.
    ref_time_sec : np.float64, optional
        Reference time subtracted from the first posterior timestamp, so
        the view's start time is expressed relative to it.

    Returns
    -------
    view : vvf.DecodedLinearPositionData

    """
    # Normalize the observed position up front (e.g. (n_time, 1) -> (n_time,)).
    observed = (
        None
        if linear_position is None
        else np.asarray(linear_position).squeeze()
    )

    trimmed = discretize_and_trim(posterior)
    frame_bounds = get_observations_per_time(trimmed, posterior)
    sampling_frequency = get_sampling_freq(posterior.time)

    # Start time is the first posterior timestamp, optionally re-referenced.
    first_timestamp = posterior.time.values[0]
    start_time_sec = (
        first_timestamp
        if ref_time_sec is None
        else first_timestamp - ref_time_sec
    )

    # Map each retained (trimmed) bin back to its index in the full bin list.
    positions = get_trimmed_bin_center_index(
        posterior.position.values, trimmed.position.values
    )

    return vvf.DecodedLinearPositionData(
        values=trimmed.values,
        positions=positions,
        frame_bounds=frame_bounds,
        positions_key=posterior.position.values.astype(np.float32),
        observed_positions=observed,
        start_time_sec=start_time_sec,
        sampling_frequency=sampling_frequency,
    )