import pandas as pd
import numpy as np
import logging
import time
from functools import lru_cache
from typing import Dict, Optional, Tuple, List
from strategy.utils import is_in_killzone
from constants import CANDLE_COUNTS

# Try to import ta library for RSI calculation
try:
    import ta  # type: ignore
    TA_AVAILABLE = True
except ImportError:
    TA_AVAILABLE = False

logger = logging.getLogger("TradingBot.MarketStructure")

class MarketStructure:
    def __init__(self, config):
        """Store the bot configuration and initialise analysis state.

        Args:
            config: Dict-like configuration. ``config['trading']['timeframes']``
                selects the timeframes analysed by ``analyze`` (default
                M1/M5/M15/H1) and ``config['killzones']`` overrides the default
                ASIAN/LONDON/NEWYORK session table.
        """
        self.config = config
        trading_section = config.get('trading', {})
        # Timeframes iterated by analyze(); fall back to the standard set.
        self.timeframes = trading_section.get('timeframes', ["M1", "M5", "M15", "H1"])
        # Last analyze() result and when it was computed, used for throttling.
        self.analysis_cache = {}
        self.last_analysis_time = 0
        self.analysis_interval = 0.1  # seconds (100ms) between full analyses
        # Default window used by detect_swings / detect_market_structure.
        self.ict_swing_window = 3
        # Session windows as "HH:MM" strings plus a weight.
        # NOTE(review): the time zone of these times is not stated here - confirm.
        fallback_killzones = {
            "ASIAN": {"start": "20:00", "end": "00:00", "weight": 0.5},
            "LONDON": {"start": "02:00", "end": "05:00", "weight": 1.0},
            "NEWYORK": {"start": "07:00", "end": "10:00", "weight": 1.0},
        }
        self.killzones = config.get("killzones", fallback_killzones)
        logger.info("Market Structure initialized with enhanced logic")

    # --- Simple Trend Methods ---
    def analyze(self, data_cache):
        """Run the quick SMA trend analysis on every configured timeframe.

        Calls are throttled: within ``self.analysis_interval`` seconds of the
        previous full run, the cached result in
        ``self.analysis_cache['last_analysis']`` is returned unchanged.

        Args:
            data_cache: Mapping of timeframe name to OHLC DataFrame. Missing,
                ``None`` or empty entries are skipped one by one.

        Returns:
            Dict of timeframe -> ``_quick_trend_analysis`` result, or ``{}`` if
            an unexpected error occurs.
        """
        try:
            now = time.time()
            if now - self.last_analysis_time < self.analysis_interval:
                return self.analysis_cache.get('last_analysis', {})
            analysis = {}
            for tf in self.timeframes:
                frame = data_cache.get(tf)
                # Skip unusable entries individually. Previously a single None
                # value raised AttributeError on `.empty`, and the except block
                # discarded the results for every timeframe.
                if frame is None or getattr(frame, 'empty', True):
                    continue
                analysis[tf] = self._quick_trend_analysis(frame)
            self.analysis_cache['last_analysis'] = analysis
            self.last_analysis_time = now
            analysis_time = time.time() - now
            if analysis_time > 0.05:
                logger.warning(f"Market structure analysis took {analysis_time:.3f}s")
            return analysis
        except Exception as e:
            logger.error(f"Error in market structure analysis: {str(e)}")
            return {}

    def _quick_trend_analysis(self, data):
        """Classify one timeframe's trend from its 20- and 50-period close SMAs.

        BULLISH when close > SMA20 > SMA50, BEARISH when close < SMA20 < SMA50,
        otherwise NEUTRAL. ``strength`` is the close's percentage distance from
        SMA20, capped at 100 (0 when NEUTRAL).

        Args:
            data: DataFrame with ``high``, ``low`` and ``close`` columns.

        Returns:
            ``{"bias", "strength", "sma20", "sma50"}``; just
            ``{"bias": "NEUTRAL", "strength": 0}`` when there are fewer than 50
            rows or an error occurs.
        """
        try:
            row_count = len(data)
            if row_count < 50:
                logger.warning(f"_quick_trend_analysis: insufficient data (len={row_count})")
                return {"bias": "NEUTRAL", "strength": 0}
            ohlc = data[['high', 'low', 'close']]
            if ohlc.isnull().any().any():
                # Warning only: a NaN close/SMA makes the comparisons below False,
                # which falls through to NEUTRAL.
                logger.warning(f"_quick_trend_analysis: NaN detected in OHLCV. Data sample:\n{ohlc.tail()}")
            closes = data['close']
            sma20 = closes.rolling(window=20, min_periods=20).mean().iloc[-1]
            sma50 = closes.rolling(window=50, min_periods=50).mean().iloc[-1]
            last_close = closes.iloc[-1]
            bias, strength = "NEUTRAL", 0
            if last_close > sma20 > sma50:
                bias = "BULLISH"
                strength = min((last_close - sma20) / sma20 * 100, 100)
            elif last_close < sma20 < sma50:
                bias = "BEARISH"
                strength = min((sma20 - last_close) / sma20 * 100, 100)
            return {"bias": bias, "strength": strength, "sma20": sma20, "sma50": sma50}
        except Exception as e:
            logger.error(f"Error in quick trend analysis: {str(e)}")
            return {"bias": "NEUTRAL", "strength": 0}

    def get_market_bias(self, analysis):
        """Combine per-timeframe trend results into one weighted bias.

        Each listed timeframe's ``strength`` is weighted (M1 0.1; M5, M15, H1
        0.3). It is added for BULLISH and subtracted for BEARISH; NEUTRAL and
        unlisted timeframes contribute nothing. The score must exceed +/-20
        (+/-10 when ``config['aggressive_mode']`` is truthy) to produce a
        directional result.

        Args:
            analysis: Dict of timeframe -> ``{"bias", "strength", ...}`` as
                produced by ``analyze``.

        Returns:
            ``"BULLISH"``, ``"BEARISH"`` or ``"NEUTRAL"`` (also on error).
        """
        try:
            if not analysis:
                logger.info("get_market_bias: No analysis provided, returning NEUTRAL")
                return "NEUTRAL"
            # Aggressive mode commits to a direction on a weaker score.
            threshold = 10 if self.config.get('aggressive_mode', False) else 20
            bias_score = 0
            for tf, weight in (("M1", 0.1), ("M5", 0.3), ("M15", 0.3), ("H1", 0.3)):
                if tf not in analysis:
                    continue
                entry = analysis[tf]
                if entry["bias"] == "BULLISH":
                    bias_score += weight * entry["strength"]
                elif entry["bias"] == "BEARISH":
                    bias_score -= weight * entry["strength"]
            logger.info(f"get_market_bias: bias_score={bias_score}, threshold={threshold}")
            if bias_score > threshold:
                return "BULLISH"
            if bias_score < -threshold:
                return "BEARISH"
            logger.info("get_market_bias: bias_score within threshold, returning NEUTRAL")
            return "NEUTRAL"
        except Exception as e:
            logger.error(f"Error getting market bias: {str(e)}")
            return "NEUTRAL"

    def get_trend_simple(self, rates_cache):
        """Return the simple (SMA-based) multi-timeframe trend.

        Args:
            rates_cache: Mapping of timeframe name to OHLC DataFrame, passed
                straight to ``analyze``.

        Returns:
            ``{"trend": "bullish"|"bearish"|"neutral", "details": per-timeframe
            analysis (shallow copy), "mode": "simple"}``; neutral with empty
            details on error.
        """
        try:
            per_timeframe = self.analyze(rates_cache)
            bias = self.get_market_bias(per_timeframe)
            if bias in ("BULLISH", "BEARISH"):
                trend = bias.lower()
            else:
                trend = "neutral"
            return {"trend": trend, "details": dict(per_timeframe), "mode": "simple"}
        except Exception as e:
            logger.error(f"Error in get_trend_simple: {str(e)}")
            return {"trend": "neutral", "details": {}, "mode": "simple"}

    # --- Advanced ICT/Structure Methods ---
    def detect_swings(self, df, swing_window=None):
        """Collect rolling high/low extremes used as "swings" by the structure logic.

        For every bar from position ``swing_window - 1`` onward, records
        ``(position, rolling max of 'high')`` and ``(position, rolling min of
        'low')`` over the trailing ``swing_window`` bars. NaN results are skipped.

        NOTE(review): these are rolling extremes at *every* bar, not isolated
        pivot (fractal) points, even though the minimum-length guard
        (``2 * window + 1``) hints that a centred pivot window was intended.
        Confirm before relying on the name.

        Args:
            df: DataFrame with ``high`` and ``low`` columns (or ``None``).
            swing_window: Trailing window size; defaults to ``self.ict_swing_window``.

        Returns:
            ``{'high': [(pos, value), ...], 'low': [(pos, value), ...]}``; empty
            lists when there is too little data or an error occurs.
        """
        try:
            window = self.ict_swing_window if swing_window is None else swing_window
            if df is None or len(df) < window * 2 + 1:
                logger.info(f"ICT detect_swings: Not enough data (len={len(df) if df is not None else 'None'}), window={window}")
                return {'high': [], 'low': []}
            rolling_high = df['high'].rolling(window, min_periods=window).max().to_numpy()
            rolling_low = df['low'].rolling(window, min_periods=window).min().to_numpy()
            positions = range(window - 1, len(df))
            swing_highs = [(pos, rolling_high[pos]) for pos in positions if not np.isnan(rolling_high[pos])]
            swing_lows = [(pos, rolling_low[pos]) for pos in positions if not np.isnan(rolling_low[pos])]
            logger.info(f"ICT detect_swings: Found {len(swing_highs)} swing_highs, {len(swing_lows)} swing_lows (len={len(df)}, window={window})")
            return {'high': swing_highs, 'low': swing_lows}
        except Exception as e:
            logger.error(f"Error in detect_swings: {str(e)}")
            return {'high': [], 'low': []}

    def detect_market_structure(self, df):
        """Classify structure as 'bullish', 'bearish' or 'sideways' from recent swings.

        Compares the two most recent swing-high values and the two most recent
        swing-low values from ``detect_swings`` (window ``self.ict_swing_window``):
        a higher high plus a higher low gives 'bullish', a lower high plus a lower
        low gives 'bearish', and anything else (including fewer than two swings
        of either kind) gives 'sideways'.
        """
        swings = self.detect_swings(df, self.ict_swing_window)
        highs = swings['high']
        lows = swings['low']
        # Defensive shape checks in case detect_swings ever hands back scalars.
        if isinstance(highs, (float, np.floating)) or isinstance(lows, (float, np.floating)):
            logger.info("detect_market_structure: swing_highs or swing_lows is scalar, returning sideways")
            return "sideways"
        if not (hasattr(highs, '__len__') and hasattr(lows, '__len__')):
            logger.info("detect_market_structure: swing_highs or swing_lows not list-like, returning sideways")
            return "sideways"
        if min(len(highs), len(lows)) < 2:
            logger.info(f"detect_market_structure: not enough swings (highs={len(highs)}, lows={len(lows)}), returning sideways")
            return "sideways"
        prev_high, last_high = highs[-2][1], highs[-1][1]
        prev_low, last_low = lows[-2][1], lows[-1][1]
        higher_high, higher_low = last_high > prev_high, last_low > prev_low
        lower_high, lower_low = last_high < prev_high, last_low < prev_low
        if higher_high and higher_low:
            return "bullish"
        if lower_high and lower_low:
            return "bearish"
        return "sideways"

    def detect_bos_choch(self, candles):
        """Placeholder for Break-of-Structure / Change-of-Character detection.

        Not implemented: always returns ``None`` and ignores ``candles``.
        Callers should treat ``None`` as "no BOS/CHoCH information available".

        Args:
            candles: Price data to analyse (currently unused).

        Returns:
            None.
        """
        # TODO: implement BOS/CHoCH detection. The previous docstring described a
        # higher/lower-timeframe phase-compatibility check, which this method
        # never performed.
        return None

    def is_in_killzone(self, current_time):
        """Report whether ``current_time`` falls inside one of this instance's killzones.

        Thin wrapper around the module-level ``is_in_killzone`` helper (imported
        from ``strategy.utils``) that supplies ``self.killzones`` as the session
        table. The return value is whatever that helper returns.
        """
        zones = self.killzones
        return is_in_killzone(current_time, zones)

    def get_ltf_phase(self, m5, m15):
        """Combine M5 and M15 market structure into one lower-timeframe phase.

        Returns 'bullish' when bullish votes outnumber bearish ones (so one
        bullish plus one sideways is enough), 'bearish' in the mirrored case,
        and 'sideways' otherwise, including one bullish plus one bearish.
        """
        votes = [self.detect_market_structure(m5), self.detect_market_structure(m15)]
        bullish = votes.count('bullish')
        bearish = votes.count('bearish')
        # "count > opposite count" already implies at least one vote.
        if bullish > bearish:
            return 'bullish'
        if bearish > bullish:
            return 'bearish'
        return 'sideways'

    def get_trend_ict(self, rates_cache):
        """Derive an ICT-style trend from market structure on M5, M15, H1 and D1.

        Each timeframe votes via ``detect_market_structure``. The trend is
        'bullish' or 'bearish' when at least three of the four votes agree, and
        'neutral' otherwise. The M5/M15 ``ltf_phase`` is reported in the details
        but does not enter the vote.

        Args:
            rates_cache: Mapping of timeframe name to OHLC DataFrame.

        Returns:
            ``{"trend", "details", "mode": "ict"}``. ``details`` contains a
            ``reason`` when data is missing or an error occurs.
        """
        try:
            m5, m15, h1, d1 = (rates_cache.get(name) for name in ("M5", "M15", "H1", "D1"))
            frames = (m5, m15, h1, d1)
            # All None checks run before any `.empty` access, as before.
            if any(frame is None for frame in frames) or any(frame.empty for frame in frames):
                return {"trend": "neutral", "details": {"reason": "Missing data"}, "mode": "ict"}
            m5_structure = self.detect_market_structure(m5)
            m15_structure = self.detect_market_structure(m15)
            ltf_phase = self.get_ltf_phase(m5, m15)
            htf_structure = self.detect_market_structure(h1)
            d1_structure = self.detect_market_structure(d1)
            structure_votes = [m5_structure, m15_structure, htf_structure, d1_structure]
            trend = "neutral"
            if structure_votes.count("bullish") >= 3:
                trend = "bullish"
            elif structure_votes.count("bearish") >= 3:
                trend = "bearish"
            details = dict(
                m5_structure=m5_structure,
                m15_structure=m15_structure,
                ltf_phase=ltf_phase,
                htf_structure=htf_structure,
                d1_structure=d1_structure,
                structure_votes=structure_votes,
            )
            return {"trend": trend, "details": details, "mode": "ict"}
        except Exception as e:
            logger.error(f"Error in get_trend_ict: {str(e)}")
            return {"trend": "neutral", "details": {"reason": str(e)}, "mode": "ict"}

    # --- Unified Interface ---
    def _ict_pullback_signals(self, df, trend_direction, bos_level, ict_zones, momentum):
        """Generate ICT pullback entries when price retraces into an FVG/order-block zone.

        Buy: ``trend_direction == 'up'``, last close above ``bos_level``, close
        inside a zone and RSI < 35. Sell: ``trend_direction == 'down'``, last
        close below ``bos_level``, close inside a zone and RSI > 65. Only the
        first zone containing the close is evaluated. The stop sits 0.2 ATR beyond
        the zone edge. The target projects the zone-edge-to-BoS distance past the
        BoS level, and a signal is emitted only if that target lies on the
        profitable side of entry.

        Args:
            df: OHLC DataFrame (``high``/``low``/``close``); the last close is the entry.
            trend_direction: ``'up'`` or ``'down'``; anything else yields no signals.
            bos_level: Break-of-structure price. ``None``, NaN or non-numeric
                values yield no signals.
            ict_zones: Iterable of ``{'start': ..., 'end': ...}`` dicts. The bounds
                may be given in either order and are normalised to (low, high).
            momentum: RSI as a Series, array/list/tuple (last value used) or
                scalar. Values that cannot be read fall back to a neutral 50.

        Returns:
            List with zero or one signal dict.
        """
        signals = []
        # Comparing price with a missing BoS level used to raise TypeError (None),
        # which aborted the caller's whole analysis.
        try:
            bos = float(bos_level)
        except (TypeError, ValueError):
            return signals
        if not np.isfinite(bos):
            return signals

        price = float(df['close'].iloc[-1])
        atr = self._calculate_atr(df)
        if isinstance(atr, (np.ndarray, list, tuple)):
            atr = atr[-1]
        atr_val = float(atr)
        if not np.isfinite(atr_val):
            # _calculate_atr returns NaN with too little history; a NaN stop-loss
            # would make any emitted order unusable.
            logger.debug(f"ICT pullback: ATR unavailable (atr={atr_val}), skipping")
            return signals

        # The RSI reading is identical for every zone, so read it once.
        try:
            if isinstance(momentum, pd.Series):
                rsi_val = float(momentum.iloc[-1])
            elif isinstance(momentum, (np.ndarray, list, tuple)):
                rsi_val = float(momentum[-1])
            else:
                rsi_val = float(momentum)
        except (TypeError, ValueError, IndexError):
            rsi_val = 50.0

        def _bounds(zone):
            # Zones may arrive as (top, bottom); containment and stop/target maths
            # need (low, high). Inverted zones previously never matched.
            lo, hi = float(zone['start']), float(zone['end'])
            return (lo, hi) if lo <= hi else (hi, lo)

        # Uptrend: price above BoS, pullback into zone, RSI oversold
        if trend_direction == 'up' and price > bos:
            for zone in ict_zones:
                lo, hi = _bounds(zone)
                if not lo <= price <= hi:
                    continue
                take_profit = bos + (bos - lo)
                # A buy whose target is at/below entry is not a valid trade.
                if rsi_val < 35 and take_profit > price:  # Tuned RSI oversold
                    signals.append({
                        'type': 'ict_pullback_buy',
                        'entry': price,
                        'stop_loss': lo - atr_val * 0.2,
                        'take_profit': take_profit,
                        'confidence': 85,
                        'direction': 'buy',
                    })
                    logger.info(f"ICT pullback buy: price={price}, zone={zone}, bos={bos_level}, rsi={rsi_val}")
                break
        # Downtrend: price below BoS, pullback into zone, RSI overbought
        elif trend_direction == 'down' and price < bos:
            for zone in ict_zones:
                lo, hi = _bounds(zone)
                if not lo <= price <= hi:
                    continue
                take_profit = bos - (hi - bos)
                # A sell whose target is at/above entry is not a valid trade.
                if rsi_val > 65 and take_profit < price:  # Tuned RSI overbought
                    signals.append({
                        'type': 'ict_pullback_sell',
                        'entry': price,
                        'stop_loss': hi + atr_val * 0.2,
                        'take_profit': take_profit,
                        'confidence': 85,
                        'direction': 'sell',
                    })
                    logger.info(f"ICT pullback sell: price={price}, zone={zone}, bos={bos_level}, rsi={rsi_val}")
                break
        return signals

    def _confirm_m15_signal_with_ltf(self, data_cache, signal):
        """Confirm an M15 signal's direction against lower-timeframe structure.

        Uses M5 structure when a non-empty M5 frame exists, otherwise M1. The
        direction is ``signal['direction']`` or, failing that, the last
        ``_``-separated token of ``signal['type']`` (e.g. ``'limit_buy'`` gives
        ``'buy'``), compared case-insensitively.

        Returns:
            True only for buy + bullish or sell + bearish structure; False
            otherwise, including when neither M5 nor M1 is usable.
        """
        m5_frame = data_cache.get('M5')
        m1_frame = data_cache.get('M1')
        side = signal.get('direction') or signal.get('type', '').split('_')[-1]
        if isinstance(side, str):
            side = side.lower()
        if m5_frame is not None and not m5_frame.empty:
            ltf_structure = self.detect_market_structure(m5_frame)
        elif m1_frame is not None and not m1_frame.empty:
            ltf_structure = self.detect_market_structure(m1_frame)
        else:
            ltf_structure = None
        if side == 'buy' and ltf_structure == 'bullish':
            return True
        if side == 'sell' and ltf_structure == 'bearish':
            return True
        return False

    def get_trend(self, data_cache: Dict[str, pd.DataFrame]) -> Dict:
        """Run the full multi-timeframe structure analysis and collect trade signals.

        Steps:
          1. Validate that D1, H4, H1, M15 and M5 are non-empty DataFrames with at
             least 10 rows.
          2. Classify the market phase with ``_detect_market_phase``.
          3. Generate phase-specific signals: range trading (sideways), trend
             following plus ICT pullbacks (trending), or reversal signals.
          4. Gate M15 signals, and any other signal that has an entry and is a
             buy/sell, through ``_confirm_m15_signal_with_ltf``.
          5. Append LTF/HTF entry signals and the M5/M15 structure summary.

        Args:
            data_cache: Mapping of timeframe name (e.g. ``'M15'``) to OHLC DataFrame.

        Returns:
            Dict with ``phase``, ``confidence``, ``signals`` (all sources
            combined), ``key_levels``, ``ltfs_signals`` and ``htfs_signals``. On
            success it also has ``m5_structure``, ``m15_structure`` and
            ``ltf_phase``. On failure ``phase`` is ``'error'`` and ``reason``
            describes the problem.
        """
        try:
            # Validate data
            required = ['D1', 'H4', 'H1', 'M15', 'M5']
            min_rows = 10
            missing = []
            for tf in required:
                df = data_cache.get(tf)
                if df is None or not isinstance(df, pd.DataFrame) or df.empty:
                    logger.error(f"get_trend: {tf} is invalid: type={type(df)}, value={repr(df)}")
                    missing.append(tf)
                elif len(df) < min_rows:
                    logger.error(f"get_trend: {tf} has too few rows: {len(df)}")
                    missing.append(tf)
            if missing:
                return {
                    'phase': 'error',
                    'confidence': 0,
                    'signals': [],
                    'key_levels': {},
                    'reason': f"Missing/invalid/too short: {missing}",
                    'ltfs_signals': [],
                    'htfs_signals': []
                }
            d1, h4, h1, m15, m5 = [data_cache[tf] for tf in required]
            current_price = m15['close'].iloc[-1]

            # 1. Detect Market Phase
            market_phase = self._detect_market_phase(d1, h4, h1, m15)
            logger.debug(f"get_trend: Detected market phase: {market_phase}")
            phase = market_phase['phase']

            # 2. Generate Trading Signals
            signals = []
            if phase == 'sideways':
                signals.extend(self._range_trading_signals(m15))
                logger.debug(f"get_trend: Range trading signals: {signals}")
            elif phase == 'trending':
                signals.extend(self._trend_following_signals(h4, h1, m15))
                # --- ICT pullback signals ---
                trend_dir = 'up' if current_price > market_phase['key_levels'].get('support', 0) else 'down'
                from strategy.entry_signals import EntrySignals
                entry_signals = EntrySignals(self.config)

                def _as_zone(item):
                    # Store every zone as start=lower bound, end=upper bound:
                    # _ict_pullback_signals tests start <= price <= end and places
                    # stops/targets assuming that order. Bearish zones used to be
                    # stored inverted (start=top), so they could never match.
                    bottom, top = item['bottom'], item['top']
                    if bottom <= top:
                        return {'start': bottom, 'end': top}
                    return {'start': top, 'end': bottom}

                # FVG zones first, then order blocks (same ordering as before).
                ict_zones = []
                for fvg in entry_signals._detect_fvg(m15):
                    if fvg['type'].lower().startswith(('bullish', 'bearish')):
                        ict_zones.append(_as_zone(fvg))
                for ob in entry_signals._detect_order_block(m15):
                    if ob['type'].lower().startswith(('bullish', 'bearish')):
                        ict_zones.append(_as_zone(ob))

                if trend_dir == 'up':
                    bos_level = entry_signals._get_recent_swing_high(m15, lookback=20)
                else:
                    bos_level = entry_signals._get_recent_swing_low(m15, lookback=20)

                if bos_level is None:
                    # Comparing the price with None would raise and abort the
                    # whole analysis, discarding the trend signals above.
                    logger.debug("get_trend: No BoS level found, skipping ICT pullback signals")
                else:
                    try:
                        close_series = m15['close']
                        if TA_AVAILABLE:
                            rsi = ta.momentum.rsi(close_series, window=14)
                        elif isinstance(close_series, pd.Series):
                            rsi = self._calculate_rsi(close_series, period=14)
                        else:
                            rsi = 50.0
                    except Exception:
                        rsi = 50.0  # Default neutral RSI
                    signals.extend(self._ict_pullback_signals(m15, trend_dir, bos_level, ict_zones, rsi))
                logger.debug(f"get_trend: Trend following + ICT signals: {signals}")
            elif phase == 'reversal':
                signals.extend(self._reversal_signals(h4, h1, m15))
                logger.debug(f"get_trend: Reversal signals: {signals}")

            # --- Filter M15 signals by LTF confirmation ---
            filtered_signals = []
            for s in signals:
                tf = s.get('timeframe', 'M15')
                is_directional = s.get('type', '').endswith(('_buy', '_sell'))
                # Explicit grouping: `entry and buy or sell` used to parse as
                # `(entry and buy) or sell`, gating entry-less sells by accident.
                if tf == 'M15' or (s.get('entry') and is_directional):
                    if self._confirm_m15_signal_with_ltf(data_cache, s):
                        s['source'] = 'structure'
                        filtered_signals.append(s)
                        logger.info(f"M15 signal confirmed by LTF: {s}")
                    else:
                        logger.info(f"M15 signal rejected by LTF: {s}")
                else:
                    s['source'] = 'structure'
                    filtered_signals.append(s)

            # --- Integrate LTF/HTF signals ---
            ltfs_signals = self.get_ltf_entry_signals(data_cache)
            for s in ltfs_signals:
                s['source'] = 'ltfs'
            htfs_signals = self.get_htf_entry_signals(data_cache)
            for s in htfs_signals:
                s['source'] = 'htfs'
            all_signals = filtered_signals + ltfs_signals + htfs_signals

            # --- Add M5 and M15 structure to details ---
            m5_structure = self.detect_market_structure(m5)
            m15_structure = self.detect_market_structure(m15)
            ltf_phase = self.get_ltf_phase(m5, m15)
            if not all_signals:
                logger.info(f"get_trend: No trading signals generated for phase {market_phase['phase']} (confidence={market_phase['confidence']}). Key levels: {market_phase['key_levels']}")
            return {
                'phase': market_phase.get('phase', 'error'),
                'confidence': market_phase.get('confidence', 0),
                'signals': all_signals,
                'key_levels': market_phase.get('key_levels', {}),
                'ltfs_signals': ltfs_signals,
                'htfs_signals': htfs_signals,
                'm5_structure': m5_structure,
                'm15_structure': m15_structure,
                'ltf_phase': ltf_phase
            }
        except Exception as e:
            import traceback
            logger.error(f"Trend analysis error: {e}\n{traceback.format_exc()}")
            return {
                'phase': 'error',
                'confidence': 0,
                'signals': [],
                'key_levels': {},
                'reason': str(e),
                'ltfs_signals': [],
                'htfs_signals': []
            }

    def _detect_market_phase(self, d1, h4, h1, m15) -> Dict:
        """Determine whether the market is sideways, reversing or trending.

        Checks, in order:
          1. Sideways: the high-low range of up to 13 M15 bars before the latest
             bar is below 15% of the latest D1 bar's range. Key levels are then
             the last 14 M15 bars' low/high and their midpoint.
          2. Reversal: ``_detect_reversal(h4, h1)`` reports ``confirmed``.
          3. Otherwise trending, with key levels from ``_find_pivot_levels(h4)``.

        Args:
            d1, h4, h1, m15: OHLC DataFrames; each must be non-empty with >= 10 rows.

        Returns:
            Dict with ``phase``, ``confidence`` and ``key_levels``. On invalid
            input or error, ``phase`` is ``'error'`` and ``reason`` explains why.
        """
        # Uses the module-level "TradingBot.MarketStructure" logger like the rest
        # of the class; a local getLogger("MarketStructure") used to shadow it and
        # route this method's logs to a differently named logger.
        try:
            # Data validation
            min_rows = 10
            for name, df in zip(['D1', 'H4', 'H1', 'M15'], [d1, h4, h1, m15]):
                if df is None or not isinstance(df, pd.DataFrame) or df.empty:
                    logger.error(f"_detect_market_phase: {name} is invalid: type={type(df)}, value={repr(df)}")
                    return {'phase': 'error', 'confidence': 0, 'key_levels': {}, 'signals': [], 'reason': f'{name} missing or invalid'}
                elif len(df) < min_rows:
                    logger.error(f"_detect_market_phase: {name} has too few rows: {len(df)}")
                    return {'phase': 'error', 'confidence': 0, 'key_levels': {}, 'signals': [], 'reason': f'{name} too few rows'}
            # Labelled "ATR" in the logs, but this is the plain high-low range of
            # the M15 bars preceding the latest one, not an average true range.
            m15_range = m15['high'].iloc[-14:-1].max() - m15['low'].iloc[-14:-1].min()
            d1_range = d1['high'].iloc[-1] - d1['low'].iloc[-1]
            logger.debug(f"_detect_market_phase: ATR={m15_range}, D1 range={d1_range}")
            # 1. Low volatility relative to the day's range -> sideways
            if d1_range > 0 and m15_range < d1_range * 0.15:
                logger.info(f"_detect_market_phase: Sideways detected (ATR={m15_range}, D1 range={d1_range})")
                window_high = m15['high'].iloc[-14:].max()
                window_low = m15['low'].iloc[-14:].min()
                return {
                    'phase': 'sideways',
                    'confidence': 90,
                    'key_levels': {
                        'support': window_low,
                        'resistance': window_high,
                        'midpoint': (window_high + window_low) / 2
                    }
                }
            # 2. Check for Reversals (CHoCH)
            reversal = self._detect_reversal(h4, h1)
            logger.debug(f"_detect_market_phase: Reversal check: {reversal}")
            if reversal['confirmed']:
                logger.info(f"_detect_market_phase: Reversal detected: {reversal}")
                return {
                    'phase': 'reversal',
                    'confidence': reversal['confidence'],
                    'key_levels': reversal['key_levels']
                }
            # 3. Default to Trending Market
            logger.info(f"_detect_market_phase: Trending detected (ATR={m15_range}, D1 range={d1_range})")
            return {
                'phase': 'trending',
                'confidence': 75,
                'key_levels': self._find_pivot_levels(h4)
            }
        except Exception as e:
            import traceback
            logger.error(f"Market phase detection error: {e}\n{traceback.format_exc()}")
            return {'phase': 'error', 'confidence': 0, 'key_levels': {}, 'signals': [], 'reason': str(e)}

    def _range_trading_signals(self, m15: pd.DataFrame) -> List[Dict]:
        """Generate signals for a range-bound (sideways) market.

        Support and resistance are the low/high of the last 14 M15 bars, and
        targets are their midpoint. Emits:
          * ``limit_buy`` at support when the close is within 0.1% above it,
            otherwise ``limit_sell`` at resistance when within 0.1% below it;
          * additionally ``mean_reversion_sell`` / ``mean_reversion_buy`` when the
            latest RSI(14) is above 70 / below 30.

        Args:
            m15: M15 OHLC DataFrame (``high``/``low``/``close``).

        Returns:
            List of signal dicts (possibly empty).
        """
        signals = []
        resistance = m15['high'].iloc[-14:].max()
        support = m15['low'].iloc[-14:].min()
        midpoint = (resistance + support) / 2
        current = m15['close'].iloc[-1]
        logger.debug(f"_range_trading_signals: support={support}, resistance={resistance}, midpoint={midpoint}, current={current}")
        # 1. Buy at support with confirmation
        if current <= support * 1.001:
            logger.info(f"_range_trading_signals: Limit buy signal at support {support}")
            signals.append({
                'type': 'limit_buy',
                'entry': support,
                'stop_loss': support * 0.998,
                'take_profit': midpoint,
                'confidence': 80,
                'direction': 'buy',
            })
        # 2. Sell at resistance with confirmation
        elif current >= resistance * 0.999:
            logger.info(f"_range_trading_signals: Limit sell signal at resistance {resistance}")
            signals.append({
                'type': 'limit_sell',
                'entry': resistance,
                'stop_loss': resistance * 1.002,
                'take_profit': midpoint,
                'confidence': 80,
                'direction': 'sell',
            })
        # 3. Fade extreme moves (for aggressive traders)
        # RSI(14) needs 15 closes (14 price changes); the previous 14-close slice
        # left a rolling-window implementation one change short.
        # NOTE(review): assumes _calculate_rsi (defined elsewhere in this class)
        # computes a standard RSI over `period` changes - confirm.
        raw_rsi = self._calculate_rsi(m15['close'].iloc[-15:], period=14)
        # get_trend treats _calculate_rsi's output as possibly a Series; reduce it
        # to its last value so the comparisons below never hit the
        # "truth value of a Series is ambiguous" error. NaN disables the fade.
        try:
            if isinstance(raw_rsi, pd.Series):
                rsi = float(raw_rsi.iloc[-1]) if len(raw_rsi) else float('nan')
            elif isinstance(raw_rsi, (np.ndarray, list, tuple)):
                rsi = float(raw_rsi[-1]) if len(raw_rsi) else float('nan')
            else:
                rsi = float(raw_rsi)
        except (TypeError, ValueError):
            rsi = float('nan')
        logger.debug(f"_range_trading_signals: RSI={rsi}")
        if rsi > 70:
            logger.info(f"_range_trading_signals: Mean reversion sell signal (RSI={rsi})")
            signals.append({
                'type': 'mean_reversion_sell',
                'entry': current,
                'stop_loss': resistance * 1.005,
                'take_profit': midpoint,
                'confidence': 65,
                'direction': 'sell',
            })
        elif rsi < 30:
            logger.info(f"_range_trading_signals: Mean reversion buy signal (RSI={rsi})")
            signals.append({
                'type': 'mean_reversion_buy',
                'entry': current,
                'stop_loss': support * 0.995,
                'take_profit': midpoint,
                'confidence': 65,
                'direction': 'buy',
            })
        if not signals:
            logger.info(f"_range_trading_signals: No range signals generated (current={current}, support={support}, resistance={resistance}, RSI={rsi})")
        return signals

    def _calculate_atr(self, df, period=14):
        """Return the latest average true range of ``df`` as a float.

        True range is the largest of high-low, |high - previous close| and
        |low - previous close|; it is averaged with a simple rolling mean over
        ``period`` bars (not Wilder smoothing).

        Args:
            df: DataFrame with ``high``, ``low`` and ``close`` columns.
            period: Averaging window in bars.

        Returns:
            The last rolling value, or NaN when ``df`` is None or has fewer than
            ``period + 1`` rows.
        """
        if df is None or len(df) < period + 1:
            return np.nan
        previous_close = df['close'].shift()
        components = [
            df['high'] - df['low'],
            np.abs(df['high'] - previous_close),
            np.abs(df['low'] - previous_close),
        ]
        true_range = pd.concat(components, axis=1).max(axis=1)
        smoothed = true_range.rolling(window=period, min_periods=period).mean().values
        # true_range is always a Series, so only the array path is ever needed.
        return float(smoothed[-1]) if len(smoothed) > 0 else np.nan

    def _trend_following_signals(self, h4, h1, m15, swing_window=340) -> List[Dict]:
        """Generate trend-continuation signals from the H4 range and M15 swings.

        A buy is produced when the latest M15 close is in the top 30% of the
        20-bar H4 high/low range, and a sell when it is in the bottom 30%. Stops
        sit one M15 ATR beyond the most recent M15 swing extreme from
        ``detect_swings``; targets are 80% of the H4 range away from entry.

        Args:
            h4: H4 OHLC DataFrame (``high``/``low``).
            h1: H1 OHLC DataFrame; currently unused, kept for interface compatibility.
            m15: M15 OHLC DataFrame (``high``/``low``/``close``).
            swing_window: Window passed to ``detect_swings`` for the stop anchors
                (default 340, the previously hard-coded value).

        Returns:
            List of signal dicts (possibly empty). Nothing is emitted when any
            required value is NaN/inf, e.g. with too little history.
        """
        signals = []
        if h4 is None or m15 is None or len(h4) < 2 or len(m15) < 1:
            logger.warning("_trend_following_signals: Skipping due to insufficient H4/M15 data.")
            return signals
        h4_high = h4['high'].rolling(20).max()
        h4_low = h4['low'].rolling(20).min()
        h4_high_last, h4_high_prev = float(h4_high.iloc[-1]), float(h4_high.iloc[-2])
        h4_low_last, h4_low_prev = float(h4_low.iloc[-1]), float(h4_low.iloc[-2])
        close = float(m15['close'].iloc[-1])
        # The previous-bar values only serve as a "enough H4 history" check.
        if any(np.isnan(x) for x in (h4_high_last, h4_low_last, h4_high_prev, h4_low_prev, close)):
            logger.warning("_trend_following_signals: Skipping due to NaN in key values.")
            return signals
        h4_range = h4_high_last - h4_low_last
        upper_30 = h4_high_last - 0.3 * h4_range
        lower_30 = h4_low_last + 0.3 * h4_range

        atr = self._calculate_atr(m15)
        if isinstance(atr, (np.ndarray, list, tuple)):
            atr = atr[-1]
        atr = float(atr)

        # detect_swings returns {'high': [(pos, price), ...], 'low': [...]}. The
        # old code called float() on the whole list, which always failed, so both
        # anchors became NaN and no trend signal could ever be emitted.
        swings = self.detect_swings(m15, swing_window)
        swing_high = float(swings['high'][-1][1]) if swings.get('high') else float('nan')
        swing_low = float(swings['low'][-1][1]) if swings.get('low') else float('nan')
        logger.debug(f"[DEBUG] swing_high={swing_high}, swing_low={swing_low}, atr={atr}, close={close}, upper_30={upper_30}, lower_30={lower_30}, h4_range={h4_range}")
        if not all(np.isfinite(v) for v in (swing_high, swing_low, atr)):
            logger.warning(f"[SKIP] _trend_following_signals: NaN/None/inf detected. swing_high={swing_high}, swing_low={swing_low}, atr={atr}")
            return signals

        if close >= upper_30:
            logger.info(f"[SIGNAL] Trend continuation buy: close={close}, upper_30={upper_30}, swing_low={swing_low}, atr={atr}, h4_range={h4_range}")
            signals.append({
                'type': 'trend_continuation_buy',
                'entry': close,
                'stop_loss': swing_low - atr,
                'take_profit': close + h4_range * 0.8,
                'confidence': 70,
                'direction': 'buy',
            })
        elif close <= lower_30:
            logger.info(f"[SIGNAL] Trend continuation sell: close={close}, lower_30={lower_30}, swing_high={swing_high}, atr={atr}, h4_range={h4_range}")
            signals.append({
                'type': 'trend_continuation_sell',
                'entry': close,
                'stop_loss': swing_high + atr,
                'take_profit': close - h4_range * 0.8,
                'confidence': 70,
                'direction': 'sell',
            })
        if signals:
            logger.debug(f"[SIGNAL-DICT] {signals[0]}")
        else:
            logger.info(f"_trend_following_signals: No trend signals generated (close={close}, h4_high={h4_high_last}, h4_low={h4_low_last})")
        return signals

    def _reversal_signals(self, h4, h1, m15) -> List[Dict]:
        signals = []
        reversal = self._detect_reversal(h4, h1)
        if reversal['confirmed']:
            if reversal['direction'] == 'bullish':
                logger.info(f"_reversal_signals: Reversal buy signal (trigger={reversal['trigger_price']})")
                signals.append({
                    'type': 'reversal_buy',
                    'entry': reversal['trigger_price'],
                    'stop_loss': reversal['key_levels']['swing_low'],
                    'take_profit': reversal['key_levels']['previous_high'],
                    'confidence': reversal['confidence'],
                    'direction': 'buy',
                })
            else:
                logger.info(f"_reversal_signals: Reversal sell signal (trigger={reversal['trigger_price']})")
                signals.append({
                    'type': 'reversal_sell',
                    'entry': reversal['trigger_price'],
                    'stop_loss': reversal['key_levels']['swing_high'],
                    'take_profit': reversal['key_levels']['previous_low'],
                    'confidence': reversal['confidence'],
                    'direction': 'sell',
                })
        if not signals:
            logger.info(f"_reversal_signals: No reversal signals generated (reversal={reversal})")
        return signals

    def _detect_swings(self, df, window=3):
        import numpy as np
        if df is None or len(df) < window + 2:
            logger.info(f"_detect_swings: DataFrame too short (len={len(df) if df is not None else 'None'}), window={window}")
            return {'high': [], 'low': []}
        highs = df['high'].rolling(window, min_periods=window).max()
        lows = df['low'].rolling(window, min_periods=window).min()
        swing_highs = []
        swing_lows = []
        for i in range(window-1, len(df)):
            if not np.isnan(highs.iloc[i]):
                swing_highs.append((i, highs.iloc[i]))
            if not np.isnan(lows.iloc[i]):
                swing_lows.append((i, lows.iloc[i]))
        logger.info(f"_detect_swings: Found {len(swing_highs)} swing_highs, {len(swing_lows)} swing_lows (len={len(df)}, window={window})")
        return {'high': swing_highs, 'low': swing_lows}

    def _detect_reversal(self, h4, h1) -> Dict:
        import pandas as pd
        # Ensure h4 and h1 are DataFrames
        if not (isinstance(h4, pd.DataFrame) and isinstance(h1, pd.DataFrame)):
            return {'confirmed': False}
        h4_swings = self._detect_swings(h4, 5)
        h1_swings = self._detect_swings(h1, 3)
        # Helper to compute direction from swings
        def get_direction(swings):
            if not isinstance(swings, dict):
                return 'sideways'
            if not isinstance(swings.get('high'), list) or not isinstance(swings.get('low'), list):
                return 'sideways'
            if len(swings['high']) < 2 or len(swings['low']) < 2:
                return 'sideways'
            last_high, prev_high = swings['high'][-1][1], swings['high'][-2][1]
            last_low, prev_low = swings['low'][-1][1], swings['low'][-2][1]
            if last_high > prev_high and last_low > prev_low:
                return 'down'
            elif last_high < prev_high and last_low < prev_low:
                return 'up'
            else:
                return 'sideways'
        h4_direction = get_direction(h4_swings)
        h1_direction = get_direction(h1_swings)
        # Ensure enough swings for index access
        def safe_get_last(swings, key):
            if isinstance(swings.get(key), list) and len(swings[key]) > 0:
                return swings[key][-1][1]
            return None
        def safe_get_prev(swings, key):
            if isinstance(swings.get(key), list) and len(swings[key]) > 1:
                return swings[key][-2][1]
            elif isinstance(swings.get(key), list) and len(swings[key]) > 0:
                return swings[key][-1][1]
            return None
        # Check h1['close'] is a pandas Series and has enough data
        h1_close = h1['close'] if h1 is not None and isinstance(h1, pd.DataFrame) and 'close' in h1 else None
        if not isinstance(h1_close, pd.Series) or len(h1_close) < 2:
            return {'confirmed': False}
        # Bullish reversal
        h1_high_last = safe_get_last(h1_swings, 'high')
        h1_low_last = safe_get_last(h1_swings, 'low')
        h4_low_last = safe_get_last(h4_swings, 'low')
        h4_high_last = safe_get_last(h4_swings, 'high')
        h4_high_prev = safe_get_prev(h4_swings, 'high')
        h4_low_prev = safe_get_prev(h4_swings, 'low')
        if (h4_direction == 'down' and
            h1_high_last is not None and h1_low_last is not None and h4_low_last is not None and
            h1_close.iloc[-1] is not None and h1_close.iloc[-2] is not None and
            h1_close.iloc[-1] > h1_high_last and
            h1_close.iloc[-2] < h4_low_last * 0.999):
            return {
                'confirmed': True,
                'direction': 'bullish',
                'trigger_price': h1_high_last,
                'key_levels': {
                    'swing_low': h4_low_last,
                    'previous_high': h4_high_prev
                },
                'confidence': 75
            }
        # Bearish reversal
        if (h4_direction == 'up' and
            h1_high_last is not None and h1_low_last is not None and h4_high_last is not None and
            h1_close.iloc[-1] is not None and h1_close.iloc[-2] is not None and
            h1_close.iloc[-1] < h1_low_last and
            h1_close.iloc[-2] > h4_high_last * 1.001):
            return {
                'confirmed': True,
                'direction': 'bearish',
                'trigger_price': h1_low_last,
                'key_levels': {
                    'swing_high': h4_high_last,
                    'previous_low': h4_low_prev
                },
                'confidence': 75
            }
        return {'confirmed': False}

    def _find_pivot_levels(self, df):
        try:
            support = df['low'].iloc[-14:].min() if len(df) >= 14 else df['low'].min()
            resistance = df['high'].iloc[-14:].max() if len(df) >= 14 else df['high'].max()
            return {'support': support, 'resistance': resistance}
        except Exception as e:
            logger.error(f"Pivot level error: {e}")
            return {'support': None, 'resistance': None}

    def _calculate_rsi(self, series: pd.Series, period=14) -> float:
        # Ensure input is a Series
        if not isinstance(series, pd.Series):
            try:
                series = pd.Series(series)
            except Exception:
                return 50.0
        delta = series.diff()
        gain = delta.where(delta > 0, 0).rolling(period).mean()
        loss = -delta.where(delta < 0, 0).rolling(period).mean()
        rs = gain / loss
        rsi = 100 - (100 / (1 + rs))
        val = None
        if hasattr(rsi, 'iloc') and hasattr(rsi, '__len__') and len(rsi) > 0:
            val = rsi.iloc[-1]
        elif isinstance(rsi, np.ndarray) and hasattr(rsi, '__len__') and rsi.size > 0:
            val = rsi[-1]
        elif isinstance(rsi, float):
            val = rsi
        if val is None or (isinstance(val, float) and (np.isnan(val) or not np.isfinite(val))):
            return 50.0
        return float(val)

    def is_phase_aligned(self, lower_tf_phase, higher_tf_phase):
        """
        Returns True if lower timeframe phase is compatible with higher timeframe phase.
        E.g., Re-Accumulation (lower) only if higher is Accumulation or Re-Accumulation.
        """
        bullish_phases = ["Accumulation", "Re-Accumulation"]
        bearish_phases = ["Distribution", "Re-Distribution"]
        if lower_tf_phase in bullish_phases and higher_tf_phase in bullish_phases:
            return True
        if lower_tf_phase in bearish_phases and higher_tf_phase in bearish_phases:
                return True
        return False

    def main_trading_loop(self, symbol="EURUSD", ltf="15M", htf="4H", get_current_time=None, fetch_candles=None, execute_order=None, sleep_fn=None):
        """
        Main trading loop for IFVG/BPR strategy with killzone filtering.
        Args:
            symbol: Trading symbol
            ltf: Lower timeframe string
            htf: Higher timeframe string
            get_current_time: function returning current time in 'HH:MM' NY time
            fetch_candles: function(symbol, tf) -> list of candles
            execute_order: function(symbol, direction, entry, stop_loss, take_profit)
            sleep_fn: function(seconds) for sleeping (default: time.sleep)
        """
        import time
        if sleep_fn is None:
            sleep_fn = time.sleep
        while True:
            if get_current_time is None or fetch_candles is None or execute_order is None:
                raise ValueError("get_current_time, fetch_candles, and execute_order must be provided.")
            current_time = get_current_time("NY")
            in_kz_name, in_kz_meta = self.is_in_killzone(current_time)
            if in_kz_meta:
                ltf_candles = fetch_candles(symbol, ltf)
                ht_candles = fetch_candles(symbol, htf)
                # ICT-only logic would go here (e.g., detect_high_prob_ifvg, generate_entry)
                # This is a placeholder for ICT-specific strategies.
            sleep_fn(300)  # 5 minutes

    def get_ltf_bias(self, data_cache):
        """Return LTF bias using M30, M15, and M5 (majority vote, fallback to NEUTRAL)."""
        votes = []
        for tf in ["M30", "M15", "M5"]:
            df = data_cache.get(tf)
            if df is not None and not df.empty:
                bias = self.detect_market_structure(df)
                votes.append(bias)
        if votes.count("bullish") > votes.count("bearish"):
            return "bullish"
        elif votes.count("bearish") > votes.count("bullish"):
            return "bearish"
        return "neutral"

    def get_htf_bias(self, data_cache):
        """Return HTF bias using H4 and D1 (majority vote, fallback to NEUTRAL)."""
        votes = []
        for tf in ["H4", "D1"]:
            df = data_cache.get(tf)
            if df is not None and not df.empty:
                bias = self.detect_market_structure(df)
                votes.append(bias)
        if votes.count("bullish") > votes.count("bearish"):
            return "bullish"
        elif votes.count("bearish") > votes.count("bullish"):
            return "bearish"
        return "neutral"

    def get_ltf_entry_signals(self, data_cache):
        import numpy as np
        import pandas as pd
        ltf_bias = self.get_ltf_bias(data_cache)
        logger.info(f"[LTF-DEBUG] LTF bias: {ltf_bias}")
        signals = []
        for tf in ["M5", "M1"]:
            df = data_cache.get(tf)
            if df is not None and not df.empty:
                entry_dir = self.detect_market_structure(df)
                logger.info(f"[LTF-DEBUG] {tf} entry_dir: {entry_dir}")
                if entry_dir == ltf_bias and entry_dir in ["bullish", "bearish"]:
                    entry_price = df['close'].iloc[-1] if hasattr(df['close'], 'iloc') and hasattr(df['close'], '__len__') and len(df['close']) > 0 else (df['close'][-1] if hasattr(df['close'], '__getitem__') and len(df['close']) > 0 else None)
                    try:
                        entry_price = float(entry_price)
                    except Exception:
                        entry_price = float('nan')
                    atr = self._calculate_atr(df)
                    if isinstance(atr, (np.ndarray, list, tuple)) and len(atr) > 0:
                        atr_val = float(atr[-1])
                    elif hasattr(atr, 'iloc') and hasattr(atr, '__len__') and len(atr) > 0:
                        atr_val = float(atr.iloc[-1])
                    elif isinstance(atr, (float, int)):
                        atr_val = float(atr)
                    else:
                        atr_val = float('nan')
                    # ATR-based TP - increased multiplier to ensure minimum 1.5 RR
                    if entry_dir == "bullish":
                        atr_tp = entry_price + atr_val * 3.0  # Increased from 1.5 to 3.0
                    else:
                        atr_tp = entry_price - atr_val * 3.0  # Increased from 1.5 to 3.0
                    # Swing-based TP from M15/M30
                    swing_tp = None
                    for swing_tf in ["M15", "M30"]:
                        swing_df = data_cache.get(swing_tf)
                        if swing_df is not None and not swing_df.empty:
                            if entry_dir == "bullish":
                                swing = swing_df['high'].rolling(20).max()
                                if isinstance(swing, pd.Series) and hasattr(swing, '__len__') and len(swing) > 0:
                                    swing_val = float(swing.iloc[-1])
                                elif isinstance(swing, (np.ndarray, list, tuple)) and len(swing) > 0:
                                    swing_val = float(swing[-1])
                                elif isinstance(swing, (float, int)):
                                    swing_val = float(swing)
                                else:
                                    swing_val = None
                            else:
                                swing = swing_df['low'].rolling(20).min()
                                if isinstance(swing, pd.Series) and hasattr(swing, '__len__') and len(swing) > 0:
                                    swing_val = float(swing.iloc[-1])
                                elif isinstance(swing, (np.ndarray, list, tuple)) and len(swing) > 0:
                                    swing_val = float(swing[-1])
                                elif isinstance(swing, (float, int)):
                                    swing_val = float(swing)
                                else:
                                    swing_val = None
                            if swing_tp is None and swing_val is not None:
                                swing_tp = swing_val
                    if any([entry_price is None, np.isnan(entry_price), np.isnan(atr_val)]):
                        logger.warning(f"get_ltf_entry_signals: Skipping due to invalid entry_price or atr_val: entry_price={entry_price}, atr_val={atr_val}")
                        continue
                    signals.append({
                        'type': f'ltf_{entry_dir}',
                        'entry': entry_price,
                        'stop_loss': entry_price - atr_val if entry_dir == "bullish" else entry_price + atr_val,
                        'take_profit': swing_tp if swing_tp is not None else atr_tp,
                        'confidence': 60,
                        'timeframe': tf
                    })
        return signals

    def get_htf_entry_signals(self, data_cache):
        """Generate entry signals on H1 if HTF bias matches direction."""
        htf_bias = self.get_htf_bias(data_cache)
        signals = []
        tf = "H1"
        df = data_cache.get(tf)
        if df is not None and not df.empty:
            entry_dir = self.detect_market_structure(df)
            if entry_dir == htf_bias and entry_dir in ["bullish", "bearish"]:
                signals.append({
                    "timeframe": tf,
                    "direction": entry_dir,
                    "type": f"htf_entry_{entry_dir}",
                    "confidence": 80
                })
        return signals

