mirror of
https://github.com/enpaul/peewee-plus.git
synced 2024-11-21 22:16:54 +00:00
Add enum field for storing enum references in the database
This commit is contained in:
parent
19b507416d
commit
ba3f143dc9
@ -1,8 +1,10 @@
|
|||||||
|
import enum
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
@ -14,7 +16,7 @@ __url__ = "https://github.com/enpaul/peewee-plus/"
|
|||||||
__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
|
__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["PathField", "PrecisionFloatField", "JSONField"]
|
__all__ = ["PathField", "PrecisionFloatField", "JSONField", "EnumField"]
|
||||||
|
|
||||||
|
|
||||||
class PathField(peewee.CharField):
|
class PathField(peewee.CharField):
|
||||||
@ -162,3 +164,55 @@ class JSONField(peewee.TextField):
|
|||||||
raise peewee.IntegrityError(
|
raise peewee.IntegrityError(
|
||||||
f"Failed to decode JSON value from database column '{self.column}'"
|
f"Failed to decode JSON value from database column '{self.column}'"
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
|
|
||||||
|
class EnumField(peewee.CharField):
|
||||||
|
"""Field class for storing Enums
|
||||||
|
|
||||||
|
This field can be used for storing members of an :class:`enum.Enum` in the database,
|
||||||
|
effectively storing a database reference to a value defined in the application.
|
||||||
|
|
||||||
|
.. warning:: This field ties database data to application structure: if the Enum passed
|
||||||
|
to this field is modified then the application may encounter errors when
|
||||||
|
trying to interface with the database schema.
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
>>> class MyOptions(enum.Enum):
|
||||||
|
... FOO = "have you ever heard the tragedy"
|
||||||
|
... BAR = "of darth plageius"
|
||||||
|
... BAZ = "the wise?"
|
||||||
|
...
|
||||||
|
>>>
|
||||||
|
>>> class MyModel(peewee.Model):
|
||||||
|
... option = EnumField(MyOptions)
|
||||||
|
...
|
||||||
|
>>> m = MyModel(option=MyOptions.FOO)
|
||||||
|
>>> m.save()
|
||||||
|
>>> m.option
|
||||||
|
<MyOptions.FOO: "have you ever heard the tragedy">
|
||||||
|
>>>
|
||||||
|
|
||||||
|
:param enumeration: The Enum to accept members of and to use for decoding database values
|
||||||
|
:raises TypeError: If the value to be written to the field is not a member of the
|
||||||
|
specified Enum
|
||||||
|
:raises peewee.IntegrityError: If the value read back from the database cannot be decoded to
|
||||||
|
a member of the specified Enum
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, enumeration: Type[enum.Enum], *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.enumeration = enumeration
|
||||||
|
|
||||||
|
def db_value(self, value: enum.Enum) -> str:
|
||||||
|
if not isinstance(value, self.enumeration):
|
||||||
|
raise TypeError(f"Enum {self.enumeration.__name__} has no value '{value}'")
|
||||||
|
return super().db_value(value.name)
|
||||||
|
|
||||||
|
def python_value(self, value: str) -> enum.Enum:
|
||||||
|
try:
|
||||||
|
return self.enumeration[super().python_value(value)]
|
||||||
|
except KeyError:
|
||||||
|
raise peewee.IntegrityError(
|
||||||
|
f"Enum {self.enumeration.__name__} has no value with name '{value}'"
|
||||||
|
) from None
|
||||||
|
53
tests/test_enumfield.py
Normal file
53
tests/test_enumfield.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
# pylint: disable=missing-class-docstring
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
import enum
|
||||||
|
|
||||||
|
import peewee
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import peewee_plus
|
||||||
|
from .fixtures import fakedb
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum(fakedb):
|
||||||
|
"""Test basic functionality of the enum field"""
|
||||||
|
|
||||||
|
class TestEnum(enum.Enum):
|
||||||
|
FOO = "fizz"
|
||||||
|
BAR = "buzz"
|
||||||
|
|
||||||
|
class TestModel(peewee.Model):
|
||||||
|
class Meta:
|
||||||
|
database = fakedb
|
||||||
|
|
||||||
|
data = peewee_plus.EnumField(TestEnum)
|
||||||
|
|
||||||
|
fakedb.create_tables([TestModel])
|
||||||
|
|
||||||
|
model = TestModel(data=TestEnum.FOO)
|
||||||
|
model.save()
|
||||||
|
|
||||||
|
model = TestModel.get()
|
||||||
|
assert model.data == TestEnum.FOO
|
||||||
|
|
||||||
|
class ModifiedEnum(enum.Enum):
|
||||||
|
BAR = "buzz"
|
||||||
|
|
||||||
|
class ModifiedModel(peewee.Model):
|
||||||
|
class Meta:
|
||||||
|
table_name = TestModel._meta.table_name # pylint: disable=protected-access
|
||||||
|
database = fakedb
|
||||||
|
|
||||||
|
data = peewee_plus.EnumField(ModifiedEnum)
|
||||||
|
|
||||||
|
with pytest.raises(peewee.IntegrityError):
|
||||||
|
ModifiedModel.get()
|
||||||
|
|
||||||
|
class BadEnum(enum.Enum):
|
||||||
|
NOTHING = "nowhere"
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
bad = TestModel(data=BadEnum.NOTHING)
|
||||||
|
bad.save()
|
Loading…
Reference in New Issue
Block a user