talk2me/app_with_db.py
Adolfo Delorenzo fa951c3141 Add comprehensive database integration, authentication, and admin dashboard
This commit introduces major enhancements to Talk2Me:

## Database Integration
- PostgreSQL support with SQLAlchemy ORM
- Redis integration for caching and real-time analytics
- Automated database initialization scripts
- Migration support infrastructure

## User Authentication System
- JWT-based API authentication
- Session-based web authentication
- API key authentication for programmatic access
- User roles and permissions (admin/user)
- Login history and session tracking
- Rate limiting per user with customizable limits

## Admin Dashboard
- Real-time analytics and monitoring
- User management interface (create, edit, delete users)
- System health monitoring
- Request/error tracking
- Language pair usage statistics
- Performance metrics visualization

## Key Features
- Dual authentication support (token + user accounts)
- Graceful fallback for missing services
- Non-blocking analytics middleware
- Comprehensive error handling
- Session management with security features

## Bug Fixes
- Fixed rate limiting bypass for admin routes
- Added missing email validation method
- Improved error handling for missing database tables
- Fixed session-based authentication for API endpoints

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-03 18:21:56 -06:00

746 lines
26 KiB
Python

# This is the updated app.py with Redis and PostgreSQL integration
# To use this, rename it to app.py after backing up the original
import os
import time
import tempfile
import requests
import json
import logging
from dotenv import load_dotenv
from flask import Flask, render_template, request, jsonify, Response, send_file, send_from_directory, stream_with_context, g
from flask_cors import CORS, cross_origin
import whisper
import torch
import ollama
from whisper_config import MODEL_SIZE, GPU_OPTIMIZATIONS, TRANSCRIBE_OPTIONS
from pywebpush import webpush, WebPushException
import base64
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
import gc
from functools import wraps
import traceback
from validators import Validators
import atexit
import threading
from datetime import datetime, timedelta
# Import new database and Redis components
from database import db, init_db, Translation, Transcription, UserPreferences, UsageAnalytics
from redis_manager import RedisManager, redis_cache
from redis_rate_limiter import RedisRateLimiter, rate_limit
from redis_session_manager import RedisSessionManager, init_app as init_redis_sessions
# Populate os.environ from a .env file (when present) before the
# configuration further down is loaded.
load_dotenv()
# Minimal root-logger setup at INFO level. The module-level `logger` is
# rebound further down, after the project ErrorLogger has been initialised.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# Import other components
from werkzeug.middleware.proxy_fix import ProxyFix
from config import init_app as init_config
from secrets_manager import init_app as init_secrets
from request_size_limiter import RequestSizeLimiter, limit_request_size
from error_logger import ErrorLogger, log_errors, log_performance, log_exception, get_logger
from memory_manager import MemoryManager, AudioProcessingContext, with_memory_management
# Error boundary decorator
def with_error_boundary(func):
    """Turn any exception escaping a Flask view into a JSON 500 response.

    The exception is reported through ``log_exception`` together with the
    request endpoint, method, path and client address. When its message
    mentions one of a few attack-related keywords it is additionally recorded
    as a security event via ``app.error_logger.log_security``. The raw
    exception text is only sent to the client while ``app.debug`` is on.

    Args:
        func: the view function to wrap.

    Returns:
        The wrapped view; calls that do not raise return ``func``'s result
        unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            log_exception(
                exc,
                message=f"Error in {func.__name__}",
                endpoint=request.endpoint,
                method=request.method,
                path=request.path,
                ip=request.remote_addr,
                function=func.__name__,
                module=func.__module__
            )

            lowered_text = str(exc).lower()
            looks_suspicious = any(
                word in lowered_text
                for word in ('inject', 'attack', 'malicious', 'unauthorized')
            )
            if looks_suspicious:
                app.error_logger.log_security(
                    'suspicious_error',
                    severity='warning',
                    error_type=type(exc).__name__,
                    error_message=str(exc),
                    endpoint=request.endpoint,
                    ip=request.remote_addr
                )

            # Hide internals from clients unless running in debug mode.
            if app.debug:
                client_message = str(exc)
            else:
                client_message = "An internal error occurred"
            payload = {
                'success': False,
                'error': client_message,
                'component': func.__name__,
                'request_id': getattr(g, 'request_id', None)
            }
            return jsonify(payload), 500
    return wrapper
app = Flask(__name__)

# Trust X-Forwarded-For / -Proto / -Host / -Prefix from exactly one proxy hop.
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)

# Configuration first, then secrets.
init_config(app)
init_secrets(app)

# Database
init_db(app)

# Redis client, also exposed on the app object
redis_manager = RedisManager(app)
app.redis = redis_manager

# Redis-backed rate limiter, also exposed on the app object
redis_rate_limiter = RedisRateLimiter(redis_manager)
app.redis_rate_limiter = redis_rate_limiter

# Redis-backed session management
init_redis_sessions(app)
# CORS policy shared by the API, speech/translation, TTS and health endpoints.
# NOTE(review): the default origins are ['*'] while supports_credentials is
# True; confirm how flask-cors handles this combination and set CORS_ORIGINS
# explicitly in production.
cors_config = dict(
    origins=app.config.get('CORS_ORIGINS', ['*']),
    methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization", "X-Requested-With", "X-Admin-Token"],
    expose_headers=["Content-Range", "X-Content-Range"],
    supports_credentials=True,
    max_age=3600,
)
# Every public route shares cors_config. The admin area reuses it but with
# its own origin list, which defaults to localhost only.
CORS(app, resources={
    **{
        pattern: cors_config
        for pattern in (
            r"/api/*",
            r"/transcribe",
            r"/translate",
            r"/translate/stream",
            r"/speak",
            r"/get_audio/*",
            r"/check_tts_server",
            r"/update_tts_config",
            r"/health/*",
        )
    },
    r"/admin/*": {
        **cors_config,
        "origins": app.config.get('ADMIN_CORS_ORIGINS', ['http://localhost:*']),
    },
})
# Configure upload folder
# Uploaded audio is written here briefly. Use the configured folder, or else a
# fixed directory under the system temp dir.
upload_folder = app.config.get('UPLOAD_FOLDER') or os.path.join(tempfile.gettempdir(), 'talk2me_uploads')
try:
    # Owner-only (0o700): uploads are user audio, and with 0o755 any local user
    # could list and read them in the shared temp dir. The mode only applies
    # when the directory is created here; exist_ok leaves an existing
    # directory's permissions untouched.
    # NOTE(review): the default path is predictable, so another local user
    # could pre-create it; consider a per-deployment UPLOAD_FOLDER.
    os.makedirs(upload_folder, mode=0o700, exist_ok=True)
    logger.info(f"Using upload folder: {upload_folder}")
except Exception as e:
    logger.error(f"Failed to create upload folder {upload_folder}: {str(e)}")
    # mkdtemp creates a uniquely named directory accessible only by its creator.
    upload_folder = tempfile.mkdtemp(prefix='talk2me_')
    logger.warning(f"Falling back to temporary folder: {upload_folder}")
app.config['UPLOAD_FOLDER'] = upload_folder
# Request body caps in bytes; app.config values override these defaults.
request_size_limiter = RequestSizeLimiter(app, {
    'max_content_length': app.config.get('MAX_CONTENT_LENGTH', 50 * 1024 ** 2),
    'max_audio_size': app.config.get('MAX_AUDIO_SIZE', 25 * 1024 ** 2),
    'max_json_size': app.config.get('MAX_JSON_SIZE', 1024 ** 2),
    'max_image_size': app.config.get('MAX_IMAGE_SIZE', 10 * 1024 ** 2),
})

# File logging. max_bytes / backup_count presumably drive log rotation
# inside ErrorLogger (not visible here).
error_logger = ErrorLogger(app, {
    'log_level': app.config.get('LOG_LEVEL', 'INFO'),
    'log_file': app.config.get('LOG_FILE', 'logs/talk2me.log'),
    'error_log_file': app.config.get('ERROR_LOG_FILE', 'logs/errors.log'),
    'max_bytes': app.config.get('LOG_MAX_BYTES', 50 * 1024 ** 2),
    'backup_count': app.config.get('LOG_BACKUP_COUNT', 10),
})
# From here on, the module logger comes from the project's logger factory.
logger = get_logger(__name__)

# Memory management: host/GPU thresholds in MB plus a cleanup interval
# (unit defined by MemoryManager).
memory_manager = MemoryManager(app, {
    'memory_threshold_mb': app.config.get('MEMORY_THRESHOLD_MB', 4096),
    'gpu_memory_threshold_mb': app.config.get('GPU_MEMORY_THRESHOLD_MB', 2048),
    'cleanup_interval': app.config.get('MEMORY_CLEANUP_INTERVAL', 30),
})
# Initialize Whisper model
# Pick the best available torch device, load the Whisper model onto it and
# apply device-specific tuning. Defines the module globals `device` and
# `whisper_model`, and registers the model with the memory manager and app.
logger.info("Initializing Whisper model with GPU optimization...")
if torch.cuda.is_available():
    device = torch.device("cuda")
    try:
        gpu_name = torch.cuda.get_device_name(0)
        if 'AMD' in gpu_name or 'Radeon' in gpu_name:
            logger.info(f"AMD GPU detected via ROCm: {gpu_name}")
        else:
            logger.info(f"NVIDIA GPU detected: {gpu_name}")
    except Exception:
        # Only the name lookup failed; the CUDA device itself is still usable.
        logger.info("GPU detected - using CUDA/ROCm acceleration")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device("mps")
    logger.info("Apple Silicon detected - using Metal Performance Shaders")
else:
    device = torch.device("cpu")
    logger.info("No GPU acceleration available - using CPU")
logger.info(f"Using device: {device}")
whisper_model = whisper.load_model(MODEL_SIZE, device=device)

# Enable GPU optimizations
if device.type == 'cuda':
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        whisper_model.eval()
        # Half precision on CUDA; /transcribe passes fp16=True on this device.
        whisper_model = whisper_model.half()
        torch.cuda.empty_cache()
        logger.info("Warming up GPU with dummy inference...")
        with torch.no_grad():
            # Whisper's encoder consumes a log-Mel spectrogram (not raw
            # samples) through embed_audio(); the model has no encode().
            # Run it once on 30 s of silence at 16 kHz so CUDA kernels are
            # initialised before the first real request.
            silence = torch.zeros(16000 * 30)
            mel = whisper.log_mel_spectrogram(
                silence, n_mels=whisper_model.dims.n_mels, device=device
            ).half()
            _ = whisper_model.embed_audio(mel.unsqueeze(0))
        logger.info(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        logger.info("Whisper model loaded and optimized for GPU")
    except Exception as e:
        logger.warning(f"Some GPU optimizations failed: {e}")
elif device.type == 'mps':
    whisper_model.eval()
    logger.info("Whisper model loaded and optimized for Apple Silicon")
else:
    whisper_model.eval()
    logger.info("Whisper model loaded (CPU mode)")
memory_manager.set_whisper_model(whisper_model)
app.whisper_model = whisper_model
# Supported languages: two-letter code (the form passed to and reported by
# Whisper in /transcribe) -> display name used throughout the UI and API.
SUPPORTED_LANGUAGES = {
    "ar": "Arabic",
    "hy": "Armenian",
    "az": "Azerbaijani",
    "en": "English",
    "fr": "French",
    "ka": "Georgian",
    "kk": "Kazakh",
    "zh": "Mandarin",
    "fa": "Farsi",
    "pt": "Portuguese",
    "ru": "Russian",
    "es": "Spanish",
    "tr": "Turkish",
    "uz": "Uzbek",
}

# Reverse lookup: display name -> language code.
LANGUAGE_TO_CODE = {name: code for code, name in SUPPORTED_LANGUAGES.items()}

# Voice identifier per display name. Not used in the code shown here;
# presumably consumed by the TTS endpoint.
LANGUAGE_TO_VOICE = {
    "Arabic": "ar-EG-ShakirNeural",
    # NOTE(review): "echo" does not follow the <locale>-<Name>Neural pattern
    # of every other entry; confirm the TTS backend accepts it.
    "Armenian": "echo",
    "Azerbaijani": "az-AZ-BanuNeural",
    "English": "en-GB-RyanNeural",
    "French": "fr-FR-DeniseNeural",
    "Georgian": "ka-GE-GiorgiNeural",
    "Kazakh": "kk-KZ-DauletNeural",
    "Mandarin": "zh-CN-YunjianNeural",
    "Farsi": "fa-IR-FaridNeural",
    "Portuguese": "pt-BR-ThalitaNeural",
    "Russian": "ru-RU-SvetlanaNeural",
    "Spanish": "es-CR-MariaNeural",
    "Turkish": "tr-TR-EmelNeural",
    "Uzbek": "uz-UZ-SardorNeural",
}
# Generate VAPID keys for push notifications
# Load the VAPID key pair used to sign Web Push requests, creating whatever is
# missing. Both PEM files live in the current working directory.
_vapid_private_missing = not os.path.exists('vapid_private.pem')
if _vapid_private_missing or not os.path.exists('vapid_public.pem'):
    if _vapid_private_missing:
        private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
        # The private key is a secret: create the file owner read/write only
        # (0o600) instead of inheriting the process umask.
        _vapid_fd = os.open('vapid_private.pem', os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(_vapid_fd, 'wb') as f:
            f.write(private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            ))
    else:
        # Only the public half is missing. Rebuild it from the existing private
        # key rather than failing on the read below or generating a new pair.
        # A new pair would invalidate subscriptions created with the current key.
        with open('vapid_private.pem', 'rb') as f:
            private_key = serialization.load_pem_private_key(
                f.read(), password=None, backend=default_backend()
            )
    public_key = private_key.public_key()
    with open('vapid_public.pem', 'wb') as f:
        f.write(public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        ))

with open('vapid_private.pem', 'rb') as f:
    # NOTE(review): these are PEM *bytes*. pywebpush documents its
    # vapid_private_key argument as a PEM file path or a DER string; confirm
    # this value is not passed to webpush() as-is.
    vapid_private_key = f.read()
with open('vapid_public.pem', 'rb') as f:
    vapid_public_pem = f.read()
vapid_public_key = serialization.load_pem_public_key(
    vapid_public_pem,
    backend=default_backend()
)
# Browser applicationServerKey: uncompressed P-256 point (0x04 || X || Y),
# base64url-encoded without padding.
public_numbers = vapid_public_key.public_numbers()
x = public_numbers.x.to_bytes(32, byteorder='big')
y = public_numbers.y.to_bytes(32, byteorder='big')
vapid_public_key_base64 = base64.urlsafe_b64encode(b'\x04' + x + y).decode('utf-8').rstrip('=')
# Push subscriptions are kept in Redis ("push_subscription:<id>" keys plus the
# "push_subscriptions" id set), not in an in-process list.

# Temporary upload cleanup: files older than TEMP_FILE_MAX_AGE seconds are
# removed by a background sweep that runs every CLEANUP_INTERVAL seconds.
TEMP_FILE_MAX_AGE = 5 * 60
CLEANUP_INTERVAL = 60
def cleanup_temp_files():
    """Delete stale regular files from ``app.config['UPLOAD_FOLDER']``.

    A file is stale when its modification time is more than
    ``TEMP_FILE_MAX_AGE`` seconds in the past. Subdirectories are skipped.
    Errors are logged, never raised, so the background cleanup thread keeps
    running.
    """
    try:
        folder = app.config['UPLOAD_FOLDER']
        if not os.path.isdir(folder):
            return
        # Compare epoch seconds. Naive local datetimes (datetime.now /
        # fromtimestamp) skew the computed age by an hour across DST changes.
        cutoff = time.time() - TEMP_FILE_MAX_AGE
        with os.scandir(folder) as entries:
            for entry in entries:
                try:
                    if not entry.is_file() or entry.stat().st_mtime >= cutoff:
                        continue
                    os.remove(entry.path)
                    logger.info(f"Cleaned up file: {entry.path}")
                except FileNotFoundError:
                    # Removed concurrently (e.g. by the request that created it)
                    # between listing and stat/remove. Previously this aborted
                    # the whole sweep; now only this entry is skipped.
                    continue
                except Exception as e:
                    logger.error(f"Failed to remove file {entry.path}: {str(e)}")
        logger.debug("Cleanup completed")
    except Exception as e:
        logger.error(f"Error during temp file cleanup: {str(e)}")
def run_cleanup_loop():
    """Body of the background cleanup thread; loops forever and never returns.

    Each iteration first sleeps ``CLEANUP_INTERVAL`` seconds, so the first
    sweep happens one interval after the thread starts, and then calls
    ``cleanup_temp_files()``.
    """
    while True:
        time.sleep(CLEANUP_INTERVAL)
        cleanup_temp_files()


# Started at import time. daemon=True so the thread does not keep the
# interpreter alive at shutdown.
cleanup_thread = threading.Thread(target=run_cleanup_loop, daemon=True)
cleanup_thread.start()
# Analytics collection helper
def collect_analytics(service: str, duration_ms: int = None, metadata: dict = None):
    """Count one request in the current UTC hour's ``UsageAnalytics`` row.

    Args:
        service: 'transcription', 'translation' or 'tts' also bump that
            service's counter and running average; any other value only bumps
            ``total_requests``.
        duration_ms: processing time of the request; when truthy it is folded
            into the service's running average.
        metadata: accepted for interface compatibility; currently unused.

    The row for (date, hour) is created on first use. Any failure is logged and
    the DB session rolled back; nothing is raised, so analytics never fail the
    calling request.
    """
    # service -> (counter column, running-average column)
    service_fields = {
        'transcription': ('transcriptions', 'avg_transcription_time_ms'),
        'translation': ('translations', 'avg_translation_time_ms'),
        'tts': ('tts_requests', 'avg_tts_time_ms'),
    }
    try:
        # Read the clock once so date and hour cannot straddle an hour boundary.
        now = datetime.utcnow()
        today, hour = now.date(), now.hour

        analytics = UsageAnalytics.query.filter_by(date=today, hour=hour).first()
        if not analytics:
            analytics = UsageAnalytics(date=today, hour=hour)
            db.session.add(analytics)

        # SQLAlchemy applies column defaults at INSERT time, so a row built just
        # above may still hold None in its counters, unless the model's
        # constructor sets them (not visible here). Treat None as 0.
        analytics.total_requests = (analytics.total_requests or 0) + 1

        if service in service_fields:
            count_attr, avg_attr = service_fields[service]
            count = (getattr(analytics, count_attr) or 0) + 1
            setattr(analytics, count_attr, count)
            if duration_ms:
                previous_avg = getattr(analytics, avg_attr)
                if previous_avg:
                    # Incremental mean over the `count` requests seen so far.
                    setattr(analytics, avg_attr,
                            (previous_avg * (count - 1) + duration_ms) / count)
                else:
                    setattr(analytics, avg_attr, duration_ms)

        db.session.commit()
    except Exception as e:
        logger.error(f"Failed to collect analytics: {e}")
        db.session.rollback()
# Routes
@app.route('/')
def index():
    """Render the main page with the supported language names sorted A-Z."""
    language_names = sorted(SUPPORTED_LANGUAGES.values())
    return render_template('index.html', languages=language_names)
@app.route('/transcribe', methods=['POST'])
@rate_limit(requests_per_minute=10, requests_per_hour=100, check_size=True)
@limit_request_size(max_audio_size=25 * 1024 * 1024)
@with_error_boundary
@log_performance('transcribe_audio')
@with_memory_management
def transcribe():
    """Transcribe an uploaded audio file with the loaded Whisper model.

    Form fields:
        audio: the audio file (required; checked by
            ``Validators.validate_audio_file``).
        source_lang: a supported language name; 'auto', empty or invalid
            values mean auto-detection.

    Returns:
        JSON ``{'success': True, 'text': ...}``, plus ``detected_language``
        when auto-detection recognised a supported language. 400 for missing
        or invalid audio; 500 if transcription fails. The result is also
        stored as a ``Transcription`` row and counted in analytics (both
        best-effort).
    """
    start_time = time.time()
    with AudioProcessingContext(app.memory_manager, name='transcribe') as ctx:
        if 'audio' not in request.files:
            return jsonify({'error': 'No audio file provided'}), 400
        audio_file = request.files['audio']
        valid, error_msg = Validators.validate_audio_file(audio_file)
        if not valid:
            return jsonify({'error': error_msg}), 400

        allowed_languages = set(SUPPORTED_LANGUAGES.values())
        source_lang = Validators.validate_language_code(
            request.form.get('source_lang', ''), allowed_languages
        ) or ''

        # mkstemp picks a unique name and creates the file owner-only. A
        # millisecond timestamp can collide between concurrent requests, letting
        # one request overwrite, transcribe or delete another user's audio.
        fd, temp_path = tempfile.mkstemp(
            prefix='input_audio_', suffix='.wav', dir=app.config['UPLOAD_FOLDER']
        )
        temp_filename = os.path.basename(temp_path)
        # Registered before writing so the context also knows about the file if
        # saving fails.
        ctx.add_temp_file(temp_path)
        with os.fdopen(fd, 'wb') as f:
            audio_file.save(f)
        audio_size = os.path.getsize(temp_path)

        # g.user_session may be absent or None (e.g. no session established).
        user_session = getattr(g, 'user_session', None)

        # Add to Redis session
        if hasattr(g, 'session_manager') and user_session is not None:
            g.session_manager.add_resource(
                session_id=user_session.session_id,
                resource_type='audio_file',
                resource_id=temp_filename,
                path=temp_path,
                size_bytes=audio_size,
                metadata={'filename': temp_filename, 'purpose': 'transcription'}
            )

        try:
            auto_detect = source_lang in ('auto', '')
            transcribe_options = {
                "task": "transcribe",
                "temperature": 0,
                "best_of": 1,
                "beam_size": 1,
                "fp16": device.type == 'cuda',
                "condition_on_previous_text": False,
                "compression_ratio_threshold": 2.4,
                "logprob_threshold": -1.0,
                "no_speech_threshold": 0.6
            }
            if not auto_detect:
                transcribe_options["language"] = LANGUAGE_TO_CODE.get(source_lang, None)

            if device.type == 'cuda':
                torch.cuda.empty_cache()
            with torch.no_grad():
                result = whisper_model.transcribe(temp_path, **transcribe_options)
            transcribed_text = result["text"]

            detected_language = None
            if auto_detect and 'language' in result:
                detected_code = result['language']
                # SUPPORTED_LANGUAGES maps code -> name (the inverse of
                # LANGUAGE_TO_CODE); None if Whisper picked an unsupported one.
                detected_language = SUPPORTED_LANGUAGES.get(detected_code)
                logger.info(f"Auto-detected language: {detected_language} ({detected_code})")

            duration_ms = int((time.time() - start_time) * 1000)

            # Save to database (best-effort)
            try:
                transcription = Transcription(
                    session_id=getattr(user_session, 'session_id', None),
                    user_id=getattr(user_session, 'user_id', None),
                    transcribed_text=transcribed_text,
                    detected_language=detected_language or source_lang,
                    transcription_time_ms=duration_ms,
                    model_used=MODEL_SIZE,
                    audio_file_size=audio_size,
                    ip_address=request.remote_addr,
                    user_agent=request.headers.get('User-Agent')
                )
                db.session.add(transcription)
                db.session.commit()
            except Exception as e:
                logger.error(f"Failed to save transcription to database: {e}")
                db.session.rollback()

            collect_analytics('transcription', duration_ms)

            # Notifications are best-effort: a push failure must not turn a
            # finished transcription into an error response.
            # NOTE(review): the body (including transcribed text) is broadcast
            # to every stored subscription, not only this user's.
            try:
                if redis_manager.scard('push_subscriptions') > 0:
                    send_push_notification(
                        title="Transcription Complete",
                        body=f"Successfully transcribed: {transcribed_text[:50]}...",
                        tag="transcription-complete"
                    )
            except Exception as e:
                logger.error(f"Failed to send transcription notification: {e}")

            response = {
                'success': True,
                'text': transcribed_text
            }
            if detected_language:
                response['detected_language'] = detected_language
            return jsonify(response)
        except Exception as e:
            logger.error(f"Transcription error: {str(e)}")
            return jsonify({'error': f'Transcription failed: {str(e)}'}), 500
        finally:
            try:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
            except Exception as e:
                logger.error(f"Failed to clean up temp file: {e}")
            if device.type == 'cuda':
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            gc.collect()
@app.route('/translate', methods=['POST'])
@rate_limit(requests_per_minute=20, requests_per_hour=300, check_size=True)
@limit_request_size(max_size=1 * 1024 * 1024)
@with_error_boundary
@log_performance('translate_text')
def translate():
    """Translate text with the local Ollama model, consulting the Redis cache first.

    JSON body:
        text: the text to translate (sanitised via ``Validators.sanitize_text``).
        source_lang: supported language name; missing/invalid becomes 'auto'.
        target_lang: supported language name (required).

    Returns:
        ``{'success': True, 'translation': ...}``, with ``'cached': True`` on a
        cache hit. Error statuses:
          - 400 for a missing/non-object body, non-string or empty text, or an
            invalid target language;
          - 413 when ``Validators.validate_json_size`` rejects the body;
          - 500 when translation fails.
        Fresh results are cached, stored as a ``Translation`` row and counted
        in analytics (storage, analytics and notification are best-effort).
    """
    model_name = "gemma3:27b"
    start_time = time.time()
    try:
        # silent=True yields None for a missing or malformed JSON body instead of
        # raising, so bad input becomes a 400 rather than a 500 from below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid JSON body'}), 400
        if not Validators.validate_json_size(data, max_size_kb=100):
            return jsonify({'error': 'Request too large'}), 413

        text = data.get('text', '')
        if not isinstance(text, str):
            return jsonify({'error': 'Text must be a string'}), 400
        text = Validators.sanitize_text(text)
        if not text:
            return jsonify({'error': 'No text provided'}), 400

        allowed_languages = set(SUPPORTED_LANGUAGES.values())
        source_lang = Validators.validate_language_code(
            data.get('source_lang', ''), allowed_languages
        ) or 'auto'
        target_lang = Validators.validate_language_code(
            data.get('target_lang', ''), allowed_languages
        )
        if not target_lang:
            return jsonify({'error': 'Invalid target language'}), 400

        # Check cache first
        cached_translation = redis_manager.get_cached_translation(
            text, source_lang, target_lang
        )
        if cached_translation:
            logger.info("Translation served from cache")
            return jsonify({
                'success': True,
                'translation': cached_translation,
                'cached': True
            })

        prompt = f"""
