import time
import logging
import sys
import signal
import traceback
import MetaTrader5 as mt5
from datetime import datetime
import pytz
import threading
import numpy as np
from collections import defaultdict
import math
from typing import Dict, Tuple, Any

from utils import load_config
from logger import setup_logger
from mt5_interface import MT5Interface
from data_handler import DataHandler
from strategy.session_checker import SessionChecker
from strategy.market_structure import MarketStructure
from strategy.entry_signals import EntrySignals
from strategy.liquidity_provider import LiquidityProvider
from strategy.macro_timing import MacroTiming
from risk_manager import RiskManager
from execution_handler import ExecutionHandler
from notification_manager import NotificationManager
from trade_logger import TradeLogger
from enhanced_logger import EnhancedLogger
from trading_loop import run_trading_loop
from strategy.ict_signal_engine import generate_trade_signal
from constants import CANDLE_COUNTS, ENTRY_TYPE_PRIORITY

# Event set by signal_handler() to request a graceful shutdown; run_bot()'s main loop polls it.
shutdown_flag = threading.Event()
# Module-wide logger; stays None until run_bot() assigns the result of setup_logger().
logger = None
# Set to True to refuse to trade on any config/risk/structure error.
# NOTE(review): in the code visible here it only gates the early return after a
# failed validate_config() in run_bot() - confirm the risk/structure part.
STRICT_MODE = True

risk_lock = threading.Lock()  # Held by execute_trade_with_risk_management() around its risk checks and order submission

def execute_trade_with_risk_management(trade_signal, risk_manager, execution_handler, mt5_interface, logger):
    """
    Centralized trade execution with risk management enforcement.

    The whole sequence (sanitize, size, risk checks, order submission) runs
    while holding the module-level ``risk_lock`` so two callers cannot both
    pass the open-risk checks before either has placed its order.

    Args:
        trade_signal: Dict with at least 'entry', 'stop_loss' and 'symbol'.
            It is sanitized in place and gains a 'position_size' key.
        risk_manager: Provides calculate_position_size, total_open_risk,
            max_open_risk_percent, would_exceed_max_open_risk and
            validate_signal.
        execution_handler: Provides execute_trade(trade_signal).
        mt5_interface: Provides get_symbol_info and get_account_info.
        logger: Logger used for all diagnostics.

    Returns:
        The result of execution_handler.execute_trade() when every check
        passes, otherwise False.
    """
    with risk_lock:
        logger.debug(f"[DEBUG] Executing trade with signal: {trade_signal}")
        sanitized_signal = sanitize_trade_signal(trade_signal, logger=logger)
        if not sanitized_signal:
            # Log the signal that was rejected rather than the empty sanitizer result.
            logger.error(f"[SKIP] Trade signal skipped due to invalid fields after sanitization: {trade_signal}")
            return False
        trade_signal = sanitized_signal

        # Position size
        position_size = risk_manager.calculate_position_size(
            trade_signal['entry'],
            trade_signal['stop_loss'],
            trade_signal['symbol']
        )
        trade_signal['position_size'] = position_size

        # Symbol info
        symbol_info = mt5_interface.get_symbol_info(trade_signal['symbol'])
        if symbol_info is None:
            logger.error(f"[RISK] Could not retrieve symbol info for {trade_signal['symbol']}. Skipping trade.")
            return False

        point = symbol_info['point']
        if not point:
            # A zero point size would make the risk estimate below divide by zero.
            logger.error(f"[RISK] Invalid point size ({point}) for {trade_signal['symbol']}. Skipping trade.")
            return False

        # Actual risk: stop distance expressed in points, valued per tick.
        # NOTE(review): this mixes 'point' with 'trade_tick_value'; if the
        # symbol's tick size differs from its point the estimate is skewed - confirm.
        stop_distance = abs(trade_signal['entry'] - trade_signal['stop_loss'])
        actual_risk = position_size * stop_distance / point * symbol_info['trade_tick_value']

        # Max open risk
        current_open_risk = risk_manager.total_open_risk(mt5_interface)
        account_info = mt5_interface.get_account_info()
        if account_info is None:
            logger.error("[RISK] Could not retrieve account info. Skipping trade.")
            return False
        max_allowed = account_info['balance'] * risk_manager.max_open_risk_percent / 100.0
        logger.debug(f"[DEBUG] Position size: {position_size}, Actual risk: {actual_risk}, Current open risk: {current_open_risk}, Max allowed: {max_allowed}")
        if current_open_risk + actual_risk > max_allowed:
            logger.error(f"[SKIP] [GLOBAL RISK BLOCK] Trade blocked: If all open trades hit SL, total risk ({current_open_risk + actual_risk:.2f}) would exceed max allowed ({max_allowed:.2f}, {risk_manager.max_open_risk_percent}% of balance). Skipping new trade. Signal: {trade_signal}")
            return False

        # Second open-risk gate, delegated to the risk manager
        if risk_manager.would_exceed_max_open_risk(actual_risk, mt5_interface):
            logger.error(f"[GLOBAL RISK] Trade blocked: would exceed max_total_risk_percent. Actual risk for this trade: {actual_risk}")
            return False

        # Validate signal
        if not risk_manager.validate_signal(trade_signal, mt5_interface):
            logger.error(f"[SKIP] [RISK] Trade blocked: risk management validation failed. Signal: {trade_signal}")
            return False

        # Execute trade
        return execution_handler.execute_trade(trade_signal)

