Skip to content

dj_merge_tables.py

is_merge_table(table)

Return True if table fields exactly match Merge table.

Source code in src/spyglass/utils/dj_merge_tables.py
def is_merge_table(table):
    """Return True if table fields exactly match Merge table.

    Parameters
    ----------
    table : str or dj.Table
        A full table name (converted to a ``dj.FreeTable``) or a DataJoint
        table object.

    Returns
    -------
    bool
        False for non-table objects and for part tables. For tables not yet
        declared, the ``definition`` attribute is compared against
        ``MERGE_DEFINITION``; if no definition is available a warning is
        logged and True is returned. For declared tables, True only if the
        primary key is exactly ``[RESERVED_PRIMARY_KEY]`` and the secondary
        attributes are exactly ``[RESERVED_SECONDARY_KEY]``.
    """

    def trim_def(definition):
        """Normalize a table definition for comparison.

        Removes ``#`` comments (including one on the final line), strips
        leading/trailing whitespace from every line and drops blank lines,
        so indentation and spacing differences do not cause a mismatch.
        """
        stripped = (
            re_sub(r"#.*", "", line).strip() for line in definition.splitlines()
        )
        return "\n".join(line for line in stripped if line)

    if isinstance(table, str):
        table = dj.FreeTable(dj.conn(), table)
    if not isinstance(table, dj.Table):
        return False
    if get_master(table.full_table_name):
        return False  # Part tables are not merge tables
    if not table.is_declared:
        if tbl_def := getattr(table, "definition", None):
            return trim_def(MERGE_DEFINITION) == trim_def(tbl_def)
        logger.warning(
            f"Cannot determine merge table status for {table.table_name}"
        )
        return True
    return table.primary_key == [
        RESERVED_PRIMARY_KEY
    ] and table.heading.secondary_attributes == [RESERVED_SECONDARY_KEY]

Merge

Bases: Manual

Adds funcs to support standard Merge table operations.

Many methods have the @classmethod decorator to permit MergeTable.method() syntax. This makes access to instance attributes (e.g., (MergeTable & "example='restriction'").restriction) harder, but these attributes have limited utility when the user wants to, for example, restrict the merged view rather than the master table itself.

