Skip to content

clusterless.py

Pipeline for decoding the animal's mental position and some category of interest from unclustered spikes and spike waveform features. See [1] for details.

References

[1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021).

ClusterlessDecodingV1

Bases: SpyglassMixin, Computed

Source code in src/spyglass/decoding/v1/clusterless.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
@schema
class ClusterlessDecodingV1(SpyglassMixin, dj.Computed):
    """Clusterless decoding results for a ClusterlessDecodingSelection entry.

    ``make`` fits a ``ClusterlessDetector`` (or estimates its parameters) on
    unclustered spike times and spike waveform features. It writes the results
    to a ``.nc`` file and the model to a ``.pkl`` file under
    ``SPYGLASS_ANALYSIS_DIR``, and registers the entry in ``DecodingOutput``.
    """

    definition = """
    -> ClusterlessDecodingSelection
    ---
    results_path: filepath@analysis # path to the results file
    classifier_path: filepath@analysis # path to the classifier file
    """

    def make(self, key):
        """Decode the selection described by ``key`` and insert the outputs.

        Parameters
        ----------
        key : dict
            Primary key of a ``ClusterlessDecodingSelection`` entry.

        Raises
        ------
        ValueError
            If, when fitting and predicting (not estimating parameters), none
            of the decoding intervals contains any position time samples.
        """
        orig_key = copy.deepcopy(key)

        # Get model parameters
        model_params = (
            DecodingParameters
            & {"decoding_param_name": key["decoding_param_name"]}
        ).fetch1()
        decoding_params, decoding_kwargs = (
            model_params["decoding_params"],
            model_params["decoding_kwargs"],
        )
        decoding_kwargs = decoding_kwargs or {}

        # Get position data
        (
            position_info,
            position_variable_names,
        ) = self.fetch_position_info(key)

        # Get the waveform features for the selected units
        # Don't need to filter by interval since the non_local_detector code will do that
        (
            spike_times,
            spike_waveform_features,
        ) = self.fetch_spike_data(key, filter_by_interval=False)

        # Get the encoding and decoding intervals
        encoding_interval = (
            IntervalList
            & {
                "nwb_file_name": key["nwb_file_name"],
                "interval_list_name": key["encoding_interval"],
            }
        ).fetch1("valid_times")
        # Position samples inside any encoding interval are training samples
        is_training = np.zeros(len(position_info), dtype=bool)
        for interval_start, interval_end in encoding_interval:
            is_training[
                np.logical_and(
                    position_info.index >= interval_start,
                    position_info.index <= interval_end,
                )
            ] = True
        # An explicit is_training in the stored kwargs takes precedence
        if "is_training" not in decoding_kwargs:
            decoding_kwargs["is_training"] = is_training

        decoding_interval = (
            IntervalList
            & {
                "nwb_file_name": key["nwb_file_name"],
                "interval_list_name": key["decoding_interval"],
            }
        ).fetch1("valid_times")

        # Decode
        classifier = ClusterlessDetector(**decoding_params)

        if key["estimate_decoding_params"]:
            # if estimating parameters, then we need to treat times outside decoding interval as missing
            # this means that times outside the decoding interval will not use the spiking data
            # a better approach would be to treat the intervals as multiple sequences
            # (see https://en.wikipedia.org/wiki/Baum%E2%80%93Welch_algorithm#Multiple_sequences)
            is_missing = np.ones(len(position_info), dtype=bool)
            for interval_start, interval_end in decoding_interval:
                is_missing[
                    np.logical_and(
                        position_info.index >= interval_start,
                        position_info.index <= interval_end,
                    )
                ] = False
            if "is_missing" not in decoding_kwargs:
                decoding_kwargs["is_missing"] = is_missing
            results = classifier.estimate_parameters(
                position_time=position_info.index.to_numpy(),
                position=position_info[position_variable_names].to_numpy(),
                spike_times=spike_times,
                spike_waveform_features=spike_waveform_features,
                time=position_info.index.to_numpy(),
                **decoding_kwargs,
            )
        else:
            VALID_FIT_KWARGS = [
                "is_training",
                "encoding_group_labels",
                "environment_labels",
                "discrete_transition_covariate_data",
            ]

            # `kwarg_name` (not `key`) so the selection key is not shadowed
            fit_kwargs = {
                kwarg_name: value
                for kwarg_name, value in decoding_kwargs.items()
                if kwarg_name in VALID_FIT_KWARGS
            }
            classifier.fit(
                position_time=position_info.index.to_numpy(),
                position=position_info[position_variable_names].to_numpy(),
                spike_times=spike_times,
                spike_waveform_features=spike_waveform_features,
                **fit_kwargs,
            )
            VALID_PREDICT_KWARGS = [
                "is_missing",
                "discrete_transition_covariate_data",
                "return_causal_posterior",
            ]
            predict_kwargs = {
                kwarg_name: value
                for kwarg_name, value in decoding_kwargs.items()
                if kwarg_name in VALID_PREDICT_KWARGS
            }

            # We treat each decoding interval as a separate sequence
            results = []
            for interval_start, interval_end in decoding_interval:
                interval_position = position_info.loc[
                    interval_start:interval_end
                ]
                interval_time = interval_position.index.to_numpy()

                if interval_time.size == 0:
                    logger.warning(
                        f"Interval {interval_start}:{interval_end} is empty"
                    )
                    continue
                results.append(
                    classifier.predict(
                        position_time=interval_time,
                        position=interval_position[
                            position_variable_names
                        ].to_numpy(),
                        spike_times=spike_times,
                        spike_waveform_features=spike_waveform_features,
                        time=interval_time,
                        **predict_kwargs,
                    )
                )
            if not results:
                # xr.concat on an empty list fails with an unhelpful message
                raise ValueError(
                    "All intervals in decoding interval list "
                    f"'{key['decoding_interval']}' for "
                    f"'{key['nwb_file_name']}' are empty; nothing to decode."
                )
            results = xr.concat(results, dim="intervals")

        # Save discrete transition and initial conditions
        results["initial_conditions"] = xr.DataArray(
            classifier.initial_conditions_,
            name="initial_conditions",
        )
        results["discrete_state_transitions"] = xr.DataArray(
            classifier.discrete_state_transitions_,
            dims=("states", "states"),
            name="discrete_state_transitions",
        )
        if (
            vars(classifier).get("discrete_transition_coefficients_")
            is not None
        ):
            results["discrete_transition_coefficients"] = (
                classifier.discrete_transition_coefficients_
            )

        # Insert results
        # in future use https://github.com/rly/ndx-xarray and analysis nwb file?

        nwb_file_name = key["nwb_file_name"].replace("_.nwb", "")

        # Generate a unique path for the results file
        path_exists = True
        while path_exists:
            results_path = (
                Path(config["SPYGLASS_ANALYSIS_DIR"])
                / nwb_file_name
                / f"{nwb_file_name}_{str(uuid.uuid4())}.nc"
            )
            path_exists = results_path.exists()
        classifier.save_results(
            results,
            results_path,
        )
        key["results_path"] = results_path

        # The model file shares the results file's stem
        classifier_path = results_path.with_suffix(".pkl")
        classifier.save_model(classifier_path)
        key["classifier_path"] = classifier_path

        self.insert1(key)

        from spyglass.decoding.decoding_merge import DecodingOutput

        DecodingOutput.insert1(orig_key, skip_duplicates=True)

    def fetch_results(self) -> xr.Dataset:
        """Retrieve the decoding results

        Returns
        -------
        xr.Dataset
            The decoding results (posteriors, etc.)
        """
        return ClusterlessDetector.load_results(self.fetch1("results_path"))

    def fetch_model(self):
        """Load the classifier saved for this entry.

        Returns
        -------
        ClusterlessDetector
            The model stored at ``classifier_path``, loaded via
            ``ClusterlessDetector.load_model``.
        """
        return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

    @staticmethod
    def fetch_environments(key):
        """Fetch the environments for the decoding model

        The environments are rebuilt from the stored decoding parameters and
        the position data. Nothing is loaded from the saved classifier.

        Parameters
        ----------
        key : dict
            The decoding selection key

        Returns
        -------
        list
            The classifier's environments. Each one exposes ``track_graph``,
            ``edge_order`` and ``edge_spacing``.
        """
        model_params = (
            DecodingParameters
            & {"decoding_param_name": key["decoding_param_name"]}
        ).fetch1()
        decoding_params, decoding_kwargs = (
            model_params["decoding_params"],
            model_params["decoding_kwargs"],
        )

        if decoding_kwargs is None:
            decoding_kwargs = {}

        (
            position_info,
            position_variable_names,
        ) = ClusterlessDecodingV1.fetch_position_info(key)
        classifier = ClusterlessDetector(**decoding_params)

        classifier.initialize_environments(
            position=position_info[position_variable_names].to_numpy(),
            environment_labels=decoding_kwargs.get("environment_labels", None),
        )

        return classifier.environments

    @staticmethod
    def _get_interval_range(key):
        """Get the maximum range of model times in the encoding and decoding intervals

        Parameters
        ----------
        key : dict
            The decoding selection key

        Returns
        -------
        Tuple[float, float]
            The minimum and maximum times for the model
        """
        encoding_interval = (
            IntervalList
            & {
                "nwb_file_name": key["nwb_file_name"],
                "interval_list_name": key["encoding_interval"],
            }
        ).fetch1("valid_times")

        decoding_interval = (
            IntervalList
            & {
                "nwb_file_name": key["nwb_file_name"],
                "interval_list_name": key["decoding_interval"],
            }
        ).fetch1("valid_times")

        return (
            min(
                np.asarray(encoding_interval).min(),
                np.asarray(decoding_interval).min(),
            ),
            max(
                np.asarray(encoding_interval).max(),
                np.asarray(decoding_interval).max(),
            ),
        )

    @staticmethod
    def fetch_position_info(key):
        """Fetch the position information for the decoding model

        Parameters
        ----------
        key : dict
            The decoding selection key

        Returns
        -------
        Tuple[pd.DataFrame, List[str]]
            The position information and the names of the position variables
        """
        position_group_key = {
            "position_group_name": key["position_group_name"],
            "nwb_file_name": key["nwb_file_name"],
        }

        min_time, max_time = ClusterlessDecodingV1._get_interval_range(key)
        position_info, position_variable_names = (
            PositionGroup & position_group_key
        ).fetch_position_info(min_time=min_time, max_time=max_time)

        return position_info, position_variable_names

    @staticmethod
    def fetch_linear_position_info(key):
        """Fetch the position information and project it onto the track graph

        Parameters
        ----------
        key : dict
            The decoding selection key

        Returns
        -------
        pd.DataFrame
            The linearized position information
        """
        # Only the first environment's track graph is used
        environment = ClusterlessDecodingV1.fetch_environments(key)[0]

        position_df = ClusterlessDecodingV1.fetch_position_info(key)[0]
        position_variable_names = (PositionGroup & key).fetch1(
            "position_variables"
        )
        position = np.asarray(position_df[position_variable_names])

        linear_position_df = get_linearized_position(
            position=position,
            track_graph=environment.track_graph,
            edge_order=environment.edge_order,
            edge_spacing=environment.edge_spacing,
        )

        min_time, max_time = ClusterlessDecodingV1._get_interval_range(key)

        return (
            pd.concat(
                [linear_position_df.set_index(position_df.index), position_df],
                axis=1,
            )
            .loc[min_time:max_time]
            .dropna(subset=position_variable_names)
        )

    @staticmethod
    def fetch_spike_data(key, filter_by_interval=True):
        """Fetch the spike times and waveform features for the decoding model

        Parameters
        ----------
        key : dict
            The decoding selection key
        filter_by_interval : bool, optional
            Whether to keep only spikes within the model's overall time range
            (see ``_get_interval_range``), by default True

        Returns
        -------
        spike_times : list[np.ndarray]
            Spike times for each entry in the waveform features group
        spike_waveform_features : list[np.ndarray]
            Waveform features aligned with ``spike_times``
        """
        waveform_keys = (
            (
                UnitWaveformFeaturesGroup.UnitFeatures
                & {
                    "nwb_file_name": key["nwb_file_name"],
                    "waveform_features_group_name": key[
                        "waveform_features_group_name"
                    ],
                }
            )
        ).fetch("KEY")
        spike_times, spike_waveform_features = (
            UnitWaveformFeatures & waveform_keys
        ).fetch_data()

        if not filter_by_interval:
            return spike_times, spike_waveform_features

        min_time, max_time = ClusterlessDecodingV1._get_interval_range(key)

        new_spike_times = []
        new_waveform_features = []
        for elec_spike_times, elec_waveform_features in zip(
            spike_times, spike_waveform_features
        ):
            # The same mask selects both the times and their features
            is_in_interval = np.logical_and(
                elec_spike_times >= min_time, elec_spike_times <= max_time
            )
            new_spike_times.append(elec_spike_times[is_in_interval])
            new_waveform_features.append(elec_waveform_features[is_in_interval])

        return new_spike_times, new_waveform_features

    @classmethod
    def get_spike_indicator(cls, key, time):
        """get spike indicator matrix for the group

        Parameters
        ----------
        key : dict
            key to identify the group
        time : np.ndarray
            time vector for which to calculate the spike indicator matrix

        Returns
        -------
        np.ndarray
            spike indicator matrix with shape (len(time), n_units)
        """
        time = np.asarray(time)
        min_time, max_time = time[[0, -1]]
        spike_times = cls.fetch_spike_data(key)[0]
        spike_indicator = np.zeros((len(time), len(spike_times)))

        for ind, times in enumerate(spike_times):
            times = times[np.logical_and(times >= min_time, times <= max_time)]
            # Row i counts spikes with time[i] <= t < time[i + 1]; spikes at or
            # after time[-2] all fall into row len(time) - 2
            spike_indicator[:, ind] = np.bincount(
                np.digitize(times, time[1:-1]),
                minlength=time.shape[0],
            )

        return spike_indicator

    @classmethod
    def get_firing_rate(cls, key, time, multiunit=False) -> np.ndarray:
        """get time-dependent firing rate for units in the group

        Parameters
        ----------
        key : dict
            key to identify the group
        time : np.ndarray
            time vector for which to calculate the firing rate
        multiunit : bool, optional
            if True, return the multiunit firing rate for units in the group,
            by default False

        Returns
        -------
        np.ndarray
            Firing rate for each unit, or a single summed multiunit column when
            ``multiunit`` is True, stacked along axis 1.
        """
        spike_indicator = cls.get_spike_indicator(key, time)
        if spike_indicator.ndim == 1:
            spike_indicator = spike_indicator[:, np.newaxis]

        # Sampling rate inferred from the median spacing of `time`
        sampling_frequency = 1 / np.median(np.diff(time))

        if multiunit:
            spike_indicator = spike_indicator.sum(axis=1, keepdims=True)
        return np.stack(
            [
                get_multiunit_population_firing_rate(
                    indicator[:, np.newaxis], sampling_frequency
                )
                for indicator in spike_indicator.T
            ],
            axis=1,
        )

    def get_ahead_behind_distance(self):
        """get the ahead-behind distance for the decoding model

        Uses the track-graph (1D) analysis when the first environment has a
        track graph, and the 2D analysis otherwise.

        Returns
        -------
        distance_metrics : np.ndarray
            Information about the distance of the animal to the mental position.
        """
        # TODO: allow specification of specific time interval
        # TODO: allow specification of track graph
        # TODO: Handle decode intervals, store in table

        classifier = self.fetch_model()
        results = self.fetch_results().squeeze()
        posterior = results.acausal_posterior.unstack("state_bins").sum("state")
        key = self.fetch1("KEY")
        environment = classifier.environments[0]

        if environment.track_graph is not None:
            linear_position_info = self.fetch_linear_position_info(key)

            orientation_name = (
                "orientation"
                if "orientation" in linear_position_info.columns
                else "head_orientation"
            )

            traj_data = analysis.get_trajectory_data(
                posterior=posterior,
                track_graph=environment.track_graph,
                decoder=classifier,
                actual_projected_position=linear_position_info[
                    ["projected_x_position", "projected_y_position"]
                ],
                track_segment_id=linear_position_info["track_segment_id"],
                actual_orientation=linear_position_info[orientation_name],
            )

            return analysis.get_ahead_behind_distance(
                environment.track_graph, *traj_data
            )
        else:
            # fetch_position_info returns (DataFrame, variable names). Take the
            # DataFrame so `.columns` and column indexing below work.
            position_info = self.fetch_position_info(key)[0]
            map_position = analysis.maximum_a_posteriori_estimate(posterior)

            orientation_name = (
                "orientation"
                if "orientation" in position_info.columns
                else "head_orientation"
            )
            position_variable_names = (PositionGroup & key).fetch1(
                "position_variables"
            )

            return analysis.get_ahead_behind_distance2D(
                position_info[position_variable_names].to_numpy(),
                position_info[orientation_name].to_numpy(),
                map_position,
                environment.track_graphDD,
            )

fetch_results()

Retrieve the decoding results

Returns:

Type Description
Dataset

The decoding results (posteriors, etc.)

Source code in src/spyglass/decoding/v1/clusterless.py
def fetch_results(self) -> xr.Dataset:
    """Load the saved decoding results for this entry.

    Returns
    -------
    xr.Dataset
        The decoding results (posteriors, etc.) read from ``results_path``.
    """
    stored_path = self.fetch1("results_path")
    loaded = ClusterlessDetector.load_results(stored_path)
    return loaded

fetch_environments(key) staticmethod

Fetch the environments for the decoding model

Parameters:

Name Type Description Default
key dict

The decoding selection key

required

Returns:

Type Description
list

The classifier's environments. Each one exposes track_graph, edge_order and edge_spacing.

Source code in src/spyglass/decoding/v1/clusterless.py
@staticmethod
def fetch_environments(key):
    """Fetch the environments for the decoding model

    The environments are rebuilt from the stored decoding parameters and the
    position data. Nothing is loaded from the saved classifier.

    Parameters
    ----------
    key : dict
        The decoding selection key

    Returns
    -------
    list
        The classifier's environments. Each one exposes ``track_graph``,
        ``edge_order`` and ``edge_spacing``.
    """
    model_params = (
        DecodingParameters
        & {"decoding_param_name": key["decoding_param_name"]}
    ).fetch1()
    decoding_params = model_params["decoding_params"]
    decoding_kwargs = model_params["decoding_kwargs"]
    if decoding_kwargs is None:
        decoding_kwargs = {}

    position_info, position_variable_names = (
        ClusterlessDecodingV1.fetch_position_info(key)
    )
    classifier = ClusterlessDetector(**decoding_params)

    # environment_labels is optional. None lets the detector use its defaults.
    classifier.initialize_environments(
        position=position_info[position_variable_names].to_numpy(),
        environment_labels=decoding_kwargs.get("environment_labels", None),
    )

    return classifier.environments

fetch_position_info(key) staticmethod

Fetch the position information for the decoding model

Parameters:

Name Type Description Default
key dict

The decoding selection key

required

Returns:

Type Description
Tuple[DataFrame, List[str]]

The position information and the names of the position variables

Source code in src/spyglass/decoding/v1/clusterless.py
@staticmethod
def fetch_position_info(key):
    """Load the position data that covers the model's intervals.

    Parameters
    ----------
    key : dict
        The decoding selection key. It must provide ``position_group_name``,
        ``nwb_file_name``, ``encoding_interval`` and ``decoding_interval``.

    Returns
    -------
    Tuple[pd.DataFrame, List[str]]
        The position information and the names of the position variables
    """
    restriction = dict(
        position_group_name=key["position_group_name"],
        nwb_file_name=key["nwb_file_name"],
    )

    # Limit the fetch to the union span of the encoding and decoding intervals
    start, stop = ClusterlessDecodingV1._get_interval_range(key)
    group = PositionGroup & restriction
    info, variable_names = group.fetch_position_info(
        min_time=start, max_time=stop
    )
    return info, variable_names

fetch_linear_position_info(key) staticmethod

Fetch the position information and project it onto the track graph

Parameters:

Name Type Description Default
key dict

The decoding selection key

required

Returns:

Type Description
DataFrame

The linearized position information

Source code in src/spyglass/decoding/v1/clusterless.py
@staticmethod
def fetch_linear_position_info(key):
    """Linearize the model's position data onto its first track graph.

    Parameters
    ----------
    key : dict
        The decoding selection key

    Returns
    -------
    pd.DataFrame
        The linearized position columns joined with the original position
        columns, limited to the model's time range. Rows with missing
        position variables are dropped.
    """
    first_environment = ClusterlessDecodingV1.fetch_environments(key)[0]

    position_df = ClusterlessDecodingV1.fetch_position_info(key)[0]
    variable_columns = (PositionGroup & key).fetch1("position_variables")
    raw_position = np.asarray(position_df[variable_columns])

    linearized = get_linearized_position(
        position=raw_position,
        track_graph=first_environment.track_graph,
        edge_order=first_environment.edge_order,
        edge_spacing=first_environment.edge_spacing,
    )

    start, stop = ClusterlessDecodingV1._get_interval_range(key)

    # Reuse the position timestamps as the index so the two frames align
    aligned = linearized.set_index(position_df.index)
    combined = pd.concat([aligned, position_df], axis=1)
    windowed = combined.loc[start:stop]
    return windowed.dropna(subset=variable_columns)

fetch_spike_data(key, filter_by_interval=True) staticmethod

Fetch the spike times for the decoding model

Parameters:

Name Type Description Default
key dict

The decoding selection key

required
filter_by_interval bool

Whether to filter for spike times in the model interval, by default True

True

Returns:

Type Description
tuple[list[ndarray], list[ndarray]]

Spike times for each entry in the model's waveform features group, and the waveform features aligned with those spike times

Source code in src/spyglass/decoding/v1/clusterless.py
@staticmethod
def fetch_spike_data(key, filter_by_interval=True):
    """Fetch the spike times and waveform features for the decoding model

    Parameters
    ----------
    key : dict
        The decoding selection key
    filter_by_interval : bool, optional
        Whether to keep only spikes within the model's overall time range
        (see ``_get_interval_range``), by default True

    Returns
    -------
    spike_times : list[np.ndarray]
        Spike times for each entry in the waveform features group
    spike_waveform_features : list[np.ndarray]
        Waveform features aligned with ``spike_times``
    """
    waveform_keys = (
        UnitWaveformFeaturesGroup.UnitFeatures
        & {
            "nwb_file_name": key["nwb_file_name"],
            "waveform_features_group_name": key[
                "waveform_features_group_name"
            ],
        }
    ).fetch("KEY")
    spike_times, spike_waveform_features = (
        UnitWaveformFeatures & waveform_keys
    ).fetch_data()

    if not filter_by_interval:
        return spike_times, spike_waveform_features

    min_time, max_time = ClusterlessDecodingV1._get_interval_range(key)

    new_spike_times = []
    new_waveform_features = []
    for elec_spike_times, elec_waveform_features in zip(
        spike_times, spike_waveform_features
    ):
        # The same mask selects both the times and their features
        is_in_interval = np.logical_and(
            elec_spike_times >= min_time, elec_spike_times <= max_time
        )
        new_spike_times.append(elec_spike_times[is_in_interval])
        new_waveform_features.append(elec_waveform_features[is_in_interval])

    return new_spike_times, new_waveform_features

get_spike_indicator(key, time) classmethod

get spike indicator matrix for the group

Parameters:

Name Type Description Default
key dict

key to identify the group

required
time ndarray

time vector for which to calculate the spike indicator matrix

required

Returns:

Type Description
ndarray

spike indicator matrix with shape (len(time), n_units)

Source code in src/spyglass/decoding/v1/clusterless.py
@classmethod
def get_spike_indicator(cls, key, time):
    """Count each unit's spikes on a time grid.

    Parameters
    ----------
    key : dict
        key to identify the group
    time : np.ndarray
        time vector for which to calculate the spike indicator matrix

    Returns
    -------
    np.ndarray
        spike indicator matrix with shape (len(time), n_units). Row i counts
        spikes with time[i] <= t < time[i + 1]. Spikes at or after time[-2]
        (up to time[-1]) all land in row len(time) - 2.
    """
    grid = np.asarray(time)
    first_time, last_time = grid[0], grid[-1]
    interior_edges = grid[1:-1]
    n_bins = grid.shape[0]

    per_unit_times = cls.fetch_spike_data(key)[0]
    counts = np.zeros((len(grid), len(per_unit_times)))

    for column, unit_times in enumerate(per_unit_times):
        in_window = (unit_times >= first_time) & (unit_times <= last_time)
        bin_index = np.digitize(unit_times[in_window], interior_edges)
        counts[:, column] = np.bincount(bin_index, minlength=n_bins)

    return counts

get_firing_rate(key, time, multiunit=False) classmethod

get time-dependent firing rate for units in the group

Parameters:

Name Type Description Default
key dict

key to identify the group

required
time ndarray

time vector for which to calculate the firing rate

required
multiunit bool

if True, return the multiunit firing rate for units in the group, by default False

False

Returns:

Type Description
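ndarray

Firing rate for each unit (or a single summed multiunit column when multiunit is True), stacked along axis 1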
Source code in src/spyglass/decoding/v1/clusterless.py
@classmethod
def get_firing_rate(cls, key, time, multiunit=False) -> np.ndarray:
    """Smooth the group's binned spikes into time-dependent firing rates.

    Parameters
    ----------
    key : dict
        key to identify the group
    time : np.ndarray
        time vector for which to calculate the firing rate
    multiunit : bool, optional
        if True, return the multiunit firing rate for units in the group,
        by default False

    Returns
    -------
    np.ndarray
        One firing-rate series per unit, or a single summed multiunit series
        when ``multiunit`` is True, stacked along axis 1.
    """
    indicator = cls.get_spike_indicator(key, time)
    if indicator.ndim == 1:
        indicator = indicator[:, np.newaxis]

    # Sampling rate inferred from the median spacing of `time`
    sampling_frequency = 1 / np.median(np.diff(time))

    if multiunit:
        indicator = indicator.sum(axis=1, keepdims=True)

    rates = []
    for column in indicator.T:
        rates.append(
            get_multiunit_population_firing_rate(
                column[:, np.newaxis], sampling_frequency
            )
        )
    return np.stack(rates, axis=1)

get_ahead_behind_distance()

get the ahead-behind distance for the decoding model

Returns:

Name Type Description
distance_metrics ndarray

Information about the distance of the animal to the mental position.

Source code in src/spyglass/decoding/v1/clusterless.py
def get_ahead_behind_distance(self):
    """get the ahead-behind distance for the decoding model

    Uses the track-graph (1D) analysis when the first environment has a track
    graph, and the 2D analysis otherwise.

    Returns
    -------
    distance_metrics : np.ndarray
        Information about the distance of the animal to the mental position.
    """
    # TODO: allow specification of specific time interval
    # TODO: allow specification of track graph
    # TODO: Handle decode intervals, store in table

    classifier = self.fetch_model()
    results = self.fetch_results().squeeze()
    posterior = results.acausal_posterior.unstack("state_bins").sum("state")
    key = self.fetch1("KEY")
    environment = classifier.environments[0]

    if environment.track_graph is not None:
        linear_position_info = self.fetch_linear_position_info(key)

        orientation_name = (
            "orientation"
            if "orientation" in linear_position_info.columns
            else "head_orientation"
        )

        traj_data = analysis.get_trajectory_data(
            posterior=posterior,
            track_graph=environment.track_graph,
            decoder=classifier,
            actual_projected_position=linear_position_info[
                ["projected_x_position", "projected_y_position"]
            ],
            track_segment_id=linear_position_info["track_segment_id"],
            actual_orientation=linear_position_info[orientation_name],
        )

        return analysis.get_ahead_behind_distance(
            environment.track_graph, *traj_data
        )

    # fetch_position_info returns (DataFrame, variable names). Take the
    # DataFrame so `.columns` and column indexing below work.
    position_info = self.fetch_position_info(key)[0]
    map_position = analysis.maximum_a_posteriori_estimate(posterior)

    orientation_name = (
        "orientation"
        if "orientation" in position_info.columns
        else "head_orientation"
    )
    position_variable_names = (PositionGroup & key).fetch1(
        "position_variables"
    )

    return analysis.get_ahead_behind_distance2D(
        position_info[position_variable_names].to_numpy(),
        position_info[orientation_name].to_numpy(),
        map_position,
        environment.track_graphDD,
    )