2023-04-13 17:34:37 +00:00
|
|
|
"""peewee+
|
|
|
|
|
|
|
|
Various extensions, helpers, and utilities for `Peewee`_
|
2021-11-25 03:52:11 +00:00
|
|
|
|
|
|
|
:constant SQLITE_DEFAULT_VARIABLE_LIMIT: The default number of variables that a single SQL query
|
|
|
|
can contain when interfacing with SQLite. The actual
|
|
|
|
number is set at compile time and is not easily
|
|
|
|
discoverable from within Python. This default value will
|
|
|
|
be correct for the vast majority of applications.
|
|
|
|
|
|
|
|
:constant SQLITE_DEFAULT_PRAGMAS: The default pragmas that should be used when instantiating an
|
|
|
|
SQLite database connection. The value for this constant is taken
|
|
|
|
directly from the `Peewee documentation`_
|
|
|
|
|
2023-04-13 17:34:37 +00:00
|
|
|
.. _`Peewee`: https://docs.peewee-orm.com/en/latest/
|
|
|
|
|
|
|
|
.. _`Peewee documentation`: https://docs.peewee-orm.com/en/latest/peewee/database.html#recommended-settings
|
2021-11-25 03:52:11 +00:00
|
|
|
"""
|
2022-01-20 05:59:02 +00:00
|
|
|
import contextlib
|
2023-05-04 18:25:05 +00:00
|
|
|
import datetime
|
2021-11-25 03:01:30 +00:00
|
|
|
import enum
|
2022-01-20 05:59:02 +00:00
|
|
|
import functools
|
2021-11-25 02:33:39 +00:00
|
|
|
import json
|
2021-11-24 22:13:41 +00:00
|
|
|
from pathlib import Path
|
2021-11-25 02:33:39 +00:00
|
|
|
from typing import Any
|
|
|
|
from typing import Dict
|
2021-11-24 22:13:41 +00:00
|
|
|
from typing import Optional
|
2021-11-25 03:40:04 +00:00
|
|
|
from typing import Sequence
|
2021-11-25 03:01:30 +00:00
|
|
|
from typing import Type
|
2021-11-25 03:40:04 +00:00
|
|
|
from typing import TypeVar
|
2021-11-24 22:13:41 +00:00
|
|
|
|
|
|
|
import peewee
|
|
|
|
|
2021-11-25 03:40:04 +00:00
|
|
|
|
2021-11-24 21:30:47 +00:00
|
|
|
__title__ = "peewee-plus"
|
2023-05-04 18:35:40 +00:00
|
|
|
__version__ = "1.3.0"
|
2021-11-24 21:30:47 +00:00
|
|
|
__license__ = "MIT"
|
|
|
|
__summary__ = "Various extensions, helpers, and utilities for Peewee"
|
|
|
|
__url__ = "https://github.com/enpaul/peewee-plus/"
|
|
|
|
__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
|
2021-11-24 22:13:41 +00:00
|
|
|
|
|
|
|
|
2021-11-25 03:40:04 +00:00
|
|
|
__all__ = [
|
2021-11-25 03:43:37 +00:00
|
|
|
"__title__",
|
|
|
|
"__version__",
|
|
|
|
"__license__",
|
|
|
|
"__summary__",
|
|
|
|
"__url__",
|
|
|
|
"__authors__",
|
2021-11-25 03:41:58 +00:00
|
|
|
"calc_batch_size",
|
|
|
|
"EnumField",
|
2022-01-20 05:59:02 +00:00
|
|
|
"flat_transaction",
|
2021-11-25 03:41:58 +00:00
|
|
|
"JSONField",
|
2021-11-25 03:40:04 +00:00
|
|
|
"PathField",
|
|
|
|
"PrecisionFloatField",
|
2021-11-25 03:41:58 +00:00
|
|
|
"SQLITE_DEFAULT_PRAGMAS",
|
2021-11-25 03:40:04 +00:00
|
|
|
"SQLITE_DEFAULT_VARIABLE_LIMIT",
|
2023-05-04 18:25:05 +00:00
|
|
|
"TimedeltaField",
|
2021-11-25 03:40:04 +00:00
|
|
|
]
|
|
|
|
|
|
|
|
|
2021-11-25 03:41:58 +00:00
|
|
|
# Default pragmas to apply when instantiating an SQLite database connection. The
# module docstring notes these values are taken directly from the Peewee
# documentation.
SQLITE_DEFAULT_PRAGMAS: Dict[str, Any] = dict(
    journal_mode="wal",
    cache_size=-64000,
    foreign_keys=1,
    ignore_check_constraints=0,
    synchronous=0,
)
|
|
|
|
|
|
|
|
|
2023-04-13 18:16:11 +00:00
|
|
|
SQLITE_DEFAULT_VARIABLE_LIMIT: int