Source code in src/spyglass/utils/dj_merge_tables.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
class Merge(dj.Manual):
    """Adds funcs to support standard Merge table operations.

    Many methods have the @classmethod decorator to permit MergeTable.method()
    syntax. This makes access to instance attributes (e.g., (MergeTable &
    "example='restriction'").restriction) harder, but these attributes have
    limited utility when the user wants to, for example, restrict the merged
    view rather than the master table itself.
    """

    def __init__(self):
        super().__init__()
        self._reserved_pk = RESERVED_PRIMARY_KEY
        self._reserved_sk = RESERVED_SECONDARY_KEY
        # Definition / part-key checks only run while the table is undeclared
        if not self.is_declared:
            if not is_merge_table(self):  # Check definition
                logger.warning(
                    "Merge table with non-default definition\n"
                    + f"Expected:\n{MERGE_DEFINITION.strip()}\n"
                    + f"Actual  :\n{self.definition.strip()}"
                )
            for part in self.parts(as_objects=True):
                if part.primary_key != self.primary_key:
                    logger.warning(  # PK is only 'merge_id' in parts, no others
                        f"Unexpected primary key in {part.table_name}"
                        + f"\n\tExpected: {self.primary_key}"
                        + f"\n\tActual  : {part.primary_key}"
                    )
        self._source_class_dict = {}  # lazily filled by `source_class_dict`

    @staticmethod
    def _part_name(part=None):
        """Return the CamelCase name of a part table.

        Parameters
        ----------
        part: str or dj.Table
            Full table name (backticks are stripped) or a table object, whose
            `table_name` is used.
        """
        if not isinstance(part, str):
            part = part.table_name
        return to_camel_case(part.split("__")[-1].strip("`"))

    def get_source_from_key(self, key: dict) -> str:
        """Return the source of a given key (see `_normalize_source`)."""
        return self._normalize_source(key)

    def parts(self, camel_case=False, *args, **kwargs) -> list:
        """Return a list of part tables, add option for CamelCase names.

        See DataJoint `parts` for additional arguments. If camel_case is True,
        forces return of strings rather than objects.
        """
        self._ensure_dependencies_loaded()

        if camel_case and kwargs.get("as_objects"):
            logger.warning(
                "Overriding as_objects=True to return CamelCase part names."
            )
            kwargs["as_objects"] = False

        parts = super().parts(*args, **kwargs)

        if camel_case:
            parts = [self._part_name(part) for part in parts]

        return parts

    @classmethod
    def _merge_restrict_parts(
        cls,
        restriction: str = True,
        as_objects: bool = True,
        return_empties: bool = True,
        add_invalid_restrict: bool = True,
    ) -> list:
        """Returns a list of parts with restrictions applied.

        Parameters
        ---------
        restriction: str, optional
            Restriction to apply to the parts. Default True, no restrictions.
            A dict restriction is never modified; its source key is dropped
            from a private copy before restricting the parts.
        as_objects: bool, optional
            Default True. Return part tables as objects
        return_empties: bool, optional
            Default True. Return empty part tables
        add_invalid_restrict: bool, optional
            Default True. Include part for which the restriction is invalid.

        Returns
        ------
        list
            list of datajoint tables, parts of Merge Table
        """

        cls._ensure_dependencies_loaded()

        master = cls()  # instantiate once; every cls() call runs __init__
        reserved_sk = master._reserved_sk
        sk_token = f"`{reserved_sk}`"

        # Normalize restriction to sql string
        restr_str = make_condition(master, restriction, set())

        parts_all = cls.parts(as_objects=True)
        # If the restriction makes ref to a source, we only want that part
        if (
            not return_empties
            and isinstance(restr_str, str)
            and sk_token in restr_str
        ):
            # Source name as it appears in e.g. `source`="PartName"
            source_name = restr_str.split(f'{sk_token}="')[-1].split('"')[0]
            parts_all = [
                part
                for part in parts_all
                if from_camel_case(source_name)  # Only look at source part
                in part.full_table_name
            ]
        if isinstance(restriction, dict):  # restr by source already done above
            # Drop the source key from a copy; never mutate the caller's dict
            restriction = {
                k: v for k, v in restriction.items() if k != reserved_sk
            }
            # If a dict restriction has all invalid keys, it is treated as True
            if not add_invalid_restrict:
                parts_all = [  # so exclude tables w/ nonmatching attrs
                    p
                    for p in parts_all
                    if all(k in p.heading.names for k in restriction)
                ]

        parts = []
        for part in parts_all:
            try:
                parts.append(part.restrict(restriction))
            except DataJointError:  # If restriction not valid on given part
                if add_invalid_restrict:
                    parts.append(part)

        if not return_empties:
            parts = [p for p in parts if len(p)]
        if not as_objects:
            parts = [p.full_table_name for p in parts]

        return parts

    @classmethod
    def _merge_restrict_parents(
        cls,
        restriction: str = True,
        parent_name: str = None,
        as_objects: bool = True,
        return_empties: bool = True,
        add_invalid_restrict: bool = True,
    ) -> list:
        """Returns a list of part parents with restrictions applied.

        Rather than part tables, we look at parents of those parts, the source
        of the data.

        Parameters
        ---------
        restriction: str, optional
            Restriction to apply to the returned parent. Default True, no
            restrictions.
        parent_name: str, optional
            CamelCase name of the parent.
        as_objects: bool, optional
            Default True. Return part tables as objects
        return_empties: bool, optional
            Default True. Return empty part tables
        add_invalid_restrict: bool, optional
            Default True. Include part for which the restriction is invalid.

        Returns
        ------
        list
            list of datajoint tables, parents of parts of Merge Table
        """
        merge_table_name = cls().table_name  # hoisted out of the comprehension
        # .restrict(restriction) does not work on returned part FreeTable
        # & part.fetch below restricts parent to entries in merge table
        part_parents = [
            parent
            & part.fetch(*part.heading.secondary_attributes, as_dict=True)
            for part in cls._merge_restrict_parts(
                restriction=restriction,
                return_empties=return_empties,
                add_invalid_restrict=add_invalid_restrict,
            )
            for parent in part.parents(as_objects=True)  # ID respective parents
            if merge_table_name not in parent.full_table_name  # Not merge table
        ]
        if parent_name:
            part_parents = [
                p
                for p in part_parents
                if from_camel_case(parent_name) in p.full_table_name
            ]
        if not as_objects:
            part_parents = [p.full_table_name for p in part_parents]

        return part_parents

    @classmethod
    def _merge_repr(
        cls, restriction: str = True, include_empties=False
    ) -> dj.expression.Union:
        """Merged view, including null entries for columns unique to one part.

        Parameters
        ---------
        restriction: str, optional
            Restriction to apply to the merged view
        include_empties: bool, optional
            Default False. Add columns for empty parts.

        Returns
        ------
        datajoint.expression.Union
            The union of all matching parts joined with the master, or None
            (after logging a warning) if no parts match.
        """

        parts = [
            cls() * p  # join with master to include sec key (i.e., 'source')
            for p in cls._merge_restrict_parts(
                restriction=restriction,
                add_invalid_restrict=False,
                return_empties=include_empties,
            )
        ]
        if not parts:
            logger.warning("No parts found. Try adjusting restriction.")
            return

        attr_dict = {  # NULL for non-numeric, 0 for numeric
            attr.name: "0" if attr.numeric else "NULL"
            for attr in iter_chain.from_iterable(
                part.heading.attributes.values() for part in parts
            )
        }

        def _proj_part(part):
            """Project part, adding NULL/0 for missing attributes"""
            return dj.U(*attr_dict.keys()) * part.proj(
                ...,  # include all attributes from part
                **{
                    k: v
                    for k, v in attr_dict.items()
                    if k not in part.heading.names
                },
            )

        query = _proj_part(parts[0])  # start with first part
        for part in parts[1:]:  # add remaining parts
            query += _proj_part(part)

        return query

    @classmethod
    def _merge_insert(cls, rows: list, part_name: str = None, **kwargs) -> None:
        """Insert rows into merge, ensuring data exists in part parent(s).

        Parameters
        ---------
        rows: List[dict]
            An iterable where an element is a dictionary. It is materialized
            once before validation, so one-shot iterables (generators) work.
        part_name: str, optional
            CamelCase name of the part table. If given, only that part is
            considered.
        kwargs: dict
            Passed to the DataJoint `insert` calls for master and parts.

        Raises
        ------
        TypeError
            If rows is not a list of dicts
        ValueError
            If data doesn't exist in part parents, integrity error
        """
        cls._ensure_dependencies_loaded()

        type_err_msg = "Input `rows` must be a list of dictionaries"
        # Materialize first: validating by iterating a generator would
        # exhaust it and silently insert nothing afterwards.
        try:
            rows = list(rows)
        except TypeError:
            raise TypeError(type_err_msg)
        if not all(isinstance(r, dict) for r in rows):
            raise TypeError(type_err_msg)

        parts = cls._merge_restrict_parts(as_objects=True)
        if part_name:
            parts = [
                p
                for p in parts
                if from_camel_case(part_name) in p.full_table_name
            ]

        master = cls()
        reserved_pk, reserved_sk = master._reserved_pk, master._reserved_sk
        # (part, CamelCase name, source parent) computed once, not per row.
        # The last entry of `parents()` is taken as the part's source table.
        part_info = [
            (part, cls._part_name(part), part.parents(as_objects=True)[-1])
            for part in parts
        ]

        master_entries = []
        parts_entries = {p: [] for p in parts}
        for row in rows:
            keys = []  # stays empty unless row is found in some part parent
            for part, source, part_parent in part_info:
                if not part_parent & row:  # row is not in this part parent
                    continue
                keys = (part_parent & row).fetch("KEY")  # get pk
                if len(keys) > 1:
                    raise ValueError(
                        "Ambiguous entry. Data has mult rows in "
                        + f"{source}:\n\tData:{row}\n\t{keys}"
                    )
                key = keys[0]
                if part & key:
                    logger.info(f"Key already in part {source}: {key}")
                    continue
                master_sk = {reserved_sk: source}
                uuid = dj.hash.key_hash(key | master_sk)
                master_pk = {reserved_pk: uuid}

                master_entries.append({**master_pk, **master_sk})
                parts_entries[part].append({**master_pk, **key})

            if not keys:
                raise ValueError(
                    "Non-existing entry in any of the parent tables - Entry: "
                    + f"{row}"
                )

        with cls._safe_context():
            super().insert(master, master_entries, **kwargs)
            for part, part_entries in parts_entries.items():
                part.insert(part_entries, **kwargs)

    @classmethod
    def _ensure_dependencies_loaded(cls) -> None:
        """Ensure connection dependencies loaded.

        Otherwise parts returns none
        """
        if not dj.conn.connection.dependencies._loaded:
            dj.conn.connection.dependencies.load()

    def insert(self, rows: list, **kwargs):
        """Merges table specific insert, ensuring data exists in part parents.

        Parameters
        ---------
        rows: List[dict]
            An iterable where an element is a dictionary.

        Raises
        ------
        TypeError
            If rows is not a list of dicts
        ValueError
            If data doesn't exist in part parents, integrity error
        """
        self._merge_insert(rows, **kwargs)

    @classmethod
    def merge_view(cls, restriction: str = True):
        """Prints merged view, including null entries for unique columns.

        Note: To handle this Union as a table-like object, use
        `merge_restrict`.

        Parameters
        ---------
        restriction: str, optional
            Restriction to apply to the merged view
        """

        # If we overwrite `preview`, we then encounter issues with operators
        # getting passed a `Union`, which doesn't have a method we can
        # intercept to manage master/parts

        return pprint(cls._merge_repr(restriction=restriction))

    @classmethod
    def merge_html(cls, restriction: str = True):
        """Displays HTML of the merged view in notebooks."""

        return HTML(repr_html(cls._merge_repr(restriction=restriction)))

    @classmethod
    def merge_restrict(cls, restriction: str = True) -> dj.U:
        """Given a restriction, return a merged view with restriction applied.

        Example
        -------
            >>> MergeTable.merge_restrict("field = 1")

        Parameters
        ----------
        restriction: str
            Restriction one would apply if `merge_view` was a real table.

        Returns
        -------
        datajoint.Union
            Merged view with restriction applied, or None if no parts match.
        """
        return cls._merge_repr(restriction=restriction)

    @classmethod
    def merge_delete(cls, restriction: str = True, **kwargs):
        """Given a restriction string, delete corresponding entries.

        Does nothing if no parts match the restriction.

        Parameters
        ----------
        restriction: str
            Optional restriction to apply before deletion from master/part
            tables. If not provided, delete all entries.
        kwargs: dict
            Additional keyword arguments for DataJoint delete.

        Example
        -------
            >>> MergeTable.merge_delete("field = 1")
        """
        merged = cls.merge_restrict(restriction)
        if merged is None:  # no matching parts; nothing to delete
            return
        reserved_pk = cls()._reserved_pk
        uuids = [
            {k: v}
            for entry in merged.fetch("KEY")
            for k, v in entry.items()
            if k == reserved_pk
        ]
        (cls() & uuids).delete(**kwargs)

    @classmethod
    def merge_delete_parent(
        cls, restriction: str = True, dry_run=True, **kwargs
    ) -> list:
        """Delete entries from merge master, part, and respective part parents

        Note: Clears merge entries from their respective parents.

        Parameters
        ----------
        restriction: str
            Optional restriction to apply before deletion from parents. If not
            provided, delete all entries present in Merge Table.
        dry_run: bool
            Default True. If true, return list of tables with entries that would
            be deleted. Otherwise, table entries.
        kwargs: dict
            Additional keyword arguments for DataJoint delete.

        Returns
        -------
        list or None
            The restricted part parents when `dry_run`, otherwise None.
        """
        part_parents = cls._merge_restrict_parents(
            restriction=restriction, as_objects=True, return_empties=False
        )

        if dry_run:
            return part_parents

        merged = cls.merge_restrict(restriction)
        if merged is None:  # merged view is empty: do not touch parents
            return
        merge_ids = merged.fetch(RESERVED_PRIMARY_KEY, as_dict=True)

        # CB: Removed transaction protection here bc 'no' confirmation resp
        # still resulted in deletes. If re-add, consider transaction=False
        super().delete((cls & merge_ids), **kwargs)

        if cls & merge_ids:  # If 'no' on del prompt from above, skip below
            return  # User can still abort del below, but yes/no is unlikly

        for part_parent in part_parents:
            super().delete(part_parent, **kwargs)

    def fetch_nwb(
        self,
        restriction: str = None,
        multi_source=False,
        disable_warning=False,
        *attrs,
        **kwargs,
    ):
        """Return the (Analysis)Nwbfile file linked in the source.

        Relies on SpyglassMixin._nwb_table_tuple to determine the table to
        fetch from and the appropriate path attribute to return.

        Parameters
        ----------
        restriction: str, optional
            Restriction to apply to parents before running fetch. Default
            None, which falls back to this instance's restriction, or to no
            restriction (True) if that is empty.
        multi_source: bool
            Return from multiple parents. Default False. Accepted for API
            compatibility; not referenced by this implementation.
        disable_warning: bool
            Default False. Accepted for API compatibility; not referenced by
            this implementation.
        attrs, kwargs
            Forwarded to each source parent's `fetch_nwb` call.

        Returns
        -------
        list
            Concatenation of each source parent's `fetch_nwb` results.

        Raises
        ------
        ValueError
            If invoked on the class with a dict as `self`
            (e.g., `Merge.fetch_nwb(key)`).

        Notes
        -----
        Nwb files not strictly returned in same order as self
        """
        if isinstance(self, dict):
            raise ValueError("Try replacing Merge.method with Merge().method")
        restriction = restriction or self.restriction or True
        sources = set((self & restriction).fetch(self._reserved_sk))
        nwb_list = []
        for source in sources:
            source_restr = (
                self & {self._reserved_sk: source} & restriction
            ).fetch("KEY")
            nwb_list.extend(
                self.merge_restrict_class(
                    source_restr, permit_multiple_rows=True
                ).fetch_nwb(*attrs, **kwargs)
            )
        return nwb_list

    @classmethod
    def merge_get_part(
        cls,
        restriction: str = True,
        join_master: bool = False,
        restrict_part=True,
        multi_source=False,
        return_empties=False,
    ) -> dj.Table:
        """Retrieve part table from a restricted Merge table.

        Note: unlike other Merge Table methods, returns the native table, not
        a FreeTable

        Parameters
        ----------
        restriction: str
            Optional restriction to apply before determining part to return.
            Default True.
        join_master: bool
            Join part with Merge master to show source field. Default False.
        restrict_part: bool
            Apply restriction to part. Default True. If False, return the
            native part table.
        multi_source: bool
            Return multiple parts. Default False.
        return_empties: bool
            Default False. Return empty part tables.

        Returns
        ------
        Union[dj.Table, List[dj.Table]]
            Native part table(s) of Merge. If `multi_source`, returns list,
            or None if no parts are found.

        Example
        -------
            >>> MergeTable.merge_get_part(restriction)
            >>> MergeTable().merge_get_part(restriction, join_master=True)

        Raises
        ------
        ValueError
            If multiple sources are found, but not expected lists and suggests
            restricting
        """
        sources = [
            cls._part_name(part)  # friendly part name
            for part in cls._merge_restrict_parts(
                restriction=restriction,
                as_objects=False,
                return_empties=return_empties,
                add_invalid_restrict=False,
            )
        ]

        if not multi_source and len(sources) != 1:
            raise ValueError(
                f"Found {len(sources)} potential parts: {sources}\n\t"
                + "Try adding a restriction before invoking `get_part`.\n\t"
                + "Or permitting multiple sources with `multi_source=True`."
            )
        if len(sources) == 0:
            return None

        parts = [
            (
                getattr(cls, source)().restrict(restriction)
                if restrict_part  # Re-apply restriction or don't
                else getattr(cls, source)()
            )
            for source in sources
        ]
        if join_master:
            parts = [cls * part for part in parts]

        return parts if multi_source else parts[0]

    @classmethod
    def merge_get_parent(
        cls,
        restriction: str = True,
        join_master: bool = False,
        multi_source: bool = False,
        return_empties: bool = False,
        add_invalid_restrict: bool = True,
    ) -> dj.FreeTable:
        """Returns a list of part parents with restrictions applied.

        Rather than part tables, we look at parents of those parts, the source
        of the data, and only the rows that have keys inserted in the merge
        table.

        Parameters
        ----------
        restriction: str
            Optional restriction to apply before determining parent to return.
            Default True.
        join_master: bool
            Default False. Join part with Merge master to show uuid and source
        multi_source: bool
            Return multiple parents. Default False.
        return_empties: bool
            Default False. Return empty parent tables.
        add_invalid_restrict: bool
            Default True. Include parent for which the restriction is invalid.

        Returns
        ------
        dj.FreeTable
            Parent of parts of Merge Table as FreeTable (a list of them if
            `multi_source`).

        Raises
        ------
        ValueError
            If not `multi_source` and the number of parents found is not 1.
        """

        part_parents = cls._merge_restrict_parents(
            restriction=restriction,
            as_objects=True,
            return_empties=return_empties,
            add_invalid_restrict=add_invalid_restrict,
        )

        if not multi_source and len(part_parents) != 1:
            raise ValueError(
                f"Found  {len(part_parents)} potential parents: {part_parents}"
                + "\n\tTry adding a string restriction when invoking "
                + "`get_parent`. Or permitting multiple sources with "
                + "`multi_source=True`."
            )

        if join_master:
            part_parents = [cls * part for part in part_parents]

        return part_parents if multi_source else part_parents[0]

    @property
    def source_class_dict(self) -> dict:
        """Map each part's CamelCase name to the same-named object found in
        the module that defines this Merge table. Cached after first access.
        """
        # NOTE: fails if table is aliased in dj.Part but not merge script
        # i.e., must import aliased table as part name
        if not self._source_class_dict:
            module = getmodule(self)
            self._source_class_dict = {
                part_name: getattr(module, part_name)
                for part_name in self.parts(camel_case=True)
            }
        return self._source_class_dict

    def _normalize_source(
        self, source: Union[str, dj.Table, dj.condition.AndList, dict]
    ) -> str:
        """Resolve `source` to a CamelCase source/part name.

        Parameters
        ----------
        source: Union[str, dj.Table, dj.condition.AndList, dict]
            A Merge table/AndList restriction (its `source` field is
            fetched), a table (its part name is used), a dict key (resolved
            via `merge_get_parent`), or a string (returned unchanged).

        Raises
        ------
        ValueError
            If a Merge/AndList restriction cannot be resolved to any source.
        """
        if isinstance(source, (Merge, dj.condition.AndList)):
            try:
                fetched_source = (self & source).fetch(self._reserved_sk)
            except DataJointError:
                raise ValueError(f"Unable to find source for {source}")
            if len(fetched_source) == 0:  # avoid a bare IndexError below
                raise ValueError(f"Unable to find source for {source}")
            source = fetched_source[0]
            if len(fetched_source) > 1:
                logger.warning(f"Multiple sources. Selecting first: {source}.")
        if isinstance(source, dj.Table):
            source = self._part_name(source)
        if isinstance(source, dict):
            source = self._part_name(self.merge_get_parent(source))

        return source

    def merge_get_parent_class(self, source: str) -> dj.Table:
        """Return the class of the parent table for a given CamelCase source.

        Parameters
        ----------
        source: Union[str, dict, dj.Table]
            Accepts a CamelCase name of the source, or key as a dict, or a part
            table.

        Returns
        -------
        dj.Table
            Class instance of the parent table, including class methods.
            None (with an error logged) if no matching class is found.
        """

        ret = self.source_class_dict.get(self._normalize_source(source))
        if not ret:
            logger.error(
                f"No source class found for {source}: \n\t"
                + f"{self.parts(camel_case=True)}"
            )
        return ret

    def merge_restrict_class(
        self, key: dict, permit_multiple_rows: bool = False
    ) -> dj.Table:
        """Returns native parent class, restricted with key.

        Parameters
        ----------
        key: dict
            Restriction used to locate a single parent via `merge_get_parent`.
        permit_multiple_rows: bool
            Default False. If False, raise when the parent has >1 matching row.

        Raises
        ------
        ValueError
            If multiple parent rows match and `permit_multiple_rows` is False.
        """
        parent = self.merge_get_parent(key)
        parent_key = parent.fetch("KEY", as_dict=True)

        if not permit_multiple_rows and len(parent_key) > 1:
            raise ValueError(
                f"Ambiguous entry. Data has mult rows in parent:\n\tData:{key}"
                + f"\n\t{parent_key}"
            )

        parent_class = self.merge_get_parent_class(parent)
        return parent_class & parent_key

    @classmethod
    def merge_fetch(cls, restriction: str = True, *attrs, **kwargs) -> list:
        """Perform a fetch across all parts. If >1 result, return as a list.

        Parameters
        ----------
        restriction: str
            Optional restriction to apply before determining parent to return.
            Default True.
        attrs, kwargs
            arguments passed to DataJoint `fetch` call

        Returns
        -------
        Union[ List[np.array], List[dict], List[pd.DataFrame] ]
            Table contents, with type determined by kwargs
        """
        results = []
        parts = cls._merge_restrict_parts(
            restriction=restriction,
            as_objects=True,
            return_empties=False,
            add_invalid_restrict=False,
        )

        for part in parts:
            try:
                results.extend(part.fetch(*attrs, **kwargs))
            except DataJointError as e:
                logger.warning(f"{e.args[0]} Skipping " + cls._part_name(part))

        # Note: this could collapse results like merge_view, but user may call
        # for recarray, pd.DataFrame, or dict, and fetched contents differ if
        # attrs or "KEY" called. Intercept format, merge, and then transform?

        if not results:
            logger.info(
                "No merge_fetch results.\n\t"
                + "If not restricting, try: `M.merge_fetch(True,'attr')\n\t"
                + "If restricting by source, use dict: "
                + "`M.merge_fetch({'source':'X'}"
            )
        return results[0] if len(results) == 1 else results

    def merge_populate(self, source: str, keys=None):
        """Populate the merge table with entries from the source table.

        Parameters
        ----------
        source: str
            Source identifier, resolved with `merge_get_parent_class`.
        keys: optional
            Keys to populate. Defaults to the parent class's `key_source`.
        """
        logger.warning("CBroz: Not fully tested. Use with caution.")
        parent_class = self.merge_get_parent_class(source)
        if not keys:
            keys = parent_class.key_source
        parent_class.populate(keys)
        successes = (parent_class & keys).fetch("KEY", as_dict=True)
        self.insert(successes)

    def delete(self, force_permission=False, *args, **kwargs):
        """Overrides datajoint.table.Table.delete.

        Calls `delete` on each non-empty part table matching this instance's
        restriction; does nothing if there are none.
        """
        if not (
            parts := self.merge_get_part(
                restriction=self.restriction,
                multi_source=True,
                return_empties=False,
            )
        ):
            return

        for part in parts:
            part.delete(force_permission=force_permission, *args, **kwargs)

    def super_delete(self, warn=True, *args, **kwargs):
        """Alias for datajoint.table.Table.delete.

        Added to support MRO of SpyglassMixin
        """
        if warn:
            logger.warning("!! Bypassing cautious_delete !!")
            self._log_delete(start=time(), super_delete=True)
        super().delete(*args, **kwargs)

