# strategy/utils.py
"""
Shared utility functions for strategy modules.
Includes time/session logic, swing detection, and common validators.
"""
import logging
import re
from functools import lru_cache, wraps
from typing import Dict, Tuple, Optional

import pandas as pd

# Named child of the "TradingBot" logger hierarchy so records from these
# utilities propagate to whatever handlers the bot configures on its root.
logger: logging.Logger = logging.getLogger("TradingBot.StrategyUtils")

class StrategyValidationError(Exception):
    """Raised when a strategy input (e.g. a timestamp or price series) fails validation."""

class ConfigValidationError(Exception):
    """Raised when a strategy configuration object is malformed."""

def validation_required(func):
    """
    Decorator that logs validation errors raised by a strategy utility, then re-raises them.

    StrategyValidationError, ConfigValidationError and ValueError raised by
    ``func`` are logged at ERROR level together with the wrapped function's
    name and then re-raised unchanged. Any other exception propagates without
    being logged here.

    ``functools.wraps`` copies the wrapped function's metadata (``__name__``,
    ``__doc__``, ``__module__``, ``__wrapped__`` ...) onto the wrapper, so
    introspection, ``help()`` and log output refer to the real function rather
    than to ``wrapper``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (StrategyValidationError, ConfigValidationError, ValueError) as e:
            # Lazy %-args: the message is only formatted if ERROR is enabled.
            logger.error("Validation error in %s: %s", func.__name__, e)
            raise
    return wrapper

def _killzone_bound_minutes(value: object) -> Optional[int]:
    """
    Parse a killzone 'start'/'end' value into minutes after midnight.

    Accepts 24h 'H:MM' or 'HH:MM' strings (surrounding whitespace ignored),
    with hour 0-24 and minute 0-59; '24:00' is allowed as an end-of-day bound.
    Returns None for anything else so the caller can skip the zone.
    """
    if not isinstance(value, str):
        return None
    match = re.fullmatch(r"(\d{1,2}):(\d{2})", value.strip())
    if match is None:
        return None
    hours, minutes = int(match.group(1)), int(match.group(2))
    if minutes > 59 or hours > 24 or (hours == 24 and minutes != 0):
        return None
    return hours * 60 + minutes


@validation_required
def is_in_killzone(current_time: str, killzones: Dict[str, Dict]) -> Tuple[Optional[str], Optional[Dict]]:
    """
    Check if current_time (string 'HH:MM') is within any configured killzone.

    Zone bounds are inclusive on both ends. A zone whose 'start' is later than
    its 'end' (e.g. '22:00' -> '01:00') is treated as wrapping past midnight.
    Zones are checked in dict iteration order and the first match wins. Zones
    missing 'start'/'end'/'weight', or whose bounds are not valid 24h times,
    are logged and skipped.

    Args:
        current_time (str): Time in zero-padded 'HH:MM' format (24h, e.g. '14:30').
        killzones (dict): Dict of killzone configs, each with 'start', 'end', 'weight'.
            Example:
                {
                    "LONDON": {"start": "02:00", "end": "05:00", "weight": 1.0},
                    ...
                }

    Returns:
        tuple: (zone_name, zone_metadata) if in killzone, else (None, None).
        zone_metadata is the config dict object itself, not a copy.

    Raises:
        ConfigValidationError: If killzones is not a dict.
        StrategyValidationError: If current_time is not a valid 'HH:MM' time
            (two-digit hour 00-23, two-digit minute 00-59).

    Example:
        >>> is_in_killzone('03:15', KILLZONES)
        ('LONDON', {'start': '02:00', 'end': '05:00', 'weight': 1.0})

    Units:
        Time is in 24h format, weights are floats (risk/session multipliers)
    """
    if not isinstance(killzones, dict):
        logger.error("is_in_killzone: killzones must be a dict")
        raise ConfigValidationError("killzones must be a dict")

    # fullmatch (unlike match + '$') rejects a trailing newline; the range
    # check rejects syntactically valid but impossible times like '25:61'.
    match = re.fullmatch(r"(\d{2}):(\d{2})", current_time) if isinstance(current_time, str) else None
    if match is None or int(match.group(1)) > 23 or int(match.group(2)) > 59:
        logger.error(f"is_in_killzone: current_time '{current_time}' is not in 'HH:MM' format")
        raise StrategyValidationError("current_time must be a string in 'HH:MM' format")
    now = int(match.group(1)) * 60 + int(match.group(2))

    for name, kz in killzones.items():
        if not all(k in kz for k in ("start", "end", "weight")):
            logger.error(f"is_in_killzone: killzone '{name}' missing required keys")
            continue
        # Compare numerically: string comparison breaks on unpadded hours ('2:00').
        start = _killzone_bound_minutes(kz["start"])
        end = _killzone_bound_minutes(kz["end"])
        if start is None or end is None:
            logger.error(f"is_in_killzone: killzone '{name}' has invalid 'start'/'end' time")
            continue
        if start <= end:
            inside = start <= now <= end
        else:
            # Session spans midnight, e.g. 22:00 -> 01:00.
            inside = now >= start or now <= end
        if inside:
            return name, kz
    return None, None

@validation_required
def detect_swings(df: pd.DataFrame, swing_window: int = 3) -> Tuple[list, list]:
    """
    Detect swing highs and lows in a price DataFrame.

    Bar ``i`` is a swing high when its 'high' is strictly greater than the
    'high' of each of the ``swing_window`` bars before it and after it; a
    swing low is the mirror image using strictly lower 'low' values. Bars
    within ``swing_window`` rows of either end of the frame never qualify,
    because their neighbourhood is incomplete.

    Args:
        df (pd.DataFrame): Must have 'high' and 'low' columns (price units, e.g. USD or pips)
        swing_window (int): Number of bars to look back/forward for swing (default: 3).
            Must be a positive integer.

    Returns:
        tuple: (swing_highs, swing_lows): freshly built lists of (position, price),
        where position is the 0-based row position in df (not the index label).

    Raises:
        ValueError: If df is not a DataFrame, is missing 'high'/'low' columns,
            or swing_window is not a positive integer.

    Example:
        >>> detect_swings(df, 3)
        ([(10, 2034.5), (17, 2040.2)], [(12, 2020.1)])

    Units:
        Price units as in df (e.g. USD, pips)
    """
    # Intentionally not memoized. The previous cache keyed on
    # (id(df), swing_window, len(df)) returned stale results when a frame was
    # mutated in place or when a new frame reused a freed object's id, grew
    # without bound, and handed out shared lists. Detection is
    # O(len(df) * swing_window) and cheap to recompute.
    if not isinstance(df, pd.DataFrame):
        logger.error("detect_swings: df must be a pandas DataFrame")
        raise ValueError("df must be a pandas DataFrame")
    for col in ("high", "low"):
        if col not in df.columns:
            logger.error(f"detect_swings: DataFrame missing required column '{col}'")
            raise ValueError(f"DataFrame must contain '{col}' column")
    # A window < 1 would make every bar a "swing" (all() over nothing is True),
    # and a negative one would index past the array ends.
    if not pd.api.types.is_integer(swing_window) or swing_window < 1:
        logger.error(f"detect_swings: swing_window must be a positive integer, got {swing_window!r}")
        raise ValueError("swing_window must be a positive integer")

    n = len(df)
    if n < 2 * swing_window + 1:
        logger.warning("detect_swings: Insufficient data for swing detection")
        return [], []

    highs = df["high"].to_numpy()
    lows = df["low"].to_numpy()
    offsets = range(1, swing_window + 1)
    swing_highs = []
    swing_lows = []
    for i in range(swing_window, n - swing_window):
        high = highs[i]
        if all(high > highs[i - j] and high > highs[i + j] for j in offsets):
            swing_highs.append((i, high))
        low = lows[i]
        if all(low < lows[i - j] and low < lows[i + j] for j in offsets):
            swing_lows.append((i, low))
    return swing_highs, swing_lows