Source code for petal_app_manager.proxies.mqtt

"""
MQTTProxy
=========

• Provides access to AWS IoT MQTT broker through TypeScript client API calls
• Handles callback server for receiving continuous message streams
• Uses async callback dispatch for message processing
• Abstracts MQTT communication details away from petals
• Provides async pub/sub operations with callback-style message handling

This proxy allows petals to interact with MQTT without worrying about
the underlying connection management and HTTP communication details.
"""

from __future__ import annotations
from typing import List, Dict, Any, Optional, Callable, Awaitable
from collections import deque, defaultdict
import asyncio
import concurrent.futures
import json
import logging
import os
from datetime import datetime
import functools
import uuid

import requests
from fastapi import FastAPI, HTTPException, APIRouter
from pydantic import BaseModel

from .base import BaseProxy
from ..organization_manager import get_organization_manager

class MessageCallback(BaseModel):
    """Request body accepted by the MQTT callback endpoint for one message."""

    # Topic the message was received on; matched against the proxy's
    # subscribed topics when the message is dispatched.
    topic: str
    # Decoded message body; a "messageId" key, when present, is used for
    # duplicate filtering by the proxy.
    payload: Dict[str, Any]
    # Optional metadata supplied by the sender; not read by the proxy's
    # dispatch logic visible in this module.
    timestamp: Optional[str] = None
    qos: Optional[int] = None
