import { InferenceSession, Tensor } from "onnxruntime-web";
import onnx_model_v1 from "../../assets/models/TicTacToeAgentV1.onnx";
import onnx_model_v2 from "../../assets/models/TicTacToeAgentV1.onnx";

/**
 * Version identifiers for the TicTacToe agents defined in this file.
 * Each agent class stores one of these in its `version` field, which is
 * interpolated into error and log messages.
 */
export class TicTacToeVersions {
  // Used by AbstractTicTacToe, which has no move policy of its own.
  static abstract = "abstract";
  // TicTacToeV0: minimax search, no model.
  static V0 = "MinMax";
  // TicTacToeV1: ONNX model loaded from onnx_model_v1.
  static V1 = "1.0";
  // TicTacToeV2: ONNX model loaded from onnx_model_v2.
  static V2 = "2.0";
}

/**
 * Base class for TicTacToe agents.
 *
 * The board is a flat array of 9 cells in row-major order. Each cell is
 * AbstractTicTacToe.X (-1), AbstractTicTacToe.O (1) or
 * AbstractTicTacToe.SPACE (0). `role` is the side this agent plays.
 * Subclasses supply the move policy by overriding `getAction`.
 */
export class AbstractTicTacToe {
  version = TicTacToeVersions.abstract;
  static X = -1;
  static O = 1;
  static SPACE = 0;
  // Every index triple that forms three-in-a-row. The order (row i,
  // column i for i = 0..2, then both diagonals) matches the order in which
  // isGameOver historically checked lines.
  static WIN_LINES = [
    [0, 1, 2],
    [0, 3, 6],
    [3, 4, 5],
    [1, 4, 7],
    [6, 7, 8],
    [2, 5, 8],
    [0, 4, 8],
    [2, 4, 6],
  ];
  session = null;
  role = -1; // default role is X
  board = [0, 0, 0, 0, 0, 0, 0, 0, 0];

  /**
   * @param {number} [role] - AbstractTicTacToe.X (-1) or AbstractTicTacToe.O (1).
   *   Defaults to X when omitted.
   * @throws {Error} if `role` is provided but is neither X nor O.
   */
  constructor(role = undefined) {
    if (role === undefined) {
      this.role = AbstractTicTacToe.X;
    } else if (role === AbstractTicTacToe.X || role === AbstractTicTacToe.O) {
      this.role = role;
    } else {
      // Note: during construction `this.version` is still the base value,
      // because subclass field initializers run after super() returns.
      throw new Error(
        `(TicTacToe v-${this.version}) constructed with invalid role: ${role}. Options are -1 (X) or 1 (O).`
      );
    }
  }

  /** @returns {number} the side this agent plays (-1 = X, 1 = O). */
  getRole() {
    return this.role;
  }

  /** @returns {number} the opposing side's role, i.e. the negation of `role`. */
  getPlayerRole() {
    return this.role === AbstractTicTacToe.X
      ? AbstractTicTacToe.O
      : AbstractTicTacToe.X;
  }

  /**
   * Base implementation has no model: clears the session and the board.
   * Subclasses override this to load their model.
   */
  async createSession() {
    this.session = null;
    this.board = [0, 0, 0, 0, 0, 0, 0, 0, 0];
  }

  /**
   * Checks whether `role` may place a piece at `action`.
   * A move is valid when the index is an in-range integer, the cell is
   * empty, and the resulting board stays balanced (|sum of cells| <= 1,
   * i.e. neither side gets two moves ahead). Reasons for rejection are logged.
   * @param {number} action - board index 0..8.
   * @param {number} role - -1 (X) or 1 (O).
   * @returns {boolean}
   */
  isValidMove(action, role) {
    if (!Number.isInteger(action) || action < 0 || action >= this.board.length) {
      console.log(`Action ${action} is invalid. Out of index.`);
      return false;
    }

    if (this.board[action] !== 0) {
      console.log(`Action ${action} is invalid. Board: ${this.board}`);
      return false;
    }

    const boardSum =
      this.board.reduce((partialSum, x) => partialSum + x, 0) + role;

    const balanced = Math.abs(boardSum) <= 1;
    if (!balanced) {
      console.log(`Action ${action} is invalid. Board unbalanced. Sum: ${Math.abs(boardSum)}. Role: ${role}`);
    }
    return balanced;
  }

  /**
   * @returns {[number[], number[]]} `[mask, indices]`. `mask[i]` is 1 when
   *   cell i is empty and 0 otherwise. `indices` lists the empty cells in
   *   ascending order.
   */
  getValidMoves() {
    // Array(n).map() skips holes, so the mask must be filled explicitly;
    // otherwise occupied cells would be holes instead of 0.
    const mask = new Array(this.board.length).fill(0);
    const indices = [];
    this.board.forEach((value, i) => {
      if (value === 0) {
        mask[i] = 1;
        indices.push(i);
      }
    });
    return [mask, indices];
  }

  /**
   * @param {ArrayLike<number>} values
   * @returns {number} index of the first maximum, or -1 for an empty input.
   */
  argmax(values) {
    return values.indexOf(Math.max(...values));
  }