get_source_from_key(key)

Return the source of a given key

Source code in src/spyglass/utils/dj_merge_tables.py
def get_source_from_key(self, key: dict) -> str:
    """Return the source name for ``key``.

    Thin wrapper that delegates to `_normalize_source`.
    """
    source = self._normalize_source(key)
    return source

parts(camel_case=False, *args, **kwargs)

Return a list of part tables, add option for CamelCase names.

See DataJoint parts for additional arguments. If camel_case is True, forces return of strings rather than objects.

Source code in src/spyglass/utils/dj_merge_tables.py
def parts(self, camel_case=False, *args, **kwargs) -> list:
    """List this table's part tables, optionally as CamelCase names.

    Wraps DataJoint's `parts`; other positional and keyword arguments are
    forwarded to it. When ``camel_case`` is True the result is a list of
    strings. In that case a caller-supplied ``as_objects=True`` is
    overridden to False, with a warning.
    """
    self._ensure_dependencies_loaded()

    wants_objects = kwargs.get("as_objects")
    if camel_case and wants_objects:
        logger.warning(
            "Overriding as_objects=True to return CamelCase part names."
        )
        kwargs["as_objects"] = False

    found = super().parts(*args, **kwargs)
    if not camel_case:
        return found
    return [self._part_name(part) for part in found]

