You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
285 lines
10 KiB
285 lines
10 KiB
from django.conf import settings
|
|
from django.apps import apps
|
|
from django.contrib.auth import get_user_model
|
|
|
|
from .models import BaseModel
|
|
|
|
import threading
|
|
import logging
|
|
|
|
from typing import List, Optional, Dict
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# User model class returned by get_user_model(), resolved once when this module
# is imported; ModelRegistry.load_sync_apps() accepts it alongside BaseModel
# subclasses when collecting syncable models.
User = get_user_model()
|
|
|
|
class ModelRegistry:
    """Name-indexed registry of the Django models that participate in syncing.

    The registry is filled lazily from ``settings.SYNC_APPS`` the first time
    ``get_model`` is called (or eagerly via ``load_sync_apps``). Models are keyed
    by their class ``__name__``; if two apps define models with the same class
    name, the one registered last replaces the other.
    """

    def __init__(self):
        # Maps model class name -> model class.
        self._registry = {}
        # Set once load_sync_apps() has completed. get_model() keys its lazy load
        # on this flag instead of on an empty registry, so a SYNC_APPS setting
        # that matches no models is not rescanned on every lookup, and models
        # registered manually beforehand do not suppress the automatic load.
        self._loaded = False

    def load_sync_apps(self):
        """Register every syncable model from the apps listed in ``settings.SYNC_APPS``.

        ``SYNC_APPS`` maps an app label to a config understood by
        ``should_sync_model``. A model is registered when it is concrete (not
        abstract), is either a ``BaseModel`` subclass or the user model, and
        passes ``should_sync_model``. Re-running re-registers the same names
        (the registry dict entries are simply overwritten).

        Raises:
            LookupError: from ``apps.get_app_config`` for an unknown app label.
        """
        sync_apps = getattr(settings, 'SYNC_APPS', {})
        for app_label, config in sync_apps.items():
            for model in apps.get_app_config(app_label).get_models():
                if not hasattr(model, '_meta') or model._meta.abstract:
                    continue
                if not (issubclass(model, BaseModel) or model == User):
                    continue
                if self.should_sync_model(model.__name__, config):
                    self.register(model)
        # Only mark as loaded after a complete scan, so a failed scan is retried
        # by the next get_model() call.
        self._loaded = True

    def should_sync_model(self, model_name, config):
        """Return whether *model_name* passes an app's sync config.

        ``config['exclude']`` (when present) lists model names that are never
        synced and takes precedence. A non-empty ``config['models']`` acts as an
        allow-list. With neither key, or with an empty/``None`` config, every
        model is synced.
        """
        if not config:
            return True
        if 'exclude' in config and model_name in config['exclude']:
            return False
        if 'models' in config and config['models']:
            return model_name in config['models']
        return True

    def register(self, model):
        """Register *model* under its class name, replacing any previous entry."""
        self._registry[model.__name__] = model

    def get_model(self, model_name):
        """Return the registered model class named *model_name*, or ``None``.

        The first call triggers ``load_sync_apps()`` if it has not completed yet.
        """
        if not self._loaded:
            self.load_sync_apps()
        return self._registry.get(model_name)


# Global instance
model_registry = ModelRegistry()
|
|
|
|
class DeviceRegistry:
    """Thread-safe registry to track device IDs associated with model instances.

    Instance IDs are normalised with ``str()`` on every operation, so an ID may
    be passed either as its original object (e.g. a UUID or int primary key)
    or as its string form interchangeably. All access is guarded by a
    re-entrant lock.
    """

    def __init__(self):
        # Maps str(instance_id) -> device_id.
        self._registry = {}
        self._lock = threading.RLock()

    def count(self):
        """Return the number of items in the registry."""
        with self._lock:
            return len(self._registry)

    def register(self, instance_id, device_id):
        """Register (or overwrite) the device_id for a model instance ID."""
        with self._lock:
            self._registry[str(instance_id)] = device_id

    def get_device_id(self, instance_id):
        """Return the device_id for a model instance ID, or ``None`` if unregistered."""
        with self._lock:
            return self._registry.get(str(instance_id))

    def unregister(self, instance_id):
        """Remove an instance from the registry; unknown IDs are ignored."""
        with self._lock:
            # Normalise the key exactly as register() does: looking up the raw
            # ID would never match when it is not already a string (UUID, int),
            # leaving stale entries behind.
            self._registry.pop(str(instance_id), None)