class MarketStructureICT:
    # --- IFVG/Entry Strategy Core Configuration ---
    # Class-level defaults. __init__ assigns each of these to an instance
    # attribute, preferring a same-named keyword argument when one is given
    # (min_rr is read from the `min_risk_reward_ratio` kwarg instead).
    # KILLZONES and ENTRY_TYPES are mutable containers; avoid mutating them in place.
    FVG_WICK_RATIO = 0.7  # Minimum wick size ratio (70% of candle range)
    DISCOUNT_ZONE_PCT = 0.5  # Lower 50% of a bullish leg = Discount Zone
    PREMIUM_ZONE_PCT = 0.5  # Upper 50% of a bullish leg = Premium Zone
    # NOTE(review): the three ratios above are not read by any method visible in
    # this section of the file; confirm where they are consumed.
    KILLZONES = {
        "ASIAN": {"start": "20:00", "end": "00:00"},  # NY Time
        "LONDON": {"start": "02:00", "end": "05:00"},
        "NEWYORK": {"start": "07:00", "end": "10:00"}
    }
    ENTRY_TYPES = ["BODY_CLOSURE", "IFVG_RETRACE", "BPR"]  # Balanced Price Range
    min_rr = 2.0  # Minimum Risk-Reward

    def __init__(self, lookback=5, swing_window=3, **kwargs):
        self.lookback = lookback
        self.swing_window = swing_window
        # Allow override of core config via kwargs
        self.FVG_WICK_RATIO = kwargs.get('FVG_WICK_RATIO', self.FVG_WICK_RATIO)
        self.DISCOUNT_ZONE_PCT = kwargs.get('DISCOUNT_ZONE_PCT', self.DISCOUNT_ZONE_PCT)
        self.PREMIUM_ZONE_PCT = kwargs.get('PREMIUM_ZONE_PCT', self.PREMIUM_ZONE_PCT)
        self.KILLZONES = kwargs.get('KILLZONES', self.KILLZONES)
        self.ENTRY_TYPES = kwargs.get('ENTRY_TYPES', self.ENTRY_TYPES)
        self.min_rr = kwargs.get('min_risk_reward_ratio', 2.0)

    def detect_swings(self, candles, swing_window=None):
        import logging
        logger = logging.getLogger("TradingBot.MarketStructureICT")
        if swing_window is None:
            swing_window = self.swing_window
        if candles is None or len(candles) < swing_window * 2 + 1:
            logger.info(f"ICT detect_swings: Not enough candles (len={len(candles) if candles is not None else 'None'}), window={swing_window}")
            return {'high': [], 'low': []}
        highs = [candle.high for candle in candles]
        lows = [candle.low for candle in candles]
        n = len(candles)
        swing_highs = []
        swing_lows = []
        for i in range(swing_window, n - swing_window):
            if all(highs[i] > highs[i - j] for j in range(1, swing_window + 1)) and \
               all(highs[i] > highs[i + j] for j in range(1, swing_window + 1)):
                swing_highs.append((i, highs[i]))
            if all(lows[i] < lows[i - j] for j in range(1, swing_window + 1)) and \
               all(lows[i] < lows[i + j] for j in range(1, swing_window + 1)):
                swing_lows.append((i, lows[i]))
        logger.info(f"ICT detect_swings: Found {len(swing_highs)} swing_highs, {len(swing_lows)} swing_lows (len={n}, window={swing_window})")
        return {'high': swing_highs, 'low': swing_lows}

    def detect_fvg(self, candles):
        """Detect Fair Value Gaps (FVG) in candle data."""
        fvgs = []
        if len(candles) < 3:
            return fvgs
        
        for i in range(1, len(candles) - 1):
            c1, c2, c3 = candles[i-1], candles[i], candles[i+1]
            
            # Bullish FVG: gap between c1 high and c3 low
            if c1.high < c3.low:
                fvgs.append({
                    'type': 'BULLISH_FVG',
                    'high': c1.high,
                    'low': c3.low,
                    'index': i
                })
            
            # Bearish FVG: gap between c1 low and c3 high
            elif c1.low > c3.high:
                fvgs.append({
                    'type': 'BEARISH_FVG',
                    'high': c1.low,
                    'low': c3.high,
                    'index': i
                })
        
        return fvgs

    def validate_ifvg(self, fvg, candles):
        """Validate if FVG is an IFVG (Imbalance Fair Value Gap)."""
        if not fvg or len(candles) < 5:
            return None
        
        # Check if price has closed beyond the FVG (body closure violation)
        fvg_high = fvg['high']
        fvg_low = fvg['low']
        
        # Look for body closure beyond FVG
        for i in range(fvg['index'] + 1, min(fvg['index'] + 10, len(candles))):
            candle = candles[i]
            
            if fvg['type'] == 'BULLISH_FVG':
                if candle.close > fvg_high:  # Body closure above FVG
                    return {
                        'type': 'BULLISH_IFVG',
                        'high': fvg_high,
                        'low': fvg_low,
                        'index': fvg['index']
                    }
            elif fvg['type'] == 'BEARISH_FVG':
                if candle.close < fvg_low:  # Body closure below FVG
                    return {
                        'type': 'BEARISH_IFVG',
                        'high': fvg_high,
                        'low': fvg_low,
                        'index': fvg['index']
                    }
        
        return None

    def find_balanced_price_range(self, candles):
        """Find Balanced Price Range (BPR) - overlapping FVGs."""
        fvgs = self.detect_fvg(candles)
        if len(fvgs) < 2:
            return None
        
        # Look for overlapping FVGs
        for i in range(len(fvgs) - 1):
            fvg1 = fvgs[i]
            fvg2 = fvgs[i + 1]
            
            # Check if FVGs overlap
            if fvg1['low'] <= fvg2['high'] and fvg2['low'] <= fvg1['high']:
                overlap_high = min(fvg1['high'], fvg2['high'])
                overlap_low = max(fvg1['low'], fvg2['low'])
                
                if fvg1['type'] == 'BULLISH_FVG' and fvg2['type'] == 'BULLISH_FVG':
                    return {
                        'type': 'BULLISH_BPR',
                        'high': overlap_high,
                        'low': overlap_low,
                        'mean': (overlap_high + overlap_low) / 2,
                        'direction': 'LONG'
                    }
                elif fvg1['type'] == 'BEARISH_FVG' and fvg2['type'] == 'BEARISH_FVG':
                    return {
                        'type': 'BEARISH_BPR',
                        'high': overlap_high,
                        'low': overlap_low,
                        'mean': (overlap_high + overlap_low) / 2,
                        'direction': 'SHORT'
                    }
        
        return None

    def _calculate_atr(self, df):
        """Calculate Average True Range."""
        high_low = df['high'] - df['low']
        high_close = (df['high'] - df['close'].shift()).abs()
        low_close = (df['low'] - df['close'].shift()).abs()
        tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
        atr = tr.rolling(window=14).mean()
        val = atr.iloc[-1] if len(atr) > 0 else None
        if val is None or pd.isna(val) or not np.isfinite(val) or val <= 0:
            return max(float(df['high'].iloc[-1] - df['low'].iloc[-1]), 0.1)
        return float(val)

    def detect_market_structure(self, candles):
        """Identify Bullish/Bearish/Sideways structure based on recent swings."""
        swings = self.detect_swings(candles, self.swing_window)
        swing_highs, swing_lows = swings['high'], swings['low']
        if len(swing_highs) < 2 or len(swing_lows) < 2:
            return "Sideways"
        # Compare last two swings
        last_high, prev_high = swing_highs[-1][1], swing_highs[-2][1]
        last_low, prev_low = swing_lows[-1][1], swing_lows[-2][1]
        if last_high > prev_high and last_low > prev_low:
            return "Bullish"
        elif last_high < prev_high and last_low < prev_low:
            return "Bearish"
        else:
            return "Sideways"

    def detect_bos_choch(self, candles):
        """Detect Break of Structure (BOS) and Change of Character (CHoCH)."""
        swings = self.detect_swings(candles, self.swing_window)
        swing_highs, swing_lows = swings['high'], swings['low']
        if len(swing_highs) < 2 or len(swing_lows) < 2:
            return None
        # BOS: price breaks previous swing high/low in trend direction
        # CHoCH: first break in the opposite direction
        # Example: last structure was bullish, now price breaks last swing low = CHoCH
        structure = self.detect_market_structure(candles)
        last_close = candles[-1].close
        if structure == "Bullish":
            # BOS if price breaks last swing high
            if last_close > swing_highs[-1][1]:
                return "BOS_Bullish"
            # CHoCH if price breaks last swing low
            if last_close < swing_lows[-1][1]:
                return "CHoCH_Bearish"
        elif structure == "Bearish":
            if last_close < swing_lows[-1][1]:
                return "BOS_Bearish"
            if last_close > swing_highs[-1][1]:
                return "CHoCH_Bullish"
        return None

    def is_bos(self, candles, structure, swing_window=None):
        """
        Confirm BOS: Price closes beyond last swing high (Bullish) or swing low (Bearish).
        Uses true swing points for confirmation.
        """
        if swing_window is None:
            swing_window = self.swing_window
        last_close = candles[-1].close
        highs = [candle.high for candle in candles]
        lows = [candle.low for candle in candles]
        n = len(candles)
        swing_highs = []
        swing_lows = []
        for i in range(swing_window, n - swing_window):
            if all(highs[i] > highs[i - j] for j in range(1, swing_window + 1)) and \
               all(highs[i] > highs[i + j] for j in range(1, swing_window + 1)):
                swing_highs.append(i)
            if all(lows[i] < lows[i - j] for j in range(1, swing_window + 1)) and \
               all(lows[i] < lows[i + j] for j in range(1, swing_window + 1)):
                swing_lows.append(i)
        if structure == "Bullish" and swing_highs:
            last_swing_high = highs[swing_highs[-1]]
            return last_close > last_swing_high
        elif structure == "Bearish" and swing_lows:
            last_swing_low = lows[swing_lows[-1]]
            return last_close < last_swing_low
        return False

    def is_choch(self, candles, structure, swing_window=None):
        """
        Confirm CHoCH: Price reverses through last HL (Bullish) or LH (Bearish).
        Uses true swing points for confirmation.
        """
        if swing_window is None:
            swing_window = self.swing_window
        last_close = candles[-1].close
        highs = [candle.high for candle in candles]
        lows = [candle.low for candle in candles]
        n = len(candles)
        swing_highs = []
        swing_lows = []
        for i in range(swing_window, n - swing_window):
            if all(highs[i] > highs[i - j] for j in range(1, swing_window + 1)) and \
               all(highs[i] > highs[i + j] for j in range(1, swing_window + 1)):
                swing_highs.append(i)
            if all(lows[i] < lows[i - j] for j in range(1, swing_window + 1)) and \
               all(lows[i] < lows[i + j] for j in range(1, swing_window + 1)):
                swing_lows.append(i)
        if structure == "Bullish" and len(swing_lows) >= 2:
            last_hl = lows[swing_lows[-2]]  # Previous higher low
            return last_close < last_hl
        elif structure == "Bearish" and len(swing_highs) >= 2:
            last_lh = highs[swing_highs[-2]]  # Previous lower high
            return last_close > last_lh
        return False

    def is_idm(self, candles, impulse_window=2, retracement_threshold=0.5, direction="bullish"):
        """
        Identify IDM: Last pullback before BOS/CHoCH (liquidity grab).
        - impulse_window: Number of candles to define the impulse.
        - retracement_threshold: Minimum retracement (as a fraction, e.g., 0.5 for 50%).
        - direction: 'bullish' or 'bearish' impulse.
        """
        if len(candles) < impulse_window + 1:
            return False
        impulse_high = max([candle.high for candle in candles[-impulse_window-1:-1]])
        impulse_low = min([candle.low for candle in candles[-impulse_window-1:-1]])
        last_close = candles[-1].close
        if direction == "bullish":
            retracement = (impulse_high - last_close) / (impulse_high - impulse_low) if (impulse_high - impulse_low) != 0 else 0
            return retracement >= retracement_threshold
        elif direction == "bearish":
            retracement = (last_close - impulse_low) / (impulse_high - impulse_low) if (impulse_high - impulse_low) != 0 else 0
            return retracement >= retracement_threshold
        return False

    def detect_order_blocks(self, candles):
        """
        Identify Bullish/Bearish Order Blocks:
        - Bullish OB: Strong up candle after downtrend
        - Bearish OB: Strong down candle after uptrend
        Returns a list of dicts with type, high, low, mean.
        """
        blocks = []
        for i in range(2, len(candles)):
            prev_candle = candles[i-1]
            current_candle = candles[i]
            # Bullish OB (Down -> Up)
            if (prev_candle.close < prev_candle.open and 
                current_candle.close > current_candle.open and 
                current_candle.close > prev_candle.high):
                blocks.append({
                    "type": "BULLISH_OB",
                    "high": current_candle.high,
                    "low": current_candle.low,
                    "mean": (current_candle.high + current_candle.low) / 2
                })
            # Bearish OB (Up -> Down)
            elif (prev_candle.close > prev_candle.open and 
                  current_candle.close < current_candle.open and 
                  current_candle.close < prev_candle.low):
                blocks.append({
                    "type": "BEARISH_OB",
                    "high": current_candle.high,
                    "low": current_candle.low,
                    "mean": (current_candle.high + current_candle.low) / 2
                })
        return blocks

    def detect_imbalance(self, candles):
        """
        Identify price gaps (imbalance/FVG) between 3 consecutive candles.
        Returns 'gap_up', 'gap_down', or None.
        """
        if len(candles) < 3:
            return None
        c1, c2, c3 = candles[-3], candles[-2], candles[-1]
        if c1.high < c3.low:
            return "gap_up"
        elif c1.low > c3.high:
            return "gap_down"
        return None

    def detect_market_phase(self, candles, window=5, volume_data=None, atr_data=None, volume_mult=1.5, atr_mult=1.2):
        """
        Classify the latest candle against the ``window`` candles that precede it.

        Let ``prior_high`` / ``prior_low`` be the highest high / lowest low of the
        ``window`` candles before the latest one. Checked in this order:
        - Distribution: latest high < prior_high AND latest low < prior_low.
        - Re-Distribution: latest high > prior_high AND latest low < prior_low
          (the latest candle takes out both sides).
        - Re-Accumulation: latest high > prior_high AND latest low > prior_low.
        - Accumulation: anything else (also the result when data is insufficient).

        Volume/ATR confirmation: when a series with at least ``window + 1`` values is
        supplied, its latest value must exceed ``mean(previous window values) * mult``.
        Without confirmation no transition phase is reported ("Accumulation").
        A missing or too-short series counts as confirmed.

        Args:
            candles: Sequence of candle objects exposing ``.high`` and ``.low``.
            window: Number of prior candles (and prior volume/ATR values) to compare against.
            volume_data: Optional sequence of volumes (list, ndarray or pandas Series).
            atr_data: Optional sequence of ATR values (list, ndarray or pandas Series).
            volume_mult: Multiplier the latest volume must exceed over the prior mean.
            atr_mult: Multiplier the latest ATR must exceed over the prior mean.
        Returns:
            One of "Distribution", "Re-Distribution", "Re-Accumulation", "Accumulation".
        """
        if window < 1 or len(candles) < window + 1:
            return "Accumulation"  # Not enough data

        def _spike_confirmed(series, mult):
            # Missing/short series do not veto the phase (matches the original contract).
            if series is None:
                return True
            # np.asarray makes the lookups positional, so pandas Series with a
            # RangeIndex work too (series[-1] on such a Series raises KeyError).
            values = np.asarray(series, dtype=float)
            if values.size < window + 1:
                return True
            return bool(values[-1] > values[-window - 1:-1].mean() * mult)

        if not (_spike_confirmed(volume_data, volume_mult) and _spike_confirmed(atr_data, atr_mult)):
            return "Accumulation"

        # Use exactly `window` prior candles, consistent with the length guard and
        # with the volume/ATR lookback above.
        recent = candles[-window - 1:]
        prior, last = recent[:-1], recent[-1]
        prior_high = max(c.high for c in prior)
        prior_low = min(c.low for c in prior)

        if last.high < prior_high and last.low < prior_low:
            return "Distribution"
        if last.high > prior_high and last.low < prior_low:
            return "Re-Distribution"
        if last.high > prior_high and last.low > prior_low:
            return "Re-Accumulation"
        return "Accumulation"

    def is_htf_liquidity_taken(self, symbol, candles, fetch_candles_func, higher_tf="H1", direction="sell", window=10):
        """
        Check whether the latest LTF candle reaches the extreme of the last ``window`` HTF candles.

        - direction "sell": True when the latest LTF high >= the highest HTF high.
        - direction "buy": True when the latest LTF low <= the lowest HTF low.

        Args:
            symbol: Symbol forwarded to ``fetch_candles_func``.
            candles: Lower-timeframe candles; only the last one is inspected.
            fetch_candles_func: Callable ``(symbol, timeframe) -> sequence of HTF candles``.
            higher_tf: Timeframe string passed to ``fetch_candles_func``.
            direction: "sell" or "buy"; any other value returns False.
            window: Number of most recent HTF candles to scan.
        Returns:
            bool. False when there are no LTF candles, the HTF fetch returns None,
            or fewer than ``window`` HTF candles are available.
        """
        # No current price to compare with -> nothing can have been swept.
        if candles is None or len(candles) == 0:
            return False
        htf_candles = fetch_candles_func(symbol, higher_tf)
        if htf_candles is None or len(htf_candles) < window:
            return False
        recent_htf = htf_candles[-window:]
        last = candles[-1]
        if direction == "sell":
            return bool(last.high >= max(c.high for c in recent_htf))
        if direction == "buy":
            return bool(last.low <= min(c.low for c in recent_htf))
        return False

    def generate_mmxm_signal(self, symbol, candles, fetch_candles_func, htf="H1", direction="sell"):
        """
        Produce an MMXM signal from the detected phase plus an HTF liquidity sweep.

        - direction "sell" (MMSM): SHORT when the phase is "Re-Distribution" and
          HTF highs were swept.
        - direction "buy" (MMBM): LONG when the required phase matches and HTF lows
          were swept. The required phase is "Re-Accumulation" only if this object has
          a ``detect_market_phase_buy`` attribute, otherwise "Accumulation".
          NOTE(review): that attribute gate looks accidental; confirm the intended buy phase.
        Any other outcome (including an unknown direction) yields HOLD.

        Returns:
            dict with "signal" ("SHORT", "LONG" or "HOLD") and the detected "phase".
        """
        phase = self.detect_market_phase(candles)

        if direction == "sell":
            swept = self.is_htf_liquidity_taken(symbol, candles, fetch_candles_func, higher_tf=htf, direction="sell")
            if swept and phase == "Re-Distribution":
                logger.info(f"MMXM: SHORT signal - Phase: {phase}, HTF liquidity swept.")
                return {"signal": "SHORT", "phase": phase}
        elif direction == "buy":
            required_phase = "Re-Accumulation" if hasattr(self, 'detect_market_phase_buy') else "Accumulation"
            swept = self.is_htf_liquidity_taken(symbol, candles, fetch_candles_func, higher_tf=htf, direction="buy")
            if swept and phase == required_phase:
                logger.info(f"MMXM: LONG signal - Phase: {phase}, HTF liquidity swept.")
                return {"signal": "LONG", "phase": phase}

        logger.info(f"MMXM: HOLD - Phase: {phase}")
        return {"signal": "HOLD", "phase": phase}

    def check_invalidation(self, candles, invalidation_buffer=0.0010, phase_window=(10, 5), direction="sell"):
        """
        Decide whether an open MMXM trade is invalidated by the latest close.

        The reference ("2nd phase") range is ``candles[-start:-end]`` for
        ``phase_window=(start, end)``. For example, (10, 5) covers the five candles
        ending five bars ago. ``end=0`` means the range runs through the latest candle.
        - sell: invalid when the latest close > highest high of the range + buffer.
        - buy: invalid when the latest close < lowest low of the range - buffer.

        Args:
            candles: Sequence of candles exposing ``.high``, ``.low``, ``.close``.
            invalidation_buffer: Price distance added beyond the range extreme.
            phase_window: (start, end) offsets from the end of ``candles``.
            direction: "sell" or "buy"; any other value returns False.
        Returns:
            bool. False when there are fewer than ``start`` candles or when the window
            is empty/inverted (start <= end, or end < 0). The original code raised
            ValueError from ``max([])`` in those cases.
        """
        start, end = phase_window
        total = len(candles)
        if end < 0 or start <= end or total < start:
            return False  # Not enough data, or no candles inside the window
        # Explicit positive bounds avoid the candles[-start:-0] == [] pitfall when end == 0.
        phase_candles = candles[total - start:total - end]
        if direction == "sell":
            distribution_high = max(candle.high for candle in phase_candles)
            is_invalid = candles[-1].close > (distribution_high + invalidation_buffer)
            if is_invalid:
                logger.info(f"MMXM: SHORT trade invalidated. Close {candles[-1].close} > {distribution_high + invalidation_buffer}")
            return is_invalid
        if direction == "buy":
            accumulation_low = min(candle.low for candle in phase_candles)
            is_invalid = candles[-1].close < (accumulation_low - invalidation_buffer)
            if is_invalid:
                logger.info(f"MMXM: LONG trade invalidated. Close {candles[-1].close} < {accumulation_low - invalidation_buffer}")
            return is_invalid
        return False

    def manage_trade(self, entry_signal, candles, higher_tf_candles):
        """
        Compute the current stop loss and take profit for an open trade.

        1. Stop loss:
           - Starts at ``entry_signal["stop"]`` (e.g. placed beyond the IFVG zone).
           - Moves to breakeven (the entry price) once the latest close has covered at
             least 50% of the initial risk (Rule of 50).
        2. Take profit:
           - The 2R target, ``entry +/- 2 * initial risk``, is always measured from the
             ORIGINAL stop. Previously the target was measured from the moved stop, so
             after breakeven it collapsed onto the entry price.
           - For LONG the target is the lower of the 2R target and the highest high of
             the last 20 HTF candles; for SHORT it is the higher of the 2R target and
             the lowest HTF low. The HTF level therefore caps the target; no minimum RR
             is enforced.
           - Without HTF candles the 2R target is used.

        Args:
            entry_signal: dict with "entry", "stop" and "direction" ("LONG" or other = short).
            candles: LTF candles; the last ``.close`` is the current price.
            higher_tf_candles: HTF candles exposing ``.high`` / ``.low``.
        Returns:
            dict with "stop_loss" and "take_profit".
        """
        initial_stop = entry_signal["stop"]
        entry = entry_signal["entry"]
        direction = entry_signal["direction"]
        stop_loss = initial_stop

        # Initial risk per unit; the epsilon prevents division by zero for a zero-width stop.
        rr_denom = abs(entry - initial_stop) or 1e-8
        if candles is not None and len(candles) > 0:
            last_close = candles[-1].close
            if direction == "LONG":
                progress_r = (last_close - entry) / rr_denom
            else:
                progress_r = (entry - last_close) / rr_denom
            if progress_r >= 0.5:
                stop_loss = entry  # Breakeven

        # Slicing [-20:] already handles lists shorter than 20.
        recent_htf = list(higher_tf_candles[-20:]) if higher_tf_candles is not None else []
        if direction == "LONG":
            rr_target = entry + 2 * (entry - initial_stop)
            target = min(max(c.high for c in recent_htf), rr_target) if recent_htf else rr_target
        else:
            rr_target = entry - 2 * (initial_stop - entry)
            target = max(min(c.low for c in recent_htf), rr_target) if recent_htf else rr_target
        return {"stop_loss": stop_loss, "take_profit": target}

    def detect_liquidity_pools(self, candles, round_number_step=0.005, tolerance=0.0005, session_length=24):
        """
        Detect liquidity pools: equal highs/lows, session high/low and round numbers.

        Args:
            candles: Sequence of candles exposing ``.high`` / ``.low``.
            round_number_step: Grid spacing for round-number levels. A non-positive step yields no levels.
            tolerance: Max difference between consecutive highs (or lows) to count as "equal".
            session_length: Number of most recent candles treated as the session.
        Returns:
            dict with:
            - "equal_highs" / "equal_lows": the later price of each consecutive pair
              within ``tolerance``.
            - "session_high" / "session_low": extremes of the last ``session_length``
              candles (or all candles when fewer), or None for empty input.
            - "round_numbers": sorted, de-duplicated grid levels nearest to every high/low.
        """
        highs = [candle.high for candle in candles] if candles is not None else []
        lows = [candle.low for candle in candles] if candles is not None else []
        if not highs:
            # Previously max([]) raised ValueError; return an empty pool set instead.
            return {
                "equal_highs": [],
                "equal_lows": [],
                "session_high": None,
                "session_low": None,
                "round_numbers": [],
            }
        equal_highs = [cur for prev, cur in zip(highs, highs[1:]) if abs(cur - prev) < tolerance]
        equal_lows = [cur for prev, cur in zip(lows, lows[1:]) if abs(cur - prev) < tolerance]
        # Slicing with a negative start already covers the "fewer candles than a session" case.
        session_highs = highs[-session_length:] if session_length > 0 else highs
        session_lows = lows[-session_length:] if session_length > 0 else lows
        if round_number_step > 0:
            # Sorted for a deterministic order (a bare set has none).
            round_numbers = sorted({round(p / round_number_step) * round_number_step for p in highs + lows})
        else:
            round_numbers = []
        return {
            "equal_highs": equal_highs,
            "equal_lows": equal_lows,
            "session_high": max(session_highs),
            "session_low": min(session_lows),
            "round_numbers": round_numbers,
        }

    def validate_block(self, block, candles):
        """
        Return whether the latest close respects an order block's 50% ("mean") level.

        A block whose "type" is "BULLISH_OB" is valid while the last close is strictly
        above ``block["mean"]``. Any other block type is treated as bearish and is
        valid while the last close is strictly below it.
        """
        latest_close = candles[-1].close
        midpoint = block["mean"]
        if block["type"] != "BULLISH_OB":
            return latest_close < midpoint
        return latest_close > midpoint

    def detect_liquidity_sweep(self, candles, block):
        """
        Check whether the last three candles ran stops beyond an order block.

        - "BULLISH_OB": True if any of the last three lows is below ``block["low"]`` (SSL sweep).
        - any other type: True if any of the last three highs is above ``block["high"]`` (BSL sweep).
        """
        tail = candles[-3:]
        if block["type"] == "BULLISH_OB":
            deepest = min(bar.low for bar in tail)
            return deepest < block["low"]
        tallest = max(bar.high for bar in tail)
        return tallest > block["high"]

    def detect_rejection(self, candles):
        """
        Classify the last candle as a rejection by where it closed within its range.

        - "BEARISH_REJECTION": (high - close) is more than 67% of the range, i.e. the
          close sits in the bottom third. This measures the upper wick plus any body,
          not the wick alone.
        - "BULLISH_REJECTION": (close - low) is more than 67% of the range.
        Returns None for a zero-range candle or when neither condition holds.
        """
        bar = candles[-1]
        span = bar.high - bar.low
        if span == 0:
            return None
        distance_from_high = (bar.high - bar.close) / span
        if distance_from_high > 0.67:
            return "BEARISH_REJECTION"
        distance_from_low = (bar.close - bar.low) / span
        return "BULLISH_REJECTION" if distance_from_low > 0.67 else None

    def confirm_entry_trigger(self, candles, direction=None, method="choch_or_fvg_or_ob", micro_window=3):
        """
        Confirm entry trigger on lower timeframe after phase/HTF setup.
        method: 'choch_or_fvg_or_ob' (default: any of CHoCH, FVG fill, or micro-OB)
        direction: 'buy' or 'sell' (optional, for OB confluence)
        Returns True if entry is confirmed.
        """
        # CHoCH
        choch = self.detect_bos_choch(candles)
        if choch is not None:
            return True
        # FVG fill (imbalance filled)
        imbalance = self.detect_imbalance(candles)
        if imbalance is not None:
            return True
        # Micro order block (last N candles)
        ob_blocks = self.detect_order_blocks(candles[-micro_window:])
        rejection = self.detect_rejection(candles)
        if direction == "buy":
            for ob in ob_blocks:
                if (
                    ob['type'] == 'BULLISH_OB' and
                    self.validate_block(ob, candles) and
                    self.detect_liquidity_sweep(candles, ob) and
                    rejection == 'BULLISH_REJECTION'
                ):
                    return True
        elif direction == "sell":
            for ob in ob_blocks:
                if (
                    ob['type'] == 'BEARISH_OB' and
                    self.validate_block(ob, candles) and
                    self.detect_liquidity_sweep(candles, ob) and
                    rejection == 'BEARISH_REJECTION'
                ):
                    return True
        # If direction not specified, fallback to any OB with validation, sweep, and matching rejection
        for ob in ob_blocks:
            if (
                self.validate_block(ob, candles) and
                self.detect_liquidity_sweep(candles, ob) and
                ((ob['type'] == 'BULLISH_OB' and rejection == 'BULLISH_REJECTION') or
                 (ob['type'] == 'BEARISH_OB' and rejection == 'BEARISH_REJECTION'))
            ):
                return True
        return False

    def manage_position_scaling(self, entry_price, current_price, liquidity_pools, scale_in_buffer=0.001, scale_out_buffer=0.001):
        """
        Decide whether to scale into or out of a position based on nearby liquidity.

        - "scale_in": price is within ``scale_in_buffer`` of an equal high/low.
        - "scale_out": otherwise, price is within ``scale_out_buffer`` of the session
          high/low or a round number.
        - "hold": neither.

        Missing keys, ``None`` lists and ``None`` levels are skipped. Previously a
        ``None`` session level (e.g. from an empty candle set) raised TypeError.

        Args:
            entry_price: Currently unused; kept for interface compatibility.
            current_price: Price to compare against the pool levels.
            liquidity_pools: dict shaped like ``detect_liquidity_pools`` output.
            scale_in_buffer: Distance to an equal high/low that triggers "scale_in".
            scale_out_buffer: Distance to a session/round level that triggers "scale_out".
        Returns:
            "scale_in", "scale_out" or "hold".
        """
        scale_in_levels = list(liquidity_pools.get("equal_highs") or []) + list(liquidity_pools.get("equal_lows") or [])
        for level in scale_in_levels:
            if level is not None and abs(current_price - level) < scale_in_buffer:
                return "scale_in"
        scale_out_levels = [liquidity_pools.get("session_high"), liquidity_pools.get("session_low")]
        scale_out_levels += list(liquidity_pools.get("round_numbers") or [])
        for level in scale_out_levels:
            if level is not None and abs(current_price - level) < scale_out_buffer:
                return "scale_out"
        return "hold"

    def dynamic_stop_loss(self, entry_price, current_price, initial_sl, partial_tp=None, structure_level=None, move_to_breakeven_buffer=0.0005):
        """
        Dynamically manage the stop loss.

        1. Break-even: if ``partial_tp`` lies strictly between the entry and the current
           price (i.e. the partial target was passed), move the SL to entry + buffer
           (long) or entry - buffer (short).
        2. Structure trail: otherwise, if ``structure_level`` is given, trail to it, but
           never beyond the initial stop. The trade side is inferred from
           ``initial_sl`` vs ``entry_price``: for a long the SL is
           ``max(initial_sl, structure_level)``, for a short ``min(...)``. Previously the
           structure level was returned unconditionally, which could widen the stop.
           When the side cannot be inferred (``initial_sl`` is None or equals entry),
           the structure level is returned as-is.
        3. Otherwise keep ``initial_sl``.

        Args:
            entry_price: Trade entry price.
            current_price: Current market price.
            initial_sl: Original stop loss.
            partial_tp: Price at which a partial TP is taken (optional).
            structure_level: Price level to trail the SL to (optional).
            move_to_breakeven_buffer: Offset from entry for the break-even SL.
        Returns:
            New stop loss price.
        """
        if partial_tp is not None and ((entry_price < partial_tp < current_price) or (entry_price > partial_tp > current_price)):
            if entry_price < current_price:
                return entry_price + move_to_breakeven_buffer  # Long
            return entry_price - move_to_breakeven_buffer  # Short
        if structure_level is not None:
            if initial_sl is None or initial_sl == entry_price:
                return structure_level
            if initial_sl < entry_price:
                return max(initial_sl, structure_level)  # Long: only tighten upwards
            return min(initial_sl, structure_level)  # Short: only tighten downwards
        return initial_sl

    def get_session_risk_multiplier(self, session):
        """
        Return a risk multiplier for a trading session.

        Canonical names are 'Asia' (0.5), 'London' (1.0) and 'NewYork' (1.0). String
        inputs are also matched case-insensitively after stripping non-alphanumerics,
        with the aliases "asian" and "ny". This lets killzone-style names such as
        "ASIAN", "LONDON" and "NEWYORK" resolve instead of silently falling back.
        Anything else (including None) gets the 'Other' multiplier (0.7).

        Args:
            session: Session name.
        Returns:
            float risk multiplier.
        """
        session_risk = {
            'Asia': 0.5,      # Lower risk
            'London': 1.0,    # Standard risk
            'NewYork': 1.0,   # Standard risk
            'Other': 0.7      # Default for unknown
        }
        if session in session_risk:
            return session_risk[session]
        if isinstance(session, str):
            aliases = {
                'asia': 'Asia',
                'asian': 'Asia',
                'london': 'London',
                'newyork': 'NewYork',
                'ny': 'NewYork',
            }
            normalized = "".join(ch for ch in session.lower() if ch.isalnum())
            canonical = aliases.get(normalized)
            if canonical is not None:
                return session_risk[canonical]
        return session_risk['Other']

    def log_trade_outcome(self, trade_id, phase, model, setup, outcome, pnl, logger=logger):
        """
        Log trade outcome for edge analytics.
        Args:
            trade_id: Unique identifier for the trade
            phase: Market phase at entry
            model: Model used (e.g., MMSM, MMBM)
            setup: Setup details (e.g., liquidity sweep, OB, etc.)
            outcome: 'win', 'loss', 'breakeven', etc.
            pnl: Profit/loss for the trade
            logger: Logger instance
        """
        logger.info(f"TRADE_LOG | ID: {trade_id} | Phase: {phase} | Model: {model} | Setup: {setup} | Outcome: {outcome} | PnL: {pnl}")

    def update_model_performance(self, model, outcome, model_stats=None, decay=0.95):
        """
        Update per-model performance statistics for adaptive model selection.

        Per model it maintains:
        - 'score': decayed score, ``score * decay + (+1 win / -1 loss / 0 otherwise)``.
        - 'count': number of recorded outcomes.
        - 'wins' / 'losses': raw win and loss counts. 'wins' is what
          ``StrategyManager.score_signal`` reads to compute a win rate. Entries created
          before these keys existed start counting from zero.
        - 'flag': True when score < -2 after at least 10 outcomes (underperforming).

        Args:
            model: Model name (e.g. MMSM, MMBM).
            outcome: 'win', 'loss', 'breakeven', etc.
            model_stats: dict of stats to update in place (a new dict if None).
            decay: Decay factor applied to the previous score.
        Returns:
            The (same) updated ``model_stats`` dict.
        """
        if model_stats is None:
            model_stats = {}
        stats = model_stats.setdefault(model, {})
        # setdefault keeps externally-created or older entries working.
        stats.setdefault('score', 0)
        stats.setdefault('count', 0)
        stats.setdefault('wins', 0)
        stats.setdefault('losses', 0)
        score_map = {'win': 1, 'loss': -1, 'breakeven': 0}
        stats['score'] = stats['score'] * decay + score_map.get(outcome, 0)
        stats['count'] += 1
        if outcome == 'win':
            stats['wins'] += 1
        elif outcome == 'loss':
            stats['losses'] += 1
        stats['flag'] = stats['score'] < -2 and stats['count'] >= 10
        return model_stats

    def generate_signal(self, candles, htf_blocks):
        """
        Generate signals based on:
        1. Validated HTF order block
        2. LTF liquidity sweep + rejection
        Args:
            candles: List of lower timeframe candles
            htf_blocks: List of validated higher timeframe order blocks
        Returns: dict with direction, entry, stop, target or None
        """
        for block in htf_blocks:
            if self.validate_block(block, candles):
                ltf_sweep = self.detect_liquidity_sweep(candles, block)
                ltf_rejection = self.detect_rejection(candles)
                # Bullish Entry
                if (block["type"] == "BULLISH_OB" and 
                    ltf_sweep and 
                    ltf_rejection == "BULLISH_REJECTION"):
                    entry = max(block["mean"], candles[-1].close)
                    return {
                        "direction": "LONG",
                        "entry": entry,
                        "stop": block["low"],
                        "target": block["high"]  # Initial target
                    }
                # Bearish Entry
                elif (block["type"] == "BEARISH_OB" and 
                      ltf_sweep and 
                      ltf_rejection == "BEARISH_REJECTION"):
                    entry = min(block["mean"], candles[-1].close)
                    return {
                        "direction": "SHORT",
                        "entry": entry,
                        "stop": block["high"],
                        "target": block["low"]
                    }
        return None

    def detect_high_prob_ifvg(self, candles, higher_tf_candles):
        """
        Identify IFVGs with:
        1. Singular FVG (no consecutive FVGs)
        2. Liquidity sweep (HTF swing highs/lows)
        3. Killzone alignment (to be implemented if time info available)
        4. Valid body closure violation
        Returns a list of valid IFVG dicts.
        """
        fvgs = self.detect_fvg(candles)  # Assumes detect_fvg is implemented
        valid_ifvgs = []
        for fvg in fvgs:
            # Step 1: Check for singular FVG (no overlapping FVGs in last 5 candles)
            if len(self.detect_fvg(candles[-5:])) > 1:
                continue
            # Step 2: Validate IFVG (body closure violation)
            ifvg = self.validate_ifvg(fvg, candles)
            if not ifvg:
                continue
            # Step 3: Confirm liquidity sweep (HTF swing high/low)
            if len(higher_tf_candles) < 10:
                continue
            ht_swing_high = max([c.high for c in higher_tf_candles[-10:]])
            ht_swing_low = min([c.low for c in higher_tf_candles[-10:]])
            if (ifvg["type"] == "BULLISH_IFVG" and 
                any(c.low <= ht_swing_low for c in candles[-3:])):
                valid_ifvgs.append(ifvg)
            elif (ifvg["type"] == "BEARISH_IFVG" and 
                  any(c.high >= ht_swing_high for c in candles[-3:])):
                valid_ifvgs.append(ifvg)
        return valid_ifvgs

    def generate_entry(self, ifvg, candles, entry_type="BODY_CLOSURE"):
        """
        Build an entry plan for an IFVG using one of three methods.

        - "BODY_CLOSURE": enter at the last close; stop at the IFVG low (BULLISH_IFVG
          -> LONG) or high (anything else -> SHORT).
        - "IFVG_RETRACE": enter at the IFVG midpoint once the last candle has traded
          back to it (low <= mid for bullish, high >= mid for bearish); stop at the
          far edge of the IFVG.
        - "BPR": if ``self.find_balanced_price_range`` exists and returns a zone whose
          direction matches the IFVG, enter at the zone mean with the stop at the zone
          low (LONG) or high (SHORT).
        In every case the target is ``entry +/- |entry - stop| * min_rr``, where
        ``min_rr`` is ``self.min_rr`` if present, else 2.0.

        Returns:
            dict with direction, entry, stop, target; or None when no entry applies.
        """
        reward_multiple = getattr(self, 'min_rr', 2.0)  # Use unified config or fallback

        def plan(side, entry_price, stop_price):
            # Reward is the absolute risk scaled by the RR multiple, applied in the trade direction.
            reward = abs(entry_price - stop_price) * reward_multiple
            target_price = entry_price + reward if side == "LONG" else entry_price - reward
            return {"direction": side, "entry": entry_price, "stop": stop_price, "target": target_price}

        if entry_type == "BODY_CLOSURE":
            bullish = ifvg["type"] == "BULLISH_IFVG"
            close_price = candles[-1].close
            if bullish:
                return plan("LONG", close_price, ifvg["low"])
            return plan("SHORT", close_price, ifvg["high"])

        if entry_type == "IFVG_RETRACE":
            midpoint = ifvg["low"] + (ifvg["high"] - ifvg["low"]) * 0.5  # 50% retrace
            if ifvg["type"] == "BULLISH_IFVG" and candles[-1].low <= midpoint:
                return plan("LONG", midpoint, ifvg["low"])
            if ifvg["type"] == "BEARISH_IFVG" and candles[-1].high >= midpoint:
                return plan("SHORT", midpoint, ifvg["high"])
            return None

        if entry_type == "BPR":
            zone = self.find_balanced_price_range(candles) if hasattr(self, 'find_balanced_price_range') else None
            if not zone:
                return None
            if zone["direction"] != ("LONG" if ifvg["type"] == "BULLISH_IFVG" else "SHORT"):
                return None
            zone_stop = zone["low"] if zone["direction"] == "LONG" else zone["high"]
            return plan(zone["direction"], zone["mean"], zone_stop)

        return None

    def get_ltf_bias(self, data_cache):
        """Return LTF bias using M30, M15, and M5 (majority vote, fallback to NEUTRAL)."""
        votes = []
        for tf in ["M30", "M15", "M5"]:
            df = data_cache.get(tf)
            if df is not None and not df.empty:
                bias = self.detect_market_structure(df)
                votes.append(bias)
        if votes.count("bullish") > votes.count("bearish"):
            return "bullish"
        elif votes.count("bearish") > votes.count("bullish"):
            return "bearish"
        return "neutral"

    def get_htf_bias(self, data_cache):
        """Return HTF bias using H4 and D1 (majority vote, fallback to NEUTRAL)."""
        votes = []
        for tf in ["H4", "D1"]:
            df = data_cache.get(tf)
            if df is not None and not df.empty:
                bias = self.detect_market_structure(df)
                votes.append(bias)
        if votes.count("bullish") > votes.count("bearish"):
            return "bullish"
        elif votes.count("bearish") > votes.count("bullish"):
            return "bearish"
        return "neutral"

    def get_ltf_entry_signals(self, data_cache):
        """
        Build LTF entry signals on M5 and M1 when their structure agrees with the LTF bias.

        For each of "M5" and "M1" whose ``detect_market_structure`` result equals
        ``get_ltf_bias(data_cache)`` and is "bullish" or "bearish", emit a signal with:
        - entry: the last close,
        - stop_loss: one ATR (from ``self._calculate_atr``) against the trade,
        - take_profit: the 20-bar swing extreme of M15 (falling back to M30) when it is
          a finite price on the profitable side of entry; otherwise entry +/- 3 ATR.
          Previously a NaN swing (fewer than 20 bars) or a swing on the wrong side of
          entry could become the take profit.
        - confidence 60 and the timeframe.
        Timeframes with a non-numeric entry or ATR are skipped with a warning.

        Args:
            data_cache: mapping of timeframe name -> pandas DataFrame (needs 'close';
                M15/M30 also need 'high'/'low').
        Returns:
            List of signal dicts (possibly empty).
        """
        ltf_bias = self.get_ltf_bias(data_cache)
        logger.info(f"[LTF-DEBUG] LTF bias: {ltf_bias}")
        signals = []
        for tf in ["M5", "M1"]:
            df = data_cache.get(tf)
            if df is None or df.empty:
                continue
            entry_dir = self.detect_market_structure(df)
            logger.info(f"[LTF-DEBUG] {tf} entry_dir: {entry_dir}")
            if entry_dir != ltf_bias or entry_dir not in ["bullish", "bearish"]:
                continue
            entry_price = self._ltf_last_float(df['close'])
            atr_val = self._ltf_last_float(self._calculate_atr(df))
            if np.isnan(entry_price) or np.isnan(atr_val):
                logger.warning(f"get_ltf_entry_signals: Skipping due to invalid entry_price or atr_val: entry_price={entry_price}, atr_val={atr_val}")
                continue
            bullish = entry_dir == "bullish"
            # ATR-based TP (3 ATR vs a 1 ATR stop) keeps the fallback at >= 1.5 RR.
            atr_tp = entry_price + atr_val * 3.0 if bullish else entry_price - atr_val * 3.0
            swing_tp = self._ltf_swing_target(data_cache, bullish, entry_price)
            signals.append({
                'type': f'ltf_{entry_dir}',
                'entry': entry_price,
                'stop_loss': entry_price - atr_val if bullish else entry_price + atr_val,
                'take_profit': swing_tp if swing_tp is not None else atr_tp,
                'confidence': 60,
                'timeframe': tf
            })
        return signals

    @staticmethod
    def _ltf_last_float(values):
        """Return the last element of a Series/ndarray/list/tuple (or a numeric scalar) as float, NaN if unavailable."""
        try:
            if hasattr(values, 'iloc'):
                return float(values.iloc[-1]) if len(values) > 0 else float('nan')
            if isinstance(values, np.ndarray) and values.ndim == 0:
                return float(values)
            if isinstance(values, (np.ndarray, list, tuple)):
                return float(values[-1]) if len(values) > 0 else float('nan')
            if isinstance(values, (int, float, np.number)):
                return float(values)
        except (TypeError, ValueError):
            # Deliberate best effort: NaN makes the caller skip this timeframe with a warning.
            pass
        return float('nan')

    @staticmethod
    def _ltf_swing_target(data_cache, bullish, entry_price, lookback=20):
        """Return the first finite, profitable-side 20-bar swing extreme from M15 then M30, else None."""
        for swing_tf in ("M15", "M30"):
            swing_df = data_cache.get(swing_tf)
            if swing_df is None or swing_df.empty:
                continue
            if bullish:
                rolling_level = swing_df['high'].rolling(lookback).max()
            else:
                rolling_level = swing_df['low'].rolling(lookback).min()
            if len(rolling_level) == 0:
                continue
            level = float(rolling_level.iloc[-1])
            # rolling() yields NaN until `lookback` bars exist; a NaN target would poison the signal.
            if not np.isfinite(level):
                continue
            # A swing on the wrong side of entry is not a profit target.
            if (bullish and level > entry_price) or (not bullish and level < entry_price):
                return level
        return None

    def get_htf_entry_signals(self, data_cache):
        """Generate entry signals on H1 if HTF bias matches direction."""
        htf_bias = self.get_htf_bias(data_cache)
        signals = []
        tf = "H1"
        df = data_cache.get(tf)
        if df is not None and not df.empty:
            entry_dir = self.detect_market_structure(df)
            if entry_dir == htf_bias and entry_dir in ["bullish", "bearish"]:
                signals.append({
                    "timeframe": tf,
                    "direction": entry_dir,
                    "type": f"htf_entry_{entry_dir}",
                    "confidence": 80
                })
        return signals