insert(rows, **kwargs)

Merges table specific insert, ensuring data exists in part parents.

Parameters:

Name Type Description Default
rows list

An iterable where an element is a dictionary.

required

Raises:

Type Description
TypeError

If rows is not a list of dicts

ValueError

If data doesn't exist in part parents, integrity error

Source code in src/spyglass/utils/dj_merge_tables.py
def insert(self, rows: list, **kwargs):
    """Insert rows using the merge-specific insert.

    Delegates to `_merge_insert`, which ensures the data exists in the part
    parents.

    Parameters
    ----------
    rows: List[dict]
        An iterable where each element is a dictionary.
    **kwargs
        Forwarded unchanged to `_merge_insert`.

    Raises
    ------
    TypeError
        If rows is not a list of dicts
    ValueError
        If data doesn't exist in part parents, integrity error
    """
    merge_insert = self._merge_insert
    merge_insert(rows, **kwargs)

merge_view(restriction=True) classmethod

Prints merged view, including null entries for unique columns.

Note: To handle this Union as a table-like object, use merge_restrict

Parameters:

Name Type Description Default
restriction str

Restriction to apply to the merged view

True
Source code in src/spyglass/utils/dj_merge_tables.py
@classmethod
def merge_view(cls, restriction: str = True):
    """Pretty-print the merged view, with nulls for source-specific columns.

    To work with this Union as a table-like object, use `merge_restrict`.

    Parameters
    ----------
    restriction: str, optional
        Restriction applied to the merged view. Default True.
    """
    # `preview` is deliberately not overridden: operators would then be
    # handed a `Union`, which has no method to intercept for managing
    # master/parts.
    merged = cls._merge_repr(restriction=restriction)
    return pprint(merged)