# With SQLite 3.32 (2020-05-22) the devs bumped the default variable limit from
# 999 to 32766. This logic attempts to import the sqlite3 bindings and determine
# whether the installed SQLite version is greater than or equal to 3.32. If the
# sqlite3 bindings cannot be imported (because they aren't installed), or the
# platform is using a version of SQLite older than 3.32, then it falls back to
# the 999 value.
try:
    import sqlite3
except ImportError:
    SQLITE_DEFAULT_VARIABLE_LIMIT = 999
else:
    # Compare as tuples so the major and minor components are ordered together:
    # (3, 31, x) sorts below (3, 32) while (3, 32, 0) and anything newer does not.
    # Checking only "major >= 3" would wrongly select the larger limit for every
    # SQLite 3.x release, including those older than 3.32.
    if sqlite3.sqlite_version_info >= (3, 32):
        SQLITE_DEFAULT_VARIABLE_LIMIT = 32766
    else:
        SQLITE_DEFAULT_VARIABLE_LIMIT = 999
|
2021-11-25 03:40:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
T = TypeVar("T", bound=peewee.Model)
|
|
|
|
|
|
|
|
|
|
|
|
def calc_batch_size(
    models: Sequence[T], sqlite_variable_limit: int = SQLITE_DEFAULT_VARIABLE_LIMIT
) -> int:
    """Determine the batch size that should be used when performing queries

    This is intended to work around the query variable limit in SQLite. Critically this is a
    limit to the number of _variables_, not _records_ that can be referenced in a single query.

    The "correct" way to calculate this is to iterate over the model list and tally the number of
    changed fields, then add one for the table name, and each time you reach the variable limit
    cut a new batch until all the models are processed. This is very complicated because peewee
    doesn't provide a simple way to reliably identify changed fields.

    The naive way to calculate this (i.e. the way this function does it) is to determine the
    maximum number of variables that _could be_ used to modify a record and use that as the
    constant batch limiter. The theoretical maximum number of variables associated with a single
    record is equal to the number of fields on that record, plus 1 (for the table name). This
    gives the batch size (i.e. number of records that can be modified in a single query) as:

    ::

        sqlite_variable_limit // (len(fields) + 1)

    Where ``fields`` is an array of the fields that could be written on the record. The result
    is never smaller than ``1`` for a non-empty sequence of models, since a batch size of zero
    cannot be used to process any records.

    Example usage:

    .. code-block:: python

        models = [MyModel(...), MyModel(...), MyModel(...), MyModel(...)]

        with database.atomic():
            MyModel.bulk_create(models, batch_size=calc_batch_size(models))

    .. note:: This function (pretty safely) requires that all the records in ``models`` are all
              instances of the same model. Only the first record is inspected.

    .. note:: This function just returns ``len(models)`` if the backend is anything other than
              :class:`peewee.SqliteDatabase`. This is because the limitation this function works
              around is only applicable to SQLite, so on other platforms the batch size can just
              be as large as possible. This also helps to support writing code that transparently
              supports multiple backends.

    :param models: Sequence of models to be created or updated that need to be batched
    :param sqlite_variable_limit: Number of variables that can be present in a single SQL query;
                                  this is defined at compile time in the SQLite bindings for the
                                  current platform and should not need to be changed unless using
                                  SQLite bindings that were compiled with custom parameters.
    :returns: Number of models that can be processed in a single batch; ``0`` only when
              ``models`` is empty
    """
    # We need to inspect the models in the logic below, so if there are no models then just
    # return zero since the batch size doesn't matter anyway
    if not models:
        return 0

    meta = models[0]._meta  # pylint: disable=protected-access

    # The variable limit only applies to SQLite; any other backend gets a single batch
    if not isinstance(meta.database, peewee.SqliteDatabase):
        return len(models)

    # One variable per field, plus one for the table name
    variables_per_record = len(meta.fields) + 1

    # Floor division keeps the arithmetic in integers. Clamp to one so that a model with more
    # fields than the variable limit does not produce a zero-sized batch.
    return max(1, int(sqlite_variable_limit // variables_per_record))
|
2021-11-24 22:13:41 +00:00
|
|
|
|
|
|
|
|
2022-01-20 05:59:02 +00:00
|
|
|
def flat_transaction(interface: peewee.Database):
    """Decorator factory that runs a callable inside at most one transaction

    The decorated function or method is executed within a transaction opened on ``interface``.
    When ``interface`` reports that a transaction is already in progress at call time, the
    callable simply runs inside that existing transaction and no nested one is opened.

    Example usage:

    .. code-block:: python

        db = peewee.SqliteDatabase("test.db")


        @flat_transaction(db)
        def subquery():
            ...


        @flat_transaction(db)
        def my_query():
            ...
            subquery()


        # This call opens only a single transaction
        my_query()

    :param interface: Peewee database interface that should be used to open the transaction
    :returns: A decorator that wraps the target callable with the transaction handling
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The check happens on every call, so the decision reflects the transaction state
            # at call time rather than at decoration time
            if interface.in_transaction():
                # Reuse the surrounding transaction: enter a no-op context instead
                context = contextlib.nullcontext()
            else:
                context = interface.atomic()

            with context:
                return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
|
|
|
2022-09-07 18:35:30 +00:00
|
|
|
# TODO: The disable=abstract-method pragmas below are to get around new linting warnings
|
|
|
|
# in pylint>2.12, but they haven't been addressed properly. They should be revisited
|
|
|
|
# and fixed properly in the future.
|
|
|
|
class PathField(peewee.CharField):  # pylint: disable=abstract-method
    """Field class for storing file paths

    This field can be used to simply store pathlib paths in the database without needing to
    cast to ``str`` on write and ``Path`` on read.

    It can also serve to save paths relative to a root path defined at runtime. This can be
    useful when an application stores files under a directory defined in the app configuration,
    such as in an environment variable or a config file.

    For example, if a model is defined like below to load a path from the ``MYAPP_DATA_DIR``
    environment variable:

    .. code-block:: python

        class MyModel(peewee.Model):
            some_path = peewee_plus.PathField(relative_to=Path(os.environ["MYAPP_DATA_DIR"]))


        p1 = MyModel(some_path=Path(os.environ["MYAPP_DATA_DIR"]) / "foo.json").save()
        p2 = MyModel(some_path=Path("bar.json")).save()

    Then the data directory can be changed without updating the database, and the code can
    still rely on the database always returning absolute paths:

    ::

        >>> os.environ["MYAPP_DATA_DIR"] = "/etc/myapp"
        >>> [item.some_path for item in MyModel.select()]
        [PosixPath('/etc/myapp/foo.json'), PosixPath('/etc/myapp/bar.json')]
        >>>
        >>> os.environ["MYAPP_DATA_DIR"] = "/opt/myapp/data"
        >>> [item.some_path for item in MyModel.select()]
        [PosixPath('/opt/myapp/data/foo.json'), PosixPath('/opt/myapp/data/bar.json')]
        >>>

    ``None`` is passed through unchanged in both directions so that the field can be used
    for nullable columns.

    :param relative_to: Optional root path that paths should be stored relative to. If specified
                        then values being set will be converted to relative paths under this path,
                        and values being read will always be absolute paths under this path.
    :raises ValueError: If ``relative_to`` is set and an absolute path being written is not
                        located under it (raised by :meth:`pathlib.PurePath.relative_to`)
    """

    def __init__(self, *args, relative_to: Optional[Path] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.relative_to = relative_to

    def db_value(self, value: Optional[Path]) -> Optional[str]:
        """Convert a path to the string stored in the database

        :param value: Path to store, or ``None`` for a null value
        :returns: String form of the path (relative to ``relative_to`` when applicable), or
                  ``None`` when ``value`` is ``None``
        """
        # Null values have no path to convert; without this guard the attribute access
        # below would raise AttributeError
        if value is None:
            return None

        # Only absolute paths need to be rebased; relative paths are stored as given
        if value.is_absolute() and self.relative_to:
            value = value.relative_to(self.relative_to)

        return super().db_value(value)

    def python_value(self, value: Optional[str]) -> Optional[Path]:
        """Convert a stored string back into a path

        :param value: Raw database value, or ``None`` for a null column
        :returns: Path built from the stored value, joined onto ``relative_to`` when one is
                  configured, or ``None`` when the stored value is null
        """
        raw = super().python_value(value)

        # A null column (e.g. a nullable field or the empty side of an outer join) has no
        # path to build
        if raw is None:
            return None

        path = Path(raw)
        return self.relative_to / path if self.relative_to else path
|
2021-11-24 23:33:43 +00:00
|
|
|
|
|
|
|
|
2022-09-07 18:35:30 +00:00
|
|
|
class PrecisionFloatField(peewee.FloatField):  # pylint: disable=abstract-method
    """Field class for storing floats with custom precision parameters

    This field adds support for specifying the ``M`` and ``D`` precision parameters of a
    ``FLOAT`` field as specified in the `MySQL documentation`_. See that documentation for
    more information on the values these parameters accept.

    .. warning:: This field implements syntax that is specific to MySQL. When used with a
                 different database backend, such as SQLite or Postgres, it behaves identically
                 to :class:`peewee.FloatField`

    .. note:: This field's implementation was adapted from here_

    .. _`MySQL documentation`: https://dev.mysql.com/doc/refman/8.0/en/floating-point-types.html
    .. _here: https://stackoverflow.com/a/67476045/5361209

    :param max_digits: Maximum number of digits, combined from left and right of the decimal place,
                       to store for the value; corresponds to the ``M`` MySQL precision parameter.
    :param decimal_places: Maximum number of digits that will be stored after the decimal place;
                           corresponds to the ``D`` MySQL precision parameter.
    """

    def __init__(self, *args, max_digits: int = 10, decimal_places: int = 4, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_digits, self.decimal_places = max_digits, decimal_places

    def get_modifiers(self):
        """Return the precision parameters as the two-item list ``[max_digits, decimal_places]``"""
        # NOTE(review): presumably peewee renders these as the ``(M, D)`` suffix of the column
        # type -- confirm against the peewee field API
        return list((self.max_digits, self.decimal_places))
|
2021-11-25 02:33:39 +00:00
|
|
|
|
|
|
|
|
2022-09-07 18:35:30 +00:00
|
|
|
class JSONField(peewee.TextField):  # pylint: disable=abstract-method
    """Field class for storing JSON-serializable data

    This field can be used to store a dictionary of data directly in the database
    without needing to call :func:`json.dumps` and :func:`json.loads` directly.

    ::

        >>> class MyModel(peewee.Model):
        ...     some_data = JSONField()
        ...
        >>> m = MyModel(some_data={"foo": 1, "bar": 2})
        >>> m.save()
        >>> m.some_data
        {'foo': 1, 'bar': 2}
        >>>

    .. warning:: If a non-JSON serializable object is set to the field then a
                 :exc:`ValueError` will be raised

    .. warning:: This is a very bad way to store data in a RDBMS and effectively makes the data
                 contained in the field unqueriable.

    .. note:: A database ``NULL`` is read back as ``None`` rather than being passed to the JSON
              decoder. Writing ``None`` is unchanged: it is encoded like any other value, i.e.
              as the JSON text ``null``.

    :param dump_params: Additional keyword arguments to unpack into :func:`json.dumps`
    :param load_params: Additional keyword arguments to unpack into :func:`json.loads`
    :raises ValueError: When attempting to set a non-JSON serializable object to the field
    :raises peewee.IntegrityError: When the underlying database value cannot be decoded as JSON
    """

    def __init__(
        self,
        *args,
        dump_params: Optional[Dict[str, Any]] = None,
        load_params: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Normalize to dictionaries so they can always be unpacked with ``**``
        self.dump_params = dump_params or {}
        self.load_params = load_params or {}

    def db_value(self, value: Any) -> str:
        """Encode a Python object to the JSON text stored in the database

        :param value: JSON-serializable object to store
        :returns: JSON document representing ``value``
        :raises ValueError: If ``value`` cannot be encoded as JSON
        """
        try:
            return super().db_value(json.dumps(value, **self.dump_params))
        except TypeError as err:
            raise ValueError(
                f"Failed to JSON encode object of type '{type(value)}'"
            ) from err

    def python_value(self, value: Optional[str]) -> Any:
        """Decode the JSON text read from the database

        :param value: Raw database value, or ``None`` for a null column
        :returns: Decoded Python object, or ``None`` when the stored value is null
        :raises peewee.IntegrityError: If the stored text is not valid JSON
        """
        raw = super().python_value(value)

        # ``json.loads(None)`` raises a TypeError that the handler below would not catch, so
        # null columns (nullable fields, outer joins) are passed through explicitly
        if raw is None:
            return None

        try:
            return json.loads(raw, **self.load_params)
        except json.JSONDecodeError as err:
            raise peewee.IntegrityError(
                f"Failed to decode JSON value from database column '{self.column}'"
            ) from err
|
2021-11-25 03:01:30 +00:00
|
|
|
|
|
|
|
|
2022-09-07 18:35:30 +00:00
|
|
|
class EnumField(peewee.CharField):  # pylint: disable=abstract-method
    """Field class for storing Enums

    This field can be used for storing members of an :class:`enum.Enum` in the database,
    effectively storing a database reference to a value defined in the application. The
    database column holds the *name* of the member, not its value.

    .. warning:: This field ties database data to application structure: if the Enum passed
                 to this field is modified then the application may encounter errors when
                 trying to interface with the database schema.

    ::

        >>> class MyOptions(enum.Enum):
        ...     FOO = "have you ever heard the tragedy"
        ...     BAR = "of darth plageius"
        ...     BAZ = "the wise?"
        ...
        >>>
        >>> class MyModel(peewee.Model):
        ...     option = EnumField(MyOptions)
        ...
        >>> m = MyModel(option=MyOptions.FOO)
        >>> m.save()
        >>> m.option
        <MyOptions.FOO: "have you ever heard the tragedy">
        >>>

    When the field is declared with ``null=True``, ``None`` is accepted on write and returned
    on read for null columns.

    :param enumeration: The Enum to accept members of and to use for decoding database values
    :raises TypeError: If the value to be written to the field is not a member of the
                       specified Enum
    :raises peewee.IntegrityError: If the value read back from the database cannot be decoded to
                                   a member of the specified Enum
    """

    def __init__(self, enumeration: Type[enum.Enum], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.enumeration = enumeration

    def db_value(self, value: Optional[enum.Enum]) -> Optional[str]:
        """Convert an Enum member to the name stored in the database

        :param value: Member of the configured Enum, or ``None`` for a nullable field
        :returns: Name of the member, or ``None`` when writing a null to a nullable field
        :raises TypeError: If ``value`` is not a member of the configured Enum
        """
        # Mirror the read side: a nullable field must be able to round-trip ``None``
        if value is None and self.null:
            return None

        if not isinstance(value, self.enumeration):
            raise TypeError(f"Enum {self.enumeration.__name__} has no value '{value}'")

        return super().db_value(value.name)

    def python_value(self, value: Optional[str]) -> Optional[enum.Enum]:
        """Convert a stored member name back to the Enum member

        :param value: Raw database value, or ``None`` for a null column
        :returns: Matching Enum member, or ``None`` for a null column on a nullable field
        :raises peewee.IntegrityError: If the stored name does not match any member
        """
        if value is None and self.null:
            return None

        try:
            return self.enumeration[super().python_value(value)]
        except KeyError:
            # The lookup failure is an implementation detail, so the original KeyError is
            # deliberately suppressed from the traceback
            raise peewee.IntegrityError(
                f"Enum {self.enumeration.__name__} has no value with name '{value}'"
            ) from None
|
2023-05-04 18:25:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TimedeltaField(peewee.BigIntegerField):  # pylint: disable=abstract-method
    """Field class for storing python-native Timedelta objects

    This is really just a helper wrapper around an integer field that performs the conversions
    automatically. It is a helpful helper though, so it's included.

    .. note:: To avoid issues with float precision, this field stores the database value as an integer.
              However, this necessitates the usage of the BigInt type to avoid overflowing the value.
              Essentially, the value this field ends up storing is the number of microseconds in the
              timedelta.

    The conversion in both directions uses integer arithmetic only, so values round-trip
    exactly. ``None`` is passed through unchanged so that the field can be used for nullable
    columns.
    """

    def db_value(self, value: Optional[datetime.timedelta]) -> Optional[int]:
        """Convert a timedelta to the whole number of microseconds it represents

        :param value: Timedelta to store, or ``None`` for a null value
        :returns: Number of microseconds in ``value``, or ``None`` when ``value`` is ``None``
        """
        if value is None:
            return None

        # Dividing by a one-microsecond timedelta yields an exact integer. Going through
        # ``total_seconds()`` would route the value through a float, which loses microsecond
        # accuracy for large intervals and can truncate to the wrong integer.
        return super().db_value(value // datetime.timedelta(microseconds=1))

    def python_value(self, value: Optional[int]) -> Optional[datetime.timedelta]:
        """Convert a stored microsecond count back into a timedelta

        :param value: Raw database value, or ``None`` for a null column
        :returns: Timedelta of that many microseconds, or ``None`` when the stored value is null
        """
        microseconds = super().python_value(value)

        if microseconds is None:
            return None

        # Build directly from integer microseconds rather than dividing down to float seconds
        return datetime.timedelta(microseconds=microseconds)
|