  /**
   * Asks the policy for a move and places `role` there.
   * @param {number} role - the side placing the piece.
   * @throws {Error} if the policy returns an invalid move (the board is
   *   printed first to aid debugging).
   */
  async takeAction(role) {
    const action = await this.getAction();
    if (!this.isValidMove(action, role)) {
      this.printBoard();
      throw new Error(`Action '${action}' is invalid. Role: ${role}. Version: ${this.version}.`);
    }
    this.board[action] = role;
  }

  /**
   * The policy used to determine which action to take. Subclasses must override.
   * @throws {Error} always, in the base class.
   */
  async getAction() {
    throw new Error(
      `(TicTacToe v-${this.version}) The getAction method is undefined.`
    );
  }

  /**
   * @returns {[number, boolean]} `[winner, done]`, where `winner` is:
   *   0 = tied, 1 = O, -1 = X, 2 = game is ongoing.
   */
  isGameOver() {
    for (const [a, b, c] of AbstractTicTacToe.WIN_LINES) {
      const first = this.board[a];
      if (first !== 0 && first === this.board[b] && first === this.board[c]) {
        return [first, true];
      }
    }

    // No winner: the game is tied once every cell is occupied.
    if (this.board.every((value) => value !== 0)) {
      return [0, true];
    }
    return [2, false];
  }

  /** @returns {number[]} a copy of the board. */
  getBoard() {
    return [...this.board];
  }

  /** Replaces the board with a copy of `board`. */
  setBoard(board) {
    this.board = [...board];
  }

  /** Clears every cell. */
  resetBoard() {
    this.board = [0, 0, 0, 0, 0, 0, 0, 0, 0];
  }

  /** @returns {boolean} true when this agent plays X. */
  roleIsX() {
    return this.role === AbstractTicTacToe.X;
  }

  /**
   * @param {number} role - a cell value (-1, 1 or 0).
   * @returns {string} "X", "O" or " ".
   * @throws {Error} for any other value.
   */
  roleAsString(role) {
    switch (role) {
      case AbstractTicTacToe.X:
        return "X";
      case AbstractTicTacToe.O:
        return "O";
      case AbstractTicTacToe.SPACE:
        return " ";
      default:
        throw new Error(`(roleAsString v-${this.version}) Invalid role ${role}`);
    }
  }

  /** Switches this agent to the other side. */
  flipRoles() {
    this.role *= -1;
  }

  /**
   * Starts a new game: clears the board and flips roles. If the agent is
   * now X, it immediately plays a uniformly random square (any square is
   * legal on an empty board).
   */
  async resetGame() {
    this.resetBoard();
    this.flipRoles();
    if (this.roleIsX()) {
      const randomAction = Math.floor(Math.random() * this.board.length);
      this.board[randomAction] = this.role;
    }
  }

  /**
   * Logs a 3x3 rendering of `board` (defaults to the current board when
   * `board` is null or undefined).
   */
  printBoard(board = undefined) {
    const cells = board ?? this.board;

    console.log("_______");
    for (let i = 0; i < 3; i++) {
      const offset = i * 3;
      const col1 = this.roleAsString(cells[0 + offset]);
      const col2 = this.roleAsString(cells[1 + offset]);
      const col3 = this.roleAsString(cells[2 + offset]);
      console.log(`${col1}|${col2}|${col3}`);
    }
    console.log("‾‾‾‾‾‾‾");
  }
}

/**
 * Non-learned agent that chooses moves by exhaustive minimax search.
 *
 * While the board is nearly empty (more than 7 open squares) the search is
 * skipped and a random open square is returned instead.
 */
export class TicTacToeV0 extends AbstractTicTacToe {
  // Kept static for callers that read `TicTacToeV0.version`. The instance
  // field is what `this.version` resolves to in inherited error messages.
  static version = TicTacToeVersions.V0;
  version = TicTacToeVersions.V0;
  // Winner recorded by the most recent terminal isGameOver() call
  // (-1 = X, 1 = O, 0 = tie); reset to null by resetGame().
  minMaxWinner = null;

  /**
   * Picks the open square with the best minimax score for `this.role`.
   * Ties are broken in favour of the lowest board index.
   * @returns {Promise<number>} board index to play, or -1 if no square is open.
   */
  getAction = async () => {
    const [, moves] = this.getValidMoves();
    if (moves.length === 0) {
      return -1;
    }
    if (moves.length > 7) {
      // search space too large, pick randomly
      return moves[this.getRandomInt(moves.length)];
    }

    // The opponent is simply the other sign; this.role must not be mutated.
    const opponent = -this.role;
    const scores = [];
    for (const move of moves) {
      this.board[move] = this.role;
      scores.push(this.minMax(this.role, opponent));
      this.board[move] = 0;
    }
    // `scores` is parallel to `moves`, so map the best score back to its move.
    return moves[this.argmax(scores)];
  };

  /** @returns {number} a random integer in [0, max). */
  getRandomInt = (max) => {
    return Math.floor(Math.random() * max);
  };