def validate_config(config):
    """
    Check that the configuration has the sections and risk settings the bot needs.

    Args:
        config: Parsed configuration; must be a dict containing the
            'risk_management' and 'trading' sections.

    Returns:
        True when the configuration is acceptable.

    Raises:
        ValueError: If the config is not a dict, a required section or key is
            missing, or a risk value is out of range.
    """
    def is_number(value):
        # bool is a subclass of int; a True/False here is a config mistake.
        return isinstance(value, (float, int)) and not isinstance(value, bool)

    if not isinstance(config, dict):
        raise ValueError("Config must be a dictionary with 'risk_management' and 'trading' sections")
    required = ['risk_management', 'trading']
    for key in required:
        if key not in config:
            raise ValueError(f"Config missing required section: {key}")
    rm = config['risk_management']
    if not isinstance(rm, dict):
        raise ValueError("Config section risk_management must be a dictionary")
    for k in ['risk_per_trade', 'max_total_risk_percent', 'max_lot_size', 'min_lot_size']:
        if k not in rm:
            raise ValueError(f"Config missing risk_management.{k}")
    risk_per_trade = rm['risk_per_trade']
    if not is_number(risk_per_trade) or risk_per_trade <= 0 or risk_per_trade >= 1:
        raise ValueError("risk_per_trade must be a decimal between 0 and 1 (e.g., 0.01 for 1% risk per trade)")
    max_total_risk = rm['max_total_risk_percent']
    if not is_number(max_total_risk) or max_total_risk <= 0:
        raise ValueError("max_total_risk_percent must be positive number")
    return True

def signal_key(signal):
    """
    Build the tuple used to de-duplicate signals.

    The three price fields are rounded to two decimals; a missing or
    unconvertible price becomes 0. The remaining fields are taken as-is.
    """
    def rounded_price(field):
        raw = signal.get(field)
        if raw is None:
            return 0
        try:
            return round(float(raw), 2)
        except Exception:
            return 0

    entry_price = rounded_price('entry')
    stop_price = rounded_price('stop_loss')
    target_price = rounded_price('take_profit')
    return (
        entry_price,
        stop_price,
        target_price,
        signal.get('direction'),
        signal.get('type'),
        signal.get('symbol'),
        signal.get('timeframe'),
    )

def signal_handler(sig, frame):
    """Signal callback: announce the shutdown request and set ``shutdown_flag``."""
    print("\nCtrl+C detected. Initiating graceful shutdown...")
    active_logger = logger
    if active_logger:
        # The module logger may not have been configured yet.
        active_logger.info("Shutdown signal received. Exiting gracefully.")
    shutdown_flag.set()

def tick_sweep_monitor(mt5_interface, symbols, interval=20, threshold=0.003):
    """
    Periodically log when the latest tick breaks the recent H1 range.

    For each symbol the high/low of the previous 15 H1 candles (the newest
    candle is excluded) is compared with the current tick: an ask above
    ``high * (1 + threshold)`` is logged as a bullish sweep, a bid below
    ``low * (1 - threshold)`` as a bearish sweep. Nothing is traded here.

    The loop ends once the module-level ``shutdown_flag`` is set.

    Args:
        mt5_interface: Provides get_rates(symbol, tf, count=...) and get_tick(symbol).
        symbols: Iterable of symbol names to monitor.
        interval: Seconds to wait between scans.
        threshold: Fractional distance beyond the range that counts as a sweep.
    """
    logger = logging.getLogger("TickSweepMonitor")
    while not shutdown_flag.is_set():
        for symbol in symbols:
            if shutdown_flag.is_set():
                break
            try:
                # Get last 16 H1 candles for the symbol
                h1_df = mt5_interface.get_rates(symbol, "H1", count=16)
                if h1_df is None or len(h1_df) < 16:
                    logger.warning(f"Not enough H1 data for sweep detection on {symbol}")
                    continue
                recent_high = h1_df['high'][:-1].max()
                recent_low = h1_df['low'][:-1].min()
                # Get latest tick
                tick = mt5_interface.get_tick(symbol)
                if not tick:
                    logger.warning(f"No tick data for {symbol}")
                    continue
                bid = tick.get('bid')
                ask = tick.get('ask')
                now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                # Check for bullish sweep
                if ask and ask > recent_high * (1 + threshold):
                    logger.info(f"[TICK SWEEP] {now} {symbol} H1 Bullish sweep detected: ask={ask} > high={recent_high} * (1+{threshold})")
                # Check for bearish sweep
                if bid and bid < recent_low * (1 - threshold):
                    logger.info(f"[TICK SWEEP] {now} {symbol} H1 Bearish sweep detected: bid={bid} < low={recent_low} * (1-{threshold})")
            except Exception as e:
                # Best effort: one bad symbol must not stop the monitor.
                logger.error(f"Error in tick sweep monitor for {symbol}: {e}")
        # Event.wait() pauses like sleep() but returns at once when shutdown is requested.
        shutdown_flag.wait(interval)

def get_sweep_threshold(symbol):
    """Return the fractional sweep threshold: 0.15% for USTECH100M, 0.3% for every other symbol."""
    if symbol == "USTECH100M":
        return 0.0015
    return 0.003

def get_ustech100m_params():
    """Return a fresh dict of the sweep-monitor settings used for USTECH100M."""
    settings = {}
    settings['sweep_threshold'] = 0.0015  # 0.15%
    settings['min_volume_ratio'] = 1.3
    settings['priority_timeframes'] = ['M5', 'M15']
    return settings

