#!/usr/bin/python

import pygame
from sys import exit


# Pygame setup: open a 300x300 window and draw the empty 3x3 grid.

pygame.init()

screen = pygame.display.set_mode((300, 300))
pygame.display.set_caption("Tic-tac-toe")

# White separators between the cells, given as (x, y, width, height):
# two vertical bars at x=98 and x=198, then two horizontal bars at y=98 and
# y=198, each 5 px thick and 290 px long.
for _grid_line in ((98, 5, 5, 290), (198, 5, 5, 290),
                   (5, 98, 290, 5), (5, 198, 290, 5)):
    screen.fill((255, 255, 255), _grid_line)

pygame.display.flip()

def mark(token, pos):
    """Draw the string *token* in a grid cell and refresh the display.

    *pos* is a (row, column) pair. Each cell is 100x100 px and the glyph is
    rendered in white, centred in its cell.
    """
    row, col = pos[0], pos[1]
    glyph = pygame.font.Font(None, 100).render(token, True, (255, 255, 255))
    where = glyph.get_rect(center=(col * 100 + 50, row * 100 + 50))

    screen.blit(glyph, where)
    pygame.display.flip()

def finish(winner):
    """Announce the result, wait for the window to be closed, then exit.

    *winner* is a value from game_over(): 'X' or 'O' prints "<winner> wins!".
    Anything not contained in the string "XO" (game_over uses ' ' for a full
    board) prints "Tie!". The banner is red on grey, centred on the board.
    """
    big_font = pygame.font.Font(None, 80)

    if winner in "XO":
        message = "%s wins!" % winner
    else:
        message = "Tie!"
    banner = big_font.render(message, True, (255, 0, 0), (100, 100, 100))
    where = banner.get_rect(center=(150, 150))

    screen.blit(banner, where)
    pygame.display.flip()

    # Block until the user closes the window; every other event is ignored.
    while True:
        if pygame.event.wait().type == pygame.QUIT:
            break
    exit(0)


# Rules of production, state representation

class Rule():
    """Production rule: put ``token`` on the cell at row ``i``, column ``j``."""

    def __init__(self, i, j, token):
        self.i = i
        self.j = j
        self.token = token

    def can_apply(self, state):
        """Return True if the target cell of *state* is still empty (' ')."""
        cell = state[self.i][self.j]
        return cell == ' '

    def apply(self, state):
        """Return a copy of *state* with the token placed; *state* is not modified."""
        successor = [cells[:] for cells in state]
        successor[self.i][self.j] = self.token
        return successor

# One rule per board cell for each player, in row-major order (0,0), (0,1), ...
rules_p1 = [Rule(cell // 3, cell % 3, 'X') for cell in range(9)]
rules_p2 = [Rule(cell // 3, cell % 3, 'O') for cell in range(9)]

def game_over(state):
    """Return the result of *state*, or False while play can continue.

    *state* is a 3x3 board of 'X', 'O' and ' ' cells.

    Returns:
        'O' or 'X' if that player has three in a row. 'O' is checked first.
        ' ' if the board is full with no winner (a tie). Note that ' ' is
        truthy, so callers can test the result with ``if w:``.
        False otherwise.
    """
    # Every winning line: 3 rows, 3 columns and both diagonals. List
    # comprehensions are used instead of map() so the concatenation also
    # works on Python 3, where map() returns an iterator.
    lines = ([list(row) for row in state]
             + [list(col) for col in zip(*state)]
             + [[state[k][k] for k in range(3)],
                [state[2 - k][k] for k in range(3)]])

    for player in ('O', 'X'):
        if [player] * 3 in lines:
            return player

    if all(' ' not in row for row in state):
        return ' '

    return False

# The live board: three independent rows, every cell starting empty (' ').
state = [[' '] * 3, [' '] * 3, [' '] * 3]


# Minimax

# Search horizon for minimax, counted in moves (plies) below the position
# being evaluated; positions at this depth are scored by h() without
# searching further.
max_depth = 6

def h(state):
    """Score *state* from O's point of view (O is the maximizing player).

    Returns 2 if O has three in a row, -2 if X does (O is checked first),
    and 0 otherwise.
    """
    # Rows, columns and both diagonals. List comprehensions are used instead
    # of map() so this also works on Python 3, where map() is lazy.
    lines = ([list(row) for row in state]
             + [list(col) for col in zip(*state)]
             + [[state[k][k] for k in range(3)],
                [state[2 - k][k] for k in range(3)]])

    if ['O', 'O', 'O'] in lines:
        return 2

    if ['X', 'X', 'X'] in lines:
        return -2

    return 0

def minimax(state):
    """Choose O's move for *state* using depth-limited minimax.

    O (rules_p2) is the maximizing player and X (rules_p1) the minimizing
    player. Search stops at finished positions (game_over) or after
    max_depth moves.

    Returns a (rule, score) pair: the best Rule for O and its minimax score.
    rule is None if *state* is already finished.

    Scores from h() are weighted by how early the result occurs. Among winning
    moves O therefore prefers the quickest win, and when every move loses it
    prefers the one that postpones the loss longest, giving X a chance to
    blunder. Positions cut off at max_depth without a winner still score 0.
    """
    def leaf_score(state, depth):
        # The factor is >= 1 for every depth <= max_depth, so the sign of
        # h() (who wins) is preserved; only the magnitude favours early wins.
        return h(state) * (max_depth + 1 - depth)

    def maxmove(state, depth):
        if depth == max_depth or game_over(state):
            return None, leaf_score(state, depth)

        best_rule, best_score = None, None
        for rule in rules_p2:
            if rule.can_apply(state):
                _, s = minmove(rule.apply(state), depth + 1)
                # Strict comparison keeps the first of equally good moves.
                if best_score is None or best_score < s:
                    best_rule, best_score = rule, s
        return best_rule, best_score

    def minmove(state, depth):
        if depth == max_depth or game_over(state):
            return None, leaf_score(state, depth)

        best_rule, best_score = None, None
        for rule in rules_p1:
            if rule.can_apply(state):
                _, s = maxmove(rule.apply(state), depth + 1)
                if best_score is None or best_score > s:
                    best_rule, best_score = rule, s
        return best_rule, best_score

    return maxmove(state, 0)


# Game loop

while True:
    xturn = True

    # Human (X) turn: wait until the player clicks an empty cell.
    while xturn:
        event = pygame.event.wait()
        # Discard any other pending events so only this one is handled.
        pygame.event.clear()

        if event.type == pygame.QUIT:
            exit(0)
        elif event.type == pygame.MOUSEBUTTONDOWN:
            # Use the position stored in the event (where the click actually
            # happened), not the cursor's current position. Floor division
            # keeps the cell indices ints on Python 3 as well as Python 2.
            xpos, ypos = event.pos
            xpos //= 100
            ypos //= 100

            if state[ypos][xpos] == ' ':
                mark('X', (ypos, xpos))
                state[ypos][xpos] = 'X'

                print("X marked %d, %d" % (ypos, xpos))

                w = game_over(state)
                if w:
                    finish(w)

                xturn = False

    # Computer (O) turn. The board cannot be full here: X moves first and a
    # full board ends the game via finish() above.
    rule, _ = minimax(state)
    mark('O', (rule.i, rule.j))
    state[rule.i][rule.j] = 'O'
    print("O marked %d, %d" % (rule.i, rule.j))

    w = game_over(state)
    if w:
        finish(w)