diff --git a/peewee_plus.py b/peewee_plus.py index fd7b522..79caaba 100644 --- a/peewee_plus.py +++ b/peewee_plus.py @@ -1,4 +1,7 @@ +import json from pathlib import Path +from typing import Any +from typing import Dict from typing import Optional import peewee @@ -11,7 +14,7 @@ __url__ = "https://github.com/enpaul/peewee-plus/" __authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"] -__all__ = ["PathField", "PrecisionFloatField"] +__all__ = ["PathField", "PrecisionFloatField", "JSONField"] class PathField(peewee.CharField): @@ -102,3 +105,60 @@ class PrecisionFloatField(peewee.FloatField): def get_modifiers(self): return [self.max_digits, self.decimal_places] + + +class JSONField(peewee.TextField): + """Field class for storing JSON-serializable data + + This field can be used to store a dictionary of data directly in the database without needing + without needing to call :func:`json.dumps` and :func:`json.loads` directly. + + :: + + >>> class MyModel(peewee.Model): + ... some_data = JSONField() + ... + >>> m = MyModel(some_data={"foo": 1, "bar": 2}) + >>> m.save() + >>> m.some_data + {'foo': 1, 'bar': 2} + >>> + + .. warning:: If a non-JSON serializable object is set to the field then a + :err:`peewee.IntegrityError` will be raised + + .. warning:: This is a very bad way to store data in a RDBMS and effectively makes the data + contained in the field unqueriable. + + :param dump_params: Additional keyword arguments to unpack into :func:`json.dump` + :param load_params: Additional keyword arguments to unpack into :func:`json.load` + """ + + def __init__( + self, + *args, + dump_params: Optional[Dict[str, Any]] = None, + load_params: Optional[Dict[str, Any]] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.dump_params = dump_params or dict() + self.load_params = load_params or dict() + + def db_value(self, value: Any) -> str: + """Convert the python value to the corresponding value to store in the database""" + try: + return super().db_value(json.dumps(value, **self.dump_params)) + except TypeError as err: + raise peewee.IntegrityError( + f"Failed to JSON encode object of type '{type(value)}'" + ) from err + + def python_value(self, value: str) -> Any: + """Convert the database-stored value to the corresponding python value""" + try: + return json.loads(super().python_value(value), **self.load_params) + except json.JSONDecodeError as err: + raise peewee.IntegrityError( + f"Failed to decode JSON value from database column '{self.column}'" + ) from err diff --git a/tests/test_jsonfield.py b/tests/test_jsonfield.py new file mode 100644 index 0000000..0a45511 --- /dev/null +++ b/tests/test_jsonfield.py @@ -0,0 +1,35 @@ +# pylint: disable=redefined-outer-name +# pylint: disable=missing-class-docstring +# pylint: disable=too-few-public-methods +# pylint: disable=unused-import +from pathlib import Path + +import peewee +import pytest + +import peewee_plus +from .fixtures import fakedb + + +def test_json(fakedb): + """Test basic usage of JSONField class""" + + class TestModel(peewee.Model): + class Meta: + database = fakedb + + some_data = peewee_plus.JSONField() + + fakedb.create_tables([TestModel]) + + data = {"foo": 10, "bar": ["hello", "world"], "baz": True} + + model = TestModel(some_data=data) + model.save() + + model = TestModel.get() + assert model.some_data == data + + with pytest.raises(peewee.IntegrityError): + bad = TestModel(some_data=Path(".")) + bad.save()