def get_session_size():
    """
    Return the position-size multiplier for the current hour.

    Uses the hour of ``datetime.now()``: 1.0 for hours 2-4 and 8-10
    (London/NY), 0.5 for hours 19-23 and 0 (Asian), 0.25 otherwise.
    """
    current_hour = datetime.now().hour
    in_london_or_ny = (2 <= current_hour < 5) or (8 <= current_hour < 11)
    if in_london_or_ny:
        return 1.0  # 100% size
    in_asian = current_hour >= 19 or current_hour < 1
    if in_asian:
        return 0.5  # 50% size
    return 0.25  # Off-hours: 25% size

def is_bullish_engulfing(prev, curr):
    """True when the current candle closes above its own open and above the previous candle's high."""
    closed_up = curr['close'] > curr['open']
    return closed_up and curr['close'] > prev['high']

def is_bearish_engulfing(prev, curr):
    """True when the current candle closes below its own open and below the previous candle's low."""
    closed_down = curr['close'] < curr['open']
    return closed_down and curr['close'] < prev['low']

def is_bullish_reversal(prev, curr):
    """True when the current candle is up (close above open) and closes above the previous high."""
    up_candle = curr['close'] > curr['open']
    return up_candle and curr['close'] > prev['high']

def is_bearish_reversal(prev, curr):
    """True when the current candle is down (close below open) and closes below the previous low."""
    down_candle = curr['close'] < curr['open']
    return down_candle and curr['close'] < prev['low']

def price_is_near_liquidity(price, pools, threshold=7):
    """Return True if any level in ``pools`` lies within ``threshold`` (inclusive) of ``price``."""
    for level in pools:
        if abs(price - level) <= threshold:
            return True
    return False

def has_long_lower_wick(candle, min_ratio=2.0):
    """
    Return True when the lower wick is longer than ``min_ratio`` times the body.

    The wick is measured from the candle's low up to the open for an up
    candle (close above open), otherwise up to the close.
    """
    body_size = abs(candle['close'] - candle['open'])
    if candle['close'] > candle['open']:
        lower_wick = candle['open'] - candle['low']
    else:
        lower_wick = candle['close'] - candle['low']
    return lower_wick > body_size * min_ratio

def has_long_upper_wick(candle, min_ratio=2.0):
    """
    Return True when the upper wick is longer than ``min_ratio`` times the body.

    The wick is measured from the candle's high down to the close for an up
    candle (close above open), otherwise down to the open.
    """
    body_size = abs(candle['close'] - candle['open'])
    if candle['close'] > candle['open']:
        upper_wick = candle['high'] - candle['close']
    else:
        upper_wick = candle['high'] - candle['open']
    return upper_wick > body_size * min_ratio

def is_basic_bullish_reversal(prev, curr):
    """
    True when the current candle closes above the previous close, does not
    make a lower low than the previous candle, and closes above its own open.
    """
    higher_close = curr['close'] > prev['close']
    return higher_close and curr['low'] >= prev['low'] and curr['close'] > curr['open']

def is_basic_bearish_reversal(prev, curr):
    """
    True when the current candle closes below the previous close, does not
    make a higher high than the previous candle, and closes below its own open.
    """
    lower_close = curr['close'] < prev['close']
    return lower_close and curr['high'] <= prev['high'] and curr['close'] < curr['open']

