mirror of
https://github.com/enpaul/peewee-plus.git
synced 2024-11-14 18:46:47 +00:00
Add function for calculating sqlite batch size
Add default sqlite variable limit constant
This commit is contained in:
parent
ba3f143dc9
commit
6484b395a2
@ -4,10 +4,13 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
|
|
||||||
__title__ = "peewee-plus"
|
__title__ = "peewee-plus"
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
__license__ = "MIT"
|
__license__ = "MIT"
|
||||||
@ -16,7 +19,86 @@ __url__ = "https://github.com/enpaul/peewee-plus/"
|
|||||||
__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
|
__authors__ = ["Ethan Paul <24588726+enpaul@users.noreply.github.com>"]
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["PathField", "PrecisionFloatField", "JSONField", "EnumField"]
|
# Public API of the module: the custom field classes plus the SQLite batching helper
# (``calc_batch_size``) and the default variable limit constant it uses.
__all__ = [
    "PathField",
    "PrecisionFloatField",
    "JSONField",
    "EnumField",
    "SQLITE_DEFAULT_VARIABLE_LIMIT",
    "calc_batch_size",
]
|
||||||
|
|
||||||
|
|
||||||
|
# Default number of variables allowed in a single SQLite query; used as the default value of
# the ``sqlite_variable_limit`` parameter of ``calc_batch_size``.
SQLITE_DEFAULT_VARIABLE_LIMIT: int = 999
|
||||||
|
|
||||||
|
|
||||||
|
# Type variable restricted to peewee models; used to annotate the ``models`` sequence accepted
# by ``calc_batch_size``.
T = TypeVar("T", bound=peewee.Model)
|
||||||
|
|
||||||
|
|
||||||
|
def calc_batch_size(
    models: Sequence[T], sqlite_variable_limit: int = SQLITE_DEFAULT_VARIABLE_LIMIT
) -> int:
    """Determine the batch size that should be used when performing queries

    This is intended to work around the query variable limit in SQLite. Critically this is a
    limit to the number of _variables_, not _records_ that can be referenced in a single query.

    The "correct" way to calculate this is to iterate over the model list and tally the number of
    changed fields, then add one for the table name, and each time you reach the variable limit
    (see ``SQLITE_DEFAULT_VARIABLE_LIMIT``) cut a new batch until all the models are processed.
    This is very complicated because peewee doesn't provide a simple way to reliably identify
    changed fields.

    The naive way to calculate this (i.e. the way this function does it) is to determine the
    maximum number of variables that _could be_ used to modify a record and use that as the
    constant batch limiter. The theoretical maximum number of variables associated with a single
    record is equal to the number of fields on that record, plus 1 (for the table name). This
    gives the batch size (i.e. number of records that can be modified in a single query) as:

    ::

        sqlite_variable_limit // (len(fields) + 1)

    Where ``fields`` is an array of the fields that could be written on the record. The result
    is never lower than one for a non-empty sequence of models, since a batch size of zero
    cannot be used to process any records.

    Example usage:

    .. code-block:: python

        models = [MyModel(...), MyModel(...), MyModel(...), MyModel(...)]

        with database.atomic():
            MyModel.bulk_create(models, batch_size=calc_batch_size(models))

    .. note:: This function assumes that all the records in ``models`` are instances of the
              same model; only the first record is inspected.

    .. note:: This function just returns ``len(models)`` if the backend is anything other than
              :class:`peewee.SqliteDatabase`. This is because the limitation this function works
              around is only applicable to SQLite, so on other platforms the batch size can just
              be as large as possible. This also helps to support writing code that transparently
              supports multiple backends. A database bound through a :class:`peewee.Proxy` (such
              as :class:`peewee.DatabaseProxy`) is resolved to the database it wraps before this
              check; an uninitialized proxy is treated as a non-SQLite backend.

    :param models: Sequence of models to be created or updated that need to be batched
    :param sqlite_variable_limit: Number of variables that can be present in a single SQL query;
                                  this is defined at compile time in the SQLite bindings for the
                                  current platform and should not need to be changed unless using
                                  SQLite bindings that were compiled with custom parameters.
    :returns: Number of models that can be processed in a single batch; zero only when
              ``models`` is empty
    """
    # We need to inspect the models in the logic below, so if there are no models then just
    # return zero since the batch size doesn't matter anyway
    if not models:
        return 0

    meta = models[0]._meta  # pylint: disable=protected-access

    # Models are commonly bound to a proxy that is pointed at the real database later on, so
    # look through the proxy to find the backend that will actually run the query
    database = meta.database
    if isinstance(database, peewee.Proxy):
        database = database.obj

    if isinstance(database, peewee.SqliteDatabase):
        # Integer division avoids a float round trip; the floor of one keeps the result usable
        # as a batch size even when a single record needs more variables than the limit allows
        return max(1, sqlite_variable_limit // (len(meta.fields) + 1))

    return len(models)
|
||||||
|
|
||||||
|
|
||||||
class PathField(peewee.CharField):
|
class PathField(peewee.CharField):
|
||||||
|
45
tests/test_calc_batch_size.py
Normal file
45
tests/test_calc_batch_size.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
# pylint: disable=missing-class-docstring
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
import peewee
|
||||||
|
|
||||||
|
import peewee_plus
|
||||||
|
from .fixtures import fakedb
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlite(fakedb):
    """Test the calculation of batch sizes on SQLite"""

    class TestModel(peewee.Model):
        class Meta:
            database = fakedb

        data = peewee.IntegerField()

    # Field values are passed by keyword so that each instance actually carries its value;
    # a positional argument is not assigned to any field by the model constructor
    models = [TestModel(data=item) for item in range(500)]
    batch_size = peewee_plus.calc_batch_size(models)

    # The batch must be usable (non-zero) and must fit under the default variable limit
    assert 0 < batch_size <= peewee_plus.SQLITE_DEFAULT_VARIABLE_LIMIT
    # With 500 records the SQLite limit must force more than one batch
    assert batch_size < len(models)

    # An empty sequence has nothing to batch
    assert peewee_plus.calc_batch_size([]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_sqlite():
    """Test the calculation of batch sizes on non-SQLite"""

    class TestModel(peewee.Model):
        class Meta:
            # An uninitialized proxy stands in for "some backend that is not SQLite"
            database = peewee.DatabaseProxy()

        data = peewee.IntegerField()

    # Three is just chosen as an arbitrary multiplier to ensure the value is larger than the
    # sqlite variable limit
    model_count = peewee_plus.SQLITE_DEFAULT_VARIABLE_LIMIT * 3

    # Field values are passed by keyword; a positional argument is not assigned to any field
    models = [TestModel(data=item) for item in range(model_count)]

    # Without the SQLite restriction every record should fit in a single batch
    assert peewee_plus.calc_batch_size(models) == model_count
|
Loading…
Reference in New Issue
Block a user