"""
Computer Use POC v2 — Google Ads Transparency Center research for Perla Helsa (UA market).
Applies history truncation optimization — keep only last 4 turns → cuts input tokens ~4×.
"""
from __future__ import annotations

import asyncio
import json
import os
import time
from pathlib import Path

from google import genai
from google.genai import types
from playwright.async_api import async_playwright

# Gemini API key: read from disk once, at import time, with surrounding whitespace removed.
KEY = Path("/srv/passepartout/google/gemini-computer-use-research.key").read_text().strip()

# Directory that receives the run's artifacts (step screenshots, run.log, report.md, summary.txt).
OUTDIR = Path("/tmp/cu-poc/perla")
OUTDIR.mkdir(parents=True, exist_ok=True)
LOG = OUTDIR / "run.log"

MODEL = "gemini-2.5-computer-use-preview-10-2025"
MAX_STEPS = 20  # upper bound on agent-loop iterations
VIEWPORT_W = 1366
VIEWPORT_H = 768
HISTORY_KEEP_TURNS = 4  # keep only the last N model+user message pairs when calling the API

# Running totals of API usage, accumulated across the whole run and reported in the summary.
TOTAL_INPUT = TOTAL_OUTPUT = TOTAL_CALLS = 0


def log(msg: str) -> None:
    """Print *msg* with an ``[HH:MM:SS]`` prefix and append the same line to ``LOG``.

    Args:
        msg: Text to record; may contain any Unicode characters.
    """
    line = f"[{time.strftime('%H:%M:%S')}] {msg}"
    print(line, flush=True)
    # Messages routinely contain Cyrillic text and symbols such as "→" and "✅";
    # pin the encoding so the write does not depend on the process locale.
    with LOG.open("a", encoding="utf-8") as f:
        f.write(line + "\n")


def denorm(v: int, scale: int) -> int:
    """Map a coordinate expressed on a 0..1000 grid onto a pixel axis of length *scale*.

    The product is divided by 1000 with true division and then truncated
    toward zero by ``int``.
    """
    scaled = v * scale / 1000
    return int(scaled)


# Lower-cased key names mapped to the case-sensitive names Playwright's keyboard expects.
_PLAYWRIGHT_KEY_NAMES = {
    "ctrl": "Control",
    "control": "Control",
    "alt": "Alt",
    "shift": "Shift",
    "meta": "Meta",
    "cmd": "Meta",
    "command": "Meta",
    "enter": "Enter",
    "return": "Enter",
    "esc": "Escape",
    "escape": "Escape",
    "tab": "Tab",
    "backspace": "Backspace",
    "delete": "Delete",
    "insert": "Insert",
    "space": "Space",
    "home": "Home",
    "end": "End",
    "pageup": "PageUp",
    "pagedown": "PageDown",
    "up": "ArrowUp",
    "down": "ArrowDown",
    "left": "ArrowLeft",
    "right": "ArrowRight",
    "arrowup": "ArrowUp",
    "arrowdown": "ArrowDown",
    "arrowleft": "ArrowLeft",
    "arrowright": "ArrowRight",
}


def _playwright_keys(keys: str) -> str:
    """Rewrite a ``"control+a"``-style combination using Playwright's key names."""
    segments = keys.split("+")
    if not all(segments):
        # Empty input, or a literal "+" key, yields empty segments: pass through untouched.
        return keys
    fixed = []
    for seg in segments:
        token = seg.strip()
        if not token:
            fixed.append(seg)  # whitespace-only segment: a space is itself a key
            continue
        lowered = token.lower()
        if lowered in _PLAYWRIGHT_KEY_NAMES:
            token = _PLAYWRIGHT_KEY_NAMES[lowered]
        elif lowered[0] == "f" and lowered[1:].isdigit():
            token = lowered.upper()  # function keys: "f5" -> "F5"
        fixed.append(token)
    return "+".join(fixed)


def _wheel_delta(direction: str, amount: int) -> tuple[int, int]:
    """Translate a scroll direction into the ``(delta_x, delta_y)`` pair for ``mouse.wheel``."""
    if direction == "down":
        return 0, amount
    if direction == "right":
        return amount, 0
    if direction == "left":
        return -amount, 0
    # "up" — and, as before, any unrecognised direction — scrolls up.
    return 0, -amount


