"""
RedisProxy
==========
Simple Redis proxy for basic pub/sub messaging and key-value operations.
Supports Unix socket connections.
"""
from __future__ import annotations
from typing import List, Optional, Callable, Awaitable, Dict, Any
import asyncio
import concurrent.futures
import logging
import redis
from .base import BaseProxy
class RedisProxy(BaseProxy):
    """
    Simple Redis proxy for pub/sub messaging and key-value operations.
    Supports Unix socket connections.

    Key-value commands, publishing and both pub/sub listen loops run their
    blocking redis-py calls in a private thread pool, so the event loop
    captured by :meth:`start` is not blocked. ``subscribe``/``unsubscribe`` and
    pattern (un)subscription issue their commands directly on the calling
    thread. Message callbacks must be async; they are scheduled on the
    captured event loop, and their failures are logged rather than raised
    into the listen loops.
    """

    # Seconds to pause before reconnecting/retrying after a listener error.
    _RETRY_DELAY = 1.0

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        debug: bool = False,
        unix_socket_path: Optional[str] = None,
    ):
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.debug = debug
        self.unix_socket_path = unix_socket_path
        self._client: Optional[redis.Redis] = None  # commands (get/set/publish/...)
        self._pubsub_client: Optional[redis.Redis] = None  # owns the PubSub objects
        self._pubsub = None  # channel subscriptions
        self._pubsub_pattern = None  # pattern subscriptions
        self._loop: Optional[asyncio.AbstractEventLoop] = None  # captured in start()
        # Executor for blocking Redis I/O (listen loops + Redis commands)
        self._exe = concurrent.futures.ThreadPoolExecutor(max_workers=4, thread_name_prefix="RedisProxy")
        self.log = logging.getLogger("RedisProxy")
        # Communication and health check attributes
        self.app_id = f"redis-proxy-{id(self)}"  # Unique app identifier
        self._is_listening = False  # Whether actively listening for messages
        self._message_handlers: Dict[str, Callable[[str, str], Awaitable[None]]] = {}  # Active message handlers
        self._subscription_tasks: Dict[str, Any] = {}  # channel -> listener task serving it
        # Store active subscriptions
        self._subscriptions: Dict[str, Callable[[str, str], Awaitable[None]]] = {}  # normal channel subscriptions
        self._subscription_task: Optional[asyncio.Task] = None
        # Pattern subscription state
        self._pattern_callbacks: Dict[str, Callable[[str, str], Awaitable[None]]] = {}  # channel: callback for pattern subpub
        self._pattern_subscription_task: Optional[asyncio.Task] = None
        self._current_pattern: Optional[str] = None
        # CPU-heavy offloading flags per handler
        self._handler_cpu_heavy: Dict[str, bool] = {}  # channel -> cpu_heavy flag
        self._pattern_cpu_heavy: Dict[str, bool] = {}  # channel -> cpu_heavy flag
        # Shutdown flag to control infinite loops
        self._shutdown_flag = False

    # ------ Lifecycle ------ #

    async def start(self):
        """Initialize the connection to Redis.

        Captures the running event loop and creates two clients: one for
        commands and one that owns the pub/sub connections. It then pings the
        server. If the ping succeeds, it creates the channel and pattern
        PubSub objects and subscribes to the default ``/petal-*`` pattern.
        Errors are logged, not raised.
        """
        self._loop = asyncio.get_running_loop()
        # Create Redis client - prioritize Unix socket
        if self.unix_socket_path:
            self.log.info("Initializing Redis connection via Unix socket: %s", self.unix_socket_path)
        else:
            self.log.info("Initializing Redis connection to %s:%s db=%s", self.host, self.port, self.db)
        kwargs = self._connection_kwargs()
        try:
            self._client = await self._loop.run_in_executor(self._exe, lambda: redis.Redis(**kwargs))
            # Create separate client for pub/sub operations
            self._pubsub_client = await self._loop.run_in_executor(self._exe, lambda: redis.Redis(**kwargs))
        except Exception as e:
            self.log.error("Failed to create Redis clients: %s", e)
            return

        # Test connection
        try:
            ping_result = await self._loop.run_in_executor(self._exe, self._client.ping)
            if not ping_result:
                self.log.warning("Redis ping returned unexpected result")
                return
            self.log.info("Redis connection established successfully")
            self._pubsub = self._pubsub_client.pubsub()
            # Patterns get their own PubSub so each listen loop owns one object.
            self._pubsub_pattern = self._pubsub_client.pubsub()
            self._subscribe_pattern("/petal-*")  # Default pattern
            # Listening only after connection and pub/sub setup succeeded.
            self._is_listening = True
            self.log.info("RedisProxy %s is now listening for messages", self.app_id)
        except Exception as e:
            self.log.error("Failed to connect to Redis: %s", e)

    async def stop(self):
        """Close the Redis connection and clean up resources.

        Stops both listener tasks, then closes the pub/sub objects and both
        clients. Finally it shuts down the executor without waiting for it.
        Close errors are logged, not raised.
        """
        self.log.info("Stopping RedisProxy...")
        # Set first so the listen loops exit at their next check.
        self._shutdown_flag = True
        self._is_listening = False
        await self._cancel_task(self._subscription_task)
        await self._cancel_task(self._pattern_subscription_task)
        self._message_handlers.clear()
        self._subscription_tasks.clear()
        await self._close_quietly(self._pubsub, "Redis pub/sub")
        await self._close_quietly(self._pubsub_pattern, "Redis pattern pub/sub")
        await self._close_quietly(self._client, "Redis connection")
        await self._close_quietly(self._pubsub_client, "Redis pub/sub connection")
        if self._exe:
            # wait=False: a worker thread may still be blocked inside a
            # get_message(timeout=1.0) call issued by a cancelled listener.
            self._exe.shutdown(wait=False)
        self.log.info("RedisProxy stopped")

    def _connection_kwargs(self) -> Dict[str, Any]:
        """Build ``redis.Redis`` keyword arguments; a Unix socket path wins over host/port."""
        kwargs: Dict[str, Any] = {
            "db": self.db,
            "password": self.password,
            "decode_responses": True,
        }
        if self.unix_socket_path:
            kwargs["unix_socket_path"] = self.unix_socket_path
        else:
            kwargs["host"] = self.host
            kwargs["port"] = self.port
        return kwargs

    async def _run_blocking(self, fn: Callable[[], Any]) -> Any:
        """Run the zero-argument callable *fn* in the proxy's executor and await its result."""
        return await self._loop.run_in_executor(self._exe, fn)

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel *task* if it is still running and wait for it to finish."""
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _close_quietly(self, resource: Any, label: str) -> None:
        """Call ``resource.close()`` in the executor, logging (not raising) failures."""
        if not resource:
            return
        try:
            await self._loop.run_in_executor(self._exe, resource.close)
        except Exception as e:
            self.log.error("Error closing %s: %s", label, e)

    # ------ Key-Value Operations ------ #

    async def get(self, key: str) -> Optional[str]:
        """Get a value from Redis.

        Returns:
            The stored value, or ``None`` if the key is missing, the client is
            not initialized, or the command fails.
        """
        if not self._client:
            self.log.error("Redis client not initialized")
            return None
        try:
            result = await self._run_blocking(lambda: self._client.get(key))
        except Exception as e:
            self.log.error("Error getting key %s: %s", key, e)
            return None
        # 📥 Log key reads
        self.log.debug("📥 Redis GET: %s = %s", key, result)
        return result

    async def set(self, key: str, value: str, ex: Optional[int] = None) -> bool:
        """Set a value in Redis, optionally expiring after *ex* seconds.

        Returns:
            ``True`` if Redis acknowledged the write, ``False`` otherwise
            (including when the client is not initialized or the command fails).
        """
        if not self._client:
            self.log.error("Redis client not initialized")
            return False
        try:
            result = await self._run_blocking(lambda: bool(self._client.set(key, value, ex=ex)))
        except Exception as e:
            self.log.error("Error setting key %s: %s", key, e)
            return False
        # 📤 Log key writes
        self.log.debug("📤 Redis SET: %s = %s (ex=%s) -> %s", key, value, ex, result)
        return result

    async def delete(self, key: str) -> int:
        """Delete a key from Redis.

        Returns:
            The value returned by ``DEL``; ``0`` if the client is not
            initialized or the command fails.
        """
        if not self._client:
            self.log.error("Redis client not initialized")
            return 0
        try:
            return await self._run_blocking(lambda: self._client.delete(key))
        except Exception as e:
            self.log.error("Error deleting key %s: %s", key, e)
            return 0

    async def exists(self, key: str) -> bool:
        """Check if a key exists in Redis; ``False`` on any error."""
        if not self._client:
            self.log.error("Redis client not initialized")
            return False
        try:
            result = await self._run_blocking(lambda: self._client.exists(key))
        except Exception as e:
            self.log.error("Error checking existence of key %s: %s", key, e)
            return False
        return bool(result)

    async def scan_keys(self, pattern: str, count: int = 100) -> List[str]:
        """
        Scan Redis keys matching a pattern.

        Args:
            pattern: Key pattern to match (e.g., ``job:*``)
            count: Number of keys to return per scan iteration

        Returns:
            List of matching keys (empty on error or if not initialized)
        """
        if not self._client:
            self.log.error("Redis client not initialized")
            return []
        try:
            # scan_iter drives the SCAN cursor until the server returns 0.
            return await self._run_blocking(
                lambda: list(self._client.scan_iter(match=pattern, count=count))
            )
        except Exception as e:
            self.log.error("Error scanning keys with pattern %s: %s", pattern, e)
            return []

    async def list_online_applications(self) -> List[str]:
        """List online applications by checking Redis keys for app registrations.

        Matches keys of the form ``app:<name>:online`` and returns the
        ``<name>`` segment of each. Uses incremental ``SCAN`` instead of
        ``KEYS`` so a large keyspace does not block the Redis server.
        """
        if not self._client:
            self.log.error("Redis client not initialized")
            return []

        def _collect() -> List[str]:
            # SCAN may return the same key more than once; dict.fromkeys
            # de-duplicates while keeping first-seen order.
            keys = dict.fromkeys(self._client.scan_iter(match="app:*:online"))
            return [key.split(":")[1] for key in keys if key.count(":") >= 2]

        try:
            return await self._run_blocking(_collect)
        except Exception as e:
            self.log.error("Error listing online applications: %s", e)
            return []

    # ------ Pub/Sub Operations ------ #

    async def publish(self, channel: str, message: str) -> int:
        """Publish a message to a channel (async, non-blocking).

        Returns:
            The value returned by ``PUBLISH``; ``0`` if the client is not
            initialized or the command fails.
        """
        if not self._client:
            self.log.error("Redis client not initialized")
            return 0
        try:
            return await self._run_blocking(lambda: self._client.publish(channel, message))
        except Exception as e:
            self.log.error("Error publishing to channel %s: %s", channel, e)
            return 0

    def subscribe(
        self,
        channel: str,
        callback: Callable[[str, str], Awaitable[None]],
        cpu_heavy: bool = False,
    ):
        """Subscribe to a channel with an async callback function.

        Args:
            channel: The Redis channel to subscribe to
            callback: Async callback function that receives (channel, data)
            cpu_heavy: If True, the callback runs to completion on a private
                event loop inside the proxy's thread-pool executor.

        Raises:
            TypeError: If callback is not an async function
        """
        if not asyncio.iscoroutinefunction(callback):
            raise TypeError(
                f"Callback must be an async function (coroutine). "
                f"Got {type(callback).__name__} instead. "
                f"Define your callback with 'async def' instead of 'def'."
            )
        if not self._pubsub:
            self.log.error("Redis pub/sub not initialized")
            return
        # Store the callback and its config in tracking dictionaries
        self._subscriptions[channel] = callback
        self._message_handlers[channel] = callback
        self._handler_cpu_heavy[channel] = cpu_heavy
        try:
            # NOTE(review): this runs on the caller's thread while the listener
            # may be inside get_message() on the same PubSub in a worker thread;
            # confirm that concurrent use of the PubSub object is acceptable.
            self._pubsub.subscribe(channel)
            # (Re)start the listener if it was never started or has exited.
            if self._loop and not self._shutdown_flag and self._needs_listener(self._subscription_task):
                self._subscription_task = self._loop.create_task(self._listen_for_messages())
            # Track this task
            self._subscription_tasks[channel] = self._subscription_task
            self.log.info("Subscribed to channel: %s", channel)
        except Exception as e:
            self.log.error("Error subscribing to channel %s: %s", channel, e)

    def unsubscribe(self, channel: str):
        """Unsubscribe from a channel and forget its callback and settings."""
        if not self._pubsub:
            self.log.error("Redis pub/sub not initialized")
            return
        try:
            self._pubsub.unsubscribe(channel)
            for registry in (
                self._subscriptions,
                self._message_handlers,
                self._subscription_tasks,
                self._handler_cpu_heavy,
            ):
                registry.pop(channel, None)
            self.log.info("Unsubscribed from channel: %s", channel)
        except Exception as e:
            self.log.error("Error unsubscribing from channel %s: %s", channel, e)

    def _subscribe_pattern(self, pattern: str = "/petal-*"):
        """Subscribe to channels matching a pattern. Only one pattern at a time.

        Switching to a different pattern first calls :meth:`_unsubscribe_pattern`
        on the current one. When that succeeds, it also clears every registered
        pattern channel callback.
        """
        if not self._pubsub_pattern:
            self.log.error("Redis pattern pub/sub not initialized")
            return
        if self._current_pattern and self._current_pattern != pattern:
            self.log.warning(f"Switching from pattern '{self._current_pattern}' to '{pattern}'")
            self._unsubscribe_pattern(self._current_pattern)
        self._current_pattern = pattern
        try:
            self._pubsub_pattern.psubscribe(pattern)
            # (Re)start the listener if it was never started or has exited.
            if self._loop and not self._shutdown_flag and self._needs_listener(self._pattern_subscription_task):
                self._pattern_subscription_task = self._loop.create_task(self._listen_for_pattern_messages())
            self.log.info("Subscribed to pattern: %s", pattern)
        except Exception as e:
            self.log.error("Error subscribing to pattern %s: %s", pattern, e)

    def register_pattern_channel_callback(
        self,
        channel: str,
        callback: Callable[[str, str], Awaitable[None]],
        cpu_heavy: bool = False,
    ):
        """Register a callback for a specific channel for pattern subscriptions.

        The callback must be an async function (coroutine).
        Sync callbacks are not supported and will raise TypeError.

        Args:
            channel: The channel to register the callback for
            callback: Async callback function that receives (channel, data)
            cpu_heavy: If True, the callback runs to completion on a private
                event loop inside the proxy's thread-pool executor.
        """
        if not asyncio.iscoroutinefunction(callback):
            raise TypeError(
                f"Callback for channel '{channel}' must be an async function (coroutine). "
                f"Got {type(callback).__name__}. Use 'async def' to define your callback."
            )
        self._pattern_callbacks[channel] = callback
        self._pattern_cpu_heavy[channel] = cpu_heavy
        self.log.info("Registered pattern callback for channel: %s", channel)

    def unregister_pattern_channel_callback(self, channel: str):
        """Unregister a callback (and its cpu_heavy flag) for a pattern-subscribed channel."""
        if channel in self._pattern_callbacks:
            del self._pattern_callbacks[channel]
            self._pattern_cpu_heavy.pop(channel, None)
            self.log.info("Unregistered pattern callback for channel: %s", channel)
        else:
            self.log.warning("No pattern callback registered for channel: %s", channel)

    def _unsubscribe_pattern(self, pattern: Optional[str] = None):
        """Unsubscribe from *pattern* (default: the current pattern).

        Unsubscribing from the current pattern also clears all registered
        pattern channel callbacks and their cpu_heavy flags.
        """
        if not self._pubsub_pattern:
            self.log.error("Redis pattern pub/sub not initialized")
            return
        # Use current pattern if none specified
        if pattern is None:
            pattern = self._current_pattern
        if not pattern:
            self.log.warning("No pattern to unsubscribe from")
            return
        try:
            self._pubsub_pattern.punsubscribe(pattern)
            if pattern == self._current_pattern:
                self._current_pattern = None
                self._pattern_callbacks.clear()
                self._pattern_cpu_heavy.clear()
            self.log.info("Unsubscribed from pattern: %s", pattern)
        except Exception as e:
            self.log.error("Error unsubscribing from pattern %s: %s", pattern, e)

    # ------ Listener loops ------ #

    @staticmethod
    def _needs_listener(task: Optional[asyncio.Task]) -> bool:
        """Return True when *task* was never created or has already finished."""
        return task is None or task.done()

    @staticmethod
    def _as_text(value: Any) -> Any:
        """Decode ``bytes`` as UTF-8 (dropping invalid bytes); return other values unchanged."""
        return value.decode("utf-8", "ignore") if isinstance(value, bytes) else value

    async def _listen_for_pattern_messages(self):
        """Listen for pattern-subscription messages and dispatch them until shutdown.

        On a Redis connection or socket error, the loop waits ``_RETRY_DELAY``
        seconds and then re-issues PSUBSCRIBE for the current pattern. Other
        non-timeout errors are logged and followed by the same pause, so a
        persistent failure cannot spin the loop.
        """
        while not self._shutdown_flag:
            try:
                message = await self._loop.run_in_executor(
                    self._exe,
                    lambda: self._pubsub_pattern.get_message(timeout=1.0),
                )
                if message and message.get("type") == "pmessage":
                    channel = self._as_text(message["channel"])
                    data = self._as_text(message["data"])
                    self.log.debug("Received pattern message at channel: %s", channel)
                    # Dispatch directly to registered callback
                    callback = self._pattern_callbacks.get(channel)
                    if callback:
                        self._invoke_callback_safely(
                            callback, channel, data,
                            cpu_heavy=self._pattern_cpu_heavy.get(channel, False),
                        )
            except (redis.exceptions.ConnectionError, OSError) as conn_err:
                if self._shutdown_flag:
                    break
                self.log.error(
                    "Pattern pub/sub connection lost: %r; reconnecting in %gs",
                    conn_err, self._RETRY_DELAY,
                )
                await asyncio.sleep(self._RETRY_DELAY)
                await self._reconnect_pattern_pubsub()
            except Exception as e:
                if "timeout" not in str(e).lower() and not self._shutdown_flag:
                    self.log.error("Error listening for pattern messages: %s", e)
                    await asyncio.sleep(self._RETRY_DELAY)

    async def _listen_for_messages(self):
        """Listen for PubSub messages and auto-reconnect on errors.

        On a Redis connection error or a socket I/O error, the loop waits
        ``_RETRY_DELAY`` seconds and then calls :meth:`_reconnect_pubsub`.
        Other non-timeout errors are logged and followed by the same pause.
        """
        while not self._shutdown_flag:
            try:
                message = await self._loop.run_in_executor(
                    self._exe,
                    lambda: self._pubsub.get_message(timeout=1.0),
                )
                # Only care about real published messages
                if message and message.get("type") == "message":
                    channel = self._as_text(message["channel"])
                    data = self._as_text(message["data"])
                    self.log.debug("Received at channel: %s", channel)
                    # Dispatch directly to registered callback
                    callback = self._subscriptions.get(channel)
                    if callback:
                        self._invoke_callback_safely(
                            callback, channel, data,
                            cpu_heavy=self._handler_cpu_heavy.get(channel, False),
                        )
            except redis.exceptions.ConnectionError as conn_err:
                if not self._shutdown_flag:
                    self.log.error("Connection lost: %r; reconnecting in %gs", conn_err, self._RETRY_DELAY)
                    await asyncio.sleep(self._RETRY_DELAY)
                    await self._reconnect_pubsub()
            except (IOError, OSError) as io_err:
                if not self._shutdown_flag:
                    self.log.error(
                        "I/O error on PubSub socket: %r; reconnecting in %gs", io_err, self._RETRY_DELAY
                    )
                    await asyncio.sleep(self._RETRY_DELAY)
                    await self._reconnect_pubsub()
            except IndexError as idx_err:
                # Defensive: swallow stray indexing bugs and keep going
                self.log.error("Unexpected indexing error: %r; continuing", idx_err)
            except Exception as e:
                # Anything else that isn't just a timeout
                if "timeout" not in str(e).lower() and not self._shutdown_flag:
                    self.log.error("Error listening for messages: %r", e)
                    await asyncio.sleep(self._RETRY_DELAY)

    # ------ Callback dispatch ------ #

    def _invoke_callback_safely(
        self, callback: Callable, channel: str, data: str, *, cpu_heavy: bool = False,
    ):
        """Safely invoke an async callback by scheduling it on the main event loop.

        All callbacks are required to be async functions. When *cpu_heavy* is
        ``True``, the coroutine is run to completion by ``asyncio.run`` inside
        the proxy's executor to avoid blocking the event loop.

        Scheduling errors are logged immediately. Exceptions raised later by
        the callback are logged by a done-callback on the returned future.
        Previously they were stored in a discarded future and lost.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            self.log.warning("Cannot invoke async callback for %s: event loop not available", channel)
            return
        coro = None
        try:
            coro = self._offload_callback(callback, channel, data) if cpu_heavy else callback(channel, data)
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception as e:
            if coro is not None:
                coro.close()  # never scheduled; avoid "coroutine was never awaited"
            self.log.error("Error in callback for channel %s: %s", channel, e)
            return
        future.add_done_callback(lambda fut: self._log_callback_failure(fut, channel))

    async def _offload_callback(self, callback: Callable, channel: str, data: str) -> None:
        """Run ``callback(channel, data)`` under ``asyncio.run`` in a worker thread."""
        # NOTE(review): the callback runs on a fresh event loop, so it cannot
        # await futures bound to self._loop (e.g. this proxy's own async
        # methods) — confirm cpu_heavy callbacks respect this.
        await self._loop.run_in_executor(self._exe, lambda: asyncio.run(callback(channel, data)))

    def _log_callback_failure(self, future: concurrent.futures.Future, channel: str) -> None:
        """Done-callback: log the exception raised by a scheduled callback, if any."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            self.log.error("Error in callback for channel %s: %s", channel, exc, exc_info=exc)

    # ------ Reconnection ------ #

    async def _reconnect_pubsub(self):
        """Re-subscribe the channel PubSub to every tracked channel after a connection error.

        Awaited by :meth:`_listen_for_messages`. Failures are logged, not
        raised, so the listen loop keeps running. Its next get_message call
        surfaces the error again and triggers another attempt after
        ``_RETRY_DELAY``.
        """
        if self._pubsub is None or self._shutdown_flag:
            return
        channels = list(self._subscriptions)
        if not channels:
            return
        try:
            await self._run_blocking(lambda: self._pubsub.subscribe(*channels))
            self.log.info("Re-subscribed to %d channel(s) after reconnect", len(channels))
        except Exception as e:
            self.log.error("Failed to re-subscribe channels after reconnect: %r", e)

    async def _reconnect_pattern_pubsub(self):
        """Re-issue PSUBSCRIBE for the current pattern after a connection error.

        Failures are logged, not raised; the pattern listen loop retries after
        its next error.
        """
        pattern = self._current_pattern
        if self._pubsub_pattern is None or not pattern or self._shutdown_flag:
            return
        try:
            await self._run_blocking(lambda: self._pubsub_pattern.psubscribe(pattern))
            self.log.info("Re-subscribed to pattern %s after reconnect", pattern)
        except Exception as e:
            self.log.error("Failed to re-subscribe to pattern %s after reconnect: %r", pattern, e)