# tests.py
import os
import tempfile
import numpy as np
from pathlib import Path
from visualizer import DroneVisualizer
from config import config
from functions import DroneDetectionSystemFunctions
import soundfile as sf
import matplotlib.pyplot as plt

# -------------------- ENVIRONMENT DETECTION --------------------
def is_notebook():
    """Return True when running inside a Jupyter kernel or Google Colab.

    Detection relies on the ``get_ipython`` builtin that IPython injects:

    * Jupyter / JupyterLab kernels use a shell class named ``ZMQInteractiveShell``;
    * Google Colab uses its own shell class (``google.colab._shell.Shell``),
      whose name differs, so it is matched by its module instead.

    A plain Python interpreter (no ``get_ipython``) and the terminal IPython
    shell both yield False.
    """
    try:
        shell = get_ipython()  # noqa: F821 - builtin injected by IPython at runtime
    except NameError:
        # Plain Python: IPython never defined the builtin.
        return False
    if shell is None:
        return False
    shell_cls = type(shell)
    return (
        shell_cls.__name__ == 'ZMQInteractiveShell'
        or shell_cls.__module__.startswith('google.colab')
    )

# -------------------- AUDIO FILE SELECTION --------------------
def select_audio_file():
    """
    Cross-platform audio file selection.

    - Notebook → shows an ``ipywidgets.FileUpload`` widget, waits for an upload
      and copies its bytes into a newly created temporary file.
    - Local VS Code / terminal → prompts for a file path on stdin.

    Returns:
        (file_path, original_filename). In the notebook branch ``file_path`` is
        a fresh file inside ``tempfile.gettempdir()``; the caller owns it and is
        responsible for deleting it.

    Raises:
        FileNotFoundError: terminal branch, when the entered path is not a file.
    """
    if is_notebook():
        import time
        import ipywidgets as widgets
        from IPython.display import display

        uploader = widgets.FileUpload(
            accept='.wav,.mp3', multiple=False,
            description="Select Audio File"
        )
        display(uploader)

        print("➡️ Please upload a file using the widget above...")
        # Poll until the user uploads a file; sleep so the loop does not pin a CPU core.
        # NOTE(review): while a cell is running the kernel may not process widget
        # messages, so in some front ends this loop can never observe the upload.
        # An observe()-based (callback) flow avoids that — confirm in the target notebook.
        while not uploader.value:
            time.sleep(0.2)

        value = uploader.value
        if isinstance(value, dict):
            # ipywidgets 7.x: {filename: {'metadata': ..., 'content': bytes}}
            filename, uploaded = next(iter(value.items()))
        else:
            # ipywidgets 8.x: tuple of dicts with 'name', 'content' (memoryview), ...
            uploaded = value[0]
            filename = uploaded['name']
        content = bytes(uploaded['content'])

        # mkstemp creates a unique file (in tempfile.gettempdir()) instead of a
        # fixed shared name, so concurrent sessions cannot overwrite each other.
        ext = Path(filename).suffix.lower()
        fd, tmp_path = tempfile.mkstemp(prefix="tmp_audio_", suffix=ext)
        with os.fdopen(fd, 'wb') as f:
            f.write(content)
        return tmp_path, filename

    # VS Code / terminal prompt; also drop surrounding quotes that shells and
    # "Copy as path" commonly add, matching run_single_file_test's handling.
    file_path = input("Enter path to audio file (.wav/.mp3): ").strip().strip('"\'')
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    return file_path, os.path.basename(file_path)

