refactor device_id system through a registry to properly notify devices

sync
Laurent 8 months ago
parent 9d2e2ec912
commit ccb71a884c
  1. 34
      sync/registry.py
  2. 52
      sync/signals.py
  3. 12
      sync/views.py

@ -2,6 +2,7 @@ from django.conf import settings
from django.apps import apps from django.apps import apps
from .models import BaseModel from .models import BaseModel
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
import threading
User = get_user_model() User = get_user_model()
@ -35,5 +36,36 @@ class SyncRegistry:
self.load_sync_apps() self.load_sync_apps()
return self._registry.get(model_name) return self._registry.get(model_name)
# Create singleton instance # Global instance
sync_registry = SyncRegistry() sync_registry = SyncRegistry()
class DeviceRegistry:
"""Thread-safe registry to track device IDs associated with model instances."""
def __init__(self):
self._registry = {}
self._lock = threading.RLock()
def count(self):
"""Return the number of items in the registry."""
with self._lock:
return len(self._registry)
def register(self, instance_id, device_id):
"""Register a device_id for a model instance ID."""
with self._lock:
self._registry[str(instance_id)] = device_id
def get_device_id(self, instance_id):
"""Get the device_id for a model instance ID."""
with self._lock:
return self._registry.get(str(instance_id))
def unregister(self, instance_id):
"""Remove an instance from the registry."""
with self._lock:
if instance_id in self._registry:
del self._registry[instance_id]
# Global instance
device_registry = DeviceRegistry()

@ -7,6 +7,7 @@ from django.contrib.auth import get_user_model
from django.utils import timezone from django.utils import timezone
from .ws_sender import websocket_sender from .ws_sender import websocket_sender
from .registry import device_registry
User = get_user_model() User = get_user_model()
@ -43,11 +44,7 @@ def synchronization_prepare(sender, instance, **kwargs):
return return
if signal == pre_save: if signal == pre_save:
device_id = None detect_foreign_key_changes(sender, instance)
if hasattr(instance, '_device_id'):
device_id = instance._device_id
detect_foreign_key_changes(sender, instance, device_id)
@receiver([post_save, post_delete]) @receiver([post_save, post_delete])
def synchronization_notifications(sender, instance, created=False, **kwargs): def synchronization_notifications(sender, instance, created=False, **kwargs):
@ -61,20 +58,12 @@ def synchronization_notifications(sender, instance, created=False, **kwargs):
if not isinstance(instance, BaseModel) and not isinstance(instance, User): if not isinstance(instance, BaseModel) and not isinstance(instance, User):
return return
device_id = None process_foreign_key_changes(sender, instance, **kwargs)
if hasattr(instance, '_device_id'):
device_id = instance._device_id
process_foreign_key_changes(sender, instance, device_id, **kwargs)
signal = kwargs.get('signal') signal = kwargs.get('signal')
save_model_log_if_possible(instance, signal, created, device_id) save_model_log_if_possible(instance, signal, created)
notify_impacted_users(instance) notify_impacted_users(instance)
# print(f'*** instance._state.db: {instance._state.db}')
# transaction.on_commit(lambda: notify_impacted_users(instance))
def notify_impacted_users(instance): def notify_impacted_users(instance):
user_ids = set() user_ids = set()
# add impacted users # add impacted users
@ -91,16 +80,15 @@ def notify_impacted_users(instance):
else: else:
print('no users to notify') print('no users to notify')
device_id = None device_id = device_registry.get_device_id(instance.id)
if hasattr(instance, '_device_id'):
device_id = instance._device_id
# print(f'notify: {user_ids}') # print(f'notify: {device_id}')
for user_id in user_ids: for user_id in user_ids:
websocket_sender.send_user_message(user_id, device_id) websocket_sender.send_user_message(user_id, device_id)
# send_user_message(user_id)
def save_model_log_if_possible(instance, signal, created, device_id): device_registry.unregister(instance.id)
def save_model_log_if_possible(instance, signal, created):
users = related_users(instance) users = related_users(instance)
# print(f'users = {len(users)}, instance = {instance}') # print(f'users = {len(users)}, instance = {instance}')
@ -125,12 +113,14 @@ def save_model_log_if_possible(instance, signal, created, device_id):
# print(f'users to notify: {user_ids}') # print(f'users to notify: {user_ids}')
instance._users_to_notify = user_ids # save this for the post_save signal instance._users_to_notify = user_ids # save this for the post_save signal
save_model_log(users, operation, model_name, instance.id, store_id, device_id) save_model_log(users, operation, model_name, instance.id, store_id)
else: else:
print(f'>>> Model Log could not be created because no linked user could be found: {instance.__class__.__name__} {instance}, {signal}') print(f'>>> Model Log could not be created because no linked user could be found: {instance.__class__.__name__} {instance}, {signal}')
def save_model_log(users, model_operation, model_name, model_id, store_id, device_id): def save_model_log(users, model_operation, model_name, model_id, store_id):
device_id = device_registry.get_device_id(model_id)
with transaction.atomic(): with transaction.atomic():
for user in users: for user in users:
@ -164,7 +154,7 @@ def save_model_log(users, model_operation, model_name, model_id, store_id, devic
# model_log.save() # model_log.save()
# model_log.users.set(users) # model_log.users.set(users)
def detect_foreign_key_changes(sender, instance, device_id): def detect_foreign_key_changes(sender, instance):
if not hasattr(instance, 'pk') or not instance.pk: if not hasattr(instance, 'pk') or not instance.pk:
return return
if not isinstance(instance, BaseModel): if not isinstance(instance, BaseModel):
@ -190,7 +180,8 @@ def detect_foreign_key_changes(sender, instance, device_id):
'new_value': new_value 'new_value': new_value
}) })
def process_foreign_key_changes(sender, instance, device_id, **kwargs): def process_foreign_key_changes(sender, instance, **kwargs):
if hasattr(instance, '_fk_changes'): if hasattr(instance, '_fk_changes'):
for change in instance._fk_changes: for change in instance._fk_changes:
for data_access in change['data_access_list']: for data_access in change['data_access_list']:
@ -198,12 +189,12 @@ def process_foreign_key_changes(sender, instance, device_id, **kwargs):
model_name = change['old_value'].__class__.__name__ model_name = change['old_value'].__class__.__name__
save_model_log(data_access.concerned_users(), 'REVOKE_ACCESS', save_model_log(data_access.concerned_users(), 'REVOKE_ACCESS',
model_name, change['old_value'].id, model_name, change['old_value'].id,
change['old_value'].get_store_id(), device_id) change['old_value'].get_store_id())
if change['new_value']: if change['new_value']:
model_name = change['new_value'].__class__.__name__ model_name = change['new_value'].__class__.__name__
save_model_log(data_access.concerned_users(), 'GRANT_ACCESS', save_model_log(data_access.concerned_users(), 'GRANT_ACCESS',
model_name, change['new_value'].id, model_name, change['new_value'].id,
change['new_value'].get_store_id(), device_id) change['new_value'].get_store_id())
@receiver(post_delete) @receiver(post_delete)
def delete_data_access_if_necessary(sender, instance, **kwargs): def delete_data_access_if_necessary(sender, instance, **kwargs):
@ -222,12 +213,11 @@ def handle_shared_with_changes(sender, instance, action, pk_set, **kwargs):
elif action == "post_remove": elif action == "post_remove":
instance.create_access_log(users, 'REVOKE_ACCESS') instance.create_access_log(users, 'REVOKE_ACCESS')
device_id = None device_id = device_registry.get_device_id(instance.id)
if hasattr(instance, '_device_id'):
device_id = instance._device_id
for user_id in pk_set: for user_id in pk_set:
websocket_sender.send_user_message(user_id, device_id) websocket_sender.send_user_message(user_id, device_id)
device_registry.unregister(instance.id)
for user in users: for user in users:
evaluate_if_user_should_sync(user) evaluate_if_user_should_sync(user)