  /**
   * Minimax value of the current board from `role`'s point of view.
   * The board is mutated during the search and restored before returning.
   * @param {number} role - the side being scored (-1 = X, 1 = O).
   * @param {number} player - the side whose turn it is to move.
   * @returns {number} 1 if `role` wins with best play, -1 if it loses, 0 for a tie.
   */
  minMax = (role, player) => {
    const [status, done] = this.isGameOver();
    if (done) {
      // status is the winner's sign (or 0 for a tie), so role * status is
      // 1 when role won and -1 when role lost.
      return status === 0 ? 0 : role * status;
    }

    const [, moves] = this.getValidMoves();
    const scores = [];
    for (const move of moves) {
      this.board[move] = player;
      scores.push(this.minMax(role, -player));
      this.board[move] = 0;
    }

    // `role` maximises its own score; the opponent minimises it.
    return player === role ? Math.max(...scores) : Math.min(...scores);
  };

  /** Same as the base isGameOver, but records the result in `minMaxWinner`. */
  isGameOver() {
    const [status, done] = super.isGameOver();
    if (status !== 2) {
      this.minMaxWinner = status;
    }
    return [status, done];
  }

  /** Clears the recorded winner, then performs the base reset (returned so it can be awaited). */
  resetGame = () => {
    this.minMaxWinner = null;
    return super.resetGame();
  };
}

/**
 * Agent backed by the v1 ONNX model (`onnx_model_v1`). It always plays X
 * and greedily picks the legal move with the highest model output.
 */
export class TicTacToeV1 extends AbstractTicTacToe {
  version = TicTacToeVersions.V1;

  /**
   * Resets the board, loads the ONNX model with the "webgl" execution
   * provider, then immediately plays the opening move as X.
   */
  async createSession() {
    await super.createSession();
    // NOTE(review): `backendHint` does not appear to be an onnxruntime-web
    // SessionOptions key (executionProviders selects the backend); confirm
    // whether it is still needed.
    this.session = await InferenceSession.create(onnx_model_v1, {
      backendHint: "webgl",
      executionProviders: ["webgl"],
    });
    await this.takeAction(); // v1 is always X
  }

  /** Plays one move as X. */
  async takeAction() {
    await super.takeAction(TicTacToeV1.X);
  }

  /** @returns {Promise<number>} the move chosen by performInference(). */
  async getAction() {
    return await this.performInference();
  }

  /**
   * Performs inference and greedily returns an action.
   * @returns {Promise<number>} index of the highest-valued empty cell.
   * @throws {Error} if createSession() has not completed.
   */
  async performInference() {
    if (!this.session) {
      throw new Error("An Inference Session has not been created yet.");
    }

    const inputArray = new Float32Array(this.board);
    const inputTensor = new Tensor("float32", inputArray, [1, 9]);

    // When the PyTorch model was exported to ONNX, its input was named
    // "input" (the name is arbitrary), which is why that key is used here.
    const outputMap = await this.session.run({ input: inputTensor });

    // Likewise "output" is the output_name chosen at export time.
    // Presumably one q-value per cell (9 values) — TODO confirm.
    const qValues = outputMap["output"].data;

    // Mask occupied cells with -Infinity so they can never win the argmax.
    // A finite sentinel such as -1000 would lose to any legal q-value
    // below it.
    const maskedValues = this.board.map((value, index) =>
      value === 0 ? qValues[index] : Number.NEGATIVE_INFINITY
    );
    return this.argmax(maskedValues);
  }

  /** TicTacToeV1 always plays as X, so "flipping" pins the role to X. */
  flipRoles() {
    this.role = TicTacToeV1.X;
  }
}

/**
 * Agent intended to be backed by the v2 ONNX model (`onnx_model_v2`).
 * Unlike V1 it can play either side. Its inference is not implemented yet.
 */
export class TicTacToeV2 extends AbstractTicTacToe {
  version = TicTacToeVersions.V2;

  /**
   * Resets the board, loads the model, and plays the opening move if this
   * agent is X.
   */
  async createSession() {
    // Await the base reset so the board is cleared before anything else runs.
    await super.createSession();
    // NOTE(review): onnx_model_v2 is currently imported from
    // TicTacToeAgentV1.onnx at the top of this file — confirm this is intended.
    this.session = await InferenceSession.create(onnx_model_v2, {
      backendHint: "webgl",
      executionProviders: ["webgl"],
    });
    if (this.roleIsX()) {
      await this.takeAction();
    }
  }

  /** Plays one move as this agent's current role. */
  async takeAction() {
    await super.takeAction(this.role);
  }

  /** @returns {Promise<number>} the move chosen by performInference(). */
  async getAction() {
    return await this.performInference();
  }

  /**
   * Not implemented yet. Previously this returned `undefined`, which only
   * surfaced later as a misleading "Action 'undefined' is invalid" error.
   * @throws {Error} always.
   */
  async performInference() {
    throw new Error(
      `(TicTacToe v-${this.version}) performInference is not implemented yet.`
    );
  }
}
