Analysing MNIST image dataset using Scarf#

%load_ext autotime

import scarf
import matplotlib.pyplot as plt
import pandas as pd
scarf.__version__
'0.28.9'
time: 1.65 s (started: 2024-01-24 15:41:01 +00:00)

1) Fetch MNIST dataset and convert to Zarr format#

The MNIST dataset consists of 60K grayscale images (28x28 pixels) of handwritten digits (0 through 9). Each image can be unraveled such that every digit is described by a 784-dimensional vector. This dataset is available for download through Scarf. We saved the data in the same format as the output of the Cell Ranger pipeline, with the matrix saved in MTX format.

scarf.fetch_dataset('lecun_60K_mnist_images', save_path='scarf_datasets')
time: 29.3 s (started: 2024-01-24 15:41:03 +00:00)
reader = scarf.CrDirReader('scarf_datasets/lecun_60K_mnist_images')
writer = scarf.CrToZarr(
    reader,
    zarr_loc='scarf_datasets/lecun_60K_mnist_images/data.zarr',
    chunk_size=(2000, 1000),
)
writer.dump(batch_size=1000)
WARNING: feature_types extraction failed from features.tsv.gz in column 2
time: 3.57 s (started: 2024-01-24 15:41:32 +00:00)

2) Load the Zarr file into the DataStore object#

ds = scarf.DataStore(
    'scarf_datasets/lecun_60K_mnist_images/data.zarr',
    min_cells_per_feature=1000,
    min_features_per_cell=10,
    nthreads=4,
)
ds
WARNING: No matches found for pattern MT-|mt. Will not add/update percentage feature
WARNING: No matches found for pattern RPS|RPL|MRPS|MRPL. Will not add/update percentage feature
DataStore has 60000 (60000) cells with 1 assays: RNA
   Cell metadata:
            'I', 'ids', 'names', 'RNA_nCounts', 'RNA_nFeatures', 
          
   RNA assay has 467 (784) features and following metadata:
            'I', 'ids', 'names', 'dropOuts', 'nCells', 
          
time: 799 ms (started: 2024-01-24 15:41:36 +00:00)

The labels for each image are embedded in their names. We will extract them and add them as a separate column in the cells attribute table.

ds.cells.insert('digit_label',
                [int(x.rsplit('_', 1)[-1])-1 for x
                 in ds.cells.fetch_all('names')], overwrite=True)
time: 51.2 ms (started: 2024-01-24 15:41:37 +00:00)
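As a quick illustration of the parsing logic above (the last underscore-separated token of each image name appears to encode the digit label plus one; the name below is purely hypothetical):

# Hypothetical name; only the final '_'-separated token is used
name = 'mnist_img_00042_6'
label = int(name.rsplit('_', 1)[-1]) - 1
print(label)  # -> 5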

3) Creating neighbourhood graph#

We will not perform any cell filtering here. We will also skip the feature selection step and use all the valid features. Since we imported the data as an RNAassay, PCA will be performed on the data. Before we begin the graph creation step, we turn off the default normalization for an RNAassay.

# Set normalization method to a dummy function that returns unnormalized data
ds.RNA.normMethod = scarf.assay.norm_dummy

ds.make_graph(feat_key='I', k=31, dims=25, n_centroids=100, show_elbow_plot=True)
INFO: ANN recall: 100.00%
../_images/558d719dce6aff1ea0051effaab9bb96d7492c380b0c3df2313ed1aab2fd76ae.png
time: 46 s (started: 2024-01-24 15:41:37 +00:00)

The elbow plot above suggests that taking the first 10 PC dimensions might have been optimal for this dataset. However, we want to capture even the very fine differences (at the risk of more noise) between the digits and hence do not rerun this with dims=10.
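For reference, rebuilding the graph with the more conservative choice would only require changing the dims argument (a sketch, not run here; all other parameters as above):

# Not run: rebuild the neighbourhood graph using only the first 10 PCs
ds.make_graph(feat_key='I', k=31, dims=10, n_centroids=100)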


4) UMAP embedding and clustering#

We will now generate a 2D embedding of the neighbourhood graph of the MNIST images. This will allow us to ascertain visually how accurately Scarf estimated the underlying manifold of this dataset. There are two critical differences between calling UMAP from within Scarf and directly from the UMAP library:

  • When called from Scarf, only the neighbourhood graph is provided to the core UMAP algorithm, rather than a normalized/scaled/reduced data matrix.

  • Scarf performs KMeans clustering while identifying neighbours; this allows Scarf to generate informed initial embedding coordinates for the data based on the KMeans centroids. The UMAP package, on the other hand, calculates a spectral layout for the graph (see the sketch below).
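For contrast, a direct call to the umap-learn package would typically look like the sketch below; the matrix X is a hypothetical stand-in for a PCA-reduced data matrix, from which umap-learn builds its own neighbour graph and spectral initialisation:

import numpy as np
import umap  # umap-learn package

# Hypothetical stand-in for a cells-by-dims (e.g. PCA-reduced) matrix
X = np.random.rand(1000, 25)
# umap-learn computes its own neighbour graph from X and, by default,
# initialises the layout with a spectral embedding of that graph
embedding = umap.UMAP(n_neighbors=31, min_dist=0.05, init='spectral').fit_transform(X)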

ds.run_umap(n_epochs=300, spread=1, min_dist=0.05, parallel=True)
	completed  0  /  300 epochs
	completed  30  /  300 epochs
	completed  60  /  300 epochs
	completed  90  /  300 epochs
	completed  120  /  300 epochs
	completed  150  /  300 epochs
	completed  180  /  300 epochs
	completed  210  /  300 epochs
	completed  240  /  300 epochs
	completed  270  /  300 epochs
time: 1min 27s (started: 2024-01-24 15:42:23 +00:00)

Before we visualize the UMAP embeddings, we will also perform clustering on the data (neighbourhood graph) using the default Paris algorithm. Here we choose to perform overclustering of the data so that we can capture fine differences within the individual digit classes.

ds.run_clustering(n_clusters=20)
time: 15.2 s (started: 2024-01-24 15:43:51 +00:00)

Relabeling the cluster ids to match their corresponding digit labels

ds.smart_label(
    to_relabel='RNA_cluster',
    base_label='digit_label',
    new_col_name='cluster_label',
)
time: 49.7 ms (started: 2024-01-24 15:44:06 +00:00)
ds.plot_layout(
    layout_key='RNA_UMAP', color_by=['digit_label', 'cluster_label'],
    do_shading=True, shade_npixels=300, legend_onside=False,
    width=4, height=4, cmap='tab20'
)
/home/docs/checkouts/readthedocs.org/user_builds/scarf/envs/latest/lib/python3.10/site-packages/scarf/plots.py:594: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.
  centers = df[[x, y, vc]].groupby(vc).median().T
../_images/8ff177cfed71c72fa8334219425df81198b3161e764488a2f02da19d7fc5c9fe.png
time: 2.72 s (started: 2024-01-24 15:44:06 +00:00)

The UMAP embedding shows that images from the same digit classes were grouped together very nicely. We obtained one cluster containing digit classes 4, 7 and 9 and another containing classes 3, 5 and 8; a similar observation has been made before in the link below. Since the images from classes 0 and 1 are well separated, one can infer that the global structure is also well preserved.

https://umap-learn.readthedocs.io/en/latest/auto_examples/plot_mnist_example.html
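If one wants to inspect or re-plot the embedding outside of plot_layout, the coordinates can be fetched directly from the cell attribute table. A minimal sketch is shown below; the column names 'RNA_UMAP1' and 'RNA_UMAP2' are an assumption based on the layout_key used above:

# Assumed column names: the layout_key followed by the axis index
umap_x = ds.cells.fetch_all('RNA_UMAP1')
umap_y = ds.cells.fetch_all('RNA_UMAP2')
plt.scatter(umap_x, umap_y, s=0.5, c=ds.cells.fetch_all('digit_label'), cmap='tab10')
plt.show()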

We can use the plot_cluster_tree method to show an explicit relationship between the clusters of images. This method leverages the hierarchical relationship between the individual images as calculated by the Paris algorithm. Each circle/pie (sized by the number of images in the cluster) represents a cluster. The colours inside each pie indicate the proportions of the actual digit classes.

ds.plot_cluster_tree(
    cluster_key='cluster_label',
    fill_by_value='digit_label',
)
../_images/10e20b72e9acdc9a7d5664f827bcbf8bffc7cc180a165568c8e6ff447d460297.png
time: 3.97 s (started: 2024-01-24 15:44:09 +00:00)

Finally, let's visualize the images from each of the clusters. To do so, we take all images from a given cluster and merge them (think of it as creating an overlay of all the images from that cluster).

clusts = pd.Series(ds.cells.fetch_all('cluster_label'))
digits = pd.Series(ds.cells.fetch_all('digit_label'))
time: 8.8 ms (started: 2024-01-24 15:44:13 +00:00)
fig = plt.figure(figsize=(8, 2))
for n, i in enumerate(sorted(clusts.unique())):
    # Average all images assigned to cluster `i` whose true digit matches
    # the cluster's leading digit (e.g. '1a' -> 1)
    mean_map = ds.RNA.rawData[((clusts == i) & (digits == int(i[0]))).values].mean(axis=0)
    # Trigger the lazy computation and restore the 28x28 image shape
    mean_map = mean_map.compute().reshape(28, 28)
    ax = fig.add_subplot(2, 10, n+1)
    ax.imshow(mean_map, cmap='binary')
    ax.set_axis_off()
    ax.set_title(i, fontsize=10)
plt.tight_layout()
plt.show()
../_images/cd700a8a85ac20f3e8a7b4511bc2b235dda9b41684ad080e272930fc0518dd10.png
time: 4.55 s (started: 2024-01-24 15:44:13 +00:00)
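As a complementary, purely tabular check of cluster purity, one could cross-tabulate the cluster labels against the true digit labels (a quick sketch using the two series fetched above):

# Count how many images of each true digit fall into each cluster
print(pd.crosstab(clusts, digits))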

It is quite clear that Scarf’s clustering was able to identify nuanced differences between the images. For example, clusters 1a and 1c captured images of ‘upright’ 1s, while clusters 1b and 1d captured ‘slanted’ 1s.

This vignette helped us learn two things:

  • Scarf is a flexible package that can handle analysis of diverse kinds of datasets.

  • Results from this dataset show that Scarf can perform quick, memory-efficient and meaningful analysis of large-scale datasets.


That is all for this vignette.