Translate the following text from {source_lang} to {target_lang}:
"{text}"
Provide only the translation without any additional text.
"""
        response = ollama.chat(
            model=model_name,
            messages=[{"role": "user", "content": prompt}]
        )
        translated_text = response['message']['content'].strip()
        duration_ms = int((time.time() - start_time) * 1000)

        redis_manager.cache_translation(
            text, source_lang, target_lang, translated_text
        )

        # Save to database (best-effort). g.user_session may be absent or None.
        user_session = getattr(g, 'user_session', None)
        try:
            translation = Translation(
                session_id=getattr(user_session, 'session_id', None),
                user_id=getattr(user_session, 'user_id', None),
                source_text=text,
                source_language=source_lang,
                target_text=translated_text,
                target_language=target_lang,
                translation_time_ms=duration_ms,
                model_used=model_name,
                ip_address=request.remote_addr,
                user_agent=request.headers.get('User-Agent')
            )
            db.session.add(translation)
            db.session.commit()
        except Exception as e:
            logger.error(f"Failed to save translation to database: {e}")
            db.session.rollback()

        collect_analytics('translation', duration_ms)

        # Best-effort notification: a push failure must not fail the request.
        # NOTE(review): the translated text is broadcast to every stored
        # subscription, not only this user's.
        try:
            if redis_manager.scard('push_subscriptions') > 0:
                send_push_notification(
                    title="Translation Complete",
                    body=f"Translated from {source_lang} to {target_lang}",
                    tag="translation-complete",
                    data={'translation': translated_text[:100]}
                )
        except Exception as e:
            logger.error(f"Failed to send translation notification: {e}")

        return jsonify({
            'success': True,
            'translation': translated_text
        })
    except Exception as e:
        logger.error(f"Translation error: {str(e)}")
        return jsonify({'error': f'Translation failed: {str(e)}'}), 500
@app.route('/api/push-subscribe', methods=['POST'])
@rate_limit(requests_per_minute=10, requests_per_hour=50)
def push_subscribe():
    """Register a browser Web Push subscription in Redis.

    Expects a PushSubscription JSON object: a non-empty string ``endpoint``
    and a ``keys`` object containing ``p256dh`` and ``auth``. It is stored as
    ``push_subscription:<id>`` for 30 days and ``<id>`` is added to the
    ``push_subscriptions`` set.

    The id is derived from the endpoint, with two effects:
      - re-subscribing from the same browser refreshes one entry (and its
        TTL) instead of piling up duplicates that each get a notification;
      - two requests in the same millisecond can no longer collide.

    Returns:
        ``{'success': True}``; 400 for a malformed subscription; 500 if
        storing fails.
    """
    import hashlib

    try:
        subscription = request.get_json(silent=True)
        if not isinstance(subscription, dict):
            return jsonify({'success': False, 'error': 'Invalid push subscription'}), 400
        endpoint = subscription.get('endpoint')
        keys = subscription.get('keys')
        if (not isinstance(endpoint, str) or not endpoint
                or not isinstance(keys, dict)
                or not keys.get('p256dh') or not keys.get('auth')):
            return jsonify({'success': False, 'error': 'Invalid push subscription'}), 400

        subscription_id = "sub_" + hashlib.sha256(endpoint.encode('utf-8')).hexdigest()
        # Store subscription in Redis
        redis_manager.set(f"push_subscription:{subscription_id}", subscription, expire=86400 * 30)  # 30 days
        redis_manager.sadd('push_subscriptions', subscription_id)
        logger.info(f"New push subscription registered: {subscription_id}")
        return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Failed to register push subscription: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def send_push_notification(title, body, icon='/static/icons/icon-192x192.png',
                           badge='/static/icons/icon-192x192.png', tag=None, data=None):
    """Send a Web Push notification to every subscription stored in Redis.

    Args:
        title, body, icon, badge: notification fields put in the JSON payload.
        tag: notification tag; defaults to 'talk2me-notification'.
        data: extra dict passed through as ``data`` (defaults to ``{}``).

    Returns:
        The number of subscriptions webpush() accepted without raising.

    Cleanup: ids whose Redis entry has expired, or whose push service answers
    404/410 (subscription gone), are removed from the registry.

    Failures: per-subscription errors are logged and skipped, so a single bad
    recipient never propagates an exception to the caller.
    """
    # Identical for every recipient, so serialise once.
    payload = json.dumps({
        'title': title,
        'body': body,
        'icon': icon,
        'badge': badge,
        'tag': tag or 'talk2me-notification',
        'data': data or {}
    })
    notification_sent = 0
    subscription_ids = redis_manager.smembers('push_subscriptions') or ()
    for sub_id in subscription_ids:
        subscription = redis_manager.get(f"push_subscription:{sub_id}")
        if not subscription:
            # The stored subscription expired (it is written with a TTL) but its
            # id is still in the set; drop the dangling id.
            redis_manager.srem('push_subscriptions', sub_id)
            continue
        # Fresh claims per recipient: webpush() fills in a missing 'aud' from the
        # subscription endpoint by mutating the dict it is given, so a shared
        # dict would carry the first push service's audience to every later one.
        # exp stays well under RFC 8292's 24 h cap to tolerate clock skew.
        claims = {
            "sub": "mailto:admin@talk2me.app",
            "exp": int(time.time()) + 12 * 60 * 60
        }
        try:
            webpush(
                subscription_info=subscription,
                data=payload,
                # pywebpush accepts a PEM *file path* (or a DER string) here; raw
                # PEM bytes are not one of its accepted forms.
                vapid_private_key='vapid_private.pem',
                vapid_claims=claims
            )
            notification_sent += 1
        except WebPushException as e:
            logger.error(f"Failed to send push notification: {str(e)}")
            # Compare against None: requests.Response is falsy for every 4xx/5xx
            # status, so a plain truth test never matched the 410.
            if e.response is not None and e.response.status_code in (404, 410):
                # Remove invalid subscription
                redis_manager.delete(f"push_subscription:{sub_id}")
                redis_manager.srem('push_subscriptions', sub_id)
        except Exception as e:
            # Malformed stored data, key or network problems: skip this recipient.
            logger.error(f"Failed to send push notification to {sub_id}: {str(e)}")
    logger.info(f"Sent {notification_sent} push notifications")
    return notification_sent
# Process-wide bookkeeping: when the app module was loaded and how many
# requests it has seen.
app.start_time = time.time()
app.request_count = 0


@app.before_request
def before_request():
    """Increment ``app.request_count`` for every incoming request."""
    seen_so_far = getattr(app, 'request_count', 0)
    app.request_count = seen_so_far + 1
# Error handlers
@app.errorhandler(404)
def not_found_error(error):
    """Log the unknown URL and answer with a JSON 404 body."""
    logger.warning(f"404 error: {request.url}")
    body = {'success': False, 'error': 'Resource not found', 'status': 404}
    return jsonify(body), 404
@app.errorhandler(500)
def internal_error(error):
    """Log the error and current traceback, then answer with a generic JSON 500."""
    logger.error(f"500 error: {str(error)}")
    logger.error(traceback.format_exc())
    body = {'success': False, 'error': 'Internal server error', 'status': 500}
    return jsonify(body), 500
if __name__ == '__main__':
    # Debug mode follows the app configuration instead of being hard-coded on.
    # debug=True together with host='0.0.0.0' exposes the Werkzeug interactive
    # debugger (arbitrary code execution) to every host that can reach port 5005.
    app.run(host='0.0.0.0', port=5005, debug=app.debug)