Add enum field for storing enum references in the database

This commit is contained in:
Ethan Paul 2021-11-24 22:01:30 -05:00
parent 19b507416d
commit ba3f143dc9
No known key found for this signature in database
GPG Key ID: D0E2CBF1245E92BF
2 changed files with 108 additions and 1 deletions

View File

@ -1,8 +1,10 @@
import enum
import json import json
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
from typing import Type
import peewee import peewee
@ -14,7 +16,7 @@ __url__ = "https://github.com/enpaul/peewee-plus/"
__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"] __authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
__all__ = ["PathField", "PrecisionFloatField", "JSONField"] __all__ = ["PathField", "PrecisionFloatField", "JSONField", "EnumField"]
class PathField(peewee.CharField): class PathField(peewee.CharField):
@ -162,3 +164,55 @@ class JSONField(peewee.TextField):
raise peewee.IntegrityError( raise peewee.IntegrityError(
f"Failed to decode JSON value from database column '{self.column}'" f"Failed to decode JSON value from database column '{self.column}'"
) from err ) from err
class EnumField(peewee.CharField):
"""Field class for storing Enums
This field can be used for storing members of an :class:`enum.Enum` in the database,
effectively storing a database reference to a value defined in the application.
.. warning:: This field ties database data to application structure: if the Enum passed
to this field is modified then the application may encounter errors when
trying to interface with the database schema.
::
>>> class MyOptions(enum.Enum):
... FOO = "have you ever heard the tragedy"
... BAR = "of darth plageius"
... BAZ = "the wise?"
...
>>>
>>> class MyModel(peewee.Model):
... option = EnumField(MyOptions)
...
>>> m = MyModel(option=MyOptions.FOO)
>>> m.save()
>>> m.option
<MyOptions.FOO: "have you ever heard the tragedy">
>>>
:param enumeration: The Enum to accept members of and to use for decoding database values
:raises TypeError: If the value to be written to the field is not a member of the
specified Enum
:raises peewee.IntegrityError: If the value read back from the database cannot be decoded to
a member of the specified Enum
"""
def __init__(self, enumeration: Type[enum.Enum], *args, **kwargs):
super().__init__(*args, **kwargs)
self.enumeration = enumeration
def db_value(self, value: enum.Enum) -> str:
if not isinstance(value, self.enumeration):
raise TypeError(f"Enum {self.enumeration.__name__} has no value '{value}'")
return super().db_value(value.name)
def python_value(self, value: str) -> enum.Enum:
try:
return self.enumeration[super().python_value(value)]
except KeyError:
raise peewee.IntegrityError(
f"Enum {self.enumeration.__name__} has no value with name '{value}'"
) from None

53
tests/test_enumfield.py Normal file
View File

@ -0,0 +1,53 @@
# pylint: disable=redefined-outer-name
# pylint: disable=missing-class-docstring
# pylint: disable=too-few-public-methods
# pylint: disable=unused-import
import enum
import peewee
import pytest
import peewee_plus
from .fixtures import fakedb
def test_enum(fakedb):
"""Test basic functionality of the enum field"""
class TestEnum(enum.Enum):
FOO = "fizz"
BAR = "buzz"
class TestModel(peewee.Model):
class Meta:
database = fakedb
data = peewee_plus.EnumField(TestEnum)
fakedb.create_tables([TestModel])
model = TestModel(data=TestEnum.FOO)
model.save()
model = TestModel.get()
assert model.data == TestEnum.FOO
class ModifiedEnum(enum.Enum):
BAR = "buzz"
class ModifiedModel(peewee.Model):
class Meta:
table_name = TestModel._meta.table_name # pylint: disable=protected-access
database = fakedb
data = peewee_plus.EnumField(ModifiedEnum)
with pytest.raises(peewee.IntegrityError):
ModifiedModel.get()
class BadEnum(enum.Enum):
NOTHING = "nowhere"
with pytest.raises(TypeError):
bad = TestModel(data=BadEnum.NOTHING)
bad.save()