Add enum field for storing enum references in the database

This commit is contained in:
Ethan Paul 2021-11-24 22:01:30 -05:00
parent 19b507416d
commit ba3f143dc9
No known key found for this signature in database
GPG Key ID: D0E2CBF1245E92BF
2 changed files with 108 additions and 1 deletions

View File

@ -1,8 +1,10 @@
import enum
import json
from pathlib import Path
from typing import Any
from typing import Dict
from typing import Optional
from typing import Type
import peewee
@ -14,7 +16,7 @@ __url__ = "https://github.com/enpaul/peewee-plus/"
__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
__all__ = ["PathField", "PrecisionFloatField", "JSONField"]
__all__ = ["PathField", "PrecisionFloatField", "JSONField", "EnumField"]
class PathField(peewee.CharField):
@ -162,3 +164,55 @@ class JSONField(peewee.TextField):
raise peewee.IntegrityError(
f"Failed to decode JSON value from database column '{self.column}'"
) from err
class EnumField(peewee.CharField):
"""Field class for storing Enums
This field can be used for storing members of an :class:`enum.Enum` in the database,
effectively storing a database reference to a value defined in the application.
.. warning:: This field ties database data to application structure: if the Enum passed
to this field is modified then the application may encounter errors when
trying to interface with the database schema.
::
>>> class MyOptions(enum.Enum):
... FOO = "have you ever heard the tragedy"
... BAR = "of darth plageius"
... BAZ = "the wise?"
...
>>>
>>> class MyModel(peewee.Model):
... option = EnumField(MyOptions)
...
>>> m = MyModel(option=MyOptions.FOO)
>>> m.save()
>>> m.option
<MyOptions.FOO: "have you ever heard the tragedy">
>>>
:param enumeration: The Enum to accept members of and to use for decoding database values
:raises TypeError: If the value to be written to the field is not a member of the
specified Enum
:raises peewee.IntegrityError: If the value read back from the database cannot be decoded to
a member of the specified Enum
"""
def __init__(self, enumeration: Type[enum.Enum], *args, **kwargs):
super().__init__(*args, **kwargs)
self.enumeration = enumeration
def db_value(self, value: enum.Enum) -> str:
if not isinstance(value, self.enumeration):
raise TypeError(f"Enum {self.enumeration.__name__} has no value '{value}'")
return super().db_value(value.name)
def python_value(self, value: str) -> enum.Enum:
try:
return self.enumeration[super().python_value(value)]
except KeyError:
raise peewee.IntegrityError(
f"Enum {self.enumeration.__name__} has no value with name '{value}'"
) from None

53
tests/test_enumfield.py Normal file
View File

@ -0,0 +1,53 @@
# pylint: disable=redefined-outer-name
# pylint: disable=missing-class-docstring
# pylint: disable=too-few-public-methods
# pylint: disable=unused-import
import enum
import peewee
import pytest
import peewee_plus
from .fixtures import fakedb
def test_enum(fakedb):
"""Test basic functionality of the enum field"""
class TestEnum(enum.Enum):
FOO = "fizz"
BAR = "buzz"
class TestModel(peewee.Model):
class Meta:
database = fakedb
data = peewee_plus.EnumField(TestEnum)
fakedb.create_tables([TestModel])
model = TestModel(data=TestEnum.FOO)
model.save()
model = TestModel.get()
assert model.data == TestEnum.FOO
class ModifiedEnum(enum.Enum):
BAR = "buzz"
class ModifiedModel(peewee.Model):
class Meta:
table_name = TestModel._meta.table_name # pylint: disable=protected-access
database = fakedb
data = peewee_plus.EnumField(ModifiedEnum)
with pytest.raises(peewee.IntegrityError):
ModifiedModel.get()
class BadEnum(enum.Enum):
NOTHING = "nowhere"
with pytest.raises(TypeError):
bad = TestModel(data=BadEnum.NOTHING)
bad.save()