# Global instance
device_registry = DeviceRegistry()
|
|
|
|
class RelatedUsersRegistry:
    """Thread-safe registry to track the users related to model instances.

    Instance IDs are normalised with ``str()`` on every operation. Registering
    an instance again merges the new users into the stored entry via
    ``union``. All access is guarded by a re-entrant lock.
    """

    def __init__(self):
        # Maps str(instance_id) -> users collection exactly as registered
        # (presumably a set — TODO confirm against callers).
        self._registry = {}
        self._lock = threading.RLock()

    def count(self):
        """Return the number of items in the registry."""
        with self._lock:
            return len(self._registry)

    def register(self, instance_id, users):
        """Associate *users* with a model instance ID.

        The first registration stores *users* as given (no copy). Later
        registrations replace the entry with ``existing.union(users)``, so the
        stored object must provide ``union`` (e.g. a set).
        """
        with self._lock:
            instance_id_str = str(instance_id)
            if instance_id_str in self._registry:
                existing_users = self._registry[instance_id_str]
                self._registry[instance_id_str] = existing_users.union(users)
            else:
                self._registry[instance_id_str] = users

    def get_users(self, instance_id):
        """Return the users registered for a model instance ID, or ``None``."""
        with self._lock:
            return self._registry.get(str(instance_id))

    def unregister(self, instance_id):
        """Remove an instance from the registry; unknown IDs are ignored."""
        with self._lock:
            # Normalise the key exactly as register() does: looking up the raw
            # ID would never match when it is not already a string (UUID, int),
            # leaving stale entries behind.
            self._registry.pop(str(instance_id), None)