# -------------------- SINGLE FILE TEST --------------------
def run_single_file_test(file_path=None, threshold=0.70, analyze_long=True, show_visualization=True, true_position=None):
    """
    Run drone detection on a single audio file and print a summary.

    In a notebook the ``file_path`` argument is ignored and the user uploads a
    file through a widget (the resulting temp copy is deleted afterwards, even
    if detection fails). Outside a notebook ``file_path`` is used directly, or
    the user is prompted for a path when it is None.

    Args:
        file_path: Path (str or os.PathLike) to a .wav/.mp3 file; surrounding
            double quotes are stripped.
        threshold: Detection threshold forwarded to the detector.
        analyze_long: If True, the file is passed to the detector once. If
            False, files with 3+ channels are passed once and files with fewer
            channels are passed three times (once per input slot).
        show_visualization: Draw the environment map with the estimate.
        true_position: Optional ground-truth position; when given and a
            position is estimated, the Euclidean localization error is reported.

    Returns:
        The result dict produced by ``detect_and_localize_if_drone_enhanced``.

    Raises:
        FileNotFoundError: when the given or entered path is not a file.
    """
    in_notebook = is_notebook()
    # Notebook environment: ignore file_path → use widget
    if file_path is None or in_notebook:
        file_path, filename = select_audio_file()
    else:
        # Local VS Code / terminal
        file_path = os.fspath(file_path).strip('"')  # remove quotes from command line
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        filename = os.path.basename(file_path)

    # Only a notebook upload yields a temp copy that this function owns.
    is_temp_upload = in_notebook and Path(file_path).parent == Path(tempfile.gettempdir())

    try:
        print(f"📄 File: {filename}")
        print(f"🎚️ Threshold: {threshold:.2f}")
        print(f"🔍 Long audio analysis: {'ENABLED' if analyze_long else 'DISABLED'}")

        # Initialize
        dds = DroneDetectionSystemFunctions()
        dds.load_best_model()
        visualizer = DroneVisualizer(config) if show_visualization else None

        # Run detection. sf.info reads only the header, so the channel count is
        # obtained without decoding the whole file.
        if analyze_long or sf.info(file_path).channels >= 3:
            result = dds.detect_and_localize_if_drone_enhanced(file_path, threshold=threshold)
        else:
            # Fewer than 3 channels: reuse the same file for all three inputs.
            result = dds.detect_and_localize_if_drone_enhanced(file_path, file_path, file_path, threshold=threshold)

        # `is not None` (not truthiness): positions may be numpy arrays, whose
        # truth value is ambiguous and would raise.
        position = result.get("position")
        error = None
        if position is not None and true_position is not None:
            error = np.linalg.norm(np.array(position) - np.array(true_position))

        # Visualization
        if visualizer is not None:
            visualizer.create_environment_map(true_position=true_position)
            if result.get("detected") and position is not None:
                visualizer.add_estimated_position(position, confidence=result["probability"], error=error)
            visualizer.add_localization_info(probability=result["probability"], filename=filename)
            visualizer.show()

        # Print summary
        print(f"\n✅ Detection result: {'DRONE DETECTED' if result['detected'] else 'NO DRONE'}")
        if position is not None:
            print(f"📍 Estimated Position: {position}")
            if error is not None:
                print(f"📏 Localization error: {error:.3f} m")

        return result
    finally:
        # Cleanup temp file for notebook
        if is_temp_upload:
            os.unlink(file_path)

