{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Moving Digit\nAn example of linear SFA (`sksfa.SFA`) applied to a simple image time-series:\na one-digit version of the moving MNIST dataset. Each data point is one flattened image frame with 4096 dimensions.\n\n<img src=\"file://../images/moving_mnist.gif\" align=\"center\">\n\nIf the digit's x-position does not change significantly faster or slower than its y-position, linear SFA with only two output features\nsuccessfully extracts a smooth (and possibly flipped) representation of the position of the digit in the image.\n\nThe ground-truth positions are used only for comparison in the plot; they are not used during training.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from matplotlib.animation import FuncAnimation\n",
        "from sksfa import SFA\n",
        "\n",
        "# Fraction of the time-series used to fit SFA; the remainder is held out for testing.\n",
        "split_ratio = 0.7\n",
        "# Number of leading test samples shown in the figure.\n",
        "cutoff = 60\n",
        "\n",
        "# Load the image frames and the matching digit positions, dropping singleton axes.\n",
        "all_sequences = np.load(\"data/mmnist_data.npy\").squeeze()\n",
        "ground_truth = np.load(\"data/mmnist_positions.npy\").squeeze()\n",
        "n_points = all_sequences.shape[0]\n",
        "split_idx = int(split_ratio * n_points)\n",
        "\n",
        "# Flatten every frame into one feature vector per time step.\n",
        "data = all_sequences.reshape(n_points, -1)\n",
        "training_data = data[:split_idx]\n",
        "test_data = data[split_idx:]\n",
        "# Ground truth is only sliced for the held-out part; it is never passed to SFA.\n",
        "test_gt = ground_truth[split_idx:]\n",
        "\n",
        "# Linear SFA with two output features, fitted on the training frames only.\n",
        "sfa = SFA(2)\n",
        "sfa.fit(training_data)\n",
        "output = sfa.transform(test_data)\n",
        "\n",
        "# Sort the ground-truth coordinates by the variance of their step-to-step change\n",
        "# (ascending), so the coordinate with the least-varying steps is plotted in the top row.\n",
        "gt_delta = np.var(test_gt[1:] - test_gt[:-1], axis=0)\n",
        "gt_order = np.argsort(gt_delta)\n",
        "gt_labels = [\"x\", \"y\"]  # assumes the position columns are ordered (x, y) - TODO confirm\n",
        "\n",
        "# Left column: extracted SFA features; right column: reordered ground truth.\n",
        "fig, ax = plt.subplots(2, 2, sharex=True)\n",
        "for row in range(2):\n",
        "    ax[row, 0].plot(output[:cutoff, row])\n",
        "    ax[row, 0].set_ylabel(f\"SFA feature {row + 1}\")\n",
        "    ax[row, 1].plot(test_gt[:cutoff, gt_order[row]])\n",
        "    # Label each panel with the coordinate it shows, since gt_order may swap x and y.\n",
        "    ax[row, 1].set_ylabel(gt_labels[gt_order[row]])\n",
        "ax[0, 0].set_title(\"Extracted features\")\n",
        "ax[0, 1].set_title(\"True position\")\n",
        "for bottom_ax in ax[1]:\n",
        "    bottom_ax.set_xlabel(\"Time step\")\n",
        "\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}