#!/usr/bin/env python3
"""
重新生成第 13-14 页
模型：Google Gemini 3.1 Flash Image Preview (Nano Banana 2)
特点：带角色参考图，确保角色一致性
"""

import os
import requests
import json
import base64
import time
from pathlib import Path
from datetime import datetime

# API 配置 - OpenRouter
OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY')
API_URL = "https://openrouter.ai/api/v1/chat/completions"

# 模型：Gemini 3.1 Flash Image Preview
MODEL = "google/gemini-3.1-flash-image-preview"
IMAGE_SIZE = "512x512"

# 输出目录
OUTPUT_DIR = Path("/root/.openclaw/workspace/trilingual-picturebook/output")
OUTPUT_DIR.mkdir(exist_ok=True)

# 角色参考图
CHARACTER_SHEET_DIR = Path("/root/.openclaw/workspace/trilingual-picturebook")
YANGYANG_REF = CHARACTER_SHEET_DIR / "_charsheet_yangyang.png"
ZHUAZHUA_REF = CHARACTER_SHEET_DIR / "_charsheet_zhuazhua.png"


def load_image_as_base64(image_path: Path) -> str:
    """加载图片为 base64 data URL"""
    with open(image_path, "rb") as f:
        image_data = f.read()
    base64_data = base64.b64encode(image_data).decode('utf-8')
    return f"data:image/png;base64,{base64_data}"


def generate_image(prompt_text: str, filename: str):
    """调用 Gemini API 生成图片（带角色参考）"""
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://openclaw.ai",
        "X-Title": "Picturebook Gemini"
    }
    
    # 构建内容：参考图 + 提示词
    content_parts = []
    
    # 添加角色参考图
    if YANGYANG_REF.exists():
        print(f"   📎 参考图：yangyang ({YANGYANG_REF.stat().st_size // 1024} KB)")
        content_parts.append({
            "type": "image_url",
            "image_url": load_image_as_base64(YANGYANG_REF)
        })
    
    if ZHUAZHUA_REF.exists():
        print(f"   📎 参考图：zhuazhua ({ZHUAZHUA_REF.stat().st_size // 1024} KB)")
        content_parts.append({
            "type": "image_url",
            "image_url": load_image_as_base64(ZHUAZHUA_REF)
        })
    
    # 添加提示词
    content_parts.append({
        "type": "text",
        "text": prompt_text
    })
    
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": content_parts}],
        "max_tokens": 1000
    }
    
    # 重试逻辑
    max_retries = 6
    base_delay = 20
    
    for attempt in range(max_retries):
        try:
            print(f"   🔄 尝试 {attempt + 1}/{max_retries}...")
            response = requests.post(API_URL, headers=headers, json=payload, timeout=180)
            
            if response.status_code == 429:
                wait_time = min(base_delay * (2 ** attempt), 300)
                print(f"   ⏳ 限流，等待 {wait_time} 秒...")
                time.sleep(wait_time)
                continue
            
            response.raise_for_status()
            result = response.json()
            
            # 提取图片
            if "choices" in result and len(result["choices"]) > 0:
                message = result["choices"][0].get("message", {})
                if "images" in message and len(message["images"]) > 0:
                    image_url = message["images"][0].get("image_url", "")
                    if isinstance(image_url, dict):
                        image_url = image_url.get("url", "")
                    
                    print(f"   📥 下载图片...")
                    if image_url.startswith("data:image"):
                        image_data = base64.b64decode(image_url.split(",")[1])
                    else:
                        image_resp = requests.get(image_url, timeout=60)
                        image_resp.raise_for_status()
                        image_data = image_resp.content
                    
                    # 保存
                    output_path = OUTPUT_DIR / f"{filename}.png"
                    with open(output_path, "wb") as f:
                        f.write(image_data)
                    
                    file_size_kb = len(image_data) / 1024
                    print(f"   ✅ 已保存：{filename} ({file_size_kb:.1f} KB)")
                    return {"success": True, "path": str(output_path), "size_kb": file_size_kb}
            
            print(f"   ❌ 未返回图片")
            return {"success": False, "error": "No image"}
            
        except Exception as e:
            print(f"   ❌ 错误：{e}")
            if attempt < max_retries - 1:
                wait_time = min(base_delay * (2 ** attempt), 300)
                time.sleep(wait_time)
            else:
                return {"success": False, "error": str(e)}
    
    return {"success": False, "error": "Max retries"}


def main():
    print("=" * 60)
    print("📖 Gemini 3.1 重新生成第 13-14 页")
    print("=" * 60)
    print(f"📁 输出：{OUTPUT_DIR}")
    print(f"📐 尺寸：{IMAGE_SIZE}")
    print(f"🎨 模型：{MODEL}")
    print(f"📎 参考图：yangyang + zhuazhua")
    
    # 第 13 页提示词 - 强调角色一致性（简体）
    prompt_13 = """
【重要：请仔细参考提供的角色图片，保持角色一致】

场景：龙凤胎兄妹手牵手走下木质楼梯

哥哥（参考 yangyang 图片）:
- 2 岁中国男孩，黑色短发，圆脸大眼睛
- 穿浅蓝色 T 恤和短裤
- 表情：关爱妹妹的哥哥样，走在前面引导

妹妹（参考 zhuazhua 图片）:
- 2 岁中国女孩，黑色头发扎两个小辫子
- 穿粉色连衣裙
- 表情：信任哥哥，跟在后面

氛围：温馨的晨光从窗户照进来，舒适的家居环境
风格：水彩儿童绘本插画，柔和的粉彩色调，方形构图
要求：无文字，无水印，高质量

重要：角色必须与参考图片一致！哥哥蓝色衣服，妹妹粉色衣服 + 小辫子
"""
    
    # 第 14 页提示词
    prompt_14 = """
【重要：请仔细参考提供的角色图片，保持角色一致】

场景：小狗旺财在楼梯下迎接龙凤胎

小狗旺财：
- 可爱的金毛幼犬
- 坐在楼梯底部，抬头向上看
- 表情：兴奋开心，摇尾巴

背景中的龙凤胎（参考图片）:
- 哥哥（yangyang）：蓝色衣服，在楼梯上
- 妹妹（zhuazhua）：粉色衣服 + 小辫子，在楼梯上
- 两人正走下来准备和小狗互动

氛围：温馨的晨光，舒适的家居环境
风格：水彩儿童绘本插画，柔和的粉彩色调，方形构图
要求：无文字，无水印，高质量
"""
    
    results = []
    
    # 生成第 13 页
    print("\n--- 第 13 页 ---")
    result_13 = generate_image(prompt_13, "gemini-page-13-v4")
    results.append(("gemini-page-13-v4", result_13))
    
    time.sleep(5)
    
    # 生成第 14 页
    print("\n--- 第 14 页 ---")
    result_14 = generate_image(prompt_14, "gemini-page-14-v4")
    results.append(("gemini-page-14-v4", result_14))
    
    # 保存日志
    log_path = OUTPUT_DIR / f"gemini-regen-log-{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    
    print(f"\n📋 日志：{log_path}")
    print("\n" + "=" * 60)
    success = sum(1 for _, r in results if r.get("success"))
    print(f"✅ 成功：{success}/{len(results)}")
    print("=" * 60)


if __name__ == "__main__":
    main()
