"""Flask backend for the voice-translator service.

Routes cover speech-to-text with Whisper (``/transcribe``), translation through
Ollama (``/translate``), speech synthesis through an OpenAI-compatible TTS server
(``/speak``), Web Push subscription management, icon/static helpers and health probes.
"""
import base64
import gc  # For garbage collection
import json
import logging
import os
import tempfile
import time
from urllib.parse import urlparse

import ollama
import requests
import torch
import whisper
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from flask import Flask, render_template, request, jsonify, Response, send_file, send_from_directory
from pywebpush import webpush, WebPushException

from whisper_config import MODEL_SIZE, GPU_OPTIMIZATIONS, TRANSCRIBE_OPTIONS

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Per-process scratch directory for uploaded and generated audio.
app.config['UPLOAD_FOLDER'] = tempfile.mkdtemp()
app.config['TTS_SERVER'] = os.environ.get('TTS_SERVER_URL', 'http://localhost:5050/v1/audio/speech')
# NOTE(review): a literal API key is committed as the fallback value; prefer supplying
# TTS_API_KEY through the environment and dropping this default.
app.config['TTS_API_KEY'] = os.environ.get('TTS_API_KEY', '56461d8b44607f2cfcb8030dee313a8e')

# --- VAPID keys for Web Push -------------------------------------------------
# Absolute paths so the key files stay reachable even if the working directory changes.
VAPID_PRIVATE_KEY_PATH = os.path.abspath('vapid_private.pem')
VAPID_PUBLIC_KEY_PATH = os.path.abspath('vapid_public.pem')

if not os.path.exists(VAPID_PRIVATE_KEY_PATH):
    # First start: generate a P-256 key pair and persist both halves as PEM.
    private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
    public_key = private_key.public_key()
    with open(VAPID_PRIVATE_KEY_PATH, 'wb') as f:
        f.write(private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ))
    with open(VAPID_PUBLIC_KEY_PATH, 'wb') as f:
        f.write(public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ))

# Load the VAPID private key. The public key is derived from it, so the pair can never
# disagree and a missing vapid_public.pem is not fatal.
with open(VAPID_PRIVATE_KEY_PATH, 'rb') as f:
    vapid_private_key = f.read()
vapid_public_key = serialization.load_pem_private_key(
    vapid_private_key, password=None, backend=default_backend()
).public_key()
vapid_public_pem = vapid_public_key.public_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
# The client receives the uncompressed EC point (0x04 || X || Y), base64url-encoded
# without padding.
vapid_public_key_base64 = base64.urlsafe_b64encode(
    vapid_public_key.public_bytes(
        encoding=serialization.Encoding.X962,
        format=serialization.PublicFormat.UncompressedPoint,
    )
).decode('utf-8').rstrip('=')

# Lifetime of the VAPID JWT. RFC 8292 forbids an ``exp`` more than 24 hours ahead, so
# stay well inside that limit rather than sitting exactly on it.
VAPID_CLAIM_TTL_SECONDS = 12 * 60 * 60

# Store subscriptions in memory (in production, use a database)
push_subscriptions = []

# --- Icons and static files --------------------------------------------------
# Icon names requested at the site root, mapped to the files under static/icons.
ROOT_ICON_FILES = {
    'favicon.ico': 'favicon.ico',
    'apple-touch-icon.png': 'apple-icon-180x180.png',
    'apple-touch-icon-precomposed.png': 'apple-icon-180x180.png',
    'apple-touch-icon-120x120.png': 'apple-icon-120x120.png',
    'apple-touch-icon-120x120-precomposed.png': 'apple-icon-120x120.png',
}


@app.route('/<filename>')
def root_files(filename):
    """Serve a well-known icon requested at the site root; anything else is a 404.

    The route carries a ``<filename>`` variable: registered on plain ``'/'`` it would
    shadow ``index`` and fail because ``filename`` is never supplied.
    """
    icon = ROOT_ICON_FILES.get(filename)
    if icon is None:
        return "File not found", 404
    return send_from_directory('static/icons', icon)


@app.route('/favicon.ico')
def favicon():
    return send_from_directory('static/icons', 'favicon.ico')