class StrategyManager:
    """
    Unified manager to run multiple strategies and select the best signal.

    Signals are scored by RR, recent win rate and confluence count. Signals below
    ``min_rr`` are disqualified and never selected. Ties go to the strategy that
    appears first in ``strategies``.
    """

    # Score returned by score_signal for signals below min_rr; such signals are never selected.
    DISQUALIFIED_SCORE = -1e6

    def __init__(self, strategies, min_rr=2.0, model_stats=None):
        """
        Args:
            strategies: List of strategy instances (ordered by priority, highest first)
            min_rr: Minimum risk-reward ratio
            model_stats: Optional dict tracking recent performance for each strategy,
                keyed by strategy class name (uses "count" and, if present, "wins")
        """
        self.strategies = strategies
        self.min_rr = min_rr
        self.model_stats = model_stats if model_stats is not None else {}

    def score_signal(self, signal, strategy_name):
        """
        Compute a composite score for a signal (higher is better).

        score = 2 * RR + win_rate + 0.1 * confluences, where
        - RR = |target - entry| / |entry - stop| (risk floored at 1e-8),
        - win_rate = wins / count from ``model_stats[strategy_name]`` when both a
          positive count and a "wins" entry exist, otherwise a neutral 0.5. Previously
          a missing "wins" key counted as a 0% win rate,
        - confluences = ``signal.get("confluences", 0)``.
        Returns ``DISQUALIFIED_SCORE`` when RR < ``min_rr``.

        Args:
            signal: dict with "entry", "stop", "target" (optionally "confluences").
            strategy_name: Key into ``model_stats``.
        Returns:
            float score.
        """
        risk = max(abs(signal["entry"] - signal["stop"]), 1e-8)
        rr = abs(signal["target"] - signal["entry"]) / risk
        if rr < self.min_rr:
            return self.DISQUALIFIED_SCORE  # Disqualify signals below min RR
        win_rate = 0.5  # Default neutral
        stats = self.model_stats.get(strategy_name) or {}
        count = stats.get("count", 0)
        if count > 0 and "wins" in stats:
            win_rate = stats["wins"] / count
        confluences = signal.get("confluences", 0)
        # Composite score: prioritize RR, then win rate, then confluences
        return rr * 2 + win_rate + 0.1 * confluences

    def get_best_signal(self, *args, **kwargs):
        """
        Run all strategies and select the best-scoring qualified signal.

        Strategies without ``generate_signal`` and falsy signals are ignored. Signals
        scored as ``DISQUALIFIED_SCORE`` are skipped; previously they could still be
        returned (and executed) when no better signal existed.

        Args:
            *args, **kwargs: passed to each strategy's ``generate_signal``.
        Returns:
            (signal, strategy_name) or (None, None) if no qualified signal.
        """
        best_signal = None
        best_score = None
        best_strategy = None
        for strat in self.strategies:
            if not hasattr(strat, 'generate_signal'):
                continue
            signal = strat.generate_signal(*args, **kwargs)
            if not signal:
                continue
            strategy_name = strat.__class__.__name__
            score = self.score_signal(signal, strategy_name)
            if score <= self.DISQUALIFIED_SCORE:
                continue
            # Strategies are visited in priority order, so a strict ">" keeps the earlier one on ties.
            if best_score is None or score > best_score:
                best_signal = signal
                best_score = score
                best_strategy = strategy_name
        return best_signal, best_strategy

    def execute_best_signal(self, execute_order, *args, **kwargs):
        """
        Run all strategies, select and execute the best signal.

        Args:
            execute_order: function(symbol, direction, entry, stop_loss, take_profit),
                called with keyword arguments; symbol defaults to 'EURUSD' when the
                signal has none.
            *args, **kwargs: passed to generate_signal
        Returns:
            (signal, strategy_name) or (None, None) if no valid signal.
        """
        signal, strategy_name = self.get_best_signal(*args, **kwargs)
        if signal:
            execute_order(
                symbol=signal.get('symbol', 'EURUSD'),
                direction=signal['direction'],
                entry=signal['entry'],
                stop_loss=signal['stop'],
                take_profit=signal['target']
            )
        return signal, strategy_name

