{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Moving Digit (Experimental - Hierarchical SFA)\nAn example of `sksfa.HSFA` applied to a simple image time-series:\na one-digit version of the moving MNIST dataset. Each data point is 4096-dimensional.\n\n<img src=\"file://../images/moving_mnist.gif\" align=\"center\">\n\nIf the change in x is not significantly faster or slower than the change in y, HSFA with only two output features\nsuccessfully extracts a smooth (and possibly flipped) representation of the position of the digit in the image.\n\nThe ground truth is used only for comparison; it is not used during training.\n\n(Note that a problem like this can also be solved with linear SFA. This example serves to show\nhow an HSFA network is initialized and how it can be applied directly to image data without flattening.)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from sksfa import HSFA\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "\n",
        "# Loading and preparing the data\n",
        "# - HSFA requires a color channel, even for grayscale images, so a trailing\n",
        "#   axis of length 1 is appended after squeezing\n",
        "split_ratio = 0.7\n",
        "data = np.load(\"data/mmnist_data.npy\").squeeze()[..., None]\n",
        "ground_truth = np.load(\"data/mmnist_positions.npy\").squeeze()\n",
        "n_points = data.shape[0]\n",
        "split_idx = int(split_ratio * n_points)\n",
        "\n",
        "# Split along the first axis: the leading `split_ratio` share of the samples is\n",
        "# used for fitting, the remainder is held out for the comparison plot\n",
        "training_data = data[:split_idx]\n",
        "training_gt = ground_truth[:split_idx]\n",
        "test_data = data[split_idx:]\n",
        "test_gt = ground_truth[split_idx:]\n",
        "\n",
        "# Preparing the HSFA-network:\n",
        "# - each layer needs a 6-tuple for configuration\n",
        "# - each 6-tuple contains (kernel_width, kernel_height, stride_width, stride_height, n_features, expansion_degree)\n",
        "# The final layer will always be a fully connected SFA layer\n",
        "layer_configurations = [(8, 8, 8, 8, 8, 1),\n",
        "                        (2, 2, 2, 2, 8, 2)]\n",
        "\n",
        "hsfa = HSFA(n_components=2,\n",
        "            input_shape=data.shape[1:],\n",
        "            layer_configurations=layer_configurations,\n",
        "            internal_batch_size=100,\n",
        "            noise_std=0.01)\n",
        "\n",
        "hsfa.summary()\n",
        "\n",
        "# The ground truth is never passed to the model; it is only plotted below\n",
        "hsfa.fit(training_data)\n",
        "output = hsfa.transform(test_data)\n",
        "\n",
        "# Sort the ground-truth columns by the variance of their step-to-step change,\n",
        "# smallest first (presumably to mirror the order of the extracted features -\n",
        "# confirm against the HSFA documentation)\n",
        "gt_delta = np.var(test_gt[1:] - test_gt[:-1], axis=0)\n",
        "gt_order = np.argsort(gt_delta)\n",
        "# Assumes the ground-truth columns are (x, y) in that order - TODO confirm\n",
        "gt_labels = [\"x\", \"y\"]\n",
        "\n",
        "fig, ax = plt.subplots(2, 2, sharex=True)\n",
        "cutoff = 60  # only the first `cutoff` test samples are plotted\n",
        "ax[0, 0].plot(output[:cutoff, 0])\n",
        "ax[1, 0].plot(output[:cutoff, 1])\n",
        "ax[0, 1].plot(test_gt[:cutoff, gt_order[0]])\n",
        "ax[1, 1].plot(test_gt[:cutoff, gt_order[1]])\n",
        "ax[0, 0].set_title(\"Extracted features\")\n",
        "ax[0, 1].set_title(\"True position\")\n",
        "\n",
        "# Because the ground-truth columns are permuted by gt_order, label every panel so\n",
        "# it is clear which coordinate each extracted feature is being compared against\n",
        "for row in range(2):\n",
        "    ax[row, 0].set_ylabel(f\"feature {row}\")\n",
        "    ax[row, 1].set_ylabel(gt_labels[gt_order[row]])\n",
        "for col in range(2):\n",
        "    ax[1, col].set_xlabel(\"time step\")\n",
        "\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}