From 65a45d209d015ae04f6ad14d229432a7b45105ce Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 9 Apr 2025 15:14:17 +0200 Subject: [PATCH] Fixes issues with data access --- sync/models/data_access.py | 7 ++-- sync/registry.py | 40 +++++++++++++++++++-- sync/signals.py | 72 ++++++++++++++++++++++++-------------- sync/utils.py | 6 ++-- sync/views.py | 16 ++++----- 5 files changed, 100 insertions(+), 41 deletions(-) diff --git a/sync/models/data_access.py b/sync/models/data_access.py index 94f4c8e..53039a5 100644 --- a/sync/models/data_access.py +++ b/sync/models/data_access.py @@ -4,7 +4,7 @@ from django.utils import timezone from django.core.exceptions import ObjectDoesNotExist from django.conf import settings -from ..registry import sync_registry +from ..registry import model_registry import uuid from . import ModelLog, SideStoreModel, BaseModel @@ -18,6 +18,9 @@ class DataAccess(BaseModel): granted_at = models.DateTimeField(auto_now_add=True) # last_hierarchy_update = models.DateTimeField(default=timezone.now) + def delete_dependencies(self): + pass + def create_revoke_access_log(self): self.create_access_log(self.shared_with.all(), 'REVOKE_ACCESS') @@ -29,7 +32,7 @@ class DataAccess(BaseModel): def create_access_log(self, users, operation): """Create an access log for a list of users """ - model_class = sync_registry.get_model(self.model_name) + model_class = model_registry.get_model(self.model_name) if model_class: try: obj = model_class.objects.get(id=self.model_id) diff --git a/sync/registry.py b/sync/registry.py index 91a86c1..f27b2e2 100644 --- a/sync/registry.py +++ b/sync/registry.py @@ -6,7 +6,7 @@ import threading User = get_user_model() -class SyncRegistry: +class ModelRegistry: def __init__(self): self._registry = {} @@ -37,7 +37,7 @@ class SyncRegistry: return self._registry.get(model_name) # Global instance -sync_registry = SyncRegistry() +model_registry = ModelRegistry() class DeviceRegistry: """Thread-safe registry to track device IDs associated with model instances.""" @@ -69,3 +69,39 @@ class DeviceRegistry: # Global instance device_registry = DeviceRegistry() + +class RelatedUsersRegistry: + """Thread-safe registry to track device IDs associated with model instances.""" + + def __init__(self): + self._registry = {} + self._lock = threading.RLock() + + def count(self): + """Return the number of items in the registry.""" + with self._lock: + return len(self._registry) + + def register(self, instance_id, users): + """Register a device_id for a model instance ID.""" + with self._lock: + instance_id_str = str(instance_id) + if instance_id_str in self._registry: + existing_users = self._registry[instance_id_str] + self._registry[instance_id_str] = existing_users.union(users) + else: + self._registry[instance_id_str] = users + + def get_users(self, instance_id): + """Get the device_id for a model instance ID.""" + with self._lock: + return self._registry.get(str(instance_id)) + + def unregister(self, instance_id): + """Remove an instance from the registry.""" + with self._lock: + if instance_id in self._registry: + del self._registry[instance_id] + +# Global instance +related_users_registry = RelatedUsersRegistry() diff --git a/sync/signals.py b/sync/signals.py index 4d174f1..3632065 100644 --- a/sync/signals.py +++ b/sync/signals.py @@ -9,10 +9,12 @@ from django.contrib.auth import get_user_model from django.utils import timezone from .ws_sender import websocket_sender -from .registry import device_registry +from .registry import device_registry, related_users_registry User = get_user_model() +### Device + @receiver(post_save, sender=Device) def device_created(sender, instance, **kwargs): if not instance.user: @@ -29,12 +31,15 @@ def device_post_delete(sender, instance, **kwargs): return evaluate_if_user_should_sync(instance._user) +### Sync + @receiver(pre_save) def presave_handler(sender, instance, **kwargs): synchronization_prepare(sender, instance, **kwargs) def synchronization_prepare(sender, instance, **kwargs): + print(f'*** synchronization_prepare for instance: {instance}') signal = kwargs.get('signal') # avoid crash in manage.py createsuperuser + delete user in the admin @@ -45,6 +50,11 @@ def synchronization_prepare(sender, instance, **kwargs): if not isinstance(instance, BaseModel) and not isinstance(instance, User): return + users = related_users(instance) + print(f'* impacted users = {users}') + related_users_registry.register(instance.id, users) + # user_ids = [user.id for user in users] + if signal == pre_save: detect_foreign_key_changes(sender, instance) @@ -66,34 +76,40 @@ def synchronization_notifications(sender, instance, created=False, **kwargs): save_model_log_if_possible(instance, signal, created) notify_impacted_users(instance) -def notify_impacted_users(instance): - user_ids = set() - # add impacted users - if isinstance(instance, User): - user_ids.add(instance.id) - elif isinstance(instance, BaseModel): - owner = instance.last_updated_by - if owner: - user_ids.add(owner.id) + related_users_registry.unregister(instance.id) - if isinstance(instance, BaseModel): - if hasattr(instance, '_users_to_notify'): - user_ids.update(instance._users_to_notify) - else: - print('no users to notify') +def notify_impacted_users(instance): + print(f'*** notify_impacted_users for instance: {instance}') + # user_ids = set() + # # add impacted users + # if isinstance(instance, User): + # user_ids.add(instance.id) + # elif isinstance(instance, BaseModel): + # owner = instance.last_updated_by + # if owner: + # user_ids.add(owner.id) + + # if isinstance(instance, BaseModel): + # if hasattr(instance, '_users_to_notify'): + # user_ids.update(instance._users_to_notify) + # else: + # print('no users to notify') device_id = device_registry.get_device_id(instance.id) + users = related_users_registry.get_users(instance.id) - # print(f'notify: {device_id}') - for user_id in user_ids: - websocket_sender.send_user_message(user_id, device_id) + if users: + user_ids = [user.id for user in users] + print(f'notify device: {device_id}, users = {user_ids}') + for user_id in user_ids: + websocket_sender.send_user_message(user_id, device_id) device_registry.unregister(instance.id) def save_model_log_if_possible(instance, signal, created): - users = related_users(instance) - # print(f'users = {len(users)}, instance = {instance}') + users = related_users_registry.get_users(instance.id) + print(f'*** save_model_log >>> users = {users}, instance = {instance}') if users: if signal == post_save or signal == pre_save: if created: @@ -111,10 +127,9 @@ def save_model_log_if_possible(instance, signal, created): if operation == ModelOperation.DELETE: # delete now unnecessary logs ModelLog.objects.filter(model_id=instance.id).delete() - user_ids = [user.id for user in users] - - # print(f'users to notify: {user_ids}') - instance._users_to_notify = user_ids # save this for the post_save signal + # user_ids = [user.id for user in users] + # # print(f'users to notify: {user_ids}') + # instance._users_to_notify = user_ids # save this for the post_save signal save_model_log(users, operation, model_name, instance.id, store_id) else: @@ -200,6 +215,8 @@ def process_foreign_key_changes(sender, instance, **kwargs): model_name, change['new_value'].id, change['new_value'].get_store_id()) +### Data Access + @receiver(post_delete) def delete_data_access_if_necessary(sender, instance, **kwargs): if not isinstance(instance, BaseModel): @@ -210,6 +227,7 @@ def delete_data_access_if_necessary(sender, instance, **kwargs): @receiver(m2m_changed, sender=DataAccess.shared_with.through) def handle_shared_with_changes(sender, instance, action, pk_set, **kwargs): + print(f'm2m changed = {pk_set}') users = User.objects.filter(id__in=pk_set) if action == "post_add": @@ -233,10 +251,14 @@ def data_access_post_save(sender, instance, **kwargs): @receiver(pre_delete, sender=DataAccess) def revoke_access_after_delete(sender, instance, **kwargs): instance.create_revoke_access_log() + related_users_registry.register(instance.id, instance.shared_with.all()) + instance._user = instance.related_user @receiver(post_delete, sender=DataAccess) def data_access_post_delete(sender, instance, **kwargs): + notify_impacted_users(instance) + if not hasattr(instance, '_user') or not instance._user: return evaluate_if_user_should_sync(instance._user) @@ -285,8 +307,6 @@ def evaluate_if_user_should_sync(user): ).count() > 0: should_synchronize = True - print(f'should_synchronize = {should_synchronize}') - with transaction.atomic(): user.should_synchronize = should_synchronize # if we go from True to False we might want to delete ModelLog once the last device has synchronized diff --git a/sync/utils.py b/sync/utils.py index f8b3cd3..493c0b6 100644 --- a/sync/utils.py +++ b/sync/utils.py @@ -1,6 +1,6 @@ import importlib from django.apps import apps -from .registry import sync_registry +from .registry import model_registry from collections import defaultdict def build_serializer_class(model_name): @@ -30,14 +30,14 @@ def get_serializer(instance, model_name): return serializer(instance) def get_data(model_name, model_id): - model = sync_registry.get_model(model_name) + model = model_registry.get_model(model_name) # print(f'model_name = {model_name}') # model = apps.get_model(app_label=app_label, model_name=model_name) return model.objects.get(id=model_id) def get_serialized_data(model_name, model_id): # print(f'model_name = {model_name}') - model = sync_registry.get_model(model_name) + model = model_registry.get_model(model_name) instance = model.objects.get(id=model_id) serializer_class = build_serializer_class(model_name) serializer = serializer_class(instance) diff --git a/sync/views.py b/sync/views.py index 3d80895..1b04955 100644 --- a/sync/views.py +++ b/sync/views.py @@ -21,7 +21,7 @@ from .utils import get_serializer, build_serializer_class, get_data, get_seriali from .models import ModelLog, BaseModel, SideStoreModel, DataAccess -from .registry import sync_registry, device_registry +from .registry import model_registry, device_registry class HierarchyApiView(APIView): @@ -117,7 +117,7 @@ class SynchronizationApi(HierarchyApiView): data['last_updated_by'] = request.user.id # always refresh the user performing the operation # model = apps.get_model(app_label='tournaments', model_name=model_name) - model = sync_registry.get_model(model_name) + model = model_registry.get_model(model_name) if model_operation == 'POST': @@ -186,7 +186,7 @@ class SynchronizationApi(HierarchyApiView): last_update_str = request.query_params.get('last_update') device_id = request.query_params.get('device_id') - print(f'last_update_str = {last_update_str}') + # print(f'last_update_str = {last_update_str}') decoded_last_update = unquote(last_update_str) # Decodes %2B into + # print(f'last_update_str = {last_update_str}') @@ -198,7 +198,7 @@ class SynchronizationApi(HierarchyApiView): except ValueError: return Response({"error": f"Invalid date format for last_update: {decoded_last_update}"}, status=status.HTTP_400_BAD_REQUEST) - print(f'/data GET: {last_update}') + # print(f'/data GET: {last_update}') logs = self.query_model_logs(last_update, request.user, device_id) print(f'>>> log count = {len(logs)}') @@ -227,7 +227,7 @@ class SynchronizationApi(HierarchyApiView): deletions[log.model_name].append({'model_id': log.model_id, 'store_id': log.store_id}) elif log.operation == 'GRANT_ACCESS': - model = sync_registry.get_model(log.model_name) + model = model_registry.get_model(log.model_name) instance = model.objects.get(id=log.model_id) serializer = get_serializer(instance, log.model_name) @@ -242,7 +242,7 @@ class SynchronizationApi(HierarchyApiView): }) # Get the model instance and add its parents to hierarchy - model = sync_registry.get_model(log.model_name) + model = model_registry.get_model(log.model_name) try: instance = model.objects.get(id=log.model_id) self.add_parents_with_hierarchy_organizer(instance, revocations_parents_organizer) @@ -274,7 +274,7 @@ class SynchronizationApi(HierarchyApiView): "date": last_log_date } - print(f'sync GET response. UP = {len(updates)} / DEL = {len(deletions)} / G = {len(grants)} / R = {len(revocations)}') + # print(f'sync GET response. UP = {len(updates)} / DEL = {len(deletions)} / G = {len(grants)} / R = {len(revocations)}') # print(f'sync GET response. response = {response_data}') return Response(response_data, status=status.HTTP_200_OK) @@ -301,7 +301,7 @@ class UserDataAccessApi(HierarchyApiView): for data_access in data_access_objects: try: - model = sync_registry.get_model(data_access.model_name) + model = model_registry.get_model(data_access.model_name) instance = model.objects.get(id=data_access.model_id) # Get the base data