def advanced_tick_sweep_monitor(mt5_interface, symbols, execution_handler, interval=10, risk_manager=None):
    """
    Watch lower-timeframe candles for wick sweeps and attempt a trade once a
    sweep is confirmed.

    For every symbol and lower timeframe the loop:
      1. flags a sweep when the latest candle pierces the prior 10-candle
         swing low/high, sits near a liquidity level and has a long wick;
      2. looks for a basic reversal candle after the flagged sweep;
      3. requires an H1 fair-value gap or volume spike, and rejects the setup
         when the H4/D1 bias disagrees with the sweep direction;
      4. hands the resulting signal to ``handle_trade_attempt``.

    The loop ends once the module-level ``shutdown_flag`` is set.

    Args:
        mt5_interface: Provides get_rates(symbol, tf, count=...).
        symbols: Iterable of symbol names to monitor.
        execution_handler: Passed through to ``handle_trade_attempt``.
        interval: Seconds to wait between scans.
        risk_manager: Passed through to ``handle_trade_attempt``. When None,
            a confirmed setup is logged but no trade is attempted.
    """
    logger = logging.getLogger("AdvTickSweepMonitor")
    timeframes_htf = ["H1"]
    timeframes_filter = ["H4", "D1"]
    sweep_memory = defaultdict(dict)  # {symbol: {tf: {sweep_info}}}
    while not shutdown_flag.is_set():
        for symbol in symbols:
            if shutdown_flag.is_set():
                break
            try:
                # Instrument-specific params
                if symbol == "USTECH100M":
                    params = get_ustech100m_params()
                    timeframes_ltf = params['priority_timeframes']
                    min_vol_ratio = params['min_volume_ratio']
                else:
                    timeframes_ltf = ["M5", "M15"]
                    min_vol_ratio = 1.5
                liquidity_proximity = 10
                # Get liquidity pools (use H1 or H4 for reference)
                h1_df = mt5_interface.get_rates(symbol, "H1", count=50)
                liquidity_pools = []
                if h1_df is not None and len(h1_df) >= 30:
                    last_close = h1_df['close'].iloc[-1]
                    # Round-number levels plus the extremes of the H1 window.
                    liquidity_pools = [round(last_close/5)*5, round(last_close/10)*10]
                    liquidity_pools += [float(h1_df['high'].max()), float(h1_df['low'].min())]
                for tf in timeframes_ltf:
                    df = mt5_interface.get_rates(symbol, tf, count=15)
                    if df is None or len(df) < 11:
                        logger.warning(f"Not enough {tf} data for {symbol}")
                        continue
                    swing_lows = df['low'][-11:-1].astype(float).tolist()
                    swing_highs = df['high'][-11:-1].astype(float).tolist()
                    current = df.iloc[-1]
                    lowest_swing = min(swing_lows)
                    highest_swing = max(swing_highs)
                    long_lower_wick = has_long_lower_wick(current)
                    long_upper_wick = has_long_upper_wick(current)
                    # --- Sweep detection (wick + proximity) ---
                    bull_sweep_hit = (
                        current['low'] < lowest_swing and
                        price_is_near_liquidity(current['low'], liquidity_pools, liquidity_proximity) and
                        long_lower_wick
                    )
                    bear_sweep_hit = (
                        current['high'] > highest_swing and
                        price_is_near_liquidity(current['high'], liquidity_pools, liquidity_proximity) and
                        long_upper_wick
                    )
                    # --- Debug logging ---
                    logger.info(f"[DEBUG] {symbol} {tf} Checking sweep: low={current['low']}, min_swing={lowest_swing}, high={current['high']}, max_swing={highest_swing}")
                    logger.info(f"[DEBUG] Wick check: LLW={long_lower_wick}, LUW={long_upper_wick} | Close>Open: {current['close'] > current['open']}")
                    # --- Sweep memory logic ---
                    mem = sweep_memory[symbol].get(tf, None)
                    # If sweep detected, store in memory (only when the direction changes)
                    if bull_sweep_hit and (not mem or mem.get('type') != 'bullish'):
                        sweep_memory[symbol][tf] = {
                            "type": "bullish",
                            "sweep_price": current['low'],
                            "time": current.name,
                            "candle_idx": len(df)-1,
                            "confirmed": False
                        }
                        logger.info(f"[SWEEP FLAGGED] {symbol} {tf} BULLISH sweep flagged at {current['low']} ({current.name})")
                    elif bear_sweep_hit and (not mem or mem.get('type') != 'bearish'):
                        sweep_memory[symbol][tf] = {
                            "type": "bearish",
                            "sweep_price": current['high'],
                            "time": current.name,
                            "candle_idx": len(df)-1,
                            "confirmed": False
                        }
                        logger.info(f"[SWEEP FLAGGED] {symbol} {tf} BEARISH sweep flagged at {current['high']} ({current.name})")
                    # If sweep in memory, check for confirmation in next 1-3 candles
                    # NOTE(review): candle_idx is the last row of the window it was
                    # recorded in and every scan fetches a fresh window of the same
                    # count, so sweep_idx + i looks to be out of range on each pass
                    # and confirmation may never trigger - confirm and consider
                    # locating the sweep candle by its stored "time" instead.
                    mem = sweep_memory[symbol].get(tf, None)
                    if mem and not mem.get('confirmed'):
                        sweep_idx = mem['candle_idx']
                        for i in range(1, 4):
                            idx = sweep_idx + i
                            if idx < len(df):
                                prev_c = df.iloc[idx-1]
                                curr_c = df.iloc[idx]
                                if mem['type'] == 'bullish' and is_basic_bullish_reversal(prev_c, curr_c):
                                    mem['confirmed'] = True
                                    logger.info(f"[SWEEP CONFIRMED] {symbol} {tf} BULLISH reversal at {curr_c.name} close={curr_c['close']}")
                                    break
                                elif mem['type'] == 'bearish' and is_basic_bearish_reversal(prev_c, curr_c):
                                    mem['confirmed'] = True
                                    logger.info(f"[SWEEP CONFIRMED] {symbol} {tf} BEARISH reversal at {curr_c.name} close={curr_c['close']}")
                                    break
                    # If confirmed, proceed with HTF confirmation and trade logic
                    mem = sweep_memory[symbol].get(tf, None)
                    if mem and mem.get('confirmed') and not mem.get('executed'):
                        sweep_type = mem['type']
                        entry = mem['sweep_price']
                        # --- H1 Confirmation ---
                        confirmed = False
                        for htf in timeframes_htf:
                            htf_df = mt5_interface.get_rates(symbol, htf, count=30)
                            if htf_df is None or len(htf_df) < 21:
                                continue
                            fvg = False
                            h1_avg_vol = htf_df['tick_volume'][-21:-1].mean()
                            h1_last_vol = htf_df['tick_volume'].iloc[-1]
                            h1_vol_spike = h1_last_vol > min_vol_ratio * h1_avg_vol
                            # Any three-candle gap in the sweep direction within the window counts.
                            for i in range(2, len(htf_df)):
                                c1, c2, c3 = htf_df.iloc[i-2], htf_df.iloc[i-1], htf_df.iloc[i]
                                if sweep_type == 'bullish' and (c1['high'] < c3['low']) and (c2['low'] < c1['high']):
                                    fvg = True
                                if sweep_type == 'bearish' and (c1['low'] > c3['high']) and (c2['high'] > c1['low']):
                                    fvg = True
                                if fvg:
                                    break
                            if fvg or h1_vol_spike:
                                confirmed = True
                                logger.info(f"[ADV SWEEP] {symbol} {htf} CONFIRM: FVG={fvg}, H1_VOL_SPIKE={h1_vol_spike}")
                                break
                        # --- H4/D1 Filter ---
                        filter_block = False
                        for ftf in timeframes_filter:
                            ftf_df = mt5_interface.get_rates(symbol, ftf, count=30)
                            if ftf_df is None or len(ftf_df) < 21:
                                continue
                            # Bias = last close versus the mean of the 20 closes before it.
                            filter_bias = 'bullish' if ftf_df['close'].iloc[-1] > ftf_df['close'][-21:-1].mean() else 'bearish'
                            if filter_bias != sweep_type:
                                filter_block = True
                                logger.info(f"[ADV SWEEP] {symbol} {ftf} FILTER BLOCK: filter_bias={filter_bias}, sweep_type={sweep_type}")
                                break
                        # --- Session/Size Logic (Consistent) ---
                        size = get_session_size()
                        session = 'London/NY' if size == 1.0 else ('Asian' if size == 0.5 else 'Other')
                        # --- Final Decision ---
                        if confirmed and not filter_block:
                            logger.info(f"[ADV SWEEP] {symbol} {tf} {sweep_type.upper()} wick sweep: CONFIRMED, session={session}, size={size*100:.0f}% - PLACING TRADE")
                            direction = 1 if sweep_type == 'bullish' else -1
                            # NOTE(review): these distances are multiples of the entry
                            # PRICE (1.5x / 2.5x), which puts the stop or target below
                            # zero; an ATR- or range-based distance was probably
                            # intended. Left unchanged because the intended value
                            # cannot be determined from this file - confirm.
                            sl = entry - 1.5 * abs(entry) if direction == 1 else entry + 1.5 * abs(entry)
                            tp = entry + 2.5 * abs(entry) if direction == 1 else entry - 2.5 * abs(entry)
                            trade_signal = {
                                'symbol': symbol,
                                'type': 'BUY' if direction == 1 else 'SELL',
                                'direction': 'BUY' if direction == 1 else 'SELL',
                                'entry': entry,
                                'stop_loss': sl,
                                'take_profit': tp,
                                'timeframe': tf,
                                'confidence': 1.0,
                                'pattern': f"ADV_WICK_SWEEP_{tf}_{sweep_type.upper()}",
                                'size_multiplier': size,
                                'reason': f"Wick sweep detected and confirmed: {sweep_type} on {tf}, session={session}, size={size*100:.0f}%"
                            }
                            if risk_manager is None:
                                # Without a risk manager the centralized risk checks cannot run.
                                logger.error(f"[ADV SWEEP] {symbol} {tf} cannot place trade: no risk_manager was supplied to the monitor")
                            elif handle_trade_attempt(trade_signal, symbol, tf, risk_manager, execution_handler, mt5_interface, logger):
                                # Stop the same confirmed sweep from being traded again on the next scan.
                                mem['executed'] = True
                        elif not confirmed:
                            logger.info(f"[ADV SWEEP] {symbol} {tf} {sweep_type.upper()} wick sweep: NOT CONFIRMED by H1 FVG or volume spike")
                        elif filter_block:
                            logger.info(f"[ADV SWEEP] {symbol} {tf} {sweep_type.upper()} wick sweep: BLOCKED by H4/D1 filter")
            except Exception as e:
                # Best effort: one bad symbol must not stop the monitor.
                logger.error(f"Error in advanced tick sweep monitor for {symbol}: {e}")
        # Event.wait() pauses like sleep() but returns at once when shutdown is requested.
        shutdown_flag.wait(interval)