def unified_main_trading_loop(manager, symbol="EURUSD", ltf="15M", htf="4H", get_current_time=None, fetch_candles=None, execute_order=None, sleep_fn=None, logger=None, max_cycles=None):
    """
    Unified main trading loop using StrategyManager to select and execute the best signal.

    Each cycle does the following:

    1. Read the current time via ``get_current_time("NY")``.
    2. Gate on the first strategy's ``is_in_killzone``. If that strategy has no
       such method, or there are no strategies, always proceed.
    3. Fetch lower- and higher-timeframe candles.
    4. Pick the best signal with ``manager.get_best_signal(ltf_candles, htf_candles)``.
    5. Send the signal to ``execute_order``.
    6. Sleep 300 seconds.

    Args:
        manager: StrategyManager instance (uses ``manager.strategies`` and
            ``manager.get_best_signal``).
        symbol: Trading symbol. It is used for fetching candles and as the order
            symbol when the winning signal carries no ``'symbol'`` key.
        ltf: Lower timeframe string passed to ``fetch_candles``.
        htf: Higher timeframe string passed to ``fetch_candles``.
        get_current_time: callable invoked as ``get_current_time("NY")``. The
            result is passed to ``is_in_killzone``.
        fetch_candles: function(symbol, tf) -> list of candles
        execute_order: function(symbol, direction, entry, stop_loss, take_profit),
            called with keyword arguments.
        sleep_fn: function(seconds) for sleeping (default: time.sleep).
        logger: optional logger for info.
        max_cycles: stop after this many cycles. ``None`` (default) runs forever.
            No sleep happens after the final cycle.

    Raises:
        ValueError: if get_current_time, fetch_candles or execute_order is missing.
    """
    # Validate once up front rather than on every loop iteration.
    if get_current_time is None or fetch_candles is None or execute_order is None:
        raise ValueError("get_current_time, fetch_candles, and execute_order must be provided.")
    if sleep_fn is None:
        sleep_fn = time.sleep

    cycles_done = 0
    while max_cycles is None or cycles_done < max_cycles:
        current_time = get_current_time("NY")

        # Killzone gating uses the first strategy's config (assumed shared by all).
        # It is looked up each cycle so runtime changes to the strategy list take effect.
        strategies = manager.strategies
        kz_check = getattr(strategies[0], 'is_in_killzone', None) if strategies else None
        if kz_check is not None:
            _kz_name, kz_meta = kz_check(current_time)
            tradable = bool(kz_meta)
        else:
            # No killzone logic: always trade. Previously in_kz_meta was None
            # here, so the loop never traded.
            tradable = True

        if tradable:
            ltf_candles = fetch_candles(symbol, ltf)
            htf_candles = fetch_candles(symbol, htf)
            # ICT-only logic would go here (e.g., detect_high_prob_ifvg, generate_entry).
            # Select and execute here rather than via manager.execute_best_signal
            # so the loop's own symbol is the fallback. Otherwise signals without
            # a 'symbol' key would be sent as EURUSD.
            signal, strategy_name = manager.get_best_signal(ltf_candles, htf_candles)
            if signal:
                execute_order(
                    symbol=signal.get('symbol', symbol),
                    direction=signal['direction'],
                    entry=signal['entry'],
                    stop_loss=signal['stop'],
                    take_profit=signal['target'],
                )
            if logger:
                if signal:
                    logger.info(f"Executed {strategy_name} signal: {signal}")
                else:
                    logger.info("No valid signal this cycle.")

        cycles_done += 1
        if max_cycles is not None and cycles_done >= max_cycles:
            break  # bounded run finished: don't sleep after the last cycle
        if logger:
            logger.info("Sleeping for 5 minutes...")
        sleep_fn(300)  # 5 minutes