@app.route('/apple-touch-icon.png')
def apple_touch_icon():
    return send_from_directory('static/icons', 'apple-icon-180x180.png')


@app.route('/apple-touch-icon-precomposed.png')
def apple_touch_icon_precomposed():
    return send_from_directory('static/icons', 'apple-icon-180x180.png')


@app.route('/apple-touch-icon-120x120.png')
def apple_touch_icon_120():
    return send_from_directory('static/icons', 'apple-icon-120x120.png')


@app.route('/apple-touch-icon-120x120-precomposed.png')
def apple_touch_icon_120_precomposed():
    return send_from_directory('static/icons', 'apple-icon-120x120.png')


@app.route('/service-worker.js')
def service_worker():
    """Serve static/service-worker.js from the site root."""
    return app.send_static_file('service-worker.js')


# Make sure static files are served properly
app.static_folder = 'static'


@app.route('/static/icons/<path:filename>')
def serve_icon(filename):
    """Serve a file from static/icons (send_from_directory rejects paths escaping it)."""
    return send_from_directory('static/icons', filename)


# --- Web Push ----------------------------------------------------------------
def _is_valid_subscription(subscription):
    """Return True if *subscription* is a JSON object carrying a push ``endpoint``."""
    return isinstance(subscription, dict) and bool(subscription.get('endpoint'))


@app.route('/api/push-public-key', methods=['GET'])
def push_public_key():
    """Return the base64url VAPID public key clients need to subscribe."""
    return jsonify({'publicKey': vapid_public_key_base64})


@app.route('/api/push-subscribe', methods=['POST'])
def push_subscribe():
    """Register a push subscription (JSON body); 400 if it is not a subscription object."""
    try:
        subscription = request.get_json(silent=True)
        if not _is_valid_subscription(subscription):
            return jsonify({'success': False, 'error': 'Invalid push subscription'}), 400
        # Store subscription (in production, use a database)
        if subscription not in push_subscriptions:
            push_subscriptions.append(subscription)
            logger.info(f"New push subscription registered. Total subscriptions: {len(push_subscriptions)}")
        return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Failed to register push subscription: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/push-unsubscribe', methods=['POST'])
def push_unsubscribe():
    """Forget a previously registered push subscription; unknown ones are ignored."""
    try:
        subscription = request.get_json(silent=True)
        if subscription in push_subscriptions:
            push_subscriptions.remove(subscription)
            logger.info(f"Push subscription removed. Total subscriptions: {len(push_subscriptions)}")
        return jsonify({'success': True})
    except Exception as e:
        logger.error(f"Failed to unsubscribe: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500


def send_push_notification(title, body, icon='/static/icons/icon-192x192.png',
                           badge='/static/icons/icon-192x192.png', tag=None, data=None):
    """Send a push notification to every subscribed client, best effort.

    Subscriptions the push service reports as gone (HTTP 404/410) are dropped. Failures
    are logged and never raised, so callers' request handling is not affected.

    Returns:
        The number of notifications accepted by push services.
    """
    payload = json.dumps({
        'title': title,
        'body': body,
        'icon': icon,
        'badge': badge,
        'tag': tag or 'talk2me-notification',
        'data': data or {},
    })
    expires = int(time.time()) + VAPID_CLAIM_TTL_SECONDS

    notification_sent = 0
    for subscription in push_subscriptions[:]:  # Iterate a copy; we may remove entries
        # A fresh claims dict per subscription: pywebpush writes the 'aud' claim (derived
        # from the endpoint) into the dict it is given, and endpoints differ per service.
        claims = {"sub": "mailto:admin@talk2me.app", "exp": expires}
        try:
            webpush(
                subscription_info=subscription,
                data=payload,
                # pywebpush takes a PEM file path (or a DER string), not raw PEM bytes.
                vapid_private_key=VAPID_PRIVATE_KEY_PATH,
                vapid_claims=claims,
            )
            notification_sent += 1
        except WebPushException as e:
            logger.error(f"Failed to send push notification: {str(e)}")
            # requests.Response is falsy for 4xx/5xx, so compare against None explicitly.
            status = e.response.status_code if e.response is not None else None
            if status in (404, 410) and subscription in push_subscriptions:
                push_subscriptions.remove(subscription)
                logger.info(f"Removed expired push subscription. Total subscriptions: {len(push_subscriptions)}")
        except Exception as e:
            # A malformed subscription must not break the request that triggered the push.
            logger.error(f"Unexpected error sending push notification: {str(e)}")

    logger.info(f"Sent {notification_sent} push notifications")
    return notification_sent


# --- TTS configuration and status --------------------------------------------
def _tts_headers(api_key):
    """Build the JSON + bearer-token headers sent to the TTS server."""
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }


def _tts_error_message(details):
    """Extract a readable message from a TTS error body ({'error': {'message': ...}} or {'error': '...'})."""
    error = details.get('error') if isinstance(details, dict) else None
    if isinstance(error, dict):
        return error.get('message', 'Unknown error')
    if isinstance(error, str) and error:
        return error
    return 'Unknown error'


@app.route('/check_tts_server', methods=['GET'])
def check_tts_server():
    """Report whether the configured TTS server is reachable.

    Tries a cheap HEAD on ``/v1/models`` first and falls back to a tiny synthesis request
    if that cannot be sent. Status 200 is reported as ``online``, 401/403 as
    ``auth_error`` (the server answers but rejects the key), anything else as ``error``.
    """
    tts_server_url = app.config['TTS_SERVER']
    headers = _tts_headers(app.config['TTS_API_KEY'])

    try:
        try:
            response = requests.head(
                tts_server_url.split('/v1/audio/speech')[0] + '/v1/models',
                headers=headers,
                timeout=5,
            )
        except requests.exceptions.RequestException:
            # If the HEAD request fails, try a minimal POST
            response = requests.post(
                tts_server_url,
                headers=headers,
                json={
                    "input": "Test",
                    "voice": "echo",
                    "response_format": "mp3",
                    "speed": 1.0,
                },
                timeout=5,
            )
        status_code = response.status_code
    except requests.exceptions.RequestException as e:
        logger.error(f"Cannot connect to TTS server: {str(e)}")
        return jsonify({
            'status': 'error',
            'message': f'Cannot connect to TTS server: {str(e)}',
            'url': tts_server_url,
        })

    if status_code in (200, 401, 403):  # Even auth errors mean the server is running
        logger.info(f"TTS server is reachable at {tts_server_url}")
        online = status_code == 200
        return jsonify({
            'status': 'online' if online else 'auth_error',
            'message': 'TTS server is online' if online else 'Authentication error. Check API key.',
            'url': tts_server_url,
            'code': status_code,
        })

    logger.warning(f"TTS server returned status code {status_code}")
    return jsonify({
        'status': 'error',
        'message': f'TTS server returned status code {status_code}',
        'url': tts_server_url,
        'code': status_code,
    })


@app.route('/update_tts_config', methods=['POST'])
def update_tts_config():
    """Update the TTS server URL and/or API key from a JSON body.

    Body keys ``server_url`` (must be an http(s) URL) and ``api_key`` are both optional;
    empty values leave the current setting unchanged.
    """
    # NOTE(review): this endpoint is unauthenticated. Anyone who can reach it can point
    # TTS traffic, including the stored bearer key, at an arbitrary URL. Put it behind auth.
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'Request body must be a JSON object',
            }), 400

        tts_server_url = data.get('server_url')
        tts_api_key = data.get('api_key')

        if tts_server_url:
            parsed = urlparse(tts_server_url) if isinstance(tts_server_url, str) else None
            if parsed is None or parsed.scheme not in ('http', 'https') or not parsed.netloc:
                return jsonify({
                    'success': False,
                    'error': 'server_url must be an http(s) URL',
                }), 400
            app.config['TTS_SERVER'] = tts_server_url
            logger.info(f"Updated TTS server URL to {tts_server_url}")

        if tts_api_key:
            app.config['TTS_API_KEY'] = tts_api_key
            logger.info("Updated TTS API key")

        return jsonify({
            'success': True,
            'message': 'TTS configuration updated',
            'url': app.config['TTS_SERVER'],
        })
    except Exception as e:
        logger.error(f"Failed to update TTS config: {str(e)}")
        return jsonify({
            'success': False,
            'error': f'Failed to update TTS config: {str(e)}',
        }), 500