def sanitize_trade_signal(trade_signal, logger=None):
    """
    Ensures all price-related fields in the trade signal are valid (not NaN, None, or inf).
    Converts numpy scalar values to native Python types for all fields.

    The signal dict is modified in place: numpy scalars are unwrapped, and
    every optional price field that is missing or invalid is set to None.

    Args:
        trade_signal: Signal dict to clean; None is tolerated and rejected.
        logger: Optional logger used to report why a signal was rejected.

    Returns:
        The same (sanitized) dict, or None if the signal is None or a
        critical field ('entry', 'stop_loss') is missing or invalid.
    """
    import math
    import numpy as np

    def is_invalid(val):
        if val is None:
            return True
        # Only floating-point values can carry NaN/inf; other types pass through.
        return isinstance(val, (float, np.floating)) and not math.isfinite(val)

    if trade_signal is None:
        if logger:
            logger.error("Trade signal rejected: signal is None")
        return None

    # Convert all numpy numeric types to native Python types for all fields
    for k, v in list(trade_signal.items()):
        if isinstance(v, np.generic):
            trade_signal[k] = v.item()

    # Critical fields: the trade cannot be sized without them
    for field in ('entry', 'stop_loss'):
        val = trade_signal.get(field)
        if is_invalid(val):
            if logger:
                logger.error(f"Trade signal rejected: {field} is invalid ({val}) in {trade_signal}")
            return None

    # Optional fields: set to None if missing or invalid
    for field in ('take_profit', 'tp1', 'tp2', 'tp3', 'trailing_start', 'trailing_atr'):
        if is_invalid(trade_signal.get(field)):
            trade_signal[field] = None
    return trade_signal

