diff --git a/peewee_plus.py b/peewee_plus.py index e8696a9..735e823 100644 --- a/peewee_plus.py +++ b/peewee_plus.py @@ -4,10 +4,13 @@ from pathlib import Path from typing import Any from typing import Dict from typing import Optional +from typing import Sequence from typing import Type +from typing import TypeVar import peewee + __title__ = "peewee-plus" __version__ = "0.1.0" __license__ = "MIT" @@ -16,7 +19,86 @@ __url__ = "https://github.com/enpaul/peewee-plus/" __authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"] -__all__ = ["PathField", "PrecisionFloatField", "JSONField", "EnumField"] +__all__ = [ + "PathField", + "PrecisionFloatField", + "JSONField", + "EnumField", + "SQLITE_DEFAULT_VARIABLE_LIMIT", + "calc_batch_size", +] + + +SQLITE_DEFAULT_VARIABLE_LIMIT: int = 999 + + +T = TypeVar("T", bound=peewee.Model) + + +def calc_batch_size( + models: Sequence[T], sqlite_variable_limit: int = SQLITE_DEFAULT_VARIABLE_LIMIT +) -> int: + """Determine the batch size that should be used when performing queries + + This is intended to work around the query variable limit in SQLite. Critically this is a + limit to the number of _variables_, not _records_ that can be referenced in a single query. + + The "correct" way to calculate this is to iterate over the model list and tally the number of + changed fields, then add one for the table name, and each time you reach the + ``SQLITE_VARIABLE_LIMIT`` (which is a known constant) cut a new batch until all the models are + processed. This is very complicated because peewee doesn't provide a simple way to reliably + identify changed fields. + + The naive way to calculate this (i.e. the way this function does it) is to determine the + maximum number of variables that _could be_ used to modify a record and use that as the + constant batch limiter. The theoretical maximum number of variables associated with a single + record is equal to the number of fields on that record, plus 1 (for the table name). This + gives the batch size (i.e. number of records that can be modified in a single query) as: + + :: + + 999 / (len(fields) + 1) + + Where ``fields`` is an array of the fields that could be written on the record. + + Example usage: + + .. code-block:: python + + models = [MyModel(...), MyModel(...), MyModel(...), MyModel(...)] + + with database.atomic(): + MyModel.bulk_create(models, batch_size=calc_batch_size(models)) + + .. note:: This function (pretty safely) requires that all the records in ``models`` are all + instances of the same model. + + .. note:: This function just returns ``len(models)`` if the backend is anything other than + :class:`peewee.SqliteDatabase`. This is because the limitation this function works + around is only applicable to SQLite, so on other platforms the batch size can just + be as large as possible. This also helps to support writing code that transparently + supports multiple backends. + + :param models: Sequence of models to be created or updated that need to be batched + :param sqlite_variable_limit: Number of variables that can be present in a single SQL query; + this is defined at compile time in the SQLite bindings for the + current platform and should not need to be changed unless using + SQLite bindings that were compiled with custom parameters. + :returns: Number of models that can be processed in a single batch + """ + # We need to inspect the models in the logic below, so if there are no models then just + # return zero since the batch size doesn't matter anyway + if not models: + return 0 + if isinstance( + models[0]._meta.database, # pylint: disable=protected-access + peewee.SqliteDatabase, + ): + return int( + sqlite_variable_limit + / (len(models[0]._meta.fields) + 1) # pylint: disable=protected-access + ) + return len(models) class PathField(peewee.CharField): diff --git a/tests/test_calc_batch_size.py b/tests/test_calc_batch_size.py new file mode 100644 index 0000000..11ec09b --- /dev/null +++ b/tests/test_calc_batch_size.py @@ -0,0 +1,45 @@ +# pylint: disable=redefined-outer-name +# pylint: disable=missing-class-docstring +# pylint: disable=too-few-public-methods +# pylint: disable=unused-import +import peewee + +import peewee_plus +from .fixtures import fakedb + + +def test_sqlite(fakedb): + """Test the calculation of batch sizes on SQLite""" + + class TestModel(peewee.Model): + class Meta: + database = fakedb + + data = peewee.IntegerField() + + models = [TestModel(item) for item in range(500)] + assert ( + peewee_plus.calc_batch_size(models) <= peewee_plus.SQLITE_DEFAULT_VARIABLE_LIMIT + ) + assert peewee_plus.calc_batch_size(models) < len(models) + + assert peewee_plus.calc_batch_size([]) == 0 + + +def test_non_sqlite(): + """Test the calculation of batch sizes on non-SQLite""" + + class TestModel(peewee.Model): + class Meta: + database = peewee.DatabaseProxy() + + data = peewee.IntegerField() + + # Three is just chosen as an arbitrary multiplier to ensure the value is larger than the + # sqlite variable limit + assert peewee_plus.calc_batch_size( + [ + TestModel(item) + for item in range(peewee_plus.SQLITE_DEFAULT_VARIABLE_LIMIT * 3) + ] + ) == (peewee_plus.SQLITE_DEFAULT_VARIABLE_LIMIT * 3)