merge_html(restriction=True) classmethod

Displays HTML in notebooks.

Source code in src/spyglass/utils/dj_merge_tables.py
@classmethod
def merge_html(cls, restriction: str = True):
    """Render the merged view as HTML for display in notebooks.

    Parameters
    ----------
    restriction: str, optional
        Restriction applied to the merged view. Default True.
    """
    merged = cls._merge_repr(restriction=restriction)
    html_text = repr_html(merged)
    return HTML(html_text)

merge_restrict(restriction=True) classmethod

Given a restriction, return a merged view with restriction applied.

Example
>>> MergeTable.merge_restrict("field = 1")

Parameters:

Name Type Description Default
restriction str

Restriction one would apply if merge_view was a real table.

True

Returns:

Type Description
Union

Merged view with restriction applied.

Source code in src/spyglass/utils/dj_merge_tables.py
@classmethod
def merge_restrict(cls, restriction: str = True) -> dj.U:
    """Return the merged view with a restriction applied.

    Example
    -------
        >>> MergeTable.merge_restrict("field = 1")

    Parameters
    ----------
    restriction: str
        Restriction one would apply if `merge_view` were a real table.
        Default True.

    Returns
    -------
    datajoint.Union
        Merged view with restriction applied.
    """
    restricted_view = cls._merge_repr(restriction=restriction)
    return restricted_view