def handle_trade_attempt(trade_signal, symbol, tf, risk_manager, execution_handler, mt5_interface, logger, last_signal_key=None, last_signal_time=None):
    """
    Common entry point used by strategies and monitors to try a trade.

    Fills in a missing 'symbol', derives a missing 'direction' from the text
    of 'type' (buy/bullish -> BUY, sell/bearish -> SELL), then delegates to
    execute_trade_with_risk_management. When both tracking dicts are given,
    a successful trade records its de-duplication key and timestamp under
    (symbol, tf).

    Returns:
        True if the trade was executed, False otherwise (including when no
        direction could be determined).
    """
    from time import time as current_timestamp
    logger.debug(f"[DEBUG] handle_trade_attempt called with: {trade_signal}")

    # Defensive defaults for fields some producers leave out.
    trade_signal.setdefault('symbol', symbol)
    if 'direction' not in trade_signal and 'type' in trade_signal:
        kind = trade_signal['type'].lower()
        if 'buy' in kind or 'bullish' in kind:
            trade_signal['direction'] = 'BUY'
        elif 'sell' in kind or 'bearish' in kind:
            trade_signal['direction'] = 'SELL'

    if 'direction' not in trade_signal:
        logger.error(f"[ERROR] Trade signal missing 'direction': {trade_signal}")
        return False

    executed = execute_trade_with_risk_management(trade_signal, risk_manager, execution_handler, mt5_interface, logger)
    if not executed:
        logger.error(f"[FAIL] Failed to execute trade for {symbol} {tf} | signal={trade_signal}")
        return False

    tracking_enabled = last_signal_key is not None and last_signal_time is not None
    if tracking_enabled:
        slot = (symbol, tf)
        safe_set_signal_key(last_signal_key, slot, signal_key(trade_signal))
        safe_set_signal_time(last_signal_time, slot, current_timestamp())
    logger.info(f"[SUCCESS] Trade executed successfully for {symbol} {tf} | signal={trade_signal}")
    return True