# Example Usage (called from main.py)
if __name__ == "__main__":
    from logger import setup_logger
    from utils import load_config
    from mt5_interface import MT5Interface
    from data_handler import DataHandler

    config = load_config("config.json.template")  # Use template for structure
    # config = load_config("config.json")  # Use actual config in real run

    # load_config may return None (or a non-dict) on failure; normalise to a dict.
    cfg = config if isinstance(config, dict) else {}

    if not cfg:
        print("Failed to load configuration.")
    else:
        setup_logger(cfg)  # called for its logging-setup effect; return value unused
        mt5_interface = MT5Interface(cfg)

        if not mt5_interface.initialize():
            print("MT5 Initialization Failed.")
        else:
            # Always release the MT5 connection, even if fetching or analysis raises.
            try:
                data_handler = DataHandler(mt5_interface, cfg)
                if not data_handler.fetch_initial_data():
                    print("Failed to fetch initial data.")
                else:
                    print("Initial data fetched.")
                    latest_data = data_handler.get_latest_data()

                    # NOTE(review): only H4 presence is checked here, but
                    # MarketStructure.analyze iterates its configured timeframes
                    # (default M1/M5/M15/H1). Confirm which timeframe should gate this demo.
                    if latest_data and "H4" in latest_data and not latest_data["H4"].empty:
                        market_structure_analyzer = MarketStructure(cfg)
                        analysis = market_structure_analyzer.analyze(latest_data)

                        if analysis:
                            print("\nMarket Structure Analysis Result:")
                            for key, value in analysis.items():
                                # Summarise long lists and format floats so output stays readable.
                                if isinstance(value, list) and len(value) > 5:
                                    print(f"  {key}: [List with {len(value)} items]")
                                elif isinstance(value, float):
                                    print(f"  {key}: {value:.4f}")
                                else:
                                    print(f"  {key}: {value}")
                        else:
                            print("\nMarket structure analysis failed.")
                    else:
                        print("Failed to retrieve H4 data for analysis.")
            finally:
                mt5_interface.shutdown()