merge_delete(restriction=True, **kwargs) classmethod

Given a restriction string, delete corresponding entries.

Parameters:

Name Type Description Default
restriction str

Optional restriction to apply before deletion from master/part tables. If not provided, delete all entries.

True
kwargs

Additional keyword arguments for DataJoint delete.

{}
Example
>>> MergeTable.merge_delete("field = 1")
Source code in src/spyglass/utils/dj_merge_tables.py
@classmethod
def merge_delete(cls, restriction: str = True, **kwargs):
    """Given a restriction string, delete corresponding entries.

    The restriction is applied to the merged view (`merge_restrict`). The
    reserved primary key of each matching row is collected, and the master
    table restricted to those keys is deleted.

    Parameters
    ----------
    restriction: str
        Optional restriction to apply before deletion from master/part
        tables. If not provided, delete all entries.
    kwargs: dict
        Additional keyword arguments for DataJoint delete.

    Example
    -------
        >>> MergeTable.merge_delete("field = 1")
    """
    # Resolve the reserved primary-key name once, rather than instantiating
    # the table for every key of every fetched entry.
    pk_name = cls()._reserved_pk
    uuids = [
        {pk_name: entry[pk_name]}
        for entry in cls.merge_restrict(restriction).fetch("KEY")
        if pk_name in entry
    ]
    (cls() & uuids).delete(**kwargs)

merge_delete_parent(restriction=True, dry_run=True, **kwargs) classmethod

Delete entries from merge master, part, and respective part parents

Note: Clears merge entries from their respective parents.

Parameters:

Name Type Description Default
restriction str

Optional restriction to apply before deletion from parents. If not provided, delete all entries present in Merge Table.

True
dry_run

Default True. If true, return list of tables with entries that would be deleted. Otherwise, delete the entries and return None.

True
kwargs

Additional keyword arguments for DataJoint delete.

{}
Source code in src/spyglass/utils/dj_merge_tables.py
@classmethod
def merge_delete_parent(
    cls, restriction: str = True, dry_run=True, **kwargs
) -> list:
    """Delete entries from merge master, part, and respective part parents

    Note: Clears merge entries from their respective parents.

    Parameters
    ----------
    restriction: str
        Optional restriction to apply before deletion from parents. If not
        provided, delete all entries present in Merge Table.
    dry_run: bool
        Default True. If true, return list of tables with entries that would
        be deleted. Otherwise, perform the deletion and return None.
    kwargs: dict
        Additional keyword arguments for DataJoint delete.

    Returns
    -------
    list or None
        The restricted, non-empty part-parent tables when `dry_run` is True;
        None otherwise.
    """
    # Restricted parents of the matching parts, excluding empty ones.
    part_parents = cls._merge_restrict_parents(
        restriction=restriction, as_objects=True, return_empties=False
    )

    if dry_run:
        return part_parents

    # Reserved primary-key values of the matching merged rows, as dicts.
    merge_ids = cls.merge_restrict(restriction).fetch(
        RESERVED_PRIMARY_KEY, as_dict=True
    )

    # CB: Removed transaction protection here bc 'no' confirmation resp
    # still resulted in deletes. If re-add, consider transaction=False
    # Zero-arg super() inside this classmethod resolves to the parent
    # class's plain (unbound) `delete`, so the restricted expression is
    # passed explicitly as its `self`.
    super().delete((cls & merge_ids), **kwargs)

    if cls & merge_ids:  # If 'no' on del prompt from above, skip below
        return  # User can still abort del below, but yes/no is unlikely

    for part_parent in part_parents:
        super().delete(part_parent, **kwargs)

fetch_nwb(restriction=None, multi_source=False, disable_warning=False, *attrs, **kwargs)

Return the (Analysis)Nwbfile file linked in the source.

Relies on SpyglassMixin._nwb_table_tuple to determine the table to fetch from and the appropriate path attribute to return.

Parameters:

Name Type Description Default
restriction str

Restriction to apply to parents before running fetch. If not given, the table's existing restriction is used, or True when there is none.

None
multi_source

Return from multiple parents. Default False.

False
Notes

Nwb files not strictly returned in same order as self

Source code in src/spyglass/utils/dj_merge_tables.py
def fetch_nwb(
    self,
    restriction: str = None,
    multi_source=False,
    disable_warning=False,
    *attrs,
    **kwargs,
):
    """Return the (Analysis)Nwbfile file linked in each source.

    For every source present in the restricted table, the native parent is
    obtained with `merge_restrict_class` and its own `fetch_nwb` is called.
    SpyglassMixin._nwb_table_tuple determines the table to fetch from and
    the path attribute to return.

    Parameters
    ----------
    restriction: str, optional
        Restriction to apply before fetching. If not given, this table's
        existing restriction is used, or True when there is none.
    multi_source: bool
        Accepted for interface compatibility; not used here.
    disable_warning: bool
        Accepted for interface compatibility; not used here.
    *attrs, **kwargs
        Accepted but not forwarded.

    Returns
    -------
    list
        Concatenation of each source's `fetch_nwb` results.

    Raises
    ------
    ValueError
        If ``self`` is a dict, i.e. the method was called on the class with
        a key as first argument (use ``Merge().fetch_nwb`` instead).

    Notes
    -----
    Nwb files not strictly returned in same order as self
    """
    if isinstance(self, dict):
        raise ValueError("Try replacing Merge.method with Merge().method")
    restriction = restriction or self.restriction or True
    source_field = self._reserved_sk
    restricted = self & restriction

    results = []
    for source_name in set(restricted.fetch(source_field)):
        keys = (self & {source_field: source_name} & restriction).fetch(
            "KEY"
        )
        parent = self.merge_restrict_class(keys, permit_multiple_rows=True)
        results.extend(parent.fetch_nwb())
    return results

merge_get_part(restriction=True, join_master=False, restrict_part=True, multi_source=False, return_empties=False) classmethod

Retrieve part table from a restricted Merge table.

Note: unlike other Merge Table methods, returns the native table, not a FreeTable

Parameters:

Name Type Description Default
restriction str

Optional restriction to apply before determining part to return. Default True.

True
join_master bool

Join part with Merge master to show source field. Default False.

False
restrict_part

Apply restriction to part. Default True. If False, return the native part table.

True
multi_source

Return multiple parts. Default False.

False
return_empties

Default False. Return empty part tables.

False

Returns:

Type Description
Union[Table, List[Table]]

Native part table(s) of Merge. If multi_source, returns list.