def run_bot():
    """Main bot execution function."""
    global logger
    try:
        start_time = time.time()
        
        # Setup logging first
        config = load_config("config.json")
        try:
            validate_config(config)
        except Exception as e:
            print(f"Config validation failed: {e}")
            if logger:
                logger.error(f"Config validation failed: {e}")
            if STRICT_MODE:
                return
        
        # Use a local cfg variable for all .get calls
        cfg = config if config is not None and isinstance(config, dict) else {}
        logger = setup_logger(cfg)
        logger.info(f"Config loaded in {time.time() - start_time:.2f} seconds")
        logger.info("Starting Multi-Symbol Trading Bot...")
        
        # Get symbols and timeframes list from config
        trading_cfg = cfg.get('trading', {})
        symbols = trading_cfg.get('symbols', [trading_cfg.get('symbol', 'USTECH100M')])
        timeframes = trading_cfg.get('timeframes', ["M15"])
        # Initialize per-symbol-timeframe signal tracking
        from typing import Dict, Tuple, Any
        last_signal_key: Dict[Tuple[str, str], Any] = {(symbol, tf): None for symbol in symbols for tf in timeframes}
        last_signal_time: Dict[Tuple[str, str], float] = {(symbol, tf): 0.0 for symbol in symbols for tf in timeframes}
        min_signal_interval = 60  # seconds
        last_trading_day = datetime.now().date()
        
        # Initialize components
        init_start = time.time()
        mt5_interface = MT5Interface(cfg)
        if not mt5_interface.initialize():
            logger.error("Failed to initialize MT5. Exiting.")
            return
        logger.info(f"MT5 initialized in {time.time() - init_start:.2f} seconds")

        # Initialize other components with timing
        comp_start = time.time()
        market_structure = MarketStructure(cfg)
        entry_signals = EntrySignals(cfg)
        liquidity_provider = LiquidityProvider(cfg)
        session_checker = SessionChecker(cfg)
        risk_manager = RiskManager(cfg)
        notification_manager = NotificationManager(cfg)
        macro_timing = MacroTiming(cfg)
        trade_logger = TradeLogger(cfg)
        execution_handler = ExecutionHandler(
            cfg,
            mt5_interface,
            notification_manager,
            trade_logger
        )
        logger.info(f"All components initialized in {time.time() - comp_start:.2f} seconds")

        data_handler = DataHandler(mt5_interface, cfg)
        if not data_handler.fetch_initial_data():
            logger.error("Failed to fetch initial data. Exiting.")
            return
        logger.info("Initial data fetched successfully.")

        print(f"{symbols[0]} Trading Bot initialized successfully. Starting main loop...")
        logger.info(f"{symbols[0]} Trading Bot initialized successfully. Starting main loop...")

        # Use risk_manager for trade validation if needed

        # Run the trading loop with ICT strategy
        while not shutdown_flag.is_set():
            try:
                data_cache = data_handler.get_latest_data()
                current_utc = datetime.now(pytz.UTC)
                # --- Pre-fetch all required rates for this loop ---
                rates_cache = {}
                for symbol in symbols:
                    rates_cache[symbol] = {}
                    for tf in set(["D1", "H1"] + timeframes):
                        count = CANDLE_COUNTS.get(tf, 30)  # Use CANDLE_COUNTS, default to 30 if not found
                        rates_cache[symbol][tf] = mt5_interface.get_rates(symbol, tf, count=count)
                for symbol in symbols:
                    daily_df = rates_cache[symbol]["D1"]
                    hourly_df = rates_cache[symbol]["H1"]
                    for tf in timeframes:
                        if tf == "D1" or tf == "H1":
                            m_df = daily_df if tf == "D1" else hourly_df
                        else:
                            m_df = rates_cache[symbol][tf]
                        if daily_df is None or hourly_df is None or m_df is None:
                            logger.warning(f"Missing required data for ICT analysis for {symbol} {tf}")
                            continue
                        # --- DEBUG: Log data types and sample data before get_trend ---
                        logger.debug(f"Data cache types: {{tf: type(rates_cache[symbol][tf]) for tf in rates_cache[symbol]}}")
                        if 'D1' in rates_cache[symbol]:
                            logger.debug(f"D1 data sample: {rates_cache[symbol]['D1'].head()}")
                        else:
                            logger.error(f"D1 data missing from rates_cache for symbol {symbol}")
                        # --- REAL STRATEGY LOGIC START ---
                        # 1. Check if session is valid for trading
                        if not session_checker.is_session_open(symbol, tf, current_utc):
                            logger.info(f"Session closed for {symbol} {tf}, skipping.")
                            continue

                        # 2. Confirm market structure (e.g., bullish/bearish trend)
                        # --- NEW TREND/PHASE LOGIC START ---
                        trend_info = market_structure.get_trend(rates_cache[symbol])
                        phase = trend_info.get('phase', 'sideways')
                        confidence = trend_info.get('confidence', 0)
                        key_levels = trend_info.get('key_levels', {})
                        signals = trend_info.get('signals', [])
                        logger.info(f"[STRUCTURE] {symbol} {tf} phase={phase}, confidence={confidence}, key_levels={key_levels}")
                        if cfg and cfg.get('trend_confidence_threshold', None) is not None:
                            confidence_threshold_val = cfg.get('trend_confidence_threshold', 25)
                        else:
                            confidence_threshold_val = 25
                        if confidence < confidence_threshold_val:
                            logger.warning(f"[STRUCTURE] Low confidence ({confidence}) for {symbol} {tf} phase ({phase}), skipping.")
                            continue

                        if not signals:
                            logger.warning(f"[SKIP] No trading signals for {symbol} {tf} in phase {phase}, skipping. signals={signals}")
                            continue

                        # Log all signals and their priorities
                        for s in signals:
                            logger.info(f"[SIGNAL] {s['type']} | confidence={s.get('confidence', 0)} | priority={ENTRY_TYPE_PRIORITY.get(s['type'], 99)} | entry={s.get('entry')} | sl={s.get('stop_loss')} | tp={s.get('take_profit')} | full={s}")

                        # Sort signals by (priority, -confidence)
                        def get_signal_priority(signal):
                            return ENTRY_TYPE_PRIORITY.get(signal['type'], 99), -signal.get('confidence', 0)

                        confidence_threshold = cfg.get('signal_confidence_threshold', 60) if cfg and cfg.get('signal_confidence_threshold', None) is not None else 60
                        sorted_signals = sorted(signals, key=get_signal_priority)

                        for signal in sorted_signals:
                            logger.debug(f"[DEBUG] Considering signal for execution: {signal}")
                            if signal.get('confidence', 0) >= confidence_threshold:
                                logger.info(f"[EXECUTE] Attempting to execute signal: {signal['type']} | confidence={signal.get('confidence', 0)} | priority={ENTRY_TYPE_PRIORITY.get(signal['type'], 99)} | full={signal}")
                                if 'symbol' not in signal:
                                    signal['symbol'] = symbol
                                if 'direction' not in signal:
                                    if 'buy' in signal['type'].lower():
                                        signal['direction'] = 'BUY'
                                    elif 'sell' in signal['type'].lower():
                                        signal['direction'] = 'SELL'
                                    elif 'bullish' in signal['type'].lower():
                                        signal['direction'] = 'BUY'
                                    elif 'bearish' in signal['type'].lower():
                                        signal['direction'] = 'SELL'
                                # Defensive: error if still missing
                                if 'direction' not in signal:
                                    logger.error(f"[ERROR] Signal missing 'direction': {signal}")
                                    raise ValueError(f"Signal missing 'direction': {signal}")
                                handle_trade_attempt(signal, symbol, tf, risk_manager, execution_handler, mt5_interface, logger, last_signal_key, last_signal_time)
                            else:
                                logger.info(f"[SKIP] Signal below confidence threshold: {signal}")
                        # --- NEW TREND/PHASE LOGIC END ---
                        # Manage existing trades (use M5, M15, and M30 as reference for now)
                        if tf in ["M5", "M15", "M30"]:
                            current_price = m_df['close'].iloc[-1] if not m_df.empty else None
                            if current_price:
                                # --- ENHANCED RISK MANAGEMENT ---
                                # Call trade management for partials and BE (every M5/M15/M30 cycle)
                                execution_handler.manage_ict_trades()
                                # Enforce risk management rules on all open trades (every M5/M15/M30 cycle)
                                risk_manager.enforce_risk_on_open_trades(symbol, mt5_interface=mt5_interface, notification_manager=notification_manager)
                                # Additional risk management checks
                                try:
                                    # Check for partial close opportunities at 2:1 RR
                                    risk_manager.manage_partial_close_and_breakeven(symbol, mt5_interface=mt5_interface, notification_manager=notification_manager)
                                    # Dynamic stop loss management
                                    positions = mt5_interface.get_open_positions(symbol=symbol)
                                    for position in positions:
                                        # Move to breakeven after 50% RR achieved
                                        entry = position['price_open']
                                        sl = position['sl']
                                        current_price = position['price_current']
                                        direction = position['type']
                                        if sl and entry:
                                            # Calculate current profit percentage
                                            if direction == 0:  # BUY
                                                profit_pct = (current_price - entry) / abs(sl - entry)
                                            else:  # SELL
                                                profit_pct = (entry - current_price) / abs(sl - entry)
                                            # Move to breakeven after 50% RR
                                            if profit_pct >= 0.5 and abs(sl - entry) > 0.001:
                                                logger.info(f"Moving position {position['ticket']} to breakeven (50% RR achieved)")
                                                mt5_interface.modify_position(
                                                    position['ticket'],
                                                    sl=entry,
                                                    tp=position['tp']
                                                )
                                except Exception as e:
                                    logger.error(f"Error in enhanced risk management: {str(e)}")
                        
                        # --- LTF and HTF entry logic ---
                        ltf_signals = market_structure.get_ltf_entry_signals(rates_cache[symbol])
                        htf_signals = market_structure.get_htf_entry_signals(rates_cache[symbol])
                        for signal in ltf_signals + htf_signals:
                            handle_trade_attempt(signal, symbol, tf, risk_manager, execution_handler, mt5_interface, logger)
                # Daily risk reset
                current_day = datetime.now().date()
                if hasattr(risk_manager, 'reset_daily_limits') and callable(getattr(risk_manager, 'reset_daily_limits', None)):
                    if current_day != last_trading_day:
                        risk_manager.reset_daily_limits()
                        last_trading_day = current_day
                # Move profitable trades to breakeven before Friday close
                execution_handler.move_trades_to_breakeven_before_friday_close()
                # Enforce 2:1 RR partial take profit rule for all open trades
                for sym in symbols:
                    logger.info(f"Enforcing 2:1 RR partial take profit rule for open trades on {sym}")
                    risk_manager.manage_partial_close_and_breakeven(sym, mt5_interface=mt5_interface, notification_manager=notification_manager)
                time.sleep(cfg.get('loop_interval', 10))
            except Exception as e:
                import traceback
                logger.error(f"Error in main loop: {str(e)}\n{traceback.format_exc()}")
                time.sleep(cfg.get('loop_interval', 10))

    except Exception as e:
        if logger:
            logger.error(f"Fatal error: {str(e)}", exc_info=True)
        else:
            print(f"Fatal error: {str(e)}")
    finally:
        if 'mt5_interface' in locals():
            mt5_interface.shutdown()