class MQTTProxy(BaseProxy):
    """
    Proxy for communicating with AWS IoT MQTT through TypeScript client API calls.
    Uses async callback dispatch for message processing.

    The callback endpoint is exposed as a FastAPI router that should be registered
    with the main application. The callback_port should be set to the main app's
    port (e.g., 8000) since the callback router is now part of the main app.

    Configuration note:
        Set PETAL_CALLBACK_PORT to the main FastAPI app port (default: 8000)
        The callback URL will be:
        http://{callback_host}:{callback_port}/mqtt-callback/callback
    """

    def __init__(
        self,
        ts_client_host: str = "localhost",
        ts_client_port: int = 3004,
        callback_host: str = "localhost",
        callback_port: int = 3005,
        enable_callbacks: bool = True,
        debug: bool = False,
        request_timeout: int = 30,
        max_seen_message_ids: int = 1000,
        command_edge_topic: str = "command/edge",
        response_topic: str = "response",
        test_topic: str = "command",
        command_web_topic: str = "command/web",
        health_check_interval: float = 10.0
    ):
        """
        Store configuration and build the callback router.

        Args:
            ts_client_host: Host of the TypeScript MQTT client HTTP API.
            ts_client_port: Port of the TypeScript MQTT client HTTP API.
            callback_host: Host the TypeScript client should call back to.
            callback_port: Port the TypeScript client should call back to
                (the main FastAPI app port).
            enable_callbacks: Whether to expose the callback router and pass
                a callback URL to the TypeScript client.
            debug: Debug flag (stored; not otherwise used in this class).
            request_timeout: Timeout in seconds for HTTP requests to the
                TypeScript client.
            max_seen_message_ids: Number of recent message IDs remembered for
                duplicate filtering.
            command_edge_topic: Relative topic handlers are registered on.
            response_topic: Relative topic used for command responses.
            test_topic: Relative test topic subscribed to at startup.
            command_web_topic: Relative topic used by :meth:`publish_message`.
            health_check_interval: Seconds between health-monitor checks.
        """
        self.ts_client_host = ts_client_host
        self.ts_client_port = ts_client_port
        self.callback_host = callback_host
        self.callback_port = callback_port
        self.enable_callbacks = enable_callbacks
        self.debug = debug
        self.request_timeout = request_timeout

        # For HTTP callback router (registered with main FastAPI app)
        self.callback_router: Optional[APIRouter] = None

        # Base URL for TypeScript client
        self.ts_base_url = f"http://{self.ts_client_host}:{self.ts_client_port}"

        # Callback URL points to the main app's router endpoint (not a separate
        # server), so callback_port should be the main FastAPI app port.
        self.callback_url = (
            f"http://{self.callback_host}:{self.callback_port}/mqtt-callback/callback"
            if self.enable_callbacks
            else None
        )

        # Duplicate message filtering (bounded; oldest IDs fall off the left)
        self._seen_message_ids = deque(maxlen=max_seen_message_ids)

        # Subscription management
        self.command_edge_topic = command_edge_topic
        self.response_topic = response_topic
        self.test_topic = test_topic
        self.command_web_topic = command_web_topic
        # full topic -> list of {"callback", "subscription_id", "cpu_heavy"}
        self._handlers: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        self.subscribed_topics = set()

        # Identity is resolved in start(); initialise here so methods called
        # before start() (health_check, register_handler, ...) do not raise
        # AttributeError.
        self.robot_instance_id: Optional[str] = None
        self.device_id: Optional[str] = None

        # Connection state
        self.is_connected = False
        self._shutdown_flag = False
        self._device_topics = {}

        # Health monitoring
        self._health_monitor_task = None
        self._health_check_interval = health_check_interval

        # Strong references to in-flight handler tasks; the event loop itself
        # only keeps weak references, so unreferenced tasks may be collected.
        self._background_tasks = set()

        self._loop = None
        self._exe = concurrent.futures.ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="MQTTProxy"
        )
        self.log = logging.getLogger("MQTTProxy")

        # Setup callback router in __init__ so it's available for registration
        # with the main app before start() is called.
        if self.enable_callbacks:
            self._setup_callback_router()

    async def start(self):
        """Initialize the MQTT proxy and start callback processing.

        Resolves the device identity, checks the TypeScript client, subscribes
        to the default device topics when the organization ID is available,
        and starts the background health monitor. If no machine ID is
        available the proxy stays inactive and no monitor is started.
        """
        # Get robot instance ID for basic setup
        self.robot_instance_id = self._get_machine_id()
        self.device_id = (
            f"Instance-{self.robot_instance_id}" if self.robot_instance_id else None
        )
        self._loop = asyncio.get_running_loop()

        self.log.info("Initializing MQTTProxy connection")

        # Validate basic configuration (organization_id will be fetched on-demand)
        if not self.device_id:
            self.log.error("Robot Instance ID must be available from OrganizationManager")
            self.log.warning("MQTTProxy will remain inactive until Robot Instance ID is available")
            return

        try:
            # Note: callback_router is already set up in __init__ for early
            # registration with the main app before start() is called.
            is_healthy = await self._check_ts_client_health()

            if is_healthy:
                self.log.info("TypeScript MQTT client is healthy")
                # Mark as connected since TypeScript client is healthy
                self.is_connected = True

                # Try to subscribe to default device topics if organization ID is available
                organization_id = self._get_organization_id()
                if organization_id:
                    self.log.info("Organization ID available, subscribing to device topics...")
                    try:
                        from .. import Config
                        await asyncio.wait_for(
                            self._subscribe_to_device_topics(),
                            timeout=Config.MQTT_SUBSCRIBE_TIMEOUT,
                        )
                        self.log.info("Successfully subscribed to device topics")
                    except asyncio.TimeoutError:
                        self.log.warning("Timeout subscribing to device topics during startup")
                    except Exception as e:
                        self.log.error(f"Failed to subscribe to device topics: {e}")
                else:
                    self.log.info("Organization ID not available, skipping device topic subscription")
            else:
                self.log.warning("TypeScript MQTT client is not accessible - will monitor and retry")
                self.is_connected = False

            # Always start health monitoring task to detect connectivity restoration
            self._start_health_monitor()
            self.log.info("MQTTProxy started with health monitoring")

        except Exception as e:
            self.log.error(f"Failed to initialize MQTTProxy: {e}")
            self.log.warning("MQTTProxy connection failed - will monitor and retry")
            self.is_connected = False
            # Still start health monitor to detect when things recover
            self._start_health_monitor()

    def _start_health_monitor(self) -> None:
        """Start the health-monitor task unless one is already running."""
        task = self._health_monitor_task
        if task is not None and not task.done():
            return
        self._health_monitor_task = asyncio.create_task(self._health_monitor_loop())

    async def stop(self):
        """Clean up resources when shutting down."""
        self.log.info("Stopping MQTTProxy...")

        # Set shutdown flag first so the monitor loop cannot start another pass
        self._shutdown_flag = True

        # Cancel health monitor task
        if self._health_monitor_task and not self._health_monitor_task.done():
            self._health_monitor_task.cancel()
            try:
                await self._health_monitor_task
            except asyncio.CancelledError:
                pass

        # clear all registered handlers
        self._handlers.clear()

        # Iterate over a snapshot: _unsubscribe_from_topic mutates the set.
        for topic in list(self.subscribed_topics):
            await self._unsubscribe_from_topic(topic)
        self.subscribed_topics.clear()

        self.is_connected = False

        # Shutdown executor
        if self._exe:
            self._exe.shutdown(wait=False)

        self.log.info("MQTTProxy stopped")

    def _get_machine_id(self) -> Optional[str]:
        """
        Get the machine ID from the OrganizationManager.

        Returns:
            The machine ID if available, None otherwise
        """
        try:
            org_manager = get_organization_manager()
            machine_id = org_manager.machine_id
            if not machine_id:
                self.log.error("Machine ID not available from OrganizationManager")
                return None
            return machine_id
        except Exception as e:
            self.log.error(f"Error getting machine ID from OrganizationManager: {e}")
            return None

    def _get_organization_id(self) -> Optional[str]:
        """
        Get the organization ID from the OrganizationManager on-demand.

        Returns:
            The organization ID if available, None otherwise
        """
        try:
            org_manager = get_organization_manager()
            org_id = org_manager.organization_id
            if not org_id:
                self.log.debug("Organization ID not yet available from OrganizationManager")
                return None
            return org_id
        except Exception as e:
            self.log.debug(f"Error getting organization ID from OrganizationManager: {e}")
            return None

    def _get_organization_id_with_wait(self, timeout: float = 5.0) -> Optional[str]:
        """
        Get organization ID with optional wait for availability.

        This polls synchronously with ``time.sleep`` and therefore blocks the
        calling thread for up to ``timeout`` seconds.

        Args:
            timeout: Maximum time to wait for organization ID

        Returns:
            Organization ID if available within timeout, None otherwise
        """
        import time

        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            org_id = self._get_organization_id()
            if org_id:
                return org_id
            # Never sleep past the deadline.
            time.sleep(min(0.5, remaining))

        self.log.warning(f"Organization ID not available after {timeout}s timeout")
        return None

    def _get_base_topic(self) -> Optional[str]:
        """
        Get the base topic for this device.

        Returns:
            Base topic string if organization_id and device_id are available,
            None otherwise
        """
        organization_id = self._get_organization_id()
        if not organization_id or not self.device_id:
            self.log.warning("Cannot construct base topic: missing org or device ID")
            return None
        return f"org/{organization_id}/device/{self.device_id}"

    @property
    def organization_id(self) -> Optional[str]:
        """
        Organization ID property for backward compatibility.
        Fetches organization_id on-demand from OrganizationManager.

        Returns:
            Organization ID if available, None otherwise
        """
        return self._get_organization_id()

    # ------ Message Processing ------ #

    def _process_incoming_message(self, topic: str, payload: Dict[str, Any]):
        """
        Process an incoming MQTT message - dedup and dispatch to handlers.
        Called directly from the HTTP callback handler.

        Args:
            topic: Full topic the message arrived on.
            payload: Message body; ``payload["messageId"]`` (if present) is
                used for duplicate filtering.
        """
        try:
            # Duplicate check by messageId
            msg_id = payload.get("messageId")
            if msg_id:
                if msg_id in self._seen_message_ids:
                    self.log.debug(f"Duplicate message detected, skipping: {msg_id}")
                    return
                self._seen_message_ids.append(msg_id)

            self.log.debug(f"Processing MQTT message on topic: {topic}")

            if topic not in self.subscribed_topics:
                self.log.debug(f"Received message for unsubscribed topic: {topic}")
                return

            # Dispatch to registered handlers
            for handler in self._handlers[topic]:
                callback = handler.get("callback")
                if callback:
                    self._invoke_callback_safely(
                        callback,
                        topic,
                        payload,
                        cpu_heavy=handler.get("cpu_heavy", False),
                    )
                    continue

                subscription_id = handler.get("subscription_id")
                if subscription_id:
                    self.log.debug(f"No callback for topic: {topic} with subscription ID {subscription_id}")
                else:
                    self.log.debug(f"No callback for topic: {topic} with no subscription ID")

        except Exception as e:
            self.log.error(f"Error processing incoming message: {e}")

    def _invoke_callback_safely(
        self,
        callback: Callable,
        topic: str,
        payload: Dict[str, Any],
        *,
        cpu_heavy: bool = False,
    ):
        """Safely invoke an async callback by scheduling it on the event loop.

        Since MQTT callbacks arrive via FastAPI HTTP handlers (already on the
        event loop), we use create_task() instead of run_coroutine_threadsafe().

        When *cpu_heavy* is ``True`` the coroutine body is executed inside
        ``loop.run_in_executor`` to avoid blocking the event loop. Note that
        this uses the proxy's single-worker executor, which is shared with
        the HTTP requests to the TypeScript client.

        The created task is kept in ``self._background_tasks`` until it
        finishes, and any exception it raised is logged at that point.
        """
        try:
            loop = self._loop
            if not loop or loop.is_closed():
                self.log.warning(f"Cannot invoke async callback for {topic}: event loop not available")
                return

            if cpu_heavy:
                async def _offloaded():
                    # Run the handler in its own event loop on the worker thread.
                    await loop.run_in_executor(
                        self._exe, lambda: asyncio.run(callback(topic, payload))
                    )

                task = loop.create_task(_offloaded())
            else:
                task = loop.create_task(callback(topic, payload))

            # Hold a strong reference until completion and surface failures.
            self._background_tasks.add(task)
            task.add_done_callback(
                functools.partial(self._on_callback_task_done, topic)
            )
        except Exception as e:
            self.log.error(f"Error in callback for topic {topic}: {e}")

    def _on_callback_task_done(self, topic: str, task: "asyncio.Task") -> None:
        """Release a finished handler task and log the exception it raised, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self.log.error(f"Error in callback for topic {topic}: {exc}")

    # ------ TypeScript Client Communication ------ #

    async def _check_ts_client_health(self) -> bool:
        """Check if TypeScript MQTT client is healthy.

        Returns:
            True if ``GET {ts_base_url}/health`` answered with status 200,
            False on any other status or on any error (including the proxy
            not having been started yet).
        """
        try:
            response = await self._loop.run_in_executor(
                self._exe,
                lambda: requests.get(
                    f"{self.ts_base_url}/health", timeout=self.request_timeout
                ),
            )
            return response.status_code == 200
        except Exception as e:
            self.log.debug(f"TypeScript client unreachable: {type(e).__name__}")
            return False

    async def _health_monitor_loop(self):
        """Background task to monitor TypeScript client health and restore
        device topic subscriptions.

        A transition to healthy is only recorded once the device
        subscriptions could be checked; while the organization ID is still
        unavailable the restoration is retried on every interval.
        """
        self.log.info("Health monitor started")
        last_health_status = self.is_connected
        # True while a healthy client was seen but restoration has not
        # completed yet (used to avoid repeating the "restored" log line).
        restore_pending = False

        while not self._shutdown_flag:
            try:
                await asyncio.sleep(self._health_check_interval)

                is_healthy = await self._check_ts_client_health()
                if is_healthy == last_health_status:
                    continue

                if not is_healthy:
                    self.log.warning("TypeScript client health check failed - marking as disconnected")
                    self.is_connected = False
                    last_health_status = False
                    restore_pending = False
                    continue

                if not restore_pending:
                    self.log.info("TypeScript client health restored")

                if await self._restore_device_subscriptions():
                    self.is_connected = True
                    last_health_status = True
                    restore_pending = False
                else:
                    # Leave last_health_status unchanged so the next tick retries.
                    restore_pending = True

            except asyncio.CancelledError:
                self.log.info("Health monitor cancelled")
                break
            except Exception as e:
                self.log.error(f"Error in health monitor loop: {e}")
                # Continue monitoring despite errors

        self.log.info("Health monitor stopped")

    async def _restore_device_subscriptions(self) -> bool:
        """Subscribe to any default device topic that is not subscribed yet.

        Returns:
            False if the base topic (organization/device ID) is not available
            yet, True once the missing topics have been attempted.
        """
        base_topic = self._get_base_topic()
        if not base_topic:
            self.log.debug("Health restored but organization ID not available yet")
            return False

        missing_topics = []
        for relative_topic in (
            self.command_edge_topic,
            self.response_topic,
            self.test_topic,
        ):
            # Mirror the normalisation done by _subscribe_to_topic.
            normalized = relative_topic[1:] if relative_topic.startswith("/") else relative_topic
            if f"{base_topic}/{normalized}" not in self.subscribed_topics:
                missing_topics.append(relative_topic)

        if not missing_topics:
            self.log.info("All device topics already subscribed")
            return True

        self.log.info(f"Re-subscribing to {len(missing_topics)} missing device topics...")
        failed_topics = []
        for relative_topic in missing_topics:
            if not await self._subscribe_to_topic(relative_topic):
                failed_topics.append(relative_topic)

        if failed_topics:
            self.log.warning(f"Failed to restore device topic subscriptions: {failed_topics}")
        else:
            self.log.info("Device topic subscriptions restored")
        return True

    async def _make_ts_request(
        self, method: str, endpoint: str, data: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """Make HTTP request to TypeScript client.

        Args:
            method: HTTP method name.
            endpoint: Path appended to ``ts_base_url``.
            data: Optional JSON body.

        Returns:
            The decoded JSON response on status 200, otherwise a dict with a
            single ``"error"`` key describing the failure. Never raises.
        """
        try:
            url = f"{self.ts_base_url}{endpoint}"
            response = await self._loop.run_in_executor(
                self._exe,
                functools.partial(
                    requests.request,
                    method=method,
                    url=url,
                    json=data,
                    timeout=self.request_timeout,
                    headers={"Content-Type": "application/json"},
                ),
            )

            if response.status_code == 200:
                return response.json()

            error_msg = f"TypeScript client request failed: {response.status_code} - {response.text}"
            self.log.error(error_msg)
            return {"error": error_msg}

        except Exception as e:
            error_msg = f"Error communicating with TypeScript client: {str(e)}"
            self.log.error(error_msg)
            return {"error": error_msg}

    # ------ Callback Router (registered with main FastAPI app) ------ #

    def _setup_callback_router(self):
        """Setup FastAPI router for receiving MQTT callback messages.

        This router should be registered with the main FastAPI app using:
            app.include_router(mqtt_proxy.callback_router, prefix="/mqtt-callback")
        """
        if not self.enable_callbacks:
            return

        self.callback_router = APIRouter(tags=["MQTT Callback"])

        # Store reference to self for use in route handlers
        proxy = self

        @self.callback_router.post('/callback')
        async def message_callback(message: MessageCallback):
            """Handle incoming MQTT messages - dedup and dispatch directly."""
            try:
                proxy._process_incoming_message(message.topic, message.payload)
                return {"status": "success"}
            except Exception as e:
                proxy.log.error(f"Error processing callback message: {e}")
                return {"status": "error", "message": str(e)}

        @self.callback_router.get('/health')
        async def callback_health():
            """Health check for MQTT callback endpoint."""
            return {
                "status": "healthy",
                "timestamp": datetime.now().isoformat(),
                "subscriptions": len(proxy.subscribed_topics),
            }

        @self.callback_router.get('/stats')
        async def callback_stats():
            """Statistics for MQTT callback and message processing."""
            return {
                "subscriptions": len(proxy.subscribed_topics),
                "handlers": sum(len(handlers) for handlers in proxy._handlers.values()),
            }

        self.log.info("MQTT callback router configured (register with main app using include_router)")

    async def _subscribe_to_topic(self, topic: str) -> bool:
        """Subscribe to an MQTT topic via TypeScript client.

        Args:
            topic: Topic to subscribe to (relative to base topic)

        Returns:
            True if the subscription was accepted, False otherwise
        """
        # Bound before the try so the except clause can always format it.
        topic_subscribe = topic
        try:
            # make sure topic just does not have a leading slash
            if topic.startswith("/"):
                topic = topic[1:]

            # Get base topic
            base_topic = self._get_base_topic()
            if not base_topic:
                self.log.error(f"Cannot subscribe to {topic}: base topic not available")
                return False

            # Determine full topic to subscribe to
            topic_subscribe = f"{base_topic}/{topic}"

            request_data = {
                "topic": topic_subscribe,
                "callbackUrl": self.callback_url if self.enable_callbacks else None
            }

            result = await self._make_ts_request("POST", "/subscribe", request_data)

            if "error" in result:
                self.log.error(f"Failed to subscribe to {topic_subscribe}: {result['error']}")
                return False

            self.subscribed_topics.add(topic_subscribe)
            self.log.info(f"Subscribed to topic: {topic_subscribe}")
            return True

        except Exception as e:
            self.log.error(f"Error subscribing to {topic_subscribe}: {e}")
            return False

    async def _unsubscribe_from_topic(self, topic: str) -> bool:
        """
        Unsubscribe from an MQTT topic.

        Only the proxy's local bookkeeping is updated; no request is sent to
        the TypeScript client.

        Args:
            topic: Topic to unsubscribe from, either relative to the base
                topic or an already-subscribed full topic path

        Returns:
            True if the topic was subscribed and has been removed,
            False otherwise
        """
        # Bound before the try so the except clause can always format it.
        topic_unsubscribe = topic
        try:
            if topic not in self.subscribed_topics:
                # Treat as relative: drop one leading slash and prepend the base.
                relative = topic[1:] if topic.startswith("/") else topic
                topic_unsubscribe = f"{self._get_base_topic()}/{relative}"

            if topic_unsubscribe not in self.subscribed_topics:
                return False

            self.subscribed_topics.remove(topic_unsubscribe)
            self.log.info(f"Unsubscribed from topic: {topic_unsubscribe}")
            return True

        except Exception as e:
            self.log.error(f"Error unsubscribing from {topic_unsubscribe}: {e}")
            return False

    async def _subscribe_to_device_topics(self):
        """Subscribe to common device topics automatically.

        After the subscription attempts, the default logging handler is
        registered once on the command-edge topic (if that topic is
        subscribed and the handler is not registered already).
        """
        # Get base topic (requires org ID and device ID)
        base_topic = self._get_base_topic()
        if not base_topic:
            self.log.warning("Cannot subscribe to device topics: missing org or device ID")
            return

        # Default topics to subscribe to
        topics = [
            self.command_edge_topic,
            self.response_topic,
            self.test_topic
        ]

        for topic in topics:
            if await self._subscribe_to_topic(topic):
                self.log.info(f"Auto-subscribed to device topic: {topic}")
            else:
                self.log.error(f"Failed to auto-subscribe to device topic: {topic}")

        # register_handler() always targets the command-edge topic, so register
        # the default handler a single time rather than once per topic.
        edge_topic = f"{base_topic}/{self.command_edge_topic}"
        if edge_topic in self.subscribed_topics:
            already_registered = any(
                entry.get("callback") == self._default_message_handler
                for entry in self._handlers.get(edge_topic, ())
            )
            if not already_registered:
                self.register_handler(self._default_message_handler)

    async def _default_message_handler(self, topic: str, payload: Dict[str, Any]):
        """Default message handler for device topics: logs the message."""
        self.log.info(f"Received message on {topic}: {payload}")
        # Handle command messages
        # await self._process_command(topic, payload)

    async def _process_command(self, topic: str, payload: Dict[str, Any]):
        """Enhanced command processing: log the command and acknowledge it."""
        message_id = payload.get('messageId', 'unknown')

        # Log command for audit
        self.log.info(f"Processing command: {payload}")

        # Send response back
        await self.send_command_response(message_id, {
            'status': 'success',
            'timestamp': datetime.now().isoformat()
        })

    async def _publish_message(self, topic: str, payload: Dict[str, Any], qos: int = 1) -> bool:
        """
        Publish a message to an MQTT topic via TypeScript client.

        The caller's ``payload`` is not modified; ``deviceId`` is added to a
        copy before sending.

        Args:
            topic: The MQTT topic to publish to (full topic path)
            payload: Message payload as a dictionary
            qos: Quality of Service level (0, 1, or 2)

        Returns:
            True if published successfully, False otherwise
        """
        if not self.is_connected:
            self.log.error("MQTT proxy is not connected")
            return False

        # Topic is already the full topic path
        topic_publish = topic

        try:
            # Shallow copy so the caller's dict is left untouched.
            outgoing = dict(payload)
            outgoing["deviceId"] = self.device_id

            request_data = {
                "topic": topic_publish,
                "payload": outgoing,
                "qos": qos,
                "callbackUrl": self.callback_url
            }

            result = await self._make_ts_request("POST", "/publish", request_data)

            if "error" in result:
                self.log.error(f"Failed to publish message to {topic_publish}: {result['error']}")
                return False

            self.log.debug(f"Published message to topic: {topic_publish}")
            return True

        except Exception as e:
            self.log.error(f"Error publishing message to {topic_publish}: {e}")
            return False

    # ------ Public API methods ------ #

    async def publish_message(self, payload: Dict[str, Any], qos: int = 1) -> bool:
        """
        Publish a message to an MQTT topic via TypeScript client to
        'command/web' topic.

        Args:
            payload: Message payload as a dictionary
            qos: Quality of Service level (0, 1, or 2)

        Returns:
            True if published successfully, False otherwise (including when
            the base topic is not available yet)
        """
        base_topic = self._get_base_topic()
        if not base_topic:
            self.log.error("Cannot publish message: base topic not available")
            return False

        return await self._publish_message(
            topic=f"{base_topic}/{self.command_web_topic}",
            payload=payload,
            qos=qos
        )

    async def send_command_response(self, message_id: str, response_data: Dict[str, Any]) -> bool:
        """
        Send a command response to the response topic.

        Args:
            message_id: Original message ID to correlate response
            response_data: Response payload data; its keys override the
                default ``messageId``/``timestamp`` fields on collision

        Returns:
            True if published successfully, False otherwise (including when
            the base topic is not available yet)
        """
        base_topic = self._get_base_topic()
        if not base_topic:
            self.log.error("Cannot send command response: base topic not available")
            return False

        response_topic = f"{base_topic}/{self.response_topic}"
        response_payload = {
            'messageId': message_id,
            'timestamp': datetime.now().isoformat(),
            **response_data
        }
        return await self._publish_message(response_topic, response_payload)

    def register_handler(
        self,
        handler: Callable[[str, Dict[str, Any]], Awaitable[None]],
        cpu_heavy: bool = False,
    ) -> Optional[str]:
        """
        Register an async handler to the 'command/edge' topic.

        Args:
            handler: Async callback function to handle messages
            cpu_heavy: If True, the handler's CPU-bound work will be
                offloaded to a thread-pool executor managed by the proxy.

        Returns:
            Subscription ID string, or None if registration failed (the
            command-edge topic is not subscribed)

        Raises:
            TypeError: If handler is not an async function
        """
        if not asyncio.iscoroutinefunction(handler):
            raise TypeError(
                f"Handler must be an async function (coroutine). "
                f"Got {type(handler).__name__} instead. "
                f"Define your handler with 'async def' instead of 'def'."
            )

        topic_subscribe = f"{self._get_base_topic()}/{self.command_edge_topic}"

        # ensure topic is subscribed
        if topic_subscribe not in self.subscribed_topics:
            self.log.error(f"Cannot register handler: not subscribed to topic {topic_subscribe}")
            return None

        # Store callback
        subscription_id = str(uuid.uuid4())
        self._handlers[topic_subscribe].append({
            "callback": handler,
            "subscription_id": subscription_id,
            "cpu_heavy": cpu_heavy,
        })

        self.log.debug(f"Registered handler for topic: {self.command_edge_topic} with subscription ID: {subscription_id}")
        return subscription_id

    def unregister_handler(self, subscription_id: str) -> bool:
        """
        Unregister a handler from the 'command/edge' topic.

        Args:
            subscription_id: ID returned by :meth:`register_handler`

        Returns:
            True if a matching handler was removed, False otherwise
        """
        topic_unsubscribe = f"{self._get_base_topic()}/{self.command_edge_topic}"

        if topic_unsubscribe not in self._handlers:
            self.log.warning(f"No handlers registered for topic: {self.command_edge_topic}")
            return False

        handlers = self._handlers[topic_unsubscribe]
        for index, handler in enumerate(handlers):
            if handler.get("subscription_id") == subscription_id:
                del handlers[index]
                self.log.debug(f"Unregistered handler for topic: {self.command_edge_topic} with subscription ID: {subscription_id}")
                return True

        self.log.debug(f"No matching handler found for subscription ID: {subscription_id} on topic: {self.command_edge_topic}")
        return False

    # ------ Health Check Methods ------ #

    async def health_check(self) -> Dict[str, Any]:
        """Check MQTT proxy health status.

        Returns:
            Dict with overall status, connection details (including a live
            TypeScript client health probe), configuration, current
            subscriptions/handler counts and device identity.
        """
        health_status = {
            "status": "healthy" if self.is_connected else "unhealthy",
            "connection": {
                "ts_client": await self._check_ts_client_health(),
                "callback_router": self.enable_callbacks and self.callback_router is not None,
                "connected": self.is_connected
            },
            "configuration": {
                "ts_client_host": self.ts_client_host,
                "ts_client_port": self.ts_client_port,
                "callback_host": self.callback_host,
                "callback_port": self.callback_port,
                "enable_callbacks": self.enable_callbacks,
            },
            "subscriptions": {
                "topics": list(self.subscribed_topics),
                "handlers": {topic: len(handlers) for topic, handlers in self._handlers.items()}
            },
            "device_info": {
                "organization_id": self._get_organization_id(),
                "robot_instance_id": self.robot_instance_id
            }
        }

        return health_status