# Global instance
related_users_registry = RelatedUsersRegistry()
|
|
|
|
|
|
class SyncModelChildrenManager:
    """
    Manager class for handling model children sharing configuration.

    Reads the SYNC_MODEL_CHILDREN_SHARING setting once (at construction) and
    builds a bidirectional relationship graph for efficient lookup.

    The setting maps a model class name to a list of relationship names on that
    model, e.g. ``{'Project': ['tasks']}``. The graph holds those direct paths
    plus, for every relationship that can be resolved through ``model_registry``
    and Django model introspection, the path leading back from the related
    model (e.g. ``{'Task': [['project']]}``).
    """

    def __init__(self):
        """Read the Django setting and build the relationship graph."""
        self._model_relationships = getattr(
            settings,
            'SYNC_MODEL_CHILDREN_SHARING',
            {}
        )
        self._relationship_graph = self._build_relationship_graph()
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logger.info('self._relationship_graph = %s', self._relationship_graph)

    def _build_relationship_graph(self) -> Dict[str, List[List[str]]]:
        """
        Build a bidirectional relationship graph.

        Every configured model gets one single-element path per configured
        relationship. For each relationship that can be resolved, the related
        model additionally gets the single-element path leading back to the
        configured model. Identical paths are stored only once per model.

        Resolution is best-effort. A model unknown to ``model_registry``, or a
        relationship that cannot be matched, is skipped. Unexpected errors are
        logged with a traceback and skipped, so one bad entry does not prevent
        the rest of the graph from being built.

        Returns:
            Dict[str, List[List[str]]]: Dictionary where keys are model names and values
            are lists of relationship paths (arrays of relationship names).
        """
        graph: Dict[str, List[List[str]]] = {}

        # Direct relationships (configured models to their children).
        for model_name, relationships in self._model_relationships.items():
            # Create the entry even when no relationships are configured.
            graph.setdefault(model_name, [])
            for relationship in relationships:
                self._add_path(graph, model_name, [relationship])

        # Reverse relationships (related models back to the configured models).
        for original_model_name, relationships in self._model_relationships.items():
            try:
                original_model = model_registry.get_model(original_model_name)
            except Exception:
                logger.warning(
                    'Could not look up model %r; skipping its reverse relationship paths.',
                    original_model_name,
                    exc_info=True,
                )
                continue
            if original_model is None:
                continue

            for relationship_name in relationships:
                try:
                    self._add_reverse_path(graph, original_model, relationship_name)
                except Exception:
                    logger.warning(
                        'Could not resolve the reverse of relationship %r on model %r; skipping it.',
                        relationship_name,
                        original_model_name,
                        exc_info=True,
                    )

        return graph

    @staticmethod
    def _add_path(graph, model_name, path):
        """Append *path* to ``graph[model_name]`` (creating the entry) unless already present."""
        paths = graph.setdefault(model_name, [])
        if path not in paths:
            paths.append(path)

    def _add_reverse_path(self, graph, original_model, relationship_name):
        """Add to *graph* the path from *relationship_name*'s target back to *original_model*.

        Does nothing if the relationship is not found on *original_model*, has
        no related model (e.g. a non-relational field), or has no resolvable
        reverse accessor.
        """
        field = self._find_field(original_model, relationship_name)
        if field is None:
            return

        # Non-relational fields expose related_model = None; they have no reverse.
        if hasattr(field, 'related_model'):
            related_model = field.related_model
        else:
            related_model = getattr(field, 'model', None)
        if related_model is None:
            return

        reverse_relationship_name = self._find_reverse_relationship(
            related_model, original_model, relationship_name
        )
        if reverse_relationship_name:
            self._add_path(graph, related_model.__name__, [reverse_relationship_name])

    @staticmethod
    def _find_field(model, relationship_name):
        """Return the field or reverse relation on *model* addressed by *relationship_name*.

        A field whose ``related_name`` or ``name`` equals *relationship_name*
        matches first (earliest in ``get_fields()`` order). Failing that, reverse
        relations are matched by their accessor name, which covers the default
        ``<model>_set`` accessor that neither ``name`` nor ``related_name``
        reflects. Returns ``None`` when nothing matches.
        """
        fields = model._meta.get_fields()
        for f in fields:
            if getattr(f, 'related_name', None) == relationship_name:
                return f
            if getattr(f, 'name', None) == relationship_name:
                return f
        for f in fields:
            get_accessor_name = getattr(f, 'get_accessor_name', None)
            if callable(get_accessor_name) and get_accessor_name() == relationship_name:
                return f
        return None

    def _find_reverse_relationship(self, from_model, to_model, original_relationship_name):
        """
        Find the reverse relationship name from from_model to to_model.

        Args:
            from_model: The model to search relationships from (the related model).
            to_model: The target model to find relationship to (the original model).
            original_relationship_name: The accessor on ``to_model`` that leads to
                ``from_model``; it disambiguates when several relations link the
                two models.

        Returns:
            str or None: The accessor on ``from_model`` that leads back to
            ``to_model``, or ``None`` if it cannot be determined. Errors raised
            while introspecting are logged and also yield ``None``.
        """
        try:
            for field in from_model._meta.get_fields():
                if getattr(field, 'related_model', None) != to_model:
                    continue
                remote = getattr(field, 'remote_field', None)
                if remote is None:
                    continue

                is_reverse_relation = (
                    getattr(field, 'auto_created', False)
                    and not getattr(field, 'concrete', True)
                )
                if is_reverse_relation:
                    # Reverse relation object on from_model. Its remote_field is
                    # the forward field declared on to_model, so that field's
                    # name is the original relationship name.
                    if getattr(remote, 'name', None) == original_relationship_name:
                        get_accessor_name = getattr(field, 'get_accessor_name', None)
                        accessor = get_accessor_name() if callable(get_accessor_name) else None
                        if accessor:
                            return accessor
                else:
                    # Forward relation declared on from_model. Its remote_field
                    # is the reverse relation on to_model, whose accessor is the
                    # original relationship name.
                    get_accessor_name = getattr(remote, 'get_accessor_name', None)
                    if callable(get_accessor_name) and get_accessor_name() == original_relationship_name:
                        return field.name
        except Exception:
            logger.warning(
                'Could not introspect %r for the reverse of %r; skipping it.',
                from_model,
                original_relationship_name,
                exc_info=True,
            )

        return None

    def get_relationships(self, model_name: str) -> List[str]:
        """
        Get the list of direct relationships for a given model name.

        Args:
            model_name (str): The name of the model to look up

        Returns:
            List[str]: List of relationship names for the model (the setting's own
            list object — do not mutate it). Returns empty list if model is not found.
        """
        return self._model_relationships.get(model_name, [])

    def get_relationship_paths(self, model_name: str) -> List[List[str]]:
        """
        Get all relationship paths for a given model name.

        This includes both direct relationships and reverse paths.

        Args:
            model_name (str): The name of the model to look up

        Returns:
            List[List[str]]: List of relationship paths (each path is a list of relationship names).
            Returns empty list if model is not found.
        """
        return self._relationship_graph.get(model_name, [])

    def get_relationship_graph(self) -> Dict[str, List[List[str]]]:
        """
        Get the complete relationship graph.

        Returns:
            Dict[str, List[List[str]]]: A copy of the complete relationship graph.
            The path lists are copied too, so callers cannot mutate the manager's state.
        """
        return {
            model_name: [list(path) for path in paths]
            for model_name, paths in self._relationship_graph.items()
        }
|
|
|
|
|
|
# Create a singleton instance to use throughout the application.
# NOTE(review): constructing it here runs _build_relationship_graph() at import
# time, which reads settings and calls model_registry.get_model() (and through
# it django.apps.apps.get_app_config). If this module is imported before the
# Django app registry is ready, those lookups may fail and the affected reverse
# paths are skipped — confirm import order (e.g. import from AppConfig.ready()).
sync_model_manager = SyncModelChildrenManager()
|
|
|