# -------------------- AUTOMATED TESTS ON SYNTHETIC FILES --------------------
def run_automated_tests_on_generated_files(threshold=0.65, show_visualization=True):
    """
    Run detection + localization on the synthetic files produced by
    ``DroneDetectionSystemFunctions.generate_test_files()``.

    Each scenario is recorded exactly once. A detection failure is recorded as
    a miss with an ``"error"`` entry. A visualization failure is reported but
    does not change the recorded detection result.

    Args:
        threshold: Detection threshold passed to the detector (default 0.65).
        show_visualization: Draw an environment map for each scenario.

    Returns:
        list[dict]: one entry per scenario with keys ``test_name``, ``detected``,
        ``confidence``, ``position_error``, ``estimated_position``,
        ``true_position`` and, when detection raised, ``error``.
    """
    import traceback

    print("🧪 AUTOMATED TESTING ON GENERATED FILES")
    print("=" * 60)

    # Initialize
    dds = DroneDetectionSystemFunctions()
    dds.load_best_model()
    visualizer = DroneVisualizer(config) if show_visualization else None

    # Generate synthetic test files
    dds.generate_test_files()

    test_files = [
        ("3-channel recording", "recording_3channel.wav", [0.85, 0.30]),
        ("Clean drone - 3 files", ["mic1.wav", "mic2.wav", "mic3.wav"], [0.85, 0.30]),
        ("Indoor close", ["mic1_indoor.wav", "mic2_indoor.wav", "mic3_indoor.wav"], [-0.42, 1.05]),
        ("Outdoor far", ["mic1_outdoor.wav", "mic2_outdoor.wav", "mic3_outdoor.wav"], [3.2, -1.1]),
    ]

    results = []

    for test_name, files, true_pos in test_files:
        print(f"\n🔧 Testing: {test_name}")
        print(f"📍 Expected position: {true_pos}")

        try:
            # A list means one file per channel; a string is a multi-channel file.
            if isinstance(files, list):
                result = dds.detect_and_localize_if_drone_enhanced(*files, threshold=threshold)
            else:
                result = dds.detect_and_localize_if_drone_enhanced(files, threshold=threshold)

            # `is not None`: a missing/None position must not reach np.array.
            estimated_pos = result.get("position")
            position_error = None
            if result["detected"] and estimated_pos is not None:
                position_error = np.linalg.norm(np.array(estimated_pos) - np.array(true_pos))

            record = {
                "test_name": test_name,
                "detected": result["detected"],
                "confidence": result["probability"],
                "position_error": position_error,
                "estimated_position": estimated_pos,
                "true_position": true_pos,
            }
        except Exception as e:
            print(f"   ❌ ERROR: {e}")
            traceback.print_exc()
            results.append({
                "test_name": test_name,
                "detected": False,
                "confidence": 0.0,
                "position_error": None,
                "estimated_position": None,
                "true_position": true_pos,
                "error": str(e),
            })
            continue

        # Append before (and independently of) visualization so a plotting
        # failure can never add a second entry for the same scenario.
        results.append(record)

        if visualizer is not None:
            try:
                visualizer.create_environment_map(true_position=true_pos)
                if record["detected"] and estimated_pos is not None:
                    visualizer.add_estimated_position(estimated_pos, confidence=record["confidence"], error=position_error)
                visualizer.add_localization_info(probability=record["confidence"], filename=test_name)
                visualizer.show()
            except Exception as e:
                print(f"   ⚠️ Visualization failed: {e}")
                traceback.print_exc()

        status = "✅ DETECTED" if record["detected"] else "❌ MISSED"
        # `is not None` so that an exact (0.0) error is still reported.
        error_info = f" (error: {position_error:.3f}m)" if position_error is not None else ""
        print(f"   Result: {status} - Confidence: {record['confidence']:.3f}{error_info}")

    # Summary
    print("\n🎯 TEST SUMMARY")
    print("=" * 60)
    detected_count = sum(1 for r in results if r["detected"])
    total_tests = len(results)
    print(f"📊 Overall: {detected_count}/{total_tests} tests passed")
    for result in results:
        status = "✅ DETECTED" if result["detected"] else "❌ MISSED"
        error_info = f" (error: {result['position_error']:.3f}m)" if result["position_error"] is not None else ""
        print(f"   {result['test_name']:25} → {status} (conf: {result['confidence']:.3f}{error_info})")

    return results

# -------------------- QUICK TESTS --------------------
def _interactive_run(threshold, show_visualization):
    """Select an audio file interactively and run the long-analysis single-file test."""
    run_single_file_test(
        show_visualization=show_visualization,
        analyze_long=True,
        threshold=threshold,
    )

def quick_test():
    """Default interactive check: threshold 0.70, visualization on."""
    _interactive_run(0.70, True)

def test_with_visualization():
    """Interactive check at threshold 0.70 with the environment map shown."""
    _interactive_run(0.70, True)

def test_without_visualization():
    """Interactive check at threshold 0.70 without any plotting."""
    _interactive_run(0.70, False)

def test_high_sensitivity():
    """Interactive check with a permissive threshold (0.50)."""
    _interactive_run(0.50, True)

def test_low_sensitivity():
    """Interactive check with a strict threshold (0.85)."""
    _interactive_run(0.85, True)

# -------------------- MAIN EXECUTION --------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Drone detection test runner")
    parser.add_argument("--auto", action="store_true", help="Run automated tests on synthetic files")
    parser.add_argument("--file", type=str, help="Path to audio file for single-file test")
    parser.add_argument("--threshold", type=float, default=0.7,
                        help="Detection threshold for the single-file test (default: 0.7)")
    parser.add_argument("--no-viz", action="store_true",
                        help="Disable visualization for the single-file test")
    # Jupyter/Colab kernels start with extra argv (e.g. "-f <kernel.json>"),
    # which parse_args() would reject with SystemExit before the notebook
    # branch below is reached. Tolerate unknown args only inside a notebook.
    args, unknown = parser.parse_known_args()
    if unknown and not is_notebook():
        parser.error(f"unrecognized arguments: {' '.join(unknown)}")

    if args.auto:
        run_automated_tests_on_generated_files()
    elif args.file:
        # Run single file test in VS Code / terminal
        run_single_file_test(file_path=args.file, threshold=args.threshold, analyze_long=True,
                             show_visualization=not args.no_viz)
    elif is_notebook():
        # Run single file test interactively in notebook
        run_single_file_test()
    else:
        print("❌ No file specified. Use --file <path> or run in notebook.")
        parser.print_usage()
        # Non-zero exit so scripts/CI notice the invocation did nothing.
        raise SystemExit(1)