# --- Whisper model -----------------------------------------------------------
logger.info("Initializing Whisper model with GPU optimization...")

# Detect available acceleration
if torch.cuda.is_available():
    device = torch.device("cuda")
    # ROCm builds of torch also report through torch.cuda; tell AMD and NVIDIA apart by name.
    try:
        gpu_name = torch.cuda.get_device_name(0)
        if 'AMD' in gpu_name or 'Radeon' in gpu_name:
            logger.info(f"AMD GPU detected via ROCm: {gpu_name}")
            logger.info("Using ROCm acceleration (limited optimizations)")
        else:
            logger.info(f"NVIDIA GPU detected: {gpu_name}")
            logger.info("Using CUDA acceleration with full optimizations")
    except Exception:
        logger.info("GPU detected - using CUDA/ROCm acceleration")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device("mps")
    logger.info("Apple Silicon detected - using Metal Performance Shaders")
else:
    device = torch.device("cpu")
    logger.info("No GPU acceleration available - using CPU")

logger.info(f"Using device: {device}")

# Load model and put it in inference mode for every device type
whisper_model = whisper.load_model(MODEL_SIZE, device=device)
whisper_model.eval()

if device.type == 'cuda':
    try:
        # Enable TensorFloat-32 for faster matmuls/convolutions on Ampere+ GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Enable the cuDNN autotuner for convolution algorithms
        torch.backends.cudnn.benchmark = True

        whisper_model = whisper_model.half()  # FP16 for faster GPU inference
        torch.cuda.empty_cache()

        # Warm up CUDA kernels. The audio encoder consumes a log-Mel spectrogram, not raw
        # samples, so build one from 30 s of silence at 16 kHz and run it through the encoder.
        logger.info("Warming up GPU with dummy inference...")
        with torch.no_grad():
            silence = whisper.pad_or_trim(torch.zeros(16000 * 30))
            mel = whisper.log_mel_spectrogram(silence, n_mels=whisper_model.dims.n_mels)
            whisper_model.embed_audio(mel.unsqueeze(0).to(device).half())

        logger.info(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        logger.info("Whisper model loaded and optimized for NVIDIA GPU")
    except Exception as e:
        logger.warning(f"Some NVIDIA optimizations failed: {e}")
elif device.type == 'mps':
    # MPS doesn't support half precision well yet, so the model stays FP32
    logger.info("Whisper model loaded and optimized for Apple Silicon")
else:
    logger.info("Whisper model loaded (CPU mode)")

# --- Languages ---------------------------------------------------------------
# Supported languages: Whisper language code -> display name
SUPPORTED_LANGUAGES = {
    "ar": "Arabic",
    "hy": "Armenian",
    "az": "Azerbaijani",
    "en": "English",
    "fr": "French",
    "ka": "Georgian",
    "kk": "Kazakh",
    "zh": "Mandarin",
    "fa": "Farsi",
    "pt": "Portuguese",
    "ru": "Russian",
    "es": "Spanish",
    "tr": "Turkish",
    "uz": "Uzbek",
}

# Map language names to language codes
LANGUAGE_TO_CODE = {v: k for k, v in SUPPORTED_LANGUAGES.items()}

# Map language names to the voice sent to the TTS server. Most entries are Edge
# neural voice names ("<locale>-<Name>Neural"); Armenian falls back to the generic
# "echo" voice, which is also the default for unknown languages in speak().
LANGUAGE_TO_VOICE = {
    "Arabic": "ar-EG-ShakirNeural",
    "Armenian": "echo",
    "Azerbaijani": "az-AZ-BanuNeural",
    "English": "en-GB-RyanNeural",
    "French": "fr-FR-DeniseNeural",
    "Georgian": "ka-GE-GiorgiNeural",
    "Kazakh": "kk-KZ-DauletNeural",
    "Mandarin": "zh-CN-YunjianNeural",
    "Farsi": "fa-IR-FaridNeural",
    "Portuguese": "pt-BR-ThalitaNeural",
    "Russian": "ru-RU-SvetlanaNeural",
    "Spanish": "es-CR-MariaNeural",
    "Turkish": "tr-TR-EmelNeural",
    "Uzbek": "uz-UZ-SardorNeural",
}


# --- Core routes -------------------------------------------------------------
@app.route('/')
def index():
    return render_template('index.html', languages=sorted(SUPPORTED_LANGUAGES.values()))


@app.route('/transcribe', methods=['POST'])
def transcribe():
    """Transcribe the uploaded ``audio`` file with Whisper.

    Form fields:
        audio: the recording (multipart file).
        source_lang: a language name from SUPPORTED_LANGUAGES. '' or 'auto' (or an
            unrecognized name) lets Whisper detect the language.

    Returns JSON ``{'success': True, 'text': ...}``. ``detected_language`` is added when
    the language was auto-detected and is supported. Returns 400 without audio and 500 on
    failure.
    """
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400

    audio_file = request.files['audio']
    source_lang = request.form.get('source_lang', '')

    # '', 'auto' and unknown names all map to None, i.e. let Whisper detect the language.
    language_code = LANGUAGE_TO_CODE.get(source_lang)
    auto_detect = language_code is None
    if auto_detect and source_lang not in ('', 'auto'):
        logger.warning(f"Unsupported source language '{source_lang}', falling back to auto-detection")

    # One file per request: a shared name would let concurrent requests overwrite each other.
    fd, temp_path = tempfile.mkstemp(prefix='input_audio_', suffix='.wav', dir=app.config['UPLOAD_FOLDER'])
    os.close(fd)

    try:
        audio_file.save(temp_path)

        transcribe_options = {
            "task": "transcribe",
            "temperature": 0,  # Greedy decoding, no temperature fallback sampling
            "best_of": 1,
            "beam_size": 1,
            "fp16": device.type == 'cuda',  # Use FP16 on GPU
            "condition_on_previous_text": False,  # Faster inference
            "compression_ratio_threshold": 2.4,
            "logprob_threshold": -1.0,
            "no_speech_threshold": 0.6,
        }
        if not auto_detect:
            transcribe_options["language"] = language_code

        if device.type == 'cuda':
            torch.cuda.empty_cache()

        with torch.no_grad():  # Disable gradient computation
            result = whisper_model.transcribe(temp_path, **transcribe_options)

        transcribed_text = result["text"]

        detected_language = None
        if auto_detect and 'language' in result:
            detected_code = result['language']
            # None when Whisper detects a language this app does not support
            detected_language = SUPPORTED_LANGUAGES.get(detected_code)
            logger.info(f"Auto-detected language: {detected_language} ({detected_code})")

        if push_subscriptions:
            send_push_notification(
                title="Transcription Complete",
                body=f"Successfully transcribed: {transcribed_text[:50]}...",
                tag="transcription-complete",
            )

        response = {
            'success': True,
            'text': transcribed_text,
        }
        if detected_language:
            response['detected_language'] = detected_language
        return jsonify(response)
    except Exception as e:
        logger.error(f"Transcription error: {str(e)}")
        return jsonify({'error': f'Transcription failed: {str(e)}'}), 500
    finally:
        # Clean up the temporary file and release cached GPU memory
        if os.path.exists(temp_path):
            os.remove(temp_path)
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()


@app.route('/translate', methods=['POST'])
def translate():
    """Translate ``text`` from ``source_lang`` to ``target_lang`` with Gemma 3 via Ollama.

    Expects a JSON object with those three keys. Returns
    ``{'success': True, 'translation': ...}``, 400 if any key is missing, 500 on failure.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        text = data.get('text', '')
        source_lang = data.get('source_lang', '')
        target_lang = data.get('target_lang', '')

        if not text or not source_lang or not target_lang:
            return jsonify({'error': 'Missing required parameters'}), 400

        prompt = (
            f"Translate the following text from {source_lang} to {target_lang}:\n\n"
            f'"{text}"\n\n'
            "Provide only the translation without any additional text."
        )

        response = ollama.chat(
            model="gemma3:27b",
            messages=[{"role": "user", "content": prompt}],
        )
        translated_text = response['message']['content'].strip()

        if push_subscriptions:
            send_push_notification(
                title="Translation Complete",
                body=f"Translated from {source_lang} to {target_lang}",
                tag="translation-complete",
                data={'translation': translated_text[:100]},
            )

        return jsonify({
            'success': True,
            'translation': translated_text,
        })
    except Exception as e:
        logger.error(f"Translation error: {str(e)}")
        return jsonify({'error': f'Translation failed: {str(e)}'}), 500


@app.route('/speak', methods=['POST'])
def speak():
    """Synthesize ``text`` in ``language`` via the TTS server and store it as an MP3.

    Returns ``{'success': True, 'audio_url': '/get_audio/<file>'}``, 400 if parameters are
    missing, 500 on TTS or connection errors.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        text = data.get('text', '')
        language = data.get('language', '')

        if not text or not language:
            return jsonify({'error': 'Missing required parameters'}), 400

        voice = LANGUAGE_TO_VOICE.get(language, 'echo')  # Default to echo if language not found
        tts_server_url = app.config['TTS_SERVER']
        headers = _tts_headers(app.config['TTS_API_KEY'])
        payload = {
            "input": text,
            "voice": voice,
            "response_format": "mp3",
            "speed": 1.0,
        }

        try:
            logger.info(f"Sending TTS request to {tts_server_url}")
            logger.info(f"Text for TTS: {text}")
            logger.info(f"Selected voice: {voice}")
            logger.debug(f"Full TTS request payload: {payload}")

            tts_response = requests.post(
                tts_server_url,
                headers=headers,
                json=payload,
                timeout=15,  # Longer timeout for audio generation
            )
            logger.info(f"TTS response status: {tts_response.status_code}")

            if tts_response.status_code != 200:
                error_msg = f'TTS request failed with status {tts_response.status_code}'
                logger.error(error_msg)
                try:
                    error_details = tts_response.json()
                    logger.error(f"Error details: {error_details}")
                    error_msg = f"{error_msg}: {_tts_error_message(error_details)}"
                except ValueError as e:  # Body is not JSON
                    logger.error(f"Could not parse error response: {str(e)}")
                    logger.error(f"Raw response: {tts_response.text[:200]}")
                return jsonify({'error': error_msg}), 500

            # Unique name per request; a per-second timestamp collides under concurrent use.
            fd, temp_audio_path = tempfile.mkstemp(
                prefix='output_', suffix='.mp3', dir=app.config['UPLOAD_FOLDER']
            )
            with os.fdopen(fd, 'wb') as f:
                f.write(tts_response.content)

            return jsonify({
                'success': True,
                'audio_url': f'/get_audio/{os.path.basename(temp_audio_path)}',
            })
        except requests.exceptions.RequestException as e:
            error_msg = f'Failed to connect to TTS server: {str(e)}'
            logger.error(error_msg)
            return jsonify({'error': error_msg}), 500
    except Exception as e:
        logger.error(f"TTS error: {str(e)}")
        return jsonify({'error': f'TTS failed: {str(e)}'}), 500


@app.route('/get_audio/<filename>')
def get_audio(filename):
    """Return a generated MP3 from the upload folder, or a JSON 404 if it does not exist."""
    upload_dir = app.config['UPLOAD_FOLDER']
    # Only bare file names that exist directly inside the upload folder are served.
    if filename != os.path.basename(filename) or not os.path.isfile(os.path.join(upload_dir, filename)):
        return jsonify({'error': 'Audio file not found'}), 404
    try:
        return send_from_directory(upload_dir, filename, mimetype='audio/mpeg')
    except Exception as e:
        logger.error(f"Audio retrieval error: {str(e)}")
        return jsonify({'error': f'Audio retrieval failed: {str(e)}'}), 500


# --- Health check endpoints for monitoring -----------------------------------
@app.route('/health', methods=['GET'])
def health_check():
    """Basic health check endpoint"""
    return jsonify({
        'status': 'healthy',
        'timestamp': time.time(),
        'service': 'voice-translator',
    })


@app.route('/health/detailed', methods=['GET'])
def detailed_health_check():
    """Detailed health check with per-component status and simple metrics.

    Responds 503 only when the overall status is ``unhealthy``; ``degraded`` still
    returns 200.
    """
    health_status = {
        'status': 'healthy',
        'timestamp': time.time(),
        'components': {
            'whisper': {'status': 'unknown'},
            'ollama': {'status': 'unknown'},
            'tts': {'status': 'unknown'},
            'gpu': {'status': 'unknown'},
        },
        'metrics': {},
    }
    components = health_status['components']

    # Check Whisper model
    try:
        if whisper_model is not None:
            components['whisper']['status'] = 'healthy'
            components['whisper']['model_size'] = MODEL_SIZE
        else:
            components['whisper']['status'] = 'unhealthy'
            health_status['status'] = 'degraded'
    except Exception as e:
        components['whisper']['status'] = 'unhealthy'
        components['whisper']['error'] = str(e)
        health_status['status'] = 'unhealthy'

    # Check GPU availability (same mps guard as at startup)
    try:
        if torch.cuda.is_available():
            components['gpu']['status'] = 'healthy'
            components['gpu']['device'] = torch.cuda.get_device_name(0)
            components['gpu']['memory_allocated'] = f"{torch.cuda.memory_allocated(0) / 1024**2:.2f} MB"
            components['gpu']['memory_reserved'] = f"{torch.cuda.memory_reserved(0) / 1024**2:.2f} MB"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            components['gpu']['status'] = 'healthy'
            components['gpu']['device'] = 'Apple Silicon GPU'
        else:
            components['gpu']['status'] = 'not_available'
            components['gpu']['device'] = 'CPU'
    except Exception as e:
        components['gpu']['status'] = 'error'
        components['gpu']['error'] = str(e)

    # Check Ollama connection
    try:
        ollama_models = ollama.list()
        components['ollama']['status'] = 'healthy'
        components['ollama']['available_models'] = len(ollama_models.get('models', []))
    except Exception as e:
        components['ollama']['status'] = 'unhealthy'
        components['ollama']['error'] = str(e)
        health_status['status'] = 'degraded'

    # Check TTS server
    try:
        tts_response = requests.get(app.config['TTS_SERVER'].replace('/v1/audio/speech', '/health'), timeout=5)
        if tts_response.status_code == 200:
            components['tts']['status'] = 'healthy'
            components['tts']['server_url'] = app.config['TTS_SERVER']
        else:
            components['tts']['status'] = 'unhealthy'
            components['tts']['http_status'] = tts_response.status_code
            health_status['status'] = 'degraded'
    except Exception as e:
        components['tts']['status'] = 'unhealthy'
        components['tts']['error'] = str(e)
        health_status['status'] = 'degraded'

    # Add system metrics
    health_status['metrics']['uptime'] = time.time() - app.start_time if hasattr(app, 'start_time') else 0
    health_status['metrics']['request_count'] = getattr(app, 'request_count', 0)

    http_status = 503 if health_status['status'] == 'unhealthy' else 200
    return jsonify(health_status), http_status


@app.route('/health/ready', methods=['GET'])
def readiness_check():
    """Readiness probe - checks if service is ready to accept traffic"""
    try:
        if whisper_model is None:
            return jsonify({'status': 'not_ready', 'reason': 'Whisper model not loaded'}), 503
        ollama.list()  # Raises if Ollama is unreachable
        return jsonify({'status': 'ready', 'timestamp': time.time()})
    except Exception as e:
        return jsonify({'status': 'not_ready', 'reason': str(e)}), 503


@app.route('/health/live', methods=['GET'])
def liveness_check():
    """Liveness probe - basic check to see if process is alive"""
    return jsonify({'status': 'alive', 'timestamp': time.time()})


# Initialize app start time for metrics
app.start_time = time.time()
app.request_count = 0


@app.before_request
def before_request():
    """Count every incoming request for the detailed health metrics."""
    app.request_count = getattr(app, 'request_count', 0) + 1


if __name__ == '__main__':
    # Debug mode is opt-in: the Werkzeug debugger allows code execution from the browser
    # and this binds to all interfaces; the reloader would also load Whisper twice.
    debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=5005, debug=debug)