Skip to content

Base mixin

Utility base class for dataclasses.

BaseMixin

Bases: ABC

Mixin base class for dataclasses.

Source code in `src/snailz/_base_mixin.py` (lines 10-100)
class BaseMixin(ABC):
    """
    Mixin base class for dataclasses.

    Concrete subclasses must implement `table_name()`. The following
    optional class-level members are read with `getattr` when present:

    - `pivot_keys`: attribute names left out of the scalar record
      (derived classes are expected to save these in long form).
    - `nullable_keys`: attribute names whose columns may be NULL.
    - `primary_key`: passed as `pk` to `insert_all` in `save_db`.
    - `foreign_keys`: passed as `foreign_keys` to `insert_all` in `save_db`.
    """

    def persistable(self) -> dict:
        """
        Create persistable dictionary from object by ignoring all keys
        listed in class-level `pivot_keys` member.

        Returns:
            Mapping from each key in `persistable_keys()` to the
            corresponding instance attribute value.
        """

        return {key: self.__dict__[key] for key in self.persistable_keys()}

    def not_null_keys(self) -> set:
        """
        Generate set of persistable keys whose columns must not be null,
        i.e. those not listed in class-level `nullable_keys` member.

        Only the declared keys are consulted; the object's actual values
        are not inspected.
        """

        nullable_keys = getattr(self, "nullable_keys", set())
        return {key for key in self.persistable_keys() if key not in nullable_keys}

    def persistable_keys(self) -> list[str]:
        """
        Generate list of keys to persist for object by ignoring all
        keys listed in class-level `pivot_keys` member.

        Returns:
            Instance attribute names in `__dict__` order, minus pivot keys.
        """

        pivot_keys = getattr(self, "pivot_keys", set())
        return [key for key in self.__dict__ if key not in pivot_keys]

    @classmethod
    def save_csv(cls, outdir: Path | str, objects: list):
        """
        Save objects of derived class as CSV. Derived classes should
        override this and up-call to save scalar properties, then save
        properties that need to be pivoted to long form.

        The file is written to `outdir/<table_name()>.csv`, with columns
        taken from the first object's `persistable_keys()`. If `objects`
        is empty nothing is written, because there is no object to take
        the column names from.

        Args:
            outdir: Output directory.
            objects: Objects to save.
        """

        assert all(isinstance(obj, cls) for obj in objects)
        # Column names come from the first object; with no objects there is
        # no header to write, and indexing objects[0] would raise IndexError.
        if not objects:
            return
        with open(Path(outdir, f"{cls.table_name()}.csv"), "w", newline="") as stream:
            writer = cls._csv_dict_writer(stream, objects[0].persistable_keys())
            for obj in objects:
                writer.writerow(obj.persistable())

    @classmethod
    def save_db(cls, db: Database, objects: list):
        """
        Save objects of derived class to database. Derived classes should
        override this and up-call to save scalar properties, then save
        properties that need to be pivoted to long form.

        Rows are inserted into the table named by `table_name()`, and the
        table is then transformed so that the first object's
        `not_null_keys()` are NOT NULL. If `objects` is empty the database
        is left untouched.

        Args:
            db: Database connector.
            objects: Objects to save.
        """

        assert all(isinstance(obj, cls) for obj in objects)
        # The NOT NULL column set is taken from the first object, so an
        # empty list would raise IndexError; there is nothing to save anyway.
        if not objects:
            return
        table = db[cls.table_name()]
        primary_key = getattr(cls, "primary_key", None)
        foreign_keys = getattr(cls, "foreign_keys", [])
        table.insert_all(  # type: ignore[possibly-missing-attribute]
            (obj.persistable() for obj in objects),
            pk=primary_key,
            foreign_keys=foreign_keys,
        )
        # Constraints are applied after insertion, once the table exists.
        table.transform(  # type: ignore[possibly-missing-attribute]
            not_null=objects[0].not_null_keys()
        )

    @classmethod
    @abstractmethod
    def table_name(cls) -> str:
        """Database table name (also used as the CSV file stem by `save_csv`)."""
        pass

    @classmethod
    def _csv_dict_writer(cls, stream: TextIO, fieldnames: list[str]) -> DictWriter:
        """
        Construct a CSV dict writer with default properties.

        The header row is written immediately, and rows are terminated
        with a bare newline.

        Args:
            stream: Writeable stream to wrap.
            fieldnames: List of fields to be persisted.

        Returns:
            CSV dict writer.
        """

        writer = DictWriter(stream, fieldnames=fieldnames, lineterminator="\n")
        writer.writeheader()
        return writer

persistable()

Create a persistable dictionary from the object by ignoring all keys listed in the class-level `pivot_keys` member.

Source code in `src/snailz/_base_mixin.py` (lines 13-19)
def persistable(self) -> dict:
    """
    Build the dictionary of values to persist for this object.

    Keys come from `persistable_keys()`, which leaves out attributes
    named in the class-level `pivot_keys` member.
    """

    wanted = self.persistable_keys()
    attributes = self.__dict__
    return {name: attributes[name] for name in wanted}

not_null_keys()

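Generate the set of persistable keys whose columns must not be null, i.e. those not listed in the class-level `nullable_keys` member. Only declared keys are consulted; actual values are not inspected.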

Source code in `src/snailz/_base_mixin.py` (lines 21-25)
def not_null_keys(self) -> set:
    """
    Generate set of persistable keys whose columns must not be null,
    i.e. those not listed in class-level `nullable_keys` member.

    Only the declared keys are consulted; the object's actual values
    are not inspected.
    """

    # `nullable_keys` is optional; when absent, every persistable key is NOT NULL.
    nullable_keys = getattr(self, "nullable_keys", set())
    return {key for key in self.persistable_keys() if key not in nullable_keys}

persistable_keys()

Generate the list of keys to persist for the object by ignoring all keys listed in the class-level `pivot_keys` member.

Source code in `src/snailz/_base_mixin.py` (lines 27-34)
def persistable_keys(self) -> list[str]:
    """
    List the attribute names that should be persisted for this object.

    Every instance attribute is included, in `__dict__` order, except
    those named in the optional class-level `pivot_keys` member.
    """

    excluded = getattr(self, "pivot_keys", set())
    attribute_names = self.__dict__
    return [name for name in attribute_names if name not in excluded]

save_csv(outdir, objects) classmethod

Save objects of derived class as CSV. Derived classes should override this and up-call to save scalar properties, then save properties that need to be pivoted to long form.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outdir` | `Path \| str` | Output directory. | *required* |
| `objects` | `list` | Objects to save. | *required* |
Source code in `src/snailz/_base_mixin.py` (lines 36-52)
@classmethod
def save_csv(cls, outdir: Path | str, objects: list):
    """
    Save objects of derived class as CSV. Derived classes should
    override this and up-call to save scalar properties, then save
    properties that need to be pivoted to long form.

    Args:
        outdir: Output directory.
        objects: Objects to save.
    """

    assert all(isinstance(obj, cls) for obj in objects)
    with open(Path(outdir, f"{cls.table_name()}.csv"), "w", newline="") as stream:
        writer = cls._csv_dict_writer(stream, objects[0].persistable_keys())
        for obj in objects:
            writer.writerow(obj.persistable())

save_db(db, objects) classmethod

Save objects of derived class to database. Derived classes should override this and up-call to save scalar properties, then save properties that need to be pivoted to long form.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `db` | `Database` | Database connector. | *required* |
| `objects` | `list` | Objects to save. | *required* |
Source code in `src/snailz/_base_mixin.py` (lines 54-77)
@classmethod
def save_db(cls, db: Database, objects: list):
    """
    Save objects of derived class to database. Derived classes should
    override this and up-call to save scalar properties, then save
    properties that need to be pivoted to long form.

    Rows are inserted into the table named by `table_name()`, and the
    table is then transformed so that the first object's
    `not_null_keys()` are NOT NULL. If `objects` is empty the database
    is left untouched.

    Args:
        db: Database connector.
        objects: Objects to save.
    """

    assert all(isinstance(obj, cls) for obj in objects)
    # The NOT NULL column set is taken from the first object, so an
    # empty list would raise IndexError; there is nothing to save anyway.
    if not objects:
        return
    table = db[cls.table_name()]
    primary_key = getattr(cls, "primary_key", None)
    foreign_keys = getattr(cls, "foreign_keys", [])
    table.insert_all(  # type: ignore[possibly-missing-attribute]
        (obj.persistable() for obj in objects),
        pk=primary_key,
        foreign_keys=foreign_keys,
    )
    # Constraints are applied after insertion, once the table exists.
    table.transform(  # type: ignore[possibly-missing-attribute]
        not_null=objects[0].not_null_keys()
    )

table_name() abstractmethod classmethod

Database table name.

Source code in `src/snailz/_base_mixin.py` (lines 79-83)
@classmethod
@abstractmethod
def table_name(cls) -> str:
    """
    Return the name of the database table for this class.

    Abstract: every concrete subclass must provide its own implementation.
    """