# Redis connection and caching management
import hashlib
import json
import logging
import pickle
import time
import uuid
from datetime import timedelta
from functools import wraps
from typing import Optional, Any, Dict, List, Union
from urllib.parse import urlsplit, urlunsplit

import redis

logger = logging.getLogger(__name__)


def _redact_url(url: str) -> str:
    """Return *url* with any embedded password replaced by ``***`` (safe to log)."""
    try:
        parts = urlsplit(url)
    except ValueError:
        return '<unparseable redis url>'
    if not parts.password:
        return url
    userinfo, _, hostinfo = parts.netloc.rpartition('@')
    username = userinfo.split(':', 1)[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))


class RedisManager:
    """Manage Redis connections and operations.

    Every key, hash/list/set name and pub/sub channel passed to the helper
    methods is namespaced with ``key_prefix`` (``'talk2me:'`` unless
    ``config['key_prefix']`` overrides it). Most helpers are best-effort:
    Redis errors are logged and a neutral value (``default``, ``False``, ``0``,
    ``{}``, ``set()``) is returned instead of raising. ``check_rate_limit``,
    ``pipeline`` and ``subscribe`` are the exceptions and let errors propagate.
    """

    def __init__(self, app=None, config=None):
        self.redis_client = None
        self.config = config or {}
        self.key_prefix = self.config.get('key_prefix', 'talk2me:')

        if app:
            self.init_app(app)

    def init_app(self, app):
        """Initialize Redis with Flask app.

        Builds a connection pool from ``REDIS_URL``, ``REDIS_DECODE_RESPONSES``,
        ``REDIS_MAX_CONNECTIONS`` and ``REDIS_SOCKET_TIMEOUT`` in ``app.config``,
        verifies it with PING, and stores this manager as ``app.redis``.

        Raises:
            redis.ConnectionError: if the initial PING fails.
        """
        redis_url = app.config.get('REDIS_URL', 'redis://localhost:6379/0')

        decode_responses = app.config.get('REDIS_DECODE_RESPONSES', False)
        max_connections = app.config.get('REDIS_MAX_CONNECTIONS', 50)
        socket_timeout = app.config.get('REDIS_SOCKET_TIMEOUT', 5)

        pool = redis.ConnectionPool.from_url(
            redis_url,
            max_connections=max_connections,
            socket_timeout=socket_timeout,
            decode_responses=decode_responses,
        )
        self.redis_client = redis.Redis(connection_pool=pool)

        # Fail fast at startup rather than on the first request.
        try:
            self.redis_client.ping()
            # The URL may embed credentials (redis://:password@host), so redact it.
            logger.info(f"Redis connected successfully to {_redact_url(redis_url)}")
        except redis.ConnectionError as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise

        app.redis = self

    def _make_key(self, key: str) -> str:
        """Create a prefixed key"""
        return f"{self.key_prefix}{key}"

    @staticmethod
    def _loads(value: Any) -> Any:
        """Decode a raw Redis reply as JSON, falling back to UTF-8 text.

        ``ValueError`` covers both ``json.JSONDecodeError`` and the
        ``UnicodeDecodeError`` that ``json.loads`` raises for non-UTF-8 bytes.
        If the text fallback itself fails, the error propagates to the caller.
        """
        try:
            return json.loads(value)
        except (ValueError, TypeError):
            return value.decode('utf-8') if isinstance(value, bytes) else value

    # Basic operations
    def get(self, key: str, default=None) -> Any:
        """Get a value written by :meth:`set`, or *default* on a miss or error.

        Decoding order: JSON, then pickle (only for payloads starting with the
        pickle protocol marker byte ``0x80``), then UTF-8 text. Plain strings
        that happen to be valid JSON come back decoded (``'123'`` -> ``123``).
        """
        try:
            value = self.redis_client.get(self._make_key(key))
            if value is None:
                return default

            try:
                return json.loads(value)
            except (ValueError, TypeError):
                # ValueError also catches the UnicodeDecodeError json.loads
                # raises for binary payloads such as pickles, so those reach
                # the pickle branch instead of the outer error handler.
                pass

            # pickle protocol >= 2 (all Python 3 defaults) begins with 0x80;
            # never feed ordinary text to the unpickler.
            if isinstance(value, bytes) and value.startswith(b'\x80'):
                # SECURITY: unpickling executes code embedded in the payload.
                # Anyone with write access to this keyspace can run code here.
                try:
                    return pickle.loads(value)
                except Exception:
                    pass

            return value.decode('utf-8') if isinstance(value, bytes) else value
        except Exception as e:
            logger.error(f"Redis get error for key {key}: {e}")
            return default

    def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
        """Set value in Redis with optional expiration (seconds).

        bool/dict/list are stored as JSON, str/int/float as their ``str()``
        form, anything else is pickled. Returns False on error.
        """
        try:
            if isinstance(value, (bool, dict, list)):
                # bool is checked before int (bool subclasses int): str(False)
                # would be stored as 'False' and read back as a truthy string.
                serialized = json.dumps(value)
            elif isinstance(value, (str, int, float)):
                serialized = str(value)
            else:
                serialized = pickle.dumps(value)

            return bool(self.redis_client.set(
                self._make_key(key),
                serialized,
                ex=expire,
            ))
        except Exception as e:
            logger.error(f"Redis set error for key {key}: {e}")
            return False

    def delete(self, *keys) -> int:
        """Delete one or more keys; return how many existed, 0 on error."""
        try:
            prefixed_keys = [self._make_key(k) for k in keys]
            return self.redis_client.delete(*prefixed_keys)
        except Exception as e:
            logger.error(f"Redis delete error: {e}")
            return 0

    def exists(self, key: str) -> bool:
        """Check if key exists"""
        try:
            return bool(self.redis_client.exists(self._make_key(key)))
        except Exception as e:
            logger.error(f"Redis exists error for key {key}: {e}")
            return False

    def expire(self, key: str, seconds: int) -> bool:
        """Set expiration on a key; False if the key is missing or on error."""
        try:
            return bool(self.redis_client.expire(self._make_key(key), seconds))
        except Exception as e:
            logger.error(f"Redis expire error for key {key}: {e}")
            return False

    # Hash operations for session/rate limiting
    def hget(self, name: str, key: str, default=None) -> Any:
        """Get a hash field (JSON-decoded when possible), or *default*."""
        try:
            value = self.redis_client.hget(self._make_key(name), key)
            if value is None:
                return default
            return self._loads(value)
        except Exception as e:
            logger.error(f"Redis hget error for {name}:{key}: {e}")
            return default

    def hset(self, name: str, key: str, value: Any) -> bool:
        """Set a hash field (dict/list stored as JSON).

        Returns True whenever the write succeeds. Redis HSET itself returns the
        number of *new* fields, so overwriting an existing field would
        otherwise look like a failure.
        """
        try:
            if isinstance(value, (dict, list)):
                value = json.dumps(value)
            self.redis_client.hset(self._make_key(name), key, value)
            return True
        except Exception as e:
            logger.error(f"Redis hset error for {name}:{key}: {e}")
            return False

    def hgetall(self, name: str) -> Dict[str, Any]:
        """Get all fields of a hash as a dict with str keys; {} on error."""
        try:
            data = self.redis_client.hgetall(self._make_key(name))
            return {
                (k.decode('utf-8') if isinstance(k, bytes) else k): self._loads(v)
                for k, v in data.items()
            }
        except Exception as e:
            logger.error(f"Redis hgetall error for {name}: {e}")
            return {}

    def hdel(self, name: str, *keys) -> int:
        """Delete fields from hash"""
        try:
            return self.redis_client.hdel(self._make_key(name), *keys)
        except Exception as e:
            logger.error(f"Redis hdel error for {name}: {e}")
            return 0

    # List operations for queues
    def lpush(self, key: str, *values) -> int:
        """Push values to the left of a list (dict/list as JSON); return new length."""
        try:
            serialized = [
                json.dumps(v) if isinstance(v, (dict, list)) else v
                for v in values
            ]
            return self.redis_client.lpush(self._make_key(key), *serialized)
        except Exception as e:
            logger.error(f"Redis lpush error for {key}: {e}")
            return 0

    def rpop(self, key: str, default=None) -> Any:
        """Pop value from the right of list, or *default* if empty/on error."""
        try:
            value = self.redis_client.rpop(self._make_key(key))
            if value is None:
                return default
            return self._loads(value)
        except Exception as e:
            logger.error(f"Redis rpop error for {key}: {e}")
            return default

    def llen(self, key: str) -> int:
        """Get length of list"""
        try:
            return self.redis_client.llen(self._make_key(key))
        except Exception as e:
            logger.error(f"Redis llen error for {key}: {e}")
            return 0

    # Set operations for unique tracking
    def sadd(self, key: str, *values) -> int:
        """Add members to set"""
        try:
            return self.redis_client.sadd(self._make_key(key), *values)
        except Exception as e:
            logger.error(f"Redis sadd error for {key}: {e}")
            return 0

    def srem(self, key: str, *values) -> int:
        """Remove members from set"""
        try:
            return self.redis_client.srem(self._make_key(key), *values)
        except Exception as e:
            logger.error(f"Redis srem error for {key}: {e}")
            return 0

    def sismember(self, key: str, value: Any) -> bool:
        """Check if value is member of set"""
        try:
            return bool(self.redis_client.sismember(self._make_key(key), value))
        except Exception as e:
            logger.error(f"Redis sismember error for {key}: {e}")
            return False

    def scard(self, key: str) -> int:
        """Get number of members in set"""
        try:
            return self.redis_client.scard(self._make_key(key))
        except Exception as e:
            logger.error(f"Redis scard error for {key}: {e}")
            return 0

    def smembers(self, key: str) -> set:
        """Get all members of set, decoded to str"""
        try:
            members = self.redis_client.smembers(self._make_key(key))
            return {m.decode('utf-8') if isinstance(m, bytes) else m for m in members}
        except Exception as e:
            logger.error(f"Redis smembers error for {key}: {e}")
            return set()

    # Atomic counters
    def incr(self, key: str, amount: int = 1) -> int:
        """Increment counter; return the new value, 0 on error."""
        try:
            return self.redis_client.incr(self._make_key(key), amount)
        except Exception as e:
            logger.error(f"Redis incr error for {key}: {e}")
            return 0

    def decr(self, key: str, amount: int = 1) -> int:
        """Decrement counter; return the new value, 0 on error."""
        try:
            return self.redis_client.decr(self._make_key(key), amount)
        except Exception as e:
            logger.error(f"Redis decr error for {key}: {e}")
            return 0

    # Transaction support
    def pipeline(self):
        """Return a raw redis-py pipeline.

        Keys passed to the pipeline are NOT prefixed automatically; wrap them
        with ``_make_key`` as ``check_rate_limit`` does.
        """
        return self.redis_client.pipeline()

    # Pub/Sub support
    def publish(self, channel: str, message: Any) -> int:
        """Publish message (dict/list as JSON) to the prefixed channel."""
        try:
            if isinstance(message, (dict, list)):
                message = json.dumps(message)
            return self.redis_client.publish(self._make_key(channel), message)
        except Exception as e:
            logger.error(f"Redis publish error for {channel}: {e}")
            return 0

    def subscribe(self, *channels):
        """Subscribe to the prefixed channels and return the PubSub object."""
        pubsub = self.redis_client.pubsub()
        prefixed_channels = [self._make_key(c) for c in channels]
        pubsub.subscribe(*prefixed_channels)
        return pubsub

    # Cache helpers
    def cache_translation(self, source_text: str, source_lang: str, target_lang: str,
                          translation: str, expire_hours: int = 24) -> bool:
        """Cache a translation for *expire_hours* with a zeroed hit counter."""
        key = self._translation_key(source_text, source_lang, target_lang)
        data = {
            'translation': translation,
            'timestamp': time.time(),
            'hits': 0,
        }
        return self.set(key, data, expire=expire_hours * 3600)

    def get_cached_translation(self, source_text: str, source_lang: str,
                               target_lang: str) -> Optional[str]:
        """Get cached translation and increment its hit counter.

        The entry is rewritten with its *remaining* TTL. A plain SET would
        discard the expiry and keep every translation that was ever hit forever.
        """
        key = self._translation_key(source_text, source_lang, target_lang)
        data = self.get(key)
        if not data or not isinstance(data, dict):
            return None

        data['hits'] = data.get('hits', 0) + 1

        try:
            ttl = self.redis_client.ttl(self._make_key(key))
        except Exception as e:
            logger.error(f"Redis ttl error for key {key}: {e}")
            ttl = None

        # TTL -1: entry has no expiry, so rewrite without one. Positive: keep
        # the remaining lifetime. -2/0/None: the entry expired meanwhile or its
        # state is unknown, so skip the counter update rather than resurrect
        # the entry without an expiry.
        if ttl == -1:
            self.set(key, data)
        elif ttl is not None and ttl > 0:
            self.set(key, data, expire=ttl)

        return data.get('translation')

    def _translation_key(self, text: str, source_lang: str, target_lang: str) -> str:
        """Generate cache key for translation"""
        # MD5 only compacts arbitrarily long text into a fixed-size key; it is
        # not used for any security purpose.
        text_hash = hashlib.md5(text.encode()).hexdigest()
        return f"translation:{source_lang}:{target_lang}:{text_hash}"

    # Session management
    def save_session(self, session_id: str, data: Dict[str, Any],
                     expire_seconds: int = 3600) -> bool:
        """Save session data"""
        key = f"session:{session_id}"
        return self.set(key, data, expire=expire_seconds)

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session data"""
        key = f"session:{session_id}"
        return self.get(key)

    def delete_session(self, session_id: str) -> bool:
        """Delete session data"""
        key = f"session:{session_id}"
        return bool(self.delete(key))

    def extend_session(self, session_id: str, expire_seconds: int = 3600) -> bool:
        """Extend session expiration"""
        key = f"session:{session_id}"
        return self.expire(key, expire_seconds)

    # Rate limiting
    def check_rate_limit(self, identifier: str, limit: int,
                         window_seconds: int) -> tuple[bool, int]:
        """Check a sliding-window rate limit.

        Request timestamps live in a sorted set scored by time. Every call,
        including denied ones, records a request, so a client that keeps
        retrying while blocked stays blocked until it pauses for a full window.
        Unlike most helpers here, Redis errors propagate to the caller.

        Returns:
            ``(allowed, remaining)``, where *remaining* is never negative.
        """
        key = self._make_key(f"rate_limit:{identifier}:{window_seconds}")
        now = time.time()
        # Members must be unique. With a bare timestamp, requests landing on
        # the same clock tick collapse into one entry and are undercounted.
        member = f"{now}:{uuid.uuid4().hex}"

        pipe = self.pipeline()
        pipe.zremrangebyscore(key, 0, now - window_seconds)  # drop expired entries
        pipe.zcard(key)                                      # count the current window
        pipe.zadd(key, {member: now})                        # record this request
        pipe.expire(key, window_seconds + 1)                 # let idle limiters vanish
        results = pipe.execute()

        current_count = results[1]
        if current_count >= limit:
            # Denied requests are recorded too, so the count can exceed
            # `limit`; clamp instead of reporting a negative remainder.
            return False, 0

        return True, limit - current_count - 1

    # Cleanup
    def cleanup_expired_keys(self, pattern: str = "*", default_ttl: int = 86400) -> int:
        """Give an expiry to prefixed keys matching *pattern* that have none.

        Nothing is deleted directly: the sweep uses SCAN and applies
        *default_ttl* seconds (24 hours by default) to every matching key
        without an expiration, so no key lingers forever.

        Returns:
            Number of keys that were given a TTL, or 0 on a Redis error.
        """
        try:
            cursor = 0
            updated = 0
            match = self._make_key(pattern)
            while True:
                cursor, keys = self.redis_client.scan(cursor, match=match, count=100)
                for key in keys:
                    # TTL -1: key exists without an expiry. Keys that vanished
                    # (-2) or already expire (>= 0) are left alone.
                    if (self.redis_client.ttl(key) == -1
                            and self.redis_client.expire(key, default_ttl)):
                        updated += 1
                if cursor == 0:
                    break
            return updated
        except Exception as e:
            logger.error(f"Redis cleanup error: {e}")
            return 0


# Cache decorator
def redis_cache(expire_seconds: int = 300, key_prefix: str = ""):
    """Decorator to cache function results in Redis for *expire_seconds*.

    Uses the ``RedisManager`` stored on ``current_app.redis`` (so it needs an
    active Flask application context) and calls the function uncached when none
    is configured. The cache key hashes ``repr`` of the arguments, so arguments
    whose repr includes a memory address never hit the cache. ``None`` results
    are indistinguishable from a miss and are recomputed every time.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Imported lazily so this module can load without Flask.
            from flask import current_app
            redis_manager = getattr(current_app, 'redis', None)

            if not redis_manager:
                return func(*args, **kwargs)

            cache_key = f"{key_prefix}:{func.__name__}:"
            cache_key += hashlib.md5(
                f"{args}:{kwargs}".encode()
            ).hexdigest()

            cached = redis_manager.get(cache_key)
            if cached is not None:
                logger.debug(f"Cache hit for {func.__name__}")
                return cached

            result = func(*args, **kwargs)
            redis_manager.set(cache_key, result, expire=expire_seconds)
            return result

        return wrapper
    return decorator