# -*- coding: utf-8 -*-
# Import required modules
import os
import sys
import time

import requests
import sounddevice as sd
import numpy as np
import wave
import io
import base64
from PIL import Image
from openai import OpenAI
from devices import dfrobot_epaper  # 墨水屏驱动库
import RPi.GPIO as GPIO
from apscheduler.schedulers.background import BackgroundScheduler
import datetime
import cv2

# ================== Configuration ==================
# SiliconFlow API configuration
# NOTE(review): hard-coded API secret committed to source — consider loading from an environment variable.
SILICONFLOW_API_KEY = "sk-kxwsrzianqfxsebnihblrgyyytrrtgvvdjvdiujcuvwymrfp"  # API key
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"  # API base URL
AUDIO_MODEL = "Qwen/Qwen3-Omni-30B-A3B-Instruct"  # speech-recognition model
IMAGE_MODEL = "Kwai-Kolors/Kolors"  # text-to-image model
IMAGE_SIZE = "1280x720"  # generated image size
GENERATED_IMAGE_PATH = "./generated_raw.png"  # path for the raw generated image
CONVERTED_IMAGE_PATH = "./display.bmp"  # path for the e-paper-converted image

# E-paper hardware parameters
RASPBERRY_SPI_BUS = 0
RASPBERRY_SPI_DEV = 0
RASPBERRY_PIN_CS = 27  # chip-select pin
RASPBERRY_PIN_CD = 17  # command/data pin
RASPBERRY_PIN_BUSY = 4  # busy-detect pin
EPAPER_WIDTH = 250      # e-paper width (pixels)
EPAPER_HEIGHT = 128     # e-paper height (pixels)

# Recording parameters
SAMPLE_RATE = 16000     # sampling rate (Hz)
BUTTON_A_PIN = 21       # button A (text-to-image) GPIO pin
BUTTON_B_PIN = 20       # button B (image-to-image) GPIO pin

# GPIO initialization (pull-up input mode; buttons are active-low)
GPIO.setmode(GPIO.BCM)
GPIO.setup(BUTTON_A_PIN, GPIO.IN, pull_up_down=GPIO.PUD_UP)
GPIO.setup(BUTTON_B_PIN, GPIO.IN, pull_up_down=GPIO.PUD_UP)
print(f"Button A initialized on GPIO {BUTTON_A_PIN}")
print(f"Button B initialized on GPIO {BUTTON_B_PIN}")

# Global e-paper object, initialized once at import time
epaper = dfrobot_epaper.DFRobot_Epaper_SPI(
    RASPBERRY_SPI_BUS, RASPBERRY_SPI_DEV,
    RASPBERRY_PIN_CS, RASPBERRY_PIN_CD, RASPBERRY_PIN_BUSY
)
epaper.begin()
print("E-paper initialized globally")

def record_while_button_pressed(button_pin, device=1, samplerate=SAMPLE_RATE):
    """
    Record audio for as long as the button is held down; return it after release.

    :param button_pin: GPIO pin of the button (pull-up input, pressed == LOW)
    :param device: audio input device index (default 1, USB microphone)
    :param samplerate: sampling rate in Hz
    :return: data URI with base64-encoded 16-bit mono WAV, or None if nothing was recorded
    """
    print(f"Press and hold button (GPIO{button_pin}) to record... Release to stop.")

    # Wait for the button press (line is pulled up, so pressed reads LOW).
    while GPIO.input(button_pin) == GPIO.HIGH:
        time.sleep(0.05)
    print("Recording started...")

    audio_chunks = []
    # Use the stream as a context manager so it is always stopped and closed,
    # even if a read raises mid-recording (the original leaked the stream on error).
    with sd.InputStream(samplerate=samplerate,
                        device=device,
                        channels=1,
                        dtype='float32') as stream:
        # Pull ~100 ms chunks while the button stays pressed. stream.read()
        # already blocks for the requested frames, so no extra sleep is needed
        # (the original's sleep(0.01) only delayed release detection).
        while GPIO.input(button_pin) == GPIO.LOW:
            chunk, overflowed = stream.read(int(samplerate * 0.1))
            if overflowed:
                print("Warning: audio buffer overflow")
            audio_chunks.append(chunk.copy())
    print("Recording stopped.")

    if not audio_chunks:
        return None

    # Concatenate the float32 chunks and convert to 16-bit PCM.
    recording = np.concatenate(audio_chunks, axis=0)
    recording_int16 = (recording * 32767).clip(-32768, 32767).astype(np.int16)
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(samplerate)
        wf.writeframes(recording_int16.tobytes())
    wav_buffer.seek(0)
    b64_str = base64.b64encode(wav_buffer.read()).decode('utf-8')
    return f"data:audio/wav;base64,{b64_str}"

