from enhanced_logger import EnhancedLogger
import MetaTrader5 as mt5
import logging
import time
from datetime import datetime, timedelta
import pytz
import traceback
from strategy.liquidity_provider import LiquidityProvider
import math
from utils import load_config
from risk_manager import RiskManager
import decimal
import threading

logger = logging.getLogger("TradingBot.ExecutionHandler")

# Configuration is loaded once at import time and shared with the single
# module-level RiskManager instance used by every ExecutionHandler.
config = load_config()
risk_manager = RiskManager(config)

# Intended to make the bot refuse any trade that fails risk validation
# (not read in this part of the file).
STRICT_MODE = True

# Position tolerance tuned for USTECH100M; adjust per instrument.
open_position_tolerance = 0.5

# Minimum spacing between two trades on the same (symbol, direction): 15 minutes.
cooldown_seconds = 15 * 60

# Per-(symbol, direction) state for the cooldown and deduplication guards.
last_trade_time_per_symbol_dir = {}
last_signal_key_per_symbol_dir = {}

# Module-level locks; risk_lock is held for the whole of
# ExecutionHandler.execute_trade.
lock = threading.Lock()
risk_lock = threading.Lock()

def tolerant_signal_key(signal):
    """Build a coarse, hashable identity tuple for a trade signal.

    Entry and stop-loss are rounded to one decimal place, so signals that
    differ only by small price noise produce the same key. A price that is
    missing or cannot be converted to ``float`` contributes ``0``. Direction,
    type, symbol and timeframe are copied unchanged (``None`` if absent).

    Returns:
        ``(entry, stop_loss, direction, type, symbol, timeframe)``
    """
    def _coarse_price(raw, ndigits=1):
        # Missing or unparseable prices collapse to 0 instead of raising.
        if raw is None:
            return 0
        try:
            return round(float(raw), ndigits)
        except Exception:
            return 0

    rounded_fields = ('entry', 'stop_loss')
    verbatim_fields = ('direction', 'type', 'symbol', 'timeframe')
    rounded_part = tuple(_coarse_price(signal.get(name), 1) for name in rounded_fields)
    verbatim_part = tuple(signal.get(name) for name in verbatim_fields)
    return rounded_part + verbatim_part

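# NOTE: the former has_open_position() helper and its call sites were removed; repeated entries are now guarded by the cooldown and signal-deduplication checks in ExecutionHandler.execute_trade.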

def log_symbol_and_risk(signal):
    """Log account, symbol and per-trade risk details for *signal* (debug aid).

    Reads symbol and account info straight from the ``MetaTrader5`` module and
    logs nothing if either lookup returns a falsy value.

    Args:
        signal: Mapping with ``symbol``, ``entry``, ``stop_loss`` and
            ``position_size``. Numeric fields may be numbers or numeric
            strings; they are converted with ``float()`` before arithmetic.
    """
    symbol_info = mt5.symbol_info(signal['symbol'])
    account_info = mt5.account_info()
    if not (symbol_info and account_info):
        return
    logger.info(f"[EXEC] Account balance: {account_info.balance}, Free margin: {getattr(account_info, 'margin_free', 'N/A')}, Leverage: {getattr(account_info, 'leverage', 'N/A')}")
    logger.info(f"[EXEC] Symbol: {signal['symbol']}, Point: {symbol_info.point}, Pip value: {symbol_info.trade_tick_value}, Contract size: {symbol_info.trade_contract_size}")
    # execute_trade() float()-converts these same fields, so string values must
    # not crash this debug helper (it is called before the order is sent).
    entry = float(signal['entry'])
    stop = float(signal['stop_loss'])
    lot_size = float(signal['position_size'])
    stop_distance = abs(entry - stop)
    # A zero point size would raise ZeroDivisionError and abort the trade.
    stop_distance_pips = stop_distance / symbol_info.point if symbol_info.point else 0.0
    # NOTE(review): "pips" here are really points, and risk multiplies them by
    # trade_tick_value (value per *tick*); exact only when tick size == point.
    logger.info(f"[EXEC] entry={signal['entry']}, stop={signal['stop_loss']}, stop_pips={stop_distance_pips:.2f}, lot_size={signal['position_size']}, intended_risk={account_info.balance * risk_manager.risk_percent:.2f}, actual_risk={lot_size * stop_distance_pips * symbol_info.trade_tick_value:.2f}, pip_value={symbol_info.trade_tick_value}, contract_size={symbol_info.trade_contract_size}")

class ExecutionHandler:
    """Turn trade signals into MT5 market orders after a series of guards.

    ``execute_trade`` applies, in order:
      1. a per-(symbol, direction) cooldown,
      2. tolerant duplicate-signal detection,
      3. a global open-risk cap,
      4. the revenge-trade cool-off,
      5. ``risk_manager.validate_signal``.

    It then sizes the order and sends it. The class also exposes helpers for
    in-trade management (partial close / breakeven delegation, Friday-close
    breakeven).
    """

    def __init__(self, config, mt5_interface, notification_manager, trade_logger):
        """Store collaborators and initialise per-instance trading state.

        Args:
            config: Configuration mapping. ``trading.symbols`` is read here
                (first entry used, default ``'USTECH100M'``), as is
                ``risk_management.min_trade_interval_seconds``.
            mt5_interface: Wrapper used for symbol/account queries and order sending.
            notification_manager: Stored as ``self.notifier``.
            trade_logger: Stored as ``self.trade_logger``.
        """
        self.config = config
        self.mt5 = mt5_interface
        self.notifier = notification_manager
        self.trade_logger = trade_logger
        # Get symbol from config
        trading_config = config.get('trading', {})
        symbols = trading_config.get('symbols', ['USTECH100M'])
        self.symbol = symbols[0] if symbols else 'USTECH100M'
        self.magic_number = 12345
        self.deviation = 20  # sent as the order request's "deviation" field
        self.liquidity_provider = LiquidityProvider(config)
        self.last_trade_time = None
        self.min_trade_interval = self.config.get("risk_management", {}).get("min_trade_interval_seconds", 15)
        self.recent_trade_times = []  # In-memory list of recent trade timestamps
        # Revenge trade blocker state
        self.loss_history = []  # (timestamp, result) tuples, newest last, capped at 20
        self.cooloff_until = None  # datetime until which trading is blocked, or None
        self.debug_mode = True  # when True, execute_trade logs extra symbol/risk details
        self.risk_manager = risk_manager  # shared module-level RiskManager instance
        logger.info(f"Execution Handler initialized for {self.symbol} trading.")

    # --- Revenge Trade Blocker ---
    def record_trade_result(self, result):
        """Record a closed trade's outcome and update dynamic risk / cool-off.

        Call this after each trade closes. Result should be 'win' or 'loss'.

        The last 5 recorded results are checked. Only when there are 5 of
        them and all are losses does the method halve
        ``risk_manager.risk_percent`` (never below 0.001) and block trading
        for one hour. Each further consecutive loss halves risk again and
        extends the block. A win restores ``risk_management.risk_per_trade``
        from config (default 0.01).
        """
        now = datetime.now()
        self.loss_history.append((now, result))
        # Keep only last 20 results for memory
        self.loss_history = self.loss_history[-20:]
        recent_results = [r for _, r in self.loss_history[-5:]]
        # Require a full window: all() over fewer than 5 entries would be True
        # after a single loss and halve risk far too early.
        five_straight_losses = len(recent_results) == 5 and all(r == 'loss' for r in recent_results)
        if five_straight_losses:
            # Halve risk per trade after 5 consecutive losses
            self.risk_manager.risk_percent = max(self.risk_manager.risk_percent / 2, 0.001)
            logger.warning(f"Dynamic risk: 5 losses in a row, risk per trade halved to {self.risk_manager.risk_percent}")
            # Revenge trade block
            self.cooloff_until = now + timedelta(hours=1)
            logger.warning(f"Revenge trade block: 5 consecutive losses. Trading blocked until {self.cooloff_until}.")
        elif result == 'win':
            # Restore risk per trade to config default after a win
            default_risk = self.config.get('risk_management', {}).get('risk_per_trade', 0.01)
            if self.risk_manager.risk_percent != default_risk:
                self.risk_manager.risk_percent = default_risk
                logger.info(f"Dynamic risk: Win detected, risk per trade restored to {self.risk_manager.risk_percent}")

    def is_trading_blocked(self):
        """Return True while the revenge-trade cool-off is active.

        An expired cool-off is cleared (``cooloff_until`` reset to None) and
        False is returned.
        """
        if self.cooloff_until is None:
            return False
        if datetime.now() < self.cooloff_until:
            logger.warning(f"Trading is blocked for revenge trade cool-off until {self.cooloff_until}.")
            return True
        self.cooloff_until = None
        return False

    # Example: Scale-in logic (call this when considering adding to a position)
    def try_scale_in(self, position, current_price):
        """Log whether the risk manager allows scaling into *position*.

        Order placement for the scale-in is not implemented here (placeholder).
        """
        if self.risk_manager.can_scale_in(position, current_price):
            # Proceed with scale-in order logic here
            logger.info(f"Scale-in allowed for position {position.ticket} at price {current_price}")
            # ... (send order logic) ...
        else:
            logger.info(f"Scale-in NOT allowed for position {position.ticket} at price {current_price}")

    # Example: Call this periodically in your trading loop to manage partial closes and breakeven
    def manage_layered_exits(self):
        """Delegate partial-close / breakeven management for ``self.symbol`` to the risk manager."""
        self.risk_manager.manage_partial_close_and_breakeven(self.symbol)

    def manage_ict_trades(self):
        """Delegate open-trade management for ``self.symbol`` to the risk manager.

        Intended to partial-close at 2R and move SL to breakeven; the actual
        logic lives in ``RiskManager.manage_partial_close_and_breakeven``.
        """
        self.risk_manager.manage_partial_close_and_breakeven(self.symbol)

    def execute_trade(self, signal):
        """Run every pre-trade guard and, if all pass, send a market order.

        Args:
            signal: Mapping with at least ``symbol``, ``direction``,
                ``entry``, ``stop_loss``, ``take_profit`` and
                ``position_size`` (numeric or numeric strings).

        Returns:
            True if the order was sent and not reported as failed, else False.
        """
        with risk_lock:
            symbol = signal['symbol']
            direction = signal['direction']
            entry = float(signal['entry'])
            # --- Cooldown timer per symbol+direction ---
            now = time.time()
            cooldown_key = (symbol, direction)
            last_time = last_trade_time_per_symbol_dir.get(cooldown_key, 0)
            if now - last_time < cooldown_seconds:
                logger.warning(f"[COOLDOWN BLOCK] Trade blocked: Cooldown active for {symbol} {direction}. Last trade at {last_time}, now {now}. Wait {cooldown_seconds - (now - last_time):.1f}s.")
                return False
            # --- Tolerant deduplication key ---
            sig_key = tolerant_signal_key(signal)
            if last_signal_key_per_symbol_dir.get(cooldown_key) == sig_key:
                logger.warning(f"[DEDUPLICATION BLOCK] Trade blocked: Identical signal for {symbol} {direction} already executed. Skipping duplicate.")
                return False
            # Fetched once; reused for the risk cap and for order sizing.
            symbol_info = self.mt5.get_symbol_info(symbol) if hasattr(self.mt5, 'get_symbol_info') else None
            # --- Global risk check (redundant, but ensures no bypass) ---
            if symbol_info and self._exceeds_global_open_risk(signal, symbol_info, entry):
                return False
            # --- Proceed with normal risk validation and order sending ---
            if self.is_trading_blocked():
                logger.warning("Trade execution blocked due to revenge trade cool-off.")
                return False
            if not self.risk_manager.validate_signal(signal, self.mt5):
                logger.error(f"[RISK] Trade blocked: risk management validation failed. Signal: {signal}")
                return False
            if self.debug_mode:
                log_symbol_and_risk(signal)
            if not symbol_info:
                logger.error(f"[EXEC] Could not retrieve symbol info for {signal['symbol']}. Skipping trade.")
                return False
            digits = symbol_info['digits']
            entry = round(float(signal['entry']), digits)
            stop = round(float(signal['stop_loss']), digits)
            tp = round(float(signal['take_profit']), digits)
            account_info = self.mt5.get_account_info() if hasattr(self.mt5, 'get_account_info') else None
            volume = self._size_order_volume(symbol, symbol_info, account_info, entry, float(signal['position_size']))
            if volume is None:
                return False
            request = {
                "action": mt5.TRADE_ACTION_DEAL,
                "symbol": signal['symbol'],
                "volume": volume,
                "type": mt5.ORDER_TYPE_BUY if direction in [1, "BUY", "buy", "LONG", "long"] else mt5.ORDER_TYPE_SELL,
                "price": entry,
                "sl": stop,
                "tp": tp,
                "deviation": self.deviation,
                "type_filling": mt5.ORDER_FILLING_FOK,
                "type_time": mt5.ORDER_TIME_GTC,
            }
            result = self.mt5.send_order(request)
            if result is None:
                logger.error(f"Trade failed. Result: None. Last MT5 error: {self.mt5.get_last_error()}")
                return False
            if hasattr(result, 'retcode') and result.retcode != mt5.TRADE_RETCODE_DONE:
                logger.error(f"Trade failed. Result: {result}, retcode: {result.retcode}, comment: {getattr(result, 'comment', '')}, Last MT5 error: {self.mt5.get_last_error()}")
                return False
            logger.info(f"Trade executed successfully. Ticket: {getattr(result, 'order', None)}")
            # Update cooldown and deduplication state
            last_trade_time_per_symbol_dir[cooldown_key] = now
            last_signal_key_per_symbol_dir[cooldown_key] = sig_key
            return True

    def _exceeds_global_open_risk(self, signal, symbol_info, entry_price):
        """Return True (and log) if this trade would breach the global open-risk cap.

        The new trade's worst-case loss is added to
        ``risk_manager.total_open_risk``. That loss is lot size x stop distance
        in points x ``trade_tick_value``. The sum is compared with
        ``max_open_risk_percent`` of the account balance. Without account info
        the cap is 0, so the trade is blocked.
        """
        symbol = signal['symbol']
        # float() so numeric strings in the signal don't raise TypeError.
        position_size = float(signal.get('position_size', 0))
        stop_loss = float(signal.get('stop_loss', 0))
        # NOTE(review): distance is measured in points but priced with
        # trade_tick_value (value per *tick*); exact only when tick size ==
        # point. Presumably mirrors risk_manager.total_open_risk — confirm.
        actual_risk = position_size * abs(entry_price - stop_loss) / symbol_info['point'] * symbol_info['trade_tick_value']
        open_positions = self.mt5.get_open_positions(symbol) if hasattr(self.mt5, 'get_open_positions') else []
        account_info = self.mt5.get_account_info() if hasattr(self.mt5, 'get_account_info') else None
        current_open_risk = self.risk_manager.total_open_risk(self.mt5)
        max_allowed = account_info['balance'] * self.risk_manager.max_open_risk_percent / 100.0 if account_info else 0
        logger.info(f"[RISK-DEBUG][EH] Open risk: {current_open_risk:.2f}, New trade risk: {actual_risk:.2f}, Max allowed: {max_allowed:.2f}, Open trades: {len(open_positions)}")
        if current_open_risk + actual_risk > max_allowed:
            logger.error(f"[GLOBAL RISK BLOCK][EH] Trade blocked: If all open trades hit SL, total risk ({current_open_risk + actual_risk:.2f}) would exceed max allowed ({max_allowed:.2f}, {self.risk_manager.max_open_risk_percent}% of balance). Skipping new trade.")
            return True
        return False

    def _size_order_volume(self, symbol, symbol_info, account_info, price, raw_volume):
        """Clamp the requested lot size to broker limits and available margin.

        Returns:
            The volume rounded to the symbol's lot step, or None when free
            margin cannot cover even the symbol's minimum lot.
        """
        lot_step = symbol_info.get('volume_step', 0.01)
        min_lot = symbol_info.get('volume_min', 0.01)
        max_lot = symbol_info.get('volume_max', 100.0)
        contract_size = symbol_info['trade_contract_size']
        leverage = account_info.get('leverage', 100) if account_info else 100
        free_margin = account_info.get('margin_free', 0.0) if account_info else 0.0
        capped_lot = min(raw_volume, max_lot)
        # Approximate margin ceiling: free margin x leverage / notional of one lot.
        margin_lot_cap = (free_margin * leverage) / (contract_size * price) if price > 0 else 0
        margin_lot_cap = max(margin_lot_cap, 0)
        # Checked *before* clamping up to min_lot; after the clamp this
        # condition could never be true and min_lot would always be sent.
        if margin_lot_cap < min_lot:
            logger.warning(f"[EXEC] Not enough margin for even min lot ({min_lot}) on {symbol}. Free margin: {free_margin}, required for 1 lot: {(contract_size * price) / leverage if price > 0 else 0}")
            return None
        # The tiny epsilon stops float noise (e.g. 0.3 / 0.1 == 2.9999999999999996)
        # from flooring an exact multiple one step too low.
        final_lot = math.floor(min(capped_lot, margin_lot_cap) / lot_step + 1e-9) * lot_step
        # Requests below the broker minimum are rounded up to it; margin for
        # min_lot was verified above.
        final_lot = max(final_lot, min_lot)
        decimals = abs(decimal.Decimal(str(lot_step)).as_tuple().exponent)
        volume = float(round(final_lot, decimals))
        logger.info(f"Placing order with volume: {volume} (raw: {raw_volume}, capped: {capped_lot}, margin_lot_cap: {margin_lot_cap}, min: {min_lot}, max: {max_lot}, step: {lot_step}, free_margin: {free_margin}, leverage: {leverage}, contract_size: {contract_size}, price: {price})")
        return volume

    def move_trades_to_breakeven_before_friday_close(self):
        """Move profitable trades' stop loss to breakeven shortly before the Friday close.

        The method only acts on Fridays from 15:45 to 15:59 New York time,
        the last 15 minutes before the 4pm NY close. A position is adjusted
        only when both of these hold:
          - it is in profit at the current bid (buys) or ask (sells);
          - its stop is not already at or beyond the entry price, so a
            tighter existing stop is never loosened.
        Losing trades are left open and untouched.
        """
        ny_tz = pytz.timezone('America/New_York')
        # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
        now_ny = datetime.now(pytz.UTC).astimezone(ny_tz)
        if not (now_ny.weekday() == 4 and now_ny.hour == 15 and now_ny.minute >= 45):
            return
        positions = self.mt5.get_open_positions()
        for pos in positions or []:
            symbol = pos.symbol
            entry = pos.price_open
            sl = pos.sl
            direction = pos.type  # 0=BUY, 1=SELL
            tick = self.mt5.get_symbol_tick(symbol)
            if tick is None:
                logger.warning(f"[FRIDAY BE] No tick data for {symbol}; skipping position {pos.ticket}.")
                continue
            # Buys are valued at the bid, sells at the ask.
            current_price = tick.bid if direction == 0 else tick.ask
            is_profitable = (direction == 0 and current_price > entry) or (direction == 1 and current_price < entry)
            if not is_profitable:
                continue
            # sl == 0 means no stop is set. Otherwise skip stops that already
            # lock in breakeven or better, so they are never moved backwards.
            already_protected = bool(sl) and ((direction == 0 and sl >= entry) or (direction == 1 and sl <= entry))
            if already_protected:
                continue
            # Move SL to entry (breakeven)
            request = {
                "action": mt5.TRADE_ACTION_SLTP,
                "symbol": symbol,
                "position": pos.ticket,
                "sl": entry,
                "tp": pos.tp,
                "type_time": mt5.ORDER_TIME_GTC
            }
            result = self.mt5.send_order(request)
            if hasattr(result, 'retcode') and result.retcode == mt5.TRADE_RETCODE_DONE:
                logger.info(f"[FRIDAY BE] Moved SL to breakeven for position {pos.ticket} on {symbol} before Friday close.")
            else:
                logger.warning(f"[FRIDAY BE] Failed to move SL to breakeven for position {pos.ticket} on {symbol}.")