async def execute_action(page, name: str, args: dict) -> str:
    """Perform one model-requested UI action on *page* and describe the outcome.

    Supported actions: ``click_at``, ``hover_at``, ``type_text_at``,
    ``scroll_document``, ``scroll_at``, ``navigate``, ``open_web_browser``,
    ``go_back``, ``go_forward``, ``wait_5_seconds``, ``search`` and
    ``key_combination``. Coordinates in *args* are on a 0..1000 grid and are
    converted to viewport pixels with ``denorm``.

    Args:
        page: Playwright page the action is applied to.
        name: Action name taken from the model's function call.
        args: Arguments of the function call.

    Returns:
        A short human-readable result. Unsupported actions yield a string
        starting with ``"UNKNOWN action"``; any exception raised while
        performing the action is reported as a string starting with
        ``"ERROR"`` instead of propagating, so the caller can hand it back
        to the model.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import quote_plus

    try:
        if name == "click_at":
            x, y = denorm(args["x"], VIEWPORT_W), denorm(args["y"], VIEWPORT_H)
            await page.mouse.click(x, y)
            return f"clicked ({x},{y})"

        if name == "hover_at":
            x, y = denorm(args["x"], VIEWPORT_W), denorm(args["y"], VIEWPORT_H)
            await page.mouse.move(x, y)
            return f"hovered ({x},{y})"

        if name == "type_text_at":
            x, y = denorm(args["x"], VIEWPORT_W), denorm(args["y"], VIEWPORT_H)
            text = args.get("text", "")
            # Click first so the target field has focus before any keystroke.
            await page.mouse.click(x, y)
            if args.get("clear_before_typing", False):
                await page.keyboard.press("Control+A")
                await page.keyboard.press("Delete")
            await page.keyboard.type(text, delay=30)
            if args.get("press_enter", False):
                await page.keyboard.press("Enter")
            return f"typed {text!r} at ({x},{y})"

        if name == "scroll_document":
            direction = args.get("direction", "down")
            # "left"/"right" now scroll horizontally instead of being treated as "up".
            dx, dy = _wheel_delta(direction, 400)
            await page.mouse.wheel(dx, dy)
            return f"scrolled {direction}"

        if name == "scroll_at":
            x, y = denorm(args["x"], VIEWPORT_W), denorm(args["y"], VIEWPORT_H)
            direction = args.get("direction", "down")
            mag = int(args.get("magnitude", 400))
            dx, dy = _wheel_delta(direction, mag)
            # Move the pointer first so the wheel event lands on the element under (x, y).
            await page.mouse.move(x, y)
            await page.mouse.wheel(dx, dy)
            return f"scrolled {direction} at ({x},{y})"

        if name == "navigate":
            url = args.get("url", "")
            await page.goto(url, wait_until="domcontentloaded", timeout=30000)
            return f"navigated to {url}"

        if name == "open_web_browser":
            return "browser already open"

        if name == "go_back":
            await page.go_back()
            return "went back"

        if name == "go_forward":
            await page.go_forward()
            return "went forward"

        if name == "wait_5_seconds":
            await asyncio.sleep(5)
            return "waited 5s"

        if name == "search":
            q = args.get("query", "")
            # Percent-encode the query: spaces, "&", "#" or Cyrillic would
            # otherwise corrupt the URL.
            await page.goto(f"https://www.google.com/search?q={quote_plus(str(q))}")
            return f"searched {q!r}"

        if name == "key_combination":
            keys = args.get("keys", "")
            # Playwright key names are case-sensitive ("Enter", not "enter").
            await page.keyboard.press(_playwright_keys(keys))
            return f"pressed {keys}"

        return f"UNKNOWN action {name} args={args}"
    except Exception as e:
        # Deliberately broad: a failed action is reported to the model rather
        # than aborting the whole run.
        return f"ERROR {name}: {type(e).__name__}: {e}"


def truncate_history(history: list, keep_turns: int | None = None) -> list:
    """Return the first message plus the most recent turns of *history*.

    The first message (task text + initial screenshot) is always kept. A
    "turn" is two consecutive messages: as ``main`` builds the history, that
    is one model message followed by the user message carrying the action
    results. With that layout (one initial user message, then model/user
    pairs) the kept tail always starts on a model message, so a function
    call is never separated from its response.

    Args:
        history: Conversation messages, oldest first.
        keep_turns: Number of trailing turns to keep. ``None`` (the default)
            uses ``HISTORY_KEEP_TURNS``; values below zero are treated as 0.

    Returns:
        *history* itself when it is already short enough, otherwise a new
        list ``[history[0], *last 2*keep_turns messages]``.
    """
    if keep_turns is None:
        keep_turns = HISTORY_KEEP_TURNS
    tail_len = 2 * max(keep_turns, 0)
    if len(history) <= tail_len + 1:
        return history
    # Use an explicit start index: ``history[-0:]`` would be the WHOLE list,
    # so a negative-slice form duplicates everything when tail_len is 0.
    return [history[0]] + history[len(history) - tail_len:]


async def _run_agent_loop(client, page) -> bool:
    """Run the screenshot -> model -> action loop; return True once a report was saved."""
    global TOTAL_INPUT, TOTAL_OUTPUT, TOTAL_CALLS

    task = (
        "Задача: на Google Ads Transparency Center знайти рекламодавця 'Perla Helsa' (українською 'Перла Хелса') "
        "і зібрати інформацію про активні рекламні кампанії в УКРАЇНІ. "
        "Регіон має бути встановлений Україна (UA). "
        "Якщо з'явився consent dialog — Accept. "
        "Крок: у пошук введи 'Perla Helsa' (може знадобитись ввести на латиниці). "
        "Якщо знайдено забагато результатів різних типів — обери з них ті, що стосуються саме колагену (Collagen/Колаген). "
        "Збери до 10 рекламних креативів або все що знайдеш. "
        "Фінальний підсумок почни зі слова 'ЗВІТ:' і напиши: "
        "(1) чи знайдено рекламодавця, (2) скільки активних реклам всього в UA, "
        "(3) які формати (текст Search / image банери / відео / YouTube), "
        "(4) дати першої і останньої активної кампанії, "
        "(5) приклади headline/description з 3-5 топ-оголошень, "
        "(6) чи є специфічні кампанії по Колагену — що в них (месендж, CTA, картинка/текст). "
        "Максимум 20 кроків автоматизації. Якщо вже зібрав достатньо — завершуй одразу. "
    )

    history: list = []
    tools = [types.Tool(computer_use=types.ComputerUse(
        environment=types.Environment.ENVIRONMENT_BROWSER,
    ))]
    config = types.GenerateContentConfig(tools=tools)
    report_saved = False

    for step in range(1, MAX_STEPS + 1):
        # A per-step screenshot is always written to disk for later inspection;
        # only the very first one is sent to the model together with the task.
        shot = await page.screenshot(type="png")
        (OUTDIR / f"step-{step:02d}.png").write_bytes(shot)
        log(f"--- step {step} ---  url={page.url[:100]}  shot={len(shot)}B  hist={len(history)}")

        if step == 1:
            history.append(types.Content(role="user", parts=[
                types.Part(text=task),
                types.Part.from_bytes(data=shot, mime_type="image/png"),
            ]))

        # Only a truncated view is sent to keep input tokens down; the full
        # history is still kept locally.
        ctx = truncate_history(history)

        try:
            # Async client: the synchronous call would block the event loop
            # (and with it Playwright) for the whole duration of the request.
            response = await client.aio.models.generate_content(model=MODEL, contents=ctx, config=config)
        except Exception as e:
            log(f"API error: {type(e).__name__}: {str(e)[:400]}")
            break

        TOTAL_CALLS += 1
        um = getattr(response, "usage_metadata", None)
        if um:
            t_in = getattr(um, "prompt_token_count", 0) or 0
            t_out = getattr(um, "candidates_token_count", 0) or 0
            TOTAL_INPUT += t_in
            TOTAL_OUTPUT += t_out
            log(f"  tokens: input={t_in:,} output={t_out:,}  total: in={TOTAL_INPUT:,} out={TOTAL_OUTPUT:,}")

        if not response.candidates:
            log("no candidates")
            break
        cand = response.candidates[0]
        # Both the content and its parts may be missing; treat either as "no parts".
        parts = (cand.content.parts or []) if cand.content else []
        if cand.content is not None:
            history.append(cand.content)

        text_parts = [p.text for p in parts if p.text]
        fn_calls = [p.function_call for p in parts if p.function_call]
        if text_parts:
            full = " | ".join(text_parts)
            log(f"MODEL: {full[:400]}")
            # The task asks the model to start its final summary with this marker.
            if "ЗВІТ:" in full or "ZVIT:" in full.upper():
                report = "\n\n".join(text_parts)
                # The report is Ukrainian text: do not rely on the locale encoding.
                (OUTDIR / "report.md").write_text(report, encoding="utf-8")
                report_saved = True
                log(f"✅ Report saved: {OUTDIR}/report.md")

        if not fn_calls:
            log("no function_call — task done")
            break

        responses_parts = []
        for fc in fn_calls:
            args = dict(fc.args) if fc.args else {}
            # NOTE(review): a present safety_decision is acknowledged automatically
            # below, without any human confirmation — confirm this is acceptable
            # outside of a proof of concept.
            safety = args.pop("safety_decision", None)
            log(f"  ACT: {fc.name}({json.dumps(args, ensure_ascii=False)[:180]}){' [SAFE]' if safety else ''}")
            result = await execute_action(page, fc.name, args)
            log(f"    → {result}")
            # Short pause so the page can react before the follow-up screenshot.
            await asyncio.sleep(1.2)
            post_shot = await page.screenshot(type="png")
            resp_dict = {"result": result, "url": page.url}
            if safety:
                resp_dict["safety_acknowledgement"] = "true"
            responses_parts.append(types.Part.from_function_response(name=fc.name, response=resp_dict))
            responses_parts.append(types.Part.from_bytes(data=post_shot, mime_type="image/png"))
        history.append(types.Content(role="user", parts=responses_parts))

        if report_saved:
            break

    return report_saved


def _write_summary(report_saved: bool) -> None:
    """Log the token/cost summary of the run and store it in ``summary.txt``."""
    # Cost estimates: USD per one million input / output tokens, at two price points.
    cost_flash = (TOTAL_INPUT * 0.30 + TOTAL_OUTPUT * 2.50) / 1_000_000
    cost_pro = (TOTAL_INPUT * 1.25 + TOTAL_OUTPUT * 10.0) / 1_000_000
    summary = (
        f"\n=== SUMMARY ===\n"
        f"total_calls:   {TOTAL_CALLS}\n"
        f"total_input:   {TOTAL_INPUT:,} tokens\n"
        f"total_output:  {TOTAL_OUTPUT:,} tokens\n"
        f"grand_total:   {TOTAL_INPUT+TOTAL_OUTPUT:,} tokens\n"
        f"cost_Flash:    ${cost_flash:.4f}\n"
        f"cost_Pro:      ${cost_pro:.4f}\n"
        f"report_saved:  {report_saved}\n"
        f"history_keep:  last {HISTORY_KEEP_TURNS} turns (vs keep-all in v1 POC)\n"
    )
    log(summary)
    (OUTDIR / "summary.txt").write_text(summary, encoding="utf-8")


async def main() -> None:
    """Open the Ads Transparency Center in a headed Chromium and run the agent.

    Launches the browser with a Ukrainian locale, loads the start page, lets
    the model drive the page for up to ``MAX_STEPS`` steps, then writes the
    usage summary. The browser and the Playwright driver are shut down even
    when a step raises.
    """
    client = genai.Client(api_key=KEY)
    # ``async with`` stops the Playwright driver on every exit path.
    async with async_playwright() as pw:
        browser = await pw.chromium.launch(
            headless=False, channel="chromium",
            args=[
                f"--window-size={VIEWPORT_W},{VIEWPORT_H}",
                "--no-sandbox", "--disable-dev-shm-usage",
                "--disable-blink-features=AutomationControlled",
                "--lang=uk-UA",
            ],
        )
        try:
            context = await browser.new_context(
                viewport={"width": VIEWPORT_W, "height": VIEWPORT_H},
                locale="uk-UA",
                user_agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
            )
            # Make navigator.webdriver read as undefined in every page of this context.
            await context.add_init_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined});")

            page = await context.new_page()
            await page.goto("https://adstransparency.google.com/?region=UA", wait_until="domcontentloaded", timeout=30000)
            # Give the page a moment to finish rendering before the first screenshot.
            await asyncio.sleep(3)

            report_saved = await _run_agent_loop(client, page)
            _write_summary(report_saved)
        finally:
            # Previously skipped whenever anything above raised, leaking the browser process.
            await browser.close()


if __name__ == "__main__":
    # The browser is launched headed, so it needs an X display; this run always
    # targets display :99 (presumably a virtual framebuffer — confirm for your host).
    os.environ.update({"DISPLAY": ":99"})
    asyncio.run(main())
