From ba3f143dc92fb9b8cdbe98ec27b99a1cfe34b470 Mon Sep 17 00:00:00 2001 From: Ethan Paul <24588726+enpaul@users.noreply.github.com> Date: Wed, 24 Nov 2021 22:01:30 -0500 Subject: [PATCH] Add enum field for storing enum references in the database --- peewee_plus.py | 56 ++++++++++++++++++++++++++++++++++++++++- tests/test_enumfield.py | 53 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 tests/test_enumfield.py diff --git a/peewee_plus.py b/peewee_plus.py index 79caaba..e8696a9 100644 --- a/peewee_plus.py +++ b/peewee_plus.py @@ -1,8 +1,10 @@ +import enum import json from pathlib import Path from typing import Any from typing import Dict from typing import Optional +from typing import Type import peewee @@ -14,7 +16,7 @@ __url__ = "https://github.com/enpaul/peewee-plus/" __authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"] -__all__ = ["PathField", "PrecisionFloatField", "JSONField"] +__all__ = ["PathField", "PrecisionFloatField", "JSONField", "EnumField"] class PathField(peewee.CharField): @@ -162,3 +164,55 @@ class JSONField(peewee.TextField): raise peewee.IntegrityError( f"Failed to decode JSON value from database column '{self.column}'" ) from err + + +class EnumField(peewee.CharField): + """Field class for storing Enums + + This field can be used for storing members of an :class:`enum.Enum` in the database, + effectively storing a database reference to a value defined in the application. + + .. warning:: This field ties database data to application structure: if the Enum passed + to this field is modified then the application may encounter errors when + trying to interface with the database schema. + + :: + + >>> class MyOptions(enum.Enum): + ... FOO = "have you ever heard the tragedy" + ... BAR = "of darth plageius" + ... BAZ = "the wise?" + ... + >>> + >>> class MyModel(peewee.Model): + ... option = EnumField(MyOptions) + ... + >>> m = MyModel(option=MyOptions.FOO) + >>> m.save() + >>> m.option + + >>> + + :param enumeration: The Enum to accept members of and to use for decoding database values + :raises TypeError: If the value to be written to the field is not a member of the + specified Enum + :raises peewee.IntegrityError: If the value read back from the database cannot be decoded to + a member of the specified Enum + """ + + def __init__(self, enumeration: Type[enum.Enum], *args, **kwargs): + super().__init__(*args, **kwargs) + self.enumeration = enumeration + + def db_value(self, value: enum.Enum) -> str: + if not isinstance(value, self.enumeration): + raise TypeError(f"Enum {self.enumeration.__name__} has no value '{value}'") + return super().db_value(value.name) + + def python_value(self, value: str) -> enum.Enum: + try: + return self.enumeration[super().python_value(value)] + except KeyError: + raise peewee.IntegrityError( + f"Enum {self.enumeration.__name__} has no value with name '{value}'" + ) from None diff --git a/tests/test_enumfield.py b/tests/test_enumfield.py new file mode 100644 index 0000000..a29c5b1 --- /dev/null +++ b/tests/test_enumfield.py @@ -0,0 +1,53 @@ +# pylint: disable=redefined-outer-name +# pylint: disable=missing-class-docstring +# pylint: disable=too-few-public-methods +# pylint: disable=unused-import +import enum + +import peewee +import pytest + +import peewee_plus +from .fixtures import fakedb + + +def test_enum(fakedb): + """Test basic functionality of the enum field""" + + class TestEnum(enum.Enum): + FOO = "fizz" + BAR = "buzz" + + class TestModel(peewee.Model): + class Meta: + database = fakedb + + data = peewee_plus.EnumField(TestEnum) + + fakedb.create_tables([TestModel]) + + model = TestModel(data=TestEnum.FOO) + model.save() + + model = TestModel.get() + assert model.data == TestEnum.FOO + + class ModifiedEnum(enum.Enum): + BAR = "buzz" + + class ModifiedModel(peewee.Model): + class Meta: + table_name = TestModel._meta.table_name # pylint: disable=protected-access + database = fakedb + + data = peewee_plus.EnumField(ModifiedEnum) + + with pytest.raises(peewee.IntegrityError): + ModifiedModel.get() + + class BadEnum(enum.Enum): + NOTHING = "nowhere" + + with pytest.raises(TypeError): + bad = TestModel(data=BadEnum.NOTHING) + bad.save()