Example
>>> MergeTable.merge_get_part(restriction)
>>> MergeTable().merge_get_part(restriction, join_master=True)

Raises:

Type Description
ValueError

If multi_source is False and the number of matching sources is not exactly one. The error message lists the sources found and suggests adding a restriction.

Source code in src/spyglass/utils/dj_merge_tables.py
@classmethod
def merge_get_part(
    cls,
    restriction: str = True,
    join_master: bool = False,
    restrict_part=True,
    multi_source=False,
    return_empties=False,
) -> dj.Table:
    """Retrieve the native part table(s) matching a restriction.

    Note: unlike other Merge Table methods, returns the native table, not
    a FreeTable.

    Parameters
    ----------
    restriction: str
        Restriction used to determine which part(s) to return.
        Default True.
    join_master: bool
        If True, join each part with the Merge master to show the source
        field. Default False.
    restrict_part: bool
        If True (default), re-apply `restriction` to each returned part.
        Otherwise return the unrestricted native part table.
    multi_source: bool
        If True, return a list of parts. Default False.
    return_empties: bool
        If True, include empty part tables. Default False.

    Returns
    ------
    Union[dj.Table, List[dj.Table], None]
        A single native part table, or a list when `multi_source` is True.
        None when `multi_source` is True and no part matches.

    Example
    -------
        >>> MergeTable.merge_get_part(restriction)
        >>> MergeTable().merge_get_part(restriction, join_master=True)

    Raises
    ------
    ValueError
        If `multi_source` is False and the number of matching parts is not
        exactly one. The message lists the parts and suggests restricting.
    """
    source_names = [
        cls._part_name(part)  # friendly (CamelCase) part name
        for part in cls._merge_restrict_parts(
            restriction=restriction,
            as_objects=False,
            return_empties=return_empties,
            add_invalid_restrict=False,
        )
    ]
    n_sources = len(source_names)

    if not multi_source and n_sources != 1:
        raise ValueError(
            f"Found {n_sources} potential parts: {source_names}\n\t"
            + "Try adding a restriction before invoking `get_part`.\n\t"
            + "Or permitting multiple sources with `multi_source=True`."
        )
    if n_sources == 0:
        return None

    def _native(name):
        # Instantiate the part class, restricted again only if requested.
        table = getattr(cls, name)()
        return table.restrict(restriction) if restrict_part else table

    found = [_native(name) for name in source_names]
    if join_master:
        found = [cls * table for table in found]

    return found if multi_source else found[0]

merge_get_parent(restriction=True, join_master=False, multi_source=False, return_empties=False, add_invalid_restrict=True) classmethod

Returns the part parent(s) with restrictions applied: a single table, or a list when multi_source is True.

Rather than part tables, we look at parents of those parts, the source of the data, and only the rows that have keys inserted in the merge table.

Parameters:

Name Type Description Default
restriction str

Optional restriction to apply before determining parent to return. Default True.

True
join_master bool

Default False. Join part with Merge master to show uuid and source

False
multi_source bool

Return multiple parents. Default False.

False
return_empties bool

Default False. Return empty parent tables.

False
add_invalid_restrict bool

Default True. Include parent for which the restriction is invalid.

True

Returns:

Type Description
FreeTable

Parent of parts of Merge Table as FreeTable.

Source code in src/spyglass/utils/dj_merge_tables.py
@classmethod
def merge_get_parent(
    cls,
    restriction: str = True,
    join_master: bool = False,
    multi_source: bool = False,
    return_empties: bool = False,
    add_invalid_restrict: bool = True,
) -> dj.FreeTable:
    """Return the restricted parent(s) of the part tables.

    Rather than the part tables themselves, this looks at their parents,
    which are the sources of the data. Only rows whose keys were inserted
    into the merge table are kept.

    Parameters
    ----------
    restriction: str
        Restriction used to determine which parent(s) to return.
        Default True.
    join_master: bool
        If True, join each parent with the Merge master to show uuid and
        source. Default False.
    multi_source: bool
        If True, return a list of parents. Default False.
    return_empties: bool
        If True, include empty parent tables. Default False.
    add_invalid_restrict: bool
        Default True. Include parents for which the restriction is invalid.

    Returns
    ------
    Union[dj.FreeTable, List[dj.FreeTable]]
        A single parent FreeTable, or a list when `multi_source` is True.

    Raises
    ------
    ValueError
        If `multi_source` is False and the number of parents found is not
        exactly one.
    """
    parents = cls._merge_restrict_parents(
        restriction=restriction,
        as_objects=True,
        return_empties=return_empties,
        add_invalid_restrict=add_invalid_restrict,
    )

    if not multi_source and len(parents) != 1:
        raise ValueError(
            f"Found  {len(parents)} potential parents: {parents}"
            + "\n\tTry adding a string restriction when invoking "
            + "`get_parent`. Or permitting multiple sources with "
            + "`multi_source=True`."
        )

    if join_master:
        parents = [cls * parent for parent in parents]

    if multi_source:
        return parents
    return parents[0]

merge_get_parent_class(source)

Return the class of the parent table for a given CamelCase source.

Parameters:

Name Type Description Default
source str

Accepts a CamelCase name of the source, or key as a dict, or a part table.

required

Returns:

Type Description
Table

Class instance of the parent table, including class methods.

Source code in src/spyglass/utils/dj_merge_tables.py
def merge_get_parent_class(self, source: str) -> dj.Table:
    """Look up the parent table class for a source.

    Parameters
    ----------
    source: Union[str, dict, dj.Table]
        A CamelCase source name, a key as a dict, or a part table. It is
        normalized with `_normalize_source` before lookup.

    Returns
    -------
    dj.Table
        Class instance of the parent table, including class methods. If no
        entry is found, an error listing the CamelCase part names is logged
        and the (falsy) lookup result is returned.
    """
    lookup = self.source_class_dict
    parent_class = lookup.get(self._normalize_source(source))
    if parent_class:
        return parent_class

    logger.error(
        f"No source class found for {source}: \n\t"
        + f"{self.parts(camel_case=True)}"
    )
    return parent_class

merge_restrict_class(key, permit_multiple_rows=False)

Returns native parent class, restricted with key.

