Commit 9c191325 authored by pvskand

MNIST Example added and reducing numpy matrix size

parent 2d2c291a
+119 −139
%% Cell type:markdown id: tags:

# Using Deepchem Datasets
In this tutorial we will look at the `dataset` classes and methods available in `deepchem.data.datasets`.

%% Cell type:code id: tags:

``` python
import deepchem as dc
import numpy as np
import random
```

%% Output

    /home/skand/anaconda2/lib/python2.7/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
      from ._conv import register_converters as _register_converters

%% Cell type:markdown id: tags:

# Using NumpyDatasets
`NumpyDataset` is the class to use when your data is already in NumPy arrays.

%% Cell type:code id: tags:

``` python
# data is your dataset as a NumPy array of shape 4x4 (4 samples, 4 features each).
data = np.random.random((4, 4))
labels = np.random.random((4,)) # one label per sample, so 4 labels
```

%% Cell type:code id: tags:

``` python
from deepchem.data.datasets import NumpyDataset # import NumpyDataset
```

%% Cell type:code id: tags:

``` python
dataset = NumpyDataset(data, labels) # creates numpy dataset object
```

%% Cell type:markdown id: tags:

## Extracting X, y from NumpyDataset Object
Extracting the data and labels from the NumpyDataset is very easy.

%% Cell type:code id: tags:

``` python
dataset.X # Extracts the data (X) from the NumpyDataset Object
```

%% Output

    array([[0.85221987, 0.47412003, 0.71233837, 0.59094892],
           [0.39387594, 0.99322661, 0.75225026, 0.00995347],
           [0.24524296, 0.96471994, 0.41466874, 0.99579889],
           [0.01912096, 0.99213349, 0.61235698, 0.06214374]])

%% Cell type:code id: tags:

``` python
dataset.y # Extracts the labels (y) from the NumpyDataset Object
```

%% Output

    array([[0.75443686],
           [0.78473712],
           [0.6223576 ],
           [0.53884944]])

%% Cell type:markdown id: tags:

## Weights of a dataset - w
Apart from `X` and `y`, which are the data and the labels, you can also assign a weight `w` to each data instance. `w` has the same shape as `y` (Nx1 here, where N is the number of data instances).

**NOTE:** By default `w` is a vector initialized with equal weights (all being 1).
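To see why per-sample weights matter, here is a minimal NumPy-only sketch of a weighted mean squared error. The function `weighted_mse` and the toy arrays are assumptions made for illustration; DeepChem's own losses may combine or normalize weights differently.

```python
import numpy as np

def weighted_mse(y_true, y_pred, w):
    """Squared error per sample, scaled by its weight and normalized by the total weight."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    w = np.asarray(w, dtype=float)
    return float(np.sum(w * (y_true - y_pred) ** 2) / np.sum(w))

# Toy labels and predictions in the same Nx1 layout as dataset.y (N = 4).
y_true = np.array([[1.0], [2.0], [3.0], [4.0]])
y_pred = np.array([[1.0], [2.0], [3.0], [6.0]])  # only the last prediction is wrong

uniform = np.ones((4, 1))                               # the default: every sample counts equally
ignore_last = np.array([[1.0], [1.0], [1.0], [0.0]])    # weight 0 masks out the last sample

loss_uniform = weighted_mse(y_true, y_pred, uniform)
loss_masked = weighted_mse(y_true, y_pred, ignore_last)
```

With uniform weights the single bad prediction contributes to the loss; with a weight of zero on that sample it is ignored entirely.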

%% Cell type:code id: tags:

``` python
dataset.w # printing the weights that are assigned by default. Notice that they are a vector of 1's
```

%% Output

    array([[1.],
           [1.],
           [1.],
           [1.]])

%% Cell type:code id: tags:

``` python
w = np.random.random((4,)) # initializing weights with a random vector, one weight per sample (4 in total)
dataset_with_weights = NumpyDataset(data, labels, w) # creates numpy dataset object
```

%% Cell type:code id: tags:

``` python
dataset_with_weights.w
```

%% Output

    array([[0.8369533 ],
           [0.52828242],
           [0.43185016],
           [0.99442685]])

%% Cell type:markdown id: tags:

## Iterating over NumpyDataset
To iterate over a NumpyDataset, use the `itersamples` method. Each iteration yields four quantities: `X`, `y`, `w` and `ids`. The first three are the same as discussed above, and `ids` is the id of the data instance. By default, ids are assigned in order starting from `0`.
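To make the shape of each yielded tuple concrete, here is a plain-NumPy stand-in for this kind of per-sample iteration. The generator `iter_samples` is a hypothetical helper written for this sketch, not DeepChem's implementation; it only assumes that one row of each array belongs to one sample.

```python
import numpy as np

def iter_samples(X, y, w, ids):
    """Yield one (x, y, w, id) tuple per data instance, in row order."""
    for i in range(len(X)):
        yield X[i], y[i], w[i], ids[i]

# Toy arrays: 4 samples with 2 features each.
X = np.arange(8, dtype=float).reshape(4, 2)
y = np.array([[0.0], [1.0], [0.0], [1.0]])
w = np.ones((4, 1))      # equal weights, as in the default case
ids = np.arange(4)       # ids counted from 0

rows = list(iter_samples(X, y, w, ids))
for x_i, y_i, w_i, id_i in rows:
    print(x_i, y_i, w_i, id_i)
```

Each tuple holds one feature row, its label, its weight and its id, which is the same per-sample grouping that the loop over `dataset.itersamples()` unpacks.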

%% Cell type:code id: tags:

``` python
for x, y, w, id in dataset.itersamples():
    print(x, y, w, id)
```

%% Output

    (array([0.85221987, 0.47412003, 0.71233837, 0.59094892]), array([0.75443686]), array([1.]), 0)
    (array([0.39387594, 0.99322661, 0.75225026, 0.00995347]), array([0.78473712]), array([1.]), 1)
    (array([0.24524296, 0.96471994, 0.41466874, 0.99579889]), array([0.6223576]), array([1.]), 2)
    (array([0.01912096, 0.99213349, 0.61235698, 0.06214374]), array([0.53884944]), array([1.]), 3)

%% Cell type:markdown id: tags:

You can also extract the ids by `dataset.ids`. This would return a numpy array consisting of the ids of the data instances.
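Ids are useful for mapping a row back to the original record. The sketch below uses only NumPy and toy data; the `custom_ids` values are made-up names, and the default ids are rebuilt by hand to mirror the object-dtype array shown in this tutorial.

```python
import numpy as np

X = np.arange(12, dtype=float).reshape(4, 3)   # 4 samples, 3 features

# Positional ids, mirroring the default array([0, 1, 2, 3], dtype=object).
default_ids = np.arange(len(X)).astype(object)

# Hypothetical custom ids, e.g. names of the records each row came from.
custom_ids = np.array(["mol_a", "mol_b", "mol_c", "mol_d"], dtype=object)

# A boolean mask over the ids selects the matching row of X.
row_by_position = X[default_ids == 2][0]
row_by_name = X[custom_ids == "mol_c"][0]
```

Both lookups return the same row, since `"mol_c"` is the id at position 2.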

%% Cell type:code id: tags:

``` python
dataset.ids
```

%% Output

    array([0, 1, 2, 3], dtype=object)

%% Cell type:markdown id: tags:

## MNIST Example
To get a better understanding, let's read the MNIST data and use `NumpyDataset` to store it.
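Before loading the real data, it helps to know the array layout used in this example: each image is a flat vector of 784 values, and with `one_hot=True` each label is a 10-element one-hot vector. The sketch below uses random stand-in arrays (not actual MNIST data) to show how to turn a flat image back into a 28x28 picture and a one-hot label back into a digit.

```python
import numpy as np

rng = np.random.RandomState(0)

# Stand-in arrays with the same layout as the MNIST arrays (synthetic values).
n_samples = 5
images = rng.random_sample((n_samples, 784)).astype(np.float32)  # flattened 28x28 images
digits = np.array([3, 0, 7, 7, 1])
labels = np.eye(10)[digits]              # one-hot rows: a 1.0 at the digit's index

# A one-hot label is decoded by taking the index of its largest entry.
recovered = labels.argmax(axis=1)

# A flat 784-vector is reshaped to 28x28 before being shown as an image.
picture = images[2].reshape(28, 28)
```

The reshape changes only the view of the data; the 784 values of `images[2]` are laid out row by row in `picture`.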

%% Cell type:code id: tags:

``` python
from tensorflow.examples.tutorials.mnist import input_data
```

%% Cell type:code id: tags:

``` python
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
```

%% Output

    Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
    Extracting MNIST_data/train-images-idx3-ubyte.gz
    Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
    Extracting MNIST_data/train-labels-idx1-ubyte.gz
    Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
    Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
    Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

%% Cell type:code id: tags:

``` python
# Load the numpy data of MNIST into NumpyDataset
train = NumpyDataset(mnist.train.images, mnist.train.labels)
valid = NumpyDataset(mnist.validation.images, mnist.validation.labels)
```

%% Cell type:code id: tags:

``` python
import matplotlib.pyplot as plt
```

%% Output

    /home/skand/anaconda2/lib/python2.7/site-packages/matplotlib/font_manager.py:281: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.
      'Matplotlib is building the font cache using fc-list. '

%% Cell type:code id: tags:

``` python
# Visualize one sample
sample = np.reshape(train.X[5], (28, 28))
plt.imshow(sample)
plt.show()
```

%% Output


%% Cell type:code id: tags:

``` python
train.ids
```

%% Output

    array([0, 1, 2, ..., 54997, 54998, 54999], dtype=object)