@ -21,7 +21,7 @@ from .utils import get_serializer, build_serializer_class, get_data, get_seriali
from .models import ModelLog, BaseModel, SideStoreModel, DataAccess from .models import ModelLog, BaseModel, SideStoreModel, DataAccess
from .registry import sync_registry from .registry import sync_registry, device_registry
class HierarchyApiView(APIView): class HierarchyApiView(APIView):
@ -104,6 +104,10 @@ class SynchronizationApi(HierarchyApiView):
model_operation = op.get('operation') model_operation = op.get('operation')
model_name = op.get('model_name') model_name = op.get('model_name')
data = op.get('data') data = op.get('data')
data_id = data.get('id')
device_registry.register(data_id, device_id)
print(f'*** 1count = {device_registry.count()}')
try: try:
print(f'{model_operation} : {model_name}, id = {data['id']}') print(f'{model_operation} : {model_name}, id = {data['id']}')
@ -119,7 +123,6 @@ class SynchronizationApi(HierarchyApiView):
serializer = serializer_class(data=data, context={'request': request}) serializer = serializer_class(data=data, context={'request': request})
if serializer.is_valid(): if serializer.is_valid():
instance = serializer.save() instance = serializer.save()
instance._device_id = device_id
result = serializer.data result = serializer.data
response_status = status.HTTP_201_CREATED response_status = status.HTTP_201_CREATED
else: else:
@ -127,9 +130,7 @@ class SynchronizationApi(HierarchyApiView):
message = json.dumps(serializer.errors) message = json.dumps(serializer.errors)
response_status = status.HTTP_400_BAD_REQUEST response_status = status.HTTP_400_BAD_REQUEST
elif model_operation == 'PUT': elif model_operation == 'PUT':
data_id = data.get('id')
instance = get_data(model_name, data_id) instance = get_data(model_name, data_id)
instance._device_id = device_id
serializer = serializer_class(instance, data=data, context={'request': request}) serializer = serializer_class(instance, data=data, context={'request': request})
if serializer.is_valid(): if serializer.is_valid():
if instance.last_update <= serializer.validated_data.get('last_update'): if instance.last_update <= serializer.validated_data.get('last_update'):
@ -143,11 +144,8 @@ class SynchronizationApi(HierarchyApiView):
print(f'Data invalid ! {serializer.errors}') print(f'Data invalid ! {serializer.errors}')
response_status = status.HTTP_400_BAD_REQUEST response_status = status.HTTP_400_BAD_REQUEST
elif model_operation == 'DELETE': elif model_operation == 'DELETE':
data_id = data.get('id')
try: try:
instance = get_data(model_name, data_id) instance = get_data(model_name, data_id)
instance._device_id = device_id
try: try:
instance.delete() instance.delete()
response_status = status.HTTP_204_NO_CONTENT response_status = status.HTTP_204_NO_CONTENT

Loading…
Cancel
Save