Source code in src/spyglass/utils/dj_merge_tables.py
def merge_restrict_class(
    self, key: dict, permit_multiple_rows: bool = False
) -> dj.Table:
    """Return the native parent class restricted to the parent rows of key.

    Parameters
    ----------
    key: dict
        Restriction used to find the parent via `merge_get_parent`.
    permit_multiple_rows: bool
        If False (default), raise when the parent has more than one
        matching row.

    Raises
    ------
    ValueError
        If more than one parent row matches and multiple rows are not
        permitted.
    """
    parent = self.merge_get_parent(key)
    parent_keys = parent.fetch("KEY", as_dict=True)

    if not permit_multiple_rows and len(parent_keys) > 1:
        raise ValueError(
            f"Ambiguous entry. Data has mult rows in parent:\n\tData:{key}"
            + f"\n\t{parent_keys}"
        )

    return self.merge_get_parent_class(parent) & parent_keys

merge_fetch(restriction=True, *attrs, **kwargs) classmethod

Perform a fetch across all parts. If >1 result, return as a list.

Parameters:

Name Type Description Default
restriction str

Optional restriction to apply before determining parent to return. Default True.

True
attrs

arguments passed to DataJoint fetch call

()
kwargs

arguments passed to DataJoint fetch call

{}

Returns:

Type Description
Union[List[array], List[dict], List[DataFrame]]

Table contents, with type determined by kwargs

Source code in src/spyglass/utils/dj_merge_tables.py
@classmethod
def merge_fetch(cls, restriction: str = True, *attrs, **kwargs) -> list:
    """Perform a fetch across all parts. If >1 result, return as a list.

    Parameters
    ----------
    restriction: str
        Optional restriction to apply before determining parts to fetch
        from. Default True.
    attrs, kwargs
        arguments passed to DataJoint `fetch` call

    Returns
    -------
    Union[ List[np.array], List[dict], List[pd.DataFrame] ]
        Table contents, with type determined by kwargs. When exactly one
        item is fetched in total, that item is returned on its own.
    """
    results = []
    parts = cls()._merge_restrict_parts(
        restriction=restriction,
        as_objects=True,
        return_empties=False,
        add_invalid_restrict=False,
    )

    for part in parts:
        try:
            results.extend(part.fetch(*attrs, **kwargs))
        except DataJointError as e:
            # A part that cannot satisfy this fetch is skipped, not fatal.
            # `warning` replaces the deprecated `Logger.warn` alias.
            logger.warning(
                f"{e.args[0]} Skipping "
                + to_camel_case(part.table_name.split("__")[-1])
            )

    # Note: this could collapse results like merge_view, but user may call
    # for recarray, pd.DataFrame, or dict, and fetched contents differ if
    # attrs or "KEY" called. Intercept format, merge, and then transform?

    if not results:
        logger.info(
            "No merge_fetch results.\n\t"
            + "If not restricting, try: `M.merge_fetch(True,'attr')`\n\t"
            + "If restricting by source, use dict: "
            + "`M.merge_fetch({'source':'X'})`"
        )
    return results[0] if len(results) == 1 else results

merge_populate(source, keys=None)

Populate the merge table with entries from the source table.

Source code in src/spyglass/utils/dj_merge_tables.py
def merge_populate(self, source: str, keys=None):
    """Populate a source's parent table, then insert its keys here.

    Parameters
    ----------
    source: str
        Source identifier, resolved with `merge_get_parent_class`.
    keys: optional
        Keys to populate. When falsy, the parent's `key_source` is used.
    """
    logger.warning("CBroz: Not fully tested. Use with caution.")
    source_table = self.merge_get_parent_class(source)
    to_populate = keys if keys else source_table.key_source
    source_table.populate(to_populate)
    populated = source_table & to_populate
    self.insert(populated.fetch("KEY", as_dict=True))

delete(force_permission=False, *args, **kwargs)

Alias for cautious_delete, overwrites datajoint.table.Table.delete

Source code in src/spyglass/utils/dj_merge_tables.py
def delete(self, force_permission=False, *args, **kwargs):
    """Delete matching entries by deleting from each populated part.

    Overwrites datajoint.table.Table.delete. The current restriction is
    used to find the non-empty part tables. Each part's own `delete` is
    then called with ``force_permission`` and any extra arguments. If no
    part matches, nothing is done.
    """
    matching_parts = self.merge_get_part(
        restriction=self.restriction,
        multi_source=True,
        return_empties=False,
    )
    if not matching_parts:
        return

    for part_table in matching_parts:
        part_table.delete(force_permission=force_permission, *args, **kwargs)

super_delete(warn=True, *args, **kwargs)

Alias for datajoint.table.Table.delete.

Added to support MRO of SpyglassMixin

Source code in src/spyglass/utils/dj_merge_tables.py
def super_delete(self, warn=True, *args, **kwargs):
    """Call the parent class's `delete` directly (datajoint.table.Table).

    Exists to support the MRO of SpyglassMixin. When ``warn`` is truthy,
    a bypass warning is logged and the deletion is recorded through
    `_log_delete` before the parent `delete` runs.
    """
    if warn:
        logger.warning("!! Bypassing cautious_delete !!")
        self._log_delete(super_delete=True, start=time())
    super().delete(*args, **kwargs)

delete_downstream_merge(table, **kwargs)

Given a table/restriction, identify or delete relevant downstream merge entries

Passthrough to SpyglassMixin.delete_downstream_parts

Source code in src/spyglass/utils/dj_merge_tables.py
def delete_downstream_merge(
    table: dj.Table,
    **kwargs,
) -> list:
    """Given a table/restriction, identify or delete downstream merge entries.

    Deprecated passthrough to SpyglassMixin.delete_downstream_parts. The
    deprecation is recorded via ActivityLog.

    Parameters
    ----------
    table: dj.Table
        A Spyglass table, given either as an instance (possibly restricted)
        or as a table class, which is instantiated here.
    **kwargs
        Forwarded to `delete_downstream_parts`.

    Returns
    -------
    list
        Whatever `delete_downstream_parts` returns.

    Raises
    ------
    ValueError
        If `table` is neither a SpyglassMixin table nor a SpyglassMixin
        table class.
    """
    from spyglass.common.common_usage import ActivityLog
    from spyglass.utils.dj_mixin import SpyglassMixin

    ActivityLog().deprecate_log(
        "delete_downstream_merge. Use Table.delete_downstream_merge"
    )

    # Accept a table class as well as an instance. Previously a class failed
    # the isinstance check, so the intended `table()` fallback was
    # unreachable.
    if isinstance(table, type) and issubclass(table, SpyglassMixin):
        table = table()
    if not isinstance(table, SpyglassMixin):
        raise ValueError("Input must be a Spyglass Table.")

    return table.delete_downstream_parts(**kwargs)