import os
import time
from functools import partial
from typing import Callable, Iterable, List, Union

import unihiker

# Public names exported by ``from <this module> import *``.
__all__ = ['Menu', 'MenuItem']


def _deep_copy(iter_: Iterable):
    return [_ if not isinstance(_, Iterable) else _deep_copy(_) for _ in iter_]


class MenuItem:
    """A single menu entry: a label, a click handler and a background colour.

    :param text: Label drawn on the entry.
    :param func: Handler for the entry; ``Menu`` registers it as the
        ``onclick`` callback of the entry's widgets. Calling the item itself
        also invokes it.
    :param bg_color: Fill colour of the entry's rounded rectangle.
    """

    def __init__(self, text: str, func: Callable[..., object], bg_color: str = 'lightgray'):
        self.text = text
        self.func = func
        self.bg_color = bg_color

    def __call__(self, *args, **kwargs):
        """Invoke the handler with the given arguments and return its result."""
        return self.func(*args, **kwargs)

    def __repr__(self):
        return f'<MenuItem {self.text}>'

    def __str__(self):
        return self.text


class Menu:
    """A paged list of :class:`MenuItem` entries drawn with ``unihiker.GUI``.

    Entries are laid out eight per page below the title. The A button selects
    the previous page and the B button the next one, wrapping around at both
    ends. Call :meth:`begin` to create the widgets and start the display loop
    (it never returns).

    :param items: Initial entries. The list is copied, so later changes to the
        caller's list do not affect the menu; use :meth:`add_item` /
        :meth:`remove_item` instead.
    :param title: Text drawn at the top of the screen.
    :param title_color: Colour of the title text.
    :param exit_button: Whether to draw the "退出" (exit) button in the top-left corner.
    """

    # Off-screen coordinate used to "hide" widgets without destroying them.
    _HIDDEN_XY = -400
    # Vertical layout of a page: y of the first row and distance between rows.
    _FIRST_ROW_Y = 28
    _ROW_HEIGHT = 35
    # Seconds to wait between checks for a page change in begin().
    _POLL_INTERVAL = 0.05

    def __init__(self, items: Union[List[MenuItem], None] = None, title: str = 'Menu', title_color: str = 'black',
                 exit_button: bool = True):
        self._gui = unihiker.GUI()
        self._now_index: int = 0  # currently selected page
        self._page_num: int = 8  # entries per page
        self._total_page: int = 0
        self._fill_round_rects: List[Union[unihiker.GUI.CanvasRoundRect, None]] = []
        self._fill_round_rects_funcs: List[Callable] = []
        self._fill_round_rects_funcs_list: List[List[Callable]] = []
        self._texts: List[Union[unihiker.GUI.CanvasText, None]] = []
        self._texts_funcs: List[Callable] = []
        self._texts_funcs_list: List[List[Callable]] = []
        self._exit_button: bool = exit_button

        self._main_text = self._gui.draw_text(x=120, y=0, w=240, text=title, font_size=14, color=title_color,
                                              origin='top')
        # Copy so that add_item/remove_item never mutate the caller's list.
        self._items: List[MenuItem] = list(items) if items is not None else []

    def add_item(self, item: MenuItem):
        """Append *item* to the menu (takes effect only if called before :meth:`begin`)."""
        self._items.append(item)

    def remove_item(self, item: MenuItem):
        """Remove *item* from the menu; raises ``ValueError`` if it is not present."""
        self._items.remove(item)

    def _init(self):
        """Record deferred widget-creation calls for every entry, grouped into pages.

        For each item, a ``fill_round_rect`` and a ``draw_text`` call positioned
        off-screen is stored as a ``partial``; :meth:`begin` performs the calls.
        Per-page slices of those lists are kept so each page's size is known.
        """
        hidden = self._HIDDEN_XY
        for start in range(0, len(self._items), self._page_num):
            stop = start + self._page_num
            for item in self._items[start:stop]:
                self._fill_round_rects_funcs.append(
                    partial(self._gui.fill_round_rect, x=hidden, y=hidden, w=240, h=30, onclick=item.func,
                            color=item.bg_color))
                self._texts_funcs.append(
                    partial(self._gui.draw_text, x=hidden, y=hidden, w=240, onclick=item.func, text=item.text,
                            font_size=12))
            # Slicing already yields new lists; the partials themselves are shared.
            self._fill_round_rects_funcs_list.append(self._fill_round_rects_funcs[start:stop])
            self._texts_funcs_list.append(self._texts_funcs[start:stop])

        self._total_page = len(self._fill_round_rects_funcs_list)

    def _up_item(self):
        """Select the previous page, wrapping from the first page to the last."""
        if self._total_page == 0:
            return
        self._now_index = (self._now_index - 1) % self._total_page

    def _down_item(self):
        """Select the next page, wrapping from the last page to the first."""
        if self._total_page == 0:
            return
        self._now_index = (self._now_index + 1) % self._total_page

    def show_item(self, index: int):
        """Move every entry of page *index* to its on-screen row.

        :param index: Zero-based page number (not an item index).
        """
        first = index * self._page_num
        for row in range(len(self._fill_round_rects_funcs_list[index])):
            y = self._FIRST_ROW_Y + row * self._ROW_HEIGHT
            self._fill_round_rects[first + row].config(x=0, y=y)
            self._texts[first + row].config(x=2, y=y)

    def hide_item(self, index: int):
        """Move every entry of page *index* off-screen.

        :param index: Zero-based page number (not an item index).
        """
        first = index * self._page_num
        hidden = self._HIDDEN_XY
        for row in range(len(self._fill_round_rects_funcs_list[index])):
            self._fill_round_rects[first + row].config(x=hidden, y=hidden)
            self._texts[first + row].config(x=hidden, y=hidden)

    def begin(self):
        """Create all widgets, bind the A/B buttons and run the display loop forever.

        The loop touches the widgets only when the selected page changes, and
        sleeps between checks instead of busy-spinning.
        """
        self._init()

        self._gui.on_a_click(self._up_item)
        self._gui.on_b_click(self._down_item)
        if self._exit_button:
            # NOTE(review): `exit` raises SystemExit in whichever thread runs the
            # click callback; confirm this really terminates the program on UniHiker.
            self._gui.add_button(text="退出", x=0, y=0, w=50, h=30, onclick=exit)
        self._fill_round_rects.extend(make_rect() for make_rect in self._fill_round_rects_funcs)
        self._texts.extend(make_text() for make_text in self._texts_funcs)

        shown_page = None
        while True:
            # Read once per pass: the A/B callbacks presumably run on another
            # thread and may change _now_index at any time.
            current_page = self._now_index
            if current_page != shown_page:
                for page in range(self._total_page):
                    if page == current_page:
                        self.show_item(page)
                    else:
                        self.hide_item(page)
                shown_page = current_page
            time.sleep(self._POLL_INTERVAL)


if __name__ == '__main__':
    # Demo: one entry per lowercase letter, then the digits 1-9 and 0;
    # clicking an entry prints its own label.
    demo_labels = 'abcdefghijklmnopqrstuvwxyz' + '1234567890'
    menu = Menu([MenuItem(label, partial(print, label)) for label in demo_labels])
    menu.begin()