def speech_to_text(audio_base64_uri):
    """
    Transcribe an audio clip to text via the speech-recognition API.

    :param audio_base64_uri: audio as a base64 data URI
    :return: transcribed text, or None on failure
    """
    client = OpenAI(api_key=SILICONFLOW_API_KEY, base_url=SILICONFLOW_BASE_URL)
    # Single user message carrying the audio plus the transcription instruction.
    request_message = {
        "role": "user",
        "content": [
            {"type": "audio_url", "audio_url": {"url": audio_base64_uri}},
            {"type": "text", "text": "Please transcribe the audio content directly into text without any explanation."},
        ],
    }
    try:
        response = client.chat.completions.create(
            model=AUDIO_MODEL,
            messages=[request_message],
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Speech recognition failed: {e}")
        return None

def generate_image(prompt, api_key, model, size, save_path):
    """
    Call the text-to-image API, then download and save the resulting image.

    :param prompt: text prompt
    :param api_key: SiliconFlow API key
    :param model: model name
    :param size: image size string, e.g. "1280x720"
    :param save_path: local path for the downloaded image
    :return: True on success, False on failure
    """
    print("[1/3] Calling API to generate image...")
    client = OpenAI(api_key=api_key, base_url=SILICONFLOW_BASE_URL)
    try:
        response = client.images.generate(
            model=model,
            prompt=prompt,
            size=size,
            n=1
        )
        image_url = response.data[0].url
        print("    Image URL obtained successfully, downloading...")
        # Bounded timeout + status check: without them the download could hang
        # forever or silently save an HTML error page as the "image".
        download = requests.get(image_url, timeout=60)
        download.raise_for_status()
        with open(save_path, "wb") as f:
            f.write(download.content)
        print(f"    Image saved to {save_path}")
        return True
    except Exception as e:
        print(f"    Image generation failed: {e}")
        return False

def convert_to_epaper_bitmap(input_path, output_path, width, height, threshold=128, fit_mode='contain'):
    """
    Convert an ordinary image into a 1-bit BMP the e-paper can display.

    :param input_path: path of the source image
    :param output_path: path of the output BMP
    :param width: target width in pixels
    :param height: target height in pixels
    :param threshold: binarization threshold (pixels below it become black)
    :param fit_mode: 'contain' keeps aspect ratio and pads with white;
                     'cover' keeps aspect ratio and crops to fill;
                     any other value stretches to the exact target size
    :return: True on success, False on failure
    """
    print("[2/3] Converting image to e-paper format...")
    try:
        img = Image.open(input_path).convert("L")  # grayscale first
        target_ratio = width / height
        img_ratio = img.width / img.height
        if fit_mode == 'contain':
            # Keep aspect ratio, show the whole image, pad the rest with white.
            if img_ratio > target_ratio:
                new_width, new_height = width, int(width / img_ratio)
            else:
                new_width, new_height = int(height * img_ratio), height
            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            canvas = Image.new('L', (width, height), 255)  # white background
            x = (width - new_width) // 2
            y = (height - new_height) // 2
            canvas.paste(img, (x, y))
            img = canvas
        elif fit_mode == 'cover':
            # Keep aspect ratio, fill the screen, crop the overflow.
            if img_ratio > target_ratio:
                new_width, new_height = int(height * img_ratio), height
            else:
                new_width, new_height = width, int(width / img_ratio)
            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            left = (new_width - width) // 2
            top = (new_height - height) // 2
            img = img.crop((left, top, left + width, top + height))
        else:
            # Fallback: stretch to the panel size so the BMP always matches the
            # display dimensions (an unknown mode previously left it unresized).
            img = img.resize((width, height), Image.Resampling.LANCZOS)
        # Binarize: below threshold -> black (0), otherwise white (255).
        img = img.point(lambda px: 0 if px < threshold else 255, '1')
        img.save(output_path, format="BMP")
        print(f"    Conversion completed, saved to {output_path}")
        # Return False instead of sys.exit(1): one bad image must not kill the
        # long-running kiosk process (scheduler + button loop).
        return True
    except Exception as e:
        print(f"    Image conversion failed: {e}")
        return False

def display_on_epaper(bmp_path, x=0, y=0, flush_mode="PART"):
    """
    Show the given BMP image on the e-paper display.

    :param bmp_path: path of the 1-bit BMP to display
    :param x: top-left x coordinate on the panel
    :param y: top-left y coordinate on the panel
    :param flush_mode: "PART" for partial refresh, anything else for full refresh
    :return: True on success, False on failure
    """
    print("[3/3] Displaying on e-paper...")
    try:
        print("    Clearing screen...")
        epaper.clearScreen()
        time.sleep(0.5)  # give the panel time to finish clearing
        print(f"    Displaying image {bmp_path} at coordinates ({x}, {y})")
        epaper.bitmapFile(x, y, bmp_path)
        if flush_mode.upper() == "PART":
            epaper.flush(epaper.PART)
            print("    Partial refresh completed")
        else:
            epaper.flush(epaper.ALL)
            print("    Full refresh completed")
        print("    Image displayed on e-paper!")
        # Return False instead of sys.exit(1): a single failed refresh must not
        # terminate the whole long-running kiosk process.
        return True
    except Exception as e:
        print(f"    E-paper operation failed: {e}")
        return False

def execute_painting(prompt):
    """
    Run the full painting pipeline: generate image -> convert format -> show on e-paper.

    :param prompt: text prompt for the image-generation model
    :return: True when generation succeeded, False otherwise
    """
    print(f"Prompt: {prompt}")
    generated = generate_image(
        prompt, SILICONFLOW_API_KEY, IMAGE_MODEL, IMAGE_SIZE, GENERATED_IMAGE_PATH
    )
    if not generated:
        return False
    convert_to_epaper_bitmap(
        GENERATED_IMAGE_PATH, CONVERTED_IMAGE_PATH,
        EPAPER_WIDTH, EPAPER_HEIGHT, fit_mode='contain'
    )
    display_on_epaper(CONVERTED_IMAGE_PATH)
    return True

def weather_daily_task():
    """Scheduled job: fetch Beijing's current weather and paint a matching line drawing."""
    print("\n--- Starting daily weather task ---")
    try:
        import pyowm
        # NOTE(review): hard-coded OpenWeatherMap key — consider moving it to config/env.
        owm = pyowm.OWM('bb27bb2970186aa8850fc421a830fb47')
        observation = owm.weather_manager().weather_at_place('Beijing,CN')
        weather = observation.weather
        temp = weather.temperature('celsius')['temp']
        status = weather.detailed_status
        # Build a drawing prompt from the current conditions.
        prompt = f"{status} in Beijing, {temp} degrees Celsius, black and white simple line drawing style, pure white background, high contrast"
        print(f"Weather prompt: {prompt}")
        execute_painting(prompt)
    except Exception as e:
        # Best-effort daily job: log and carry on; the scheduler retries tomorrow.
        print(f"Weather task failed: {e}")

def capture_image_to_base64():
    """
    Take a photo with the default camera and return it as a JPEG base64 data URI.

    :return: image data URI, or None on failure
    """
    print("Capturing photo...")
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open camera.")
        return None
    try:
        ret, frame = cap.read()
    finally:
        # Release the camera even if read() raises, or the device stays locked.
        cap.release()
    if not ret:
        print("Error: Failed to capture image.")
        return None
    # imencode returns (success, buffer); the original discarded the flag and
    # would have base64-encoded an invalid buffer on encode failure.
    ok, buffer = cv2.imencode('.jpg', frame)
    if not ok:
        print("Error: Failed to encode image.")
        return None
    jpg_as_text = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{jpg_as_text}"

def generate_image_from_prompt_and_image(prompt, image_base64, save_path):
    """
    Call the image-to-image API (prompt + reference photo) and save the result.

    :param prompt: text prompt
    :param image_base64: input image as a base64 data URI
    :param save_path: local path for the downloaded image
    :return: True on success, False on failure
    """
    print("Calling API...")
    url = f"{SILICONFLOW_BASE_URL}/images/generations"
    headers = {
        "Authorization": f"Bearer {SILICONFLOW_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "Qwen/Qwen-Image-Edit-2509",  # model with image-to-image support
        "prompt": prompt,
        "image": image_base64,
        "image_size": IMAGE_SIZE,
        "batch_size": 1
    }
    try:
        # Generation can be slow, but never wait forever (the original had no timeout).
        response = requests.post(url, json=payload, headers=headers, timeout=120)
        response.raise_for_status()
        data = response.json()
        image_url = data['images'][0]['url']
        print("    Image URL obtained, downloading...")
        # Timeout + status check so a failed download isn't saved as the image.
        download = requests.get(image_url, timeout=60)
        download.raise_for_status()
        with open(save_path, "wb") as f:
            f.write(download.content)
        print(f"    Image saved to {save_path}")
        return True
    except Exception as e:
        print(f"Image generation failed: {e}")
        return False

def handle_button_a():
    """Button A workflow: record -> transcribe -> text-to-image -> display."""
    print("--- Button A pressed (Text-to-Image) ---")
    audio_uri = record_while_button_pressed(BUTTON_A_PIN)
    if audio_uri is None:
        print("No audio recorded.")
        return
    prompt = speech_to_text(audio_uri)
    if prompt:
        execute_painting(prompt)
        print("Button A round completed.\n")

def handle_button_b():
    """Button B workflow: record -> transcribe -> photograph -> image-to-image -> display."""
    print("--- Button B pressed (Image-to-Image) ---")
    # Record the spoken prompt while the button is held.
    audio_uri = record_while_button_pressed(BUTTON_B_PIN)
    if audio_uri is None:
        print("No audio recorded.")
        return
    prompt = speech_to_text(audio_uri)
    if not prompt:
        return
    # Grab the reference photo from the camera.
    image_b64 = capture_image_to_base64()
    if image_b64 is None:
        return
    # Generate from prompt + photo, then push the result to the e-paper.
    generated = generate_image_from_prompt_and_image(prompt, image_b64, GENERATED_IMAGE_PATH)
    if generated:
        convert_to_epaper_bitmap(GENERATED_IMAGE_PATH, CONVERTED_IMAGE_PATH,
                                 EPAPER_WIDTH, EPAPER_HEIGHT, fit_mode='contain')
        display_on_epaper(CONVERTED_IMAGE_PATH)
    print("Button B round completed.\n")

def main():
    """
    Entry point: start the daily weather scheduler, then poll both buttons forever.

    Raises SystemExit when no API key is configured.
    """
    if SILICONFLOW_API_KEY == "":
        print("Error: Please set a valid SiliconFlow API key")
        sys.exit(1)

    # Start the background scheduler (daily weather painting at 21:55).
    scheduler = BackgroundScheduler()
    scheduler.add_job(weather_daily_task, 'cron', hour=21, minute=55)
    scheduler.start()
    print("Scheduler started. Daily weather task set for 21:55.")

    print("Voice-to-Image E-paper System Started (A: text2img, B: img2img)")

    try:
        # Main loop: poll both buttons and dispatch the matching handler.
        while True:
            if GPIO.input(BUTTON_A_PIN) == GPIO.LOW:
                time.sleep(0.05)  # simple debounce
                if GPIO.input(BUTTON_A_PIN) == GPIO.LOW:
                    handle_button_a()

            if GPIO.input(BUTTON_B_PIN) == GPIO.LOW:
                time.sleep(0.05)
                if GPIO.input(BUTTON_B_PIN) == GPIO.LOW:
                    handle_button_b()

            time.sleep(0.05)  # reduce CPU usage
    finally:
        # Stop the scheduler's worker threads on any exit (e.g. Ctrl-C);
        # the original left it running until process teardown.
        scheduler.shutdown(wait=False)

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the kiosk loop.
        print("\nProgram exited")
    finally:
        # Release only the two button pins; the e-paper driver's pins are
        # presumably managed by the driver itself — TODO confirm.
        GPIO.cleanup([BUTTON_A_PIN, BUTTON_B_PIN])
        print("GPIO cleaned up.")