# This is the updated app.py with Redis and PostgreSQL integration # To use this, rename it to app.py after backing up the original import os import time import tempfile import requests import json import logging from dotenv import load_dotenv from flask import Flask, render_template, request, jsonify, Response, send_file, send_from_directory, stream_with_context, g from flask_cors import CORS, cross_origin import whisper import torch import ollama from whisper_config import MODEL_SIZE, GPU_OPTIMIZATIONS, TRANSCRIBE_OPTIONS from pywebpush import webpush, WebPushException import base64 from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives import serialization from cryptography.hazmat.backends import default_backend import gc from functools import wraps import traceback from validators import Validators import atexit import threading from datetime import datetime, timedelta # Import new database and Redis components from database import db, init_db, Translation, Transcription, UserPreferences, UsageAnalytics from redis_manager import RedisManager, redis_cache from redis_rate_limiter import RedisRateLimiter, rate_limit from redis_session_manager import RedisSessionManager, init_app as init_redis_sessions # Load environment variables load_dotenv() # Initialize logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Import other components from werkzeug.middleware.proxy_fix import ProxyFix from config import init_app as init_config from secrets_manager import init_app as init_secrets from request_size_limiter import RequestSizeLimiter, limit_request_size from error_logger import ErrorLogger, log_errors, log_performance, log_exception, get_logger from memory_manager import MemoryManager, AudioProcessingContext, with_memory_management # Error boundary decorator def with_error_boundary(func): @wraps(func) def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: log_exception( e, message=f"Error in {func.__name__}", endpoint=request.endpoint, method=request.method, path=request.path, ip=request.remote_addr, function=func.__name__, module=func.__module__ ) if any(keyword in str(e).lower() for keyword in ['inject', 'attack', 'malicious', 'unauthorized']): app.error_logger.log_security( 'suspicious_error', severity='warning', error_type=type(e).__name__, error_message=str(e), endpoint=request.endpoint, ip=request.remote_addr ) error_message = str(e) if app.debug else "An internal error occurred" return jsonify({ 'success': False, 'error': error_message, 'component': func.__name__, 'request_id': getattr(g, 'request_id', None) }), 500 return wrapper app = Flask(__name__) # Apply ProxyFix middleware app.wsgi_app = ProxyFix( app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1 ) # Initialize configuration and secrets init_config(app) init_secrets(app) # Initialize database init_db(app) # Initialize Redis redis_manager = RedisManager(app) app.redis = redis_manager # Initialize Redis-based rate limiter redis_rate_limiter = RedisRateLimiter(redis_manager) app.redis_rate_limiter = redis_rate_limiter # Initialize Redis-based session management init_redis_sessions(app) # Configure CORS cors_config = { "origins": app.config.get('CORS_ORIGINS', ['*']), "methods": ["GET", "POST", "OPTIONS"], "allow_headers": ["Content-Type", "Authorization", "X-Requested-With", "X-Admin-Token"], "expose_headers": ["Content-Range", "X-Content-Range"], "supports_credentials": True, "max_age": 3600 } CORS(app, resources={ r"/api/*": cors_config, r"/transcribe": cors_config, r"/translate": cors_config, r"/translate/stream": cors_config, r"/speak": cors_config, r"/get_audio/*": cors_config, r"/check_tts_server": cors_config, r"/update_tts_config": cors_config, r"/health/*": cors_config, r"/admin/*": { **cors_config, "origins": app.config.get('ADMIN_CORS_ORIGINS', ['http://localhost:*']) } }) # Configure upload folder upload_folder = app.config.get('UPLOAD_FOLDER') if not upload_folder: upload_folder = os.path.join(tempfile.gettempdir(), 'talk2me_uploads') try: os.makedirs(upload_folder, mode=0o755, exist_ok=True) logger.info(f"Using upload folder: {upload_folder}") except Exception as e: logger.error(f"Failed to create upload folder {upload_folder}: {str(e)}") upload_folder = tempfile.mkdtemp(prefix='talk2me_') logger.warning(f"Falling back to temporary folder: {upload_folder}") app.config['UPLOAD_FOLDER'] = upload_folder # Initialize request size limiter request_size_limiter = RequestSizeLimiter(app, { 'max_content_length': app.config.get('MAX_CONTENT_LENGTH', 50 * 1024 * 1024), 'max_audio_size': app.config.get('MAX_AUDIO_SIZE', 25 * 1024 * 1024), 'max_json_size': app.config.get('MAX_JSON_SIZE', 1 * 1024 * 1024), 'max_image_size': app.config.get('MAX_IMAGE_SIZE', 10 * 1024 * 1024), }) # Initialize error logging error_logger = ErrorLogger(app, { 'log_level': app.config.get('LOG_LEVEL', 'INFO'), 'log_file': app.config.get('LOG_FILE', 'logs/talk2me.log'), 'error_log_file': app.config.get('ERROR_LOG_FILE', 'logs/errors.log'), 'max_bytes': app.config.get('LOG_MAX_BYTES', 50 * 1024 * 1024), 'backup_count': app.config.get('LOG_BACKUP_COUNT', 10) }) logger = get_logger(__name__) # Initialize memory management memory_manager = MemoryManager(app, { 'memory_threshold_mb': app.config.get('MEMORY_THRESHOLD_MB', 4096), 'gpu_memory_threshold_mb': app.config.get('GPU_MEMORY_THRESHOLD_MB', 2048), 'cleanup_interval': app.config.get('MEMORY_CLEANUP_INTERVAL', 30) }) # Initialize Whisper model logger.info("Initializing Whisper model with GPU optimization...") if torch.cuda.is_available(): device = torch.device("cuda") try: gpu_name = torch.cuda.get_device_name(0) if 'AMD' in gpu_name or 'Radeon' in gpu_name: logger.info(f"AMD GPU detected via ROCm: {gpu_name}") else: logger.info(f"NVIDIA GPU detected: {gpu_name}") except: logger.info("GPU detected - using CUDA/ROCm acceleration") elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device("mps") logger.info("Apple Silicon detected - using Metal Performance Shaders") else: device = torch.device("cpu") logger.info("No GPU acceleration available - using CPU") logger.info(f"Using device: {device}") whisper_model = whisper.load_model(MODEL_SIZE, device=device) # Enable GPU optimizations if device.type == 'cuda': try: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True whisper_model.eval() whisper_model = whisper_model.half() torch.cuda.empty_cache() logger.info("Warming up GPU with dummy inference...") with torch.no_grad(): dummy_audio = torch.randn(1, 16000 * 30).to(device).half() _ = whisper_model.encode(whisper.pad_or_trim(dummy_audio)) logger.info(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") logger.info("Whisper model loaded and optimized for GPU") except Exception as e: logger.warning(f"Some GPU optimizations failed: {e}") elif device.type == 'mps': whisper_model.eval() logger.info("Whisper model loaded and optimized for Apple Silicon") else: whisper_model.eval() logger.info("Whisper model loaded (CPU mode)") memory_manager.set_whisper_model(whisper_model) app.whisper_model = whisper_model # Supported languages SUPPORTED_LANGUAGES = { "ar": "Arabic", "hy": "Armenian", "az": "Azerbaijani", "en": "English", "fr": "French", "ka": "Georgian", "kk": "Kazakh", "zh": "Mandarin", "fa": "Farsi", "pt": "Portuguese", "ru": "Russian", "es": "Spanish", "tr": "Turkish", "uz": "Uzbek" } LANGUAGE_TO_CODE = {v: k for k, v in SUPPORTED_LANGUAGES.items()} LANGUAGE_TO_VOICE = { "Arabic": "ar-EG-ShakirNeural", "Armenian": "echo", "Azerbaijani": "az-AZ-BanuNeural", "English": "en-GB-RyanNeural", "French": "fr-FR-DeniseNeural", "Georgian": "ka-GE-GiorgiNeural", "Kazakh": "kk-KZ-DauletNeural", "Mandarin": "zh-CN-YunjianNeural", "Farsi": "fa-IR-FaridNeural", "Portuguese": "pt-BR-ThalitaNeural", "Russian": "ru-RU-SvetlanaNeural", "Spanish": "es-CR-MariaNeural", "Turkish": "tr-TR-EmelNeural", "Uzbek": "uz-UZ-SardorNeural" } # Generate VAPID keys for push notifications if not os.path.exists('vapid_private.pem'): private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) public_key = private_key.public_key() with open('vapid_private.pem', 'wb') as f: f.write(private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption() )) with open('vapid_public.pem', 'wb') as f: f.write(public_key.public_bytes( encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo )) with open('vapid_private.pem', 'rb') as f: vapid_private_key = f.read() with open('vapid_public.pem', 'rb') as f: vapid_public_pem = f.read() vapid_public_key = serialization.load_pem_public_key( vapid_public_pem, backend=default_backend() ) public_numbers = vapid_public_key.public_numbers() x = public_numbers.x.to_bytes(32, byteorder='big') y = public_numbers.y.to_bytes(32, byteorder='big') vapid_public_key_base64 = base64.urlsafe_b64encode(b'\x04' + x + y).decode('utf-8').rstrip('=') # Store push subscriptions in Redis instead of memory # push_subscriptions = [] # Removed - now using Redis # Temporary file cleanup TEMP_FILE_MAX_AGE = 300 CLEANUP_INTERVAL = 60 def cleanup_temp_files(): """Clean up old temporary files""" try: current_time = datetime.now() # Clean files from upload folder if os.path.exists(app.config['UPLOAD_FOLDER']): for filename in os.listdir(app.config['UPLOAD_FOLDER']): filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) if os.path.isfile(filepath): file_age = current_time - datetime.fromtimestamp(os.path.getmtime(filepath)) if file_age > timedelta(seconds=TEMP_FILE_MAX_AGE): try: os.remove(filepath) logger.info(f"Cleaned up file: {filepath}") except Exception as e: logger.error(f"Failed to remove file {filepath}: {str(e)}") logger.debug("Cleanup completed") except Exception as e: logger.error(f"Error during temp file cleanup: {str(e)}") def run_cleanup_loop(): """Run cleanup in a separate thread""" while True: time.sleep(CLEANUP_INTERVAL) cleanup_temp_files() # Start cleanup thread cleanup_thread = threading.Thread(target=run_cleanup_loop, daemon=True) cleanup_thread.start() # Analytics collection helper def collect_analytics(service: str, duration_ms: int = None, metadata: dict = None): """Collect usage analytics to database""" try: from sqlalchemy import func today = datetime.utcnow().date() hour = datetime.utcnow().hour # Get or create analytics record analytics = UsageAnalytics.query.filter_by(date=today, hour=hour).first() if not analytics: analytics = UsageAnalytics(date=today, hour=hour) db.session.add(analytics) # Update counters analytics.total_requests += 1 if service == 'transcription': analytics.transcriptions += 1 if duration_ms: if analytics.avg_transcription_time_ms: analytics.avg_transcription_time_ms = ( (analytics.avg_transcription_time_ms * (analytics.transcriptions - 1) + duration_ms) / analytics.transcriptions ) else: analytics.avg_transcription_time_ms = duration_ms elif service == 'translation': analytics.translations += 1 if duration_ms: if analytics.avg_translation_time_ms: analytics.avg_translation_time_ms = ( (analytics.avg_translation_time_ms * (analytics.translations - 1) + duration_ms) / analytics.translations ) else: analytics.avg_translation_time_ms = duration_ms elif service == 'tts': analytics.tts_requests += 1 if duration_ms: if analytics.avg_tts_time_ms: analytics.avg_tts_time_ms = ( (analytics.avg_tts_time_ms * (analytics.tts_requests - 1) + duration_ms) / analytics.tts_requests ) else: analytics.avg_tts_time_ms = duration_ms db.session.commit() except Exception as e: logger.error(f"Failed to collect analytics: {e}") db.session.rollback() # Routes @app.route('/') def index(): return render_template('index.html', languages=sorted(SUPPORTED_LANGUAGES.values())) @app.route('/transcribe', methods=['POST']) @rate_limit(requests_per_minute=10, requests_per_hour=100, check_size=True) @limit_request_size(max_audio_size=25 * 1024 * 1024) @with_error_boundary @log_performance('transcribe_audio') @with_memory_management def transcribe(): start_time = time.time() with AudioProcessingContext(app.memory_manager, name='transcribe') as ctx: if 'audio' not in request.files: return jsonify({'error': 'No audio file provided'}), 400 audio_file = request.files['audio'] valid, error_msg = Validators.validate_audio_file(audio_file) if not valid: return jsonify({'error': error_msg}), 400 source_lang = request.form.get('source_lang', '') allowed_languages = set(SUPPORTED_LANGUAGES.values()) source_lang = Validators.validate_language_code(source_lang, allowed_languages) or '' temp_filename = f'input_audio_{int(time.time() * 1000)}.wav' temp_path = os.path.join(app.config['UPLOAD_FOLDER'], temp_filename) with open(temp_path, 'wb') as f: audio_file.save(f) ctx.add_temp_file(temp_path) # Add to Redis session if hasattr(g, 'session_manager') and hasattr(g, 'user_session'): file_size = os.path.getsize(temp_path) g.session_manager.add_resource( session_id=g.user_session.session_id, resource_type='audio_file', resource_id=temp_filename, path=temp_path, size_bytes=file_size, metadata={'filename': temp_filename, 'purpose': 'transcription'} ) try: auto_detect = source_lang == 'auto' or source_lang == '' transcribe_options = { "task": "transcribe", "temperature": 0, "best_of": 1, "beam_size": 1, "fp16": device.type == 'cuda', "condition_on_previous_text": False, "compression_ratio_threshold": 2.4, "logprob_threshold": -1.0, "no_speech_threshold": 0.6 } if not auto_detect: transcribe_options["language"] = LANGUAGE_TO_CODE.get(source_lang, None) if device.type == 'cuda': torch.cuda.empty_cache() with torch.no_grad(): result = whisper_model.transcribe( temp_path, **transcribe_options ) transcribed_text = result["text"] detected_language = None if auto_detect and 'language' in result: detected_code = result['language'] for lang_name, lang_code in LANGUAGE_TO_CODE.items(): if lang_code == detected_code: detected_language = lang_name break logger.info(f"Auto-detected language: {detected_language} ({detected_code})") # Calculate duration duration_ms = int((time.time() - start_time) * 1000) # Save to database try: transcription = Transcription( session_id=g.user_session.session_id if hasattr(g, 'user_session') else None, user_id=g.user_session.user_id if hasattr(g, 'user_session') else None, transcribed_text=transcribed_text, detected_language=detected_language or source_lang, transcription_time_ms=duration_ms, model_used=MODEL_SIZE, audio_file_size=os.path.getsize(temp_path), ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent') ) db.session.add(transcription) db.session.commit() except Exception as e: logger.error(f"Failed to save transcription to database: {e}") db.session.rollback() # Collect analytics collect_analytics('transcription', duration_ms) # Send notification if push is enabled push_count = redis_manager.scard('push_subscriptions') if push_count > 0: send_push_notification( title="Transcription Complete", body=f"Successfully transcribed: {transcribed_text[:50]}...", tag="transcription-complete" ) response = { 'success': True, 'text': transcribed_text } if detected_language: response['detected_language'] = detected_language return jsonify(response) except Exception as e: logger.error(f"Transcription error: {str(e)}") return jsonify({'error': f'Transcription failed: {str(e)}'}), 500 finally: try: if 'temp_path' in locals() and os.path.exists(temp_path): os.remove(temp_path) except Exception as e: logger.error(f"Failed to clean up temp file: {e}") if device.type == 'cuda': torch.cuda.empty_cache() torch.cuda.synchronize() gc.collect() @app.route('/translate', methods=['POST']) @rate_limit(requests_per_minute=20, requests_per_hour=300, check_size=True) @limit_request_size(max_size=1 * 1024 * 1024) @with_error_boundary @log_performance('translate_text') def translate(): start_time = time.time() try: if not Validators.validate_json_size(request.json, max_size_kb=100): return jsonify({'error': 'Request too large'}), 413 data = request.json text = data.get('text', '') text = Validators.sanitize_text(text) if not text: return jsonify({'error': 'No text provided'}), 400 allowed_languages = set(SUPPORTED_LANGUAGES.values()) source_lang = Validators.validate_language_code( data.get('source_lang', ''), allowed_languages ) or 'auto' target_lang = Validators.validate_language_code( data.get('target_lang', ''), allowed_languages ) if not target_lang: return jsonify({'error': 'Invalid target language'}), 400 # Check cache first cached_translation = redis_manager.get_cached_translation( text, source_lang, target_lang ) if cached_translation: logger.info("Translation served from cache") return jsonify({ 'success': True, 'translation': cached_translation, 'cached': True }) # Create prompt for translation prompt = f""" Translate the following text from {source_lang} to {target_lang}: "{text}" Provide only the translation without any additional text. """ response = ollama.chat( model="gemma3:27b", messages=[ { "role": "user", "content": prompt } ] ) translated_text = response['message']['content'].strip() # Calculate duration duration_ms = int((time.time() - start_time) * 1000) # Cache the translation redis_manager.cache_translation( text, source_lang, target_lang, translated_text ) # Save to database try: translation = Translation( session_id=g.user_session.session_id if hasattr(g, 'user_session') else None, user_id=g.user_session.user_id if hasattr(g, 'user_session') else None, source_text=text, source_language=source_lang, target_text=translated_text, target_language=target_lang, translation_time_ms=duration_ms, model_used="gemma3:27b", ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent') ) db.session.add(translation) db.session.commit() except Exception as e: logger.error(f"Failed to save translation to database: {e}") db.session.rollback() # Collect analytics collect_analytics('translation', duration_ms) # Send notification push_count = redis_manager.scard('push_subscriptions') if push_count > 0: send_push_notification( title="Translation Complete", body=f"Translated from {source_lang} to {target_lang}", tag="translation-complete", data={'translation': translated_text[:100]} ) return jsonify({ 'success': True, 'translation': translated_text }) except Exception as e: logger.error(f"Translation error: {str(e)}") return jsonify({'error': f'Translation failed: {str(e)}'}), 500 @app.route('/api/push-subscribe', methods=['POST']) @rate_limit(requests_per_minute=10, requests_per_hour=50) def push_subscribe(): try: subscription = request.json # Store subscription in Redis subscription_id = f"sub_{int(time.time() * 1000)}" redis_manager.set(f"push_subscription:{subscription_id}", subscription, expire=86400 * 30) # 30 days redis_manager.sadd('push_subscriptions', subscription_id) logger.info(f"New push subscription registered: {subscription_id}") return jsonify({'success': True}) except Exception as e: logger.error(f"Failed to register push subscription: {str(e)}") return jsonify({'success': False, 'error': str(e)}), 500 def send_push_notification(title, body, icon='/static/icons/icon-192x192.png', badge='/static/icons/icon-192x192.png', tag=None, data=None): """Send push notification to all subscribed clients""" claims = { "sub": "mailto:admin@talk2me.app", "exp": int(time.time()) + 86400 } notification_sent = 0 # Get all subscription IDs from Redis subscription_ids = redis_manager.smembers('push_subscriptions') for sub_id in subscription_ids: subscription = redis_manager.get(f"push_subscription:{sub_id}") if not subscription: continue try: webpush( subscription_info=subscription, data=json.dumps({ 'title': title, 'body': body, 'icon': icon, 'badge': badge, 'tag': tag or 'talk2me-notification', 'data': data or {} }), vapid_private_key=vapid_private_key, vapid_claims=claims ) notification_sent += 1 except WebPushException as e: logger.error(f"Failed to send push notification: {str(e)}") if e.response and e.response.status_code == 410: # Remove invalid subscription redis_manager.delete(f"push_subscription:{sub_id}") redis_manager.srem('push_subscriptions', sub_id) logger.info(f"Sent {notification_sent} push notifications") return notification_sent # Initialize app app.start_time = time.time() app.request_count = 0 @app.before_request def before_request(): app.request_count = getattr(app, 'request_count', 0) + 1 # Error handlers @app.errorhandler(404) def not_found_error(error): logger.warning(f"404 error: {request.url}") return jsonify({ 'success': False, 'error': 'Resource not found', 'status': 404 }), 404 @app.errorhandler(500) def internal_error(error): logger.error(f"500 error: {str(error)}") logger.error(traceback.format_exc()) return jsonify({ 'success': False, 'error': 'Internal server error', 'status': 500 }), 500 if __name__ == '__main__': app.run(host='0.0.0.0', port=5005, debug=True)