def safe_set_signal_key(d: Dict, k, v):
    """Store *v* under key *k* in *d* without ever raising.

    Values that are ``str``, ``int``, ``float`` or ``tuple`` instances are
    stored unchanged; anything else is stored as ``str(v)``.

    Any exception raised while converting or storing is swallowed and, when
    the module-level ``logger`` is set, reported as a warning.

    Returns:
        None. *d* is mutated in place.
    """
    try:
        if isinstance(v, (str, int, float, tuple)):
            stored = v
        else:
            # Fall back to the string form for every other type.
            stored = str(v)
        d[k] = stored
    except Exception as exc:
        if logger:
            logger.warning(f"Could not set last_signal_key: {exc}")

def safe_set_signal_time(d: Dict, k, v):
    """Store ``float(v)`` under key *k* in *d* without ever raising.

    If the conversion or the assignment fails, *d* is left untouched for
    that key and, when the module-level ``logger`` is set, a warning is
    logged instead of propagating the exception.

    Returns:
        None. *d* is mutated in place.
    """
    try:
        as_float = float(v)
        d[k] = as_float
    except Exception as exc:
        if logger:
            logger.warning(f"Could not set last_signal_time: {exc}")

if __name__ == "__main__":
    # Script entry point: register shutdown signal handlers, load
    # config.json, build the trading components, start the tick sweep
    # monitor thread, then hand control to run_bot().
    try:
        # Route both SIGINT and SIGTERM to the same handler.
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        loaded_config = load_config("config.json")
        if not loaded_config:
            print("Failed to load configuration. Exiting.")
            sys.exit(1)

        # loaded_config is truthy from here on, so only its type is checked.
        if not isinstance(loaded_config, dict):
            # Non-dict config: fall back to the default symbol and leave
            # every component unset before calling run_bot().
            symbols = ['USTECH100M']
            mt5_interface = risk_manager = notification_manager = None
            trade_logger = execution_handler = None
        else:
            cfg = loaded_config
            trading_cfg = cfg.get('trading', {})
            # Prefer the 'symbols' list; otherwise wrap the single 'symbol'
            # entry (default 'USTECH100M') in a one-element list.
            symbols = trading_cfg.get(
                'symbols',
                [trading_cfg.get('symbol', 'USTECH100M')],
            )
            mt5_interface = MT5Interface(cfg)
            risk_manager = RiskManager(cfg)
            notification_manager = NotificationManager(cfg)
            trade_logger = TradeLogger(cfg)
            execution_handler = ExecutionHandler(
                cfg, mt5_interface, notification_manager, trade_logger
            )
            # Start advanced tick sweep monitor in a background (daemon) thread.
            threading.Thread(
                target=advanced_tick_sweep_monitor,
                args=(mt5_interface, symbols, execution_handler),
                daemon=True,
            ).start()

        run_bot()
    except KeyboardInterrupt:
        print("\nBot stopped by user.")
    except Exception as exc:
        # Last-resort boundary: report the error with its traceback.
        print(f"\nUnexpected error: {str(exc)}")
        traceback.print_exc()
        

