I was investigating why some users were losing their data and discovered that data was being written to the wrong MongoDB collection, overwriting existing data in the process. I narrowed it down to a piece of code that was using Mongoengine's `switch_collection`. Looking at the logs, I noticed that this happened when multiple requests were being handled at the same time. It turns out that `switch_collection` is not thread-safe.
The `switch_collection` context manager works by modifying the collection name on the class object itself. That class is global state shared by every thread. So between one request switching collections and writing to it, another request can switch the class to a different collection, and both threads will then write to that one. Bad!
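To make the failure mode concrete, here is a pure-Python simulation of it. It does not use Mongoengine: `Model` and `switch_collection_like` are hypothetical stand-ins that imitate the "mutate the shared class, then restore it" behavior described above, and two `threading.Barrier`s force the unlucky interleaving that only happens occasionally under real load.

```python
import threading
from contextlib import contextmanager


class Model:
    # Hypothetical stand-in for a Document class: the collection name lives
    # on the class object itself, so every thread sees the same value.
    collection = "bank_account"


@contextmanager
def switch_collection_like(cls, name):
    # Simplified imitation of the pattern, not Mongoengine's real code:
    # mutate the shared class, hand it back, then restore the old name.
    old = cls.collection
    cls.collection = name
    try:
        yield cls
    finally:
        cls.collection = old


def simulate_race():
    writes = []  # (intended collection, collection actually used)
    lock = threading.Lock()
    switched = threading.Barrier(2)  # both threads have switched
    written = threading.Barrier(2)   # both threads have "written"

    def request(branch):
        with switch_collection_like(Model, branch) as NamedModel:
            switched.wait()
            with lock:
                writes.append((branch, NamedModel.collection))
            written.wait()  # nobody restores until both have written

    threads = [threading.Thread(target=request, args=(b,))
               for b in ("branch1", "branch2")]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return writes
```

Both threads end up writing to whichever collection was switched to last, so one of the two writes is misdirected. In this simulation, depending on the order in which the threads exit, `Model.collection` can even be left pointing at a branch collection afterwards.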
I decided that the best course of action was to stop using `switch_collection` altogether and create a separate model for each collection. Here's how I did it:
The previous setup
To create a simplified scenario, suppose we are using Mongoengine to store bank account balances. We have a model that looks like this:
```python
from mongoengine import Document, FloatField, StringField

class BankAccount(Document):
    account_number = StringField()
    balance = FloatField()
```
Let's say we store accounts in separate collections based on the bank branch, so we have collections called `branch1` and `branch2`, and a function for incrementing the balance that looks like this:
```python
from mongoengine.context_managers import switch_collection

def increment_balance(branch_name, account_number, amount):
    '''Add an amount to an account.'''
    with switch_collection(BankAccount, branch_name) as NamedAccount:
        NamedAccount.objects(account_number=account_number).update_one(
            inc__balance=amount,
        )
```
The switchless way
Instead of having one model class for all collections, we can use inheritance and dynamically create a separate model for each collection.
```python
from mongoengine import Document, FloatField, StringField

class BaseAccount(Document):
    '''Abstract base class for accounts.'''
    account_number = StringField()
    balance = FloatField()

    meta = {
        'abstract': True,
    }

class BankAccount(BaseAccount):
    '''
    Generic BankAccount model.
    Kept for legacy code that still uses `switch_collection`.
    '''

# Dynamically create a separate class for each bank branch
# and store them in a dict.
accounts_by_branch = {}
for branch_name in ['branch1', 'branch2']:
    accounts_by_branch[branch_name] = type(
        branch_name,
        (BaseAccount,),
        {
            'meta': {
                'collection': branch_name,
            }
        }
    )
```
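If the list of branches is not known when the module is imported, the same `type()` call can be made lazily. The sketch below assumes a hypothetical `get_account_model` helper that creates each branch's class on first use and caches it, with a lock so that two concurrent first requests for the same branch share a single class. To keep it runnable without Mongoengine or a MongoDB server, `BaseAccount` here is a plain-Python stand-in for the abstract base class above.

```python
import threading


class BaseAccount:
    # Plain-Python stand-in for the abstract Mongoengine BaseAccount,
    # so this sketch runs without Mongoengine installed.
    meta = {"abstract": True}


_models = {}
_models_lock = threading.Lock()


def get_account_model(branch_name):
    """Return the model class for a branch, creating it on first use.

    Hypothetical helper: the lock guarantees that concurrent first
    requests for the same branch end up sharing one class.
    """
    with _models_lock:
        model = _models.get(branch_name)
        if model is None:
            # Same pattern as above: a new subclass whose meta names its
            # own collection, so no existing class is ever mutated.
            model = type(
                branch_name,
                (BaseAccount,),
                {"meta": {"collection": branch_name}},
            )
            _models[branch_name] = model
        return model
```

With the real Mongoengine base class, caching also means each branch's model is defined once rather than on every request; the sketch assumes the real classes behave like the stand-in in this respect.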
With these classes created, instead of using `switch_collection` we can write the increment function like this:
```python
def increment_balance(branch_name, account_number, amount):
    '''Add an amount to an account.'''
    NamedAccount = accounts_by_branch[branch_name]
    NamedAccount.objects(account_number=account_number).update_one(
        inc__balance=amount,
    )
```
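A rough way to see why this version is safe is to replay the concurrent scenario in plain Python. In the sketch below, `BaseAccount` is a hypothetical stand-in whose subclasses carry their collection name as a class attribute, and barriers force two requests to overlap. Each write still uses the collection its own class names, because no request changes any class after it is created.

```python
import threading


class BaseAccount:
    # Hypothetical stand-in for the Mongoengine base class; only the
    # collection name matters for this simulation.
    collection = None


# One class per collection, created once up front and never mutated.
accounts_by_branch = {
    name: type(name, (BaseAccount,), {"collection": name})
    for name in ("branch1", "branch2")
}


def simulate_concurrent_requests():
    writes = []  # (intended collection, collection actually used)
    lock = threading.Lock()
    looked_up = threading.Barrier(2)  # both requests have their class
    written = threading.Barrier(2)    # both requests have "written"

    def request(branch):
        NamedAccount = accounts_by_branch[branch]
        looked_up.wait()
        with lock:
            writes.append((branch, NamedAccount.collection))
        written.wait()

    threads = [threading.Thread(target=request, args=(b,))
               for b in ("branch1", "branch2")]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return writes
```

This only models how writes are routed to collections; the update itself is a single `update_one` call with `inc__balance`, which MongoDB applies as an `$inc` on one document.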
Not only is our code now thread-safe, it is also cleaner and easier to read!