.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/deep_learning.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_deep_learning.py:


Deep Learning
=================

Recall that ``skerch`` operations admit any linear operator that implements left and right matrix multiplication via the ``@`` operator, as well as the ``.shape = (height, width)`` attribute.
The `CurvLinOps library `_ provides curvature linear operators that satisfy this requirement for a variety of very useful objects, such as the Hessian, the Jacobian and the Generalized Gauss-Newton (GGN) matrix.
It is built on PyTorch, but the ``curvlinops`` operators also implement SciPy's `LinearOperator `_ interface, so they are not only compatible with ``skerch``: they can also be used with most of the linear algebra routines available in SciPy.

In this example, we show how to obtain an accurate, full Hessian eigendecomposition in a deep learning setup, using ``skerch``'s sketched EIGH combined with ``curvlinops``.
To assess the quality of the resulting sketched approximation, we apply the *a-posteriori* test method discussed in :ref:`Sketched Low-Rank Decompositions` (see also :mod:`skerch.a_posteriori`).

This small-scale example, which runs in under a minute on CPU, is already borderline intractable with traditional linear algebra routines.
Thanks to the PyTorch backend, GPU acceleration can be enabled with minimal changes, reaching substantially larger scales.
Even larger scales can be reached with out-of-core, distributed computations (see :ref:`Out-of-core Operations via HDF5` for guidelines and `this paper `_ for an application example).

.. GENERATED FROM PYTHON SOURCE LINES 35-45
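
The operator interface recalled above is duck-typed: an object only needs ``.shape`` and ``@`` on both sides. As a rough illustration, the sketch below defines a hypothetical matrix-free operator ``LowRankSymOp`` representing ``Q @ diag(lbd) @ Q.T`` (a class written only for this example, not part of ``skerch`` or ``curvlinops``), using plain Python lists in place of the PyTorch tensors that ``skerch`` actually operates on. The hypothetical helper ``sq_frob_estimate`` then estimates squared Frobenius norms from a few random test measurements, relying on the standard identity ``E[||M w||^2] = ||M||_F^2`` for vectors ``w`` with i.i.d. zero-mean, unit-variance (e.g. Rademacher) entries. This only conveys the general idea behind *a-posteriori* error estimates that need nothing but ``@``; the actual procedure in :mod:`skerch.a_posteriori` may differ in its details.

.. code-block:: Python

    import math
    import random


    class LowRankSymOp:
        """Hypothetical matrix-free symmetric operator A = Q diag(lbd) Q^T.

        Only ``shape`` and ``@`` (on both sides) are provided; the dense
        matrix is never formed. Plain lists stand in for torch tensors.
        """

        def __init__(self, q_cols, lbd):
            # q_cols: k columns of Q, each a list of length n; lbd: k eigenvalues
            self.q_cols, self.lbd = q_cols, lbd
            n = len(q_cols[0])
            self.shape = (n, n)

        def __matmul__(self, x):
            # right multiplication: A @ x = sum_i lbd_i * <q_i, x> * q_i
            out = [0.0] * self.shape[0]
            for q, lbd_i in zip(self.q_cols, self.lbd):
                coef = lbd_i * sum(qj * xj for qj, xj in zip(q, x))
                for j, qj in enumerate(q):
                    out[j] += coef * qj
            return out

        def __rmatmul__(self, x):
            # left multiplication: x @ A equals A @ x here because A is symmetric
            return self @ x


    def sq_frob_estimate(matvec, n, num_meas, seed=0):
        """Estimate ||M||_F^2 as the mean of ||M w||^2 over Rademacher vectors w."""
        rng = random.Random(seed)
        total = 0.0
        for _ in range(num_meas):
            w = [rng.choice((-1.0, 1.0)) for _ in range(n)]
            total += sum(v * v for v in matvec(w))
        return total / num_meas


    # toy "exact" operator (rank 2) and a rank-1 approximation of it
    n = 4
    basis = [[1.0 if i == j else 0.0 for i in range(n)] for j in range(n)]
    op = LowRankSymOp([basis[0], basis[1]], [3.0, 0.5])
    approx = LowRankSymOp([basis[0]], [3.0])

    # squared norm of the operator, and of the error op - approx, via matvecs only
    frob_sq = sq_frob_estimate(lambda w: op @ w, n, num_meas=30)
    err_sq = sq_frob_estimate(
        lambda w: [a - b for a, b in zip(op @ w, approx @ w)], n, num_meas=30
    )
    rel_err = math.sqrt(err_sq / frob_sq)
    print(frob_sq, err_sq, rel_err)

In this toy case ``frob_sq`` and ``err_sq`` come out as exactly ``9.25`` and ``0.25``, the true squared norms, because every ``+-1`` test vector gives ``||M w||^2 = ||M||_F^2`` for a diagonal ``M``. For general operators the estimate fluctuates around the true value and becomes more reliable as ``num_meas`` grows.
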

.. code-block:: Python


    import matplotlib.pyplot as plt
    import torch
    from curvlinops import HessianLinearOperator

    from skerch.a_posteriori import apost_error
    from skerch.algorithms import seigh
    from skerch.linops import CompositeLinOp, DiagonalLinOp
    from skerch.utils import gaussian_noise, rademacher_noise

.. GENERATED FROM PYTHON SOURCE LINES 46-54

##############################################################################

Setup
-----

Curvature matrices depend on a dataset, a model and a loss function.
For this example, we create a synthetic dataset and a model with more than 30,000 parameters, resulting in a Hessian with more than 1 billion entries.

.. GENERATED FROM PYTHON SOURCE LINES 55-89

.. code-block:: Python


    SEED = 12345780
    DTYPE = torch.float32
    # medium-scale config that runs on the CPU server building the docs.
    # Set DEVICE = "cuda" and use bigger data/MLP to test larger scales locally
    DEVICE = "cpu"
    DATASET_SHAPE = (2, 50, 784)  # num_batches, batch_size, xdim
    MLP_DIMS = (784, 50, 10)
    SKETCH_DIMS, TEST_MEAS = 2000, 30

    # synthetic dataset, model and loss function
    X = gaussian_noise(DATASET_SHAPE, seed=SEED, dtype=DTYPE, device=DEVICE)
    Y = rademacher_noise(
        DATASET_SHAPE[:-1] + (MLP_DIMS[-1],), seed=SEED + 1, device=DEVICE
    ).to(DTYPE)
    dataloader = list(zip(X, Y))  # list of (inputs, targets) batches
    # MLP: Linear layers interleaved with ReLUs ([:-1] drops the trailing ReLU)
    model = torch.nn.Sequential(
        *sum(
            [
                [torch.nn.Linear(i, o), torch.nn.ReLU()]
                for i, o in zip(MLP_DIMS[:-1], MLP_DIMS[1:])
            ],
            [],
        )[:-1]
    ).to(DEVICE)
    loss_function = torch.nn.MSELoss(reduction="mean").to(DEVICE)
    params = [p for p in model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in params)

    print(model)
    print("Number of trainable parameters:", num_params)
    print("Number of Hessian entries:", num_params**2)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Sequential(
      (0): Linear(in_features=784, out_features=50, bias=True)
      (1): ReLU()
      (2): Linear(in_features=50, out_features=10, bias=True)
    )
    Number of trainable parameters: 39760
    Number of Hessian entries: 1580857600
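
The printed counts follow from simple arithmetic: a ``torch.nn.Linear(i, o)`` layer holds an ``o x i`` weight matrix plus ``o`` biases, ``torch.nn.ReLU`` has no parameters, and the Hessian is a square ``num_params x num_params`` matrix. The standalone sketch below (the helper ``mlp_num_params`` is hypothetical, written only for this illustration) reproduces both numbers for ``MLP_DIMS = (784, 50, 10)`` and adds a rough estimate of the memory a dense ``float32`` Hessian would take, assuming 4 bytes per entry.

.. code-block:: Python

    def mlp_num_params(dims):
        """Count the weights and biases of an MLP of Linear layers with sizes ``dims``."""
        # Linear(i, o) has an (o x i) weight matrix plus o biases; ReLU adds nothing
        return sum(i * o + o for i, o in zip(dims[:-1], dims[1:]))


    num_params = mlp_num_params((784, 50, 10))
    num_entries = num_params**2  # the Hessian is num_params x num_params
    dense_gb = num_entries * 4 / 1e9  # float32 uses 4 bytes per entry
    print(num_params, num_entries, round(dense_gb, 2))  # → 39760 1580857600 6.32

So merely storing this Hessian densely takes about 6.3 GB, before accounting for the workspace of a dense eigensolver, which is consistent with calling this small-scale example borderline intractable for traditional routines.
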

.. GENERATED FROM PYTHON SOURCE LINES 90-101

##############################################################################

Sketched Hessian Eigendecomposition
-----------------------------------

Now we can create the Hessian linear operator and compute its ``skerch`` eigendecomposition.
Some considerations:

* If ``SKETCH_DIMS`` is too small, the recovery quality may suffer.
* Reducing ``meas_blocksize`` helps avoid out-of-memory errors, but is slower, since fewer measurements are performed in parallel.

.. GENERATED FROM PYTHON SOURCE LINES 102-116

.. code-block:: Python


    # matrix-free Hessian of the loss with respect to the model parameters
    H = HessianLinearOperator(model, loss_function, params, dataloader)
    print(H)

    # sketched eigendecomposition (sketch size SKETCH_DIMS, Rademacher noise)
    ews, evs = seigh(
        H,
        DEVICE,
        DTYPE,
        SKETCH_DIMS,
        seed=SEED + 2,
        noise_type="rademacher",
        recovery_type="hmt",
    )
    # keep the approximation matrix-free: sH = Q @ diag(ews) @ Q.T, with Q = evs
    sH = CompositeLinOp((("Q", evs), ("Lbd", DiagonalLinOp(ews)), ("Qt", evs.T)))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


.. GENERATED FROM PYTHON SOURCE LINES 117-124

##############################################################################

Error estimation
----------------

Looks good, but how good is our recovery?
We now estimate the error via *a-posteriori* test measurements, which confirms that it is very low:

.. GENERATED FROM PYTHON SOURCE LINES 125-138

.. code-block:: Python


    # squared Frobenius norms of H (frob_sq) and of the error (err_sq),
    # estimated from TEST_MEAS random test measurements
    (frob_sq, f_sq, err_sq), _ = apost_error(
        H, sH, DEVICE, DTYPE, num_meas=TEST_MEAS, seed=SEED + num_params
    )

    # left: recovered eigenvalues; right: log-magnitude of the bottom-right
    # 200x200 block of the recovered approximation sH
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 3))
    ax1.plot(ews.cpu())
    ax2.imshow(((evs[-200:] * ews) @ evs[-200:].T).abs().log().cpu(), aspect="auto")
    fig.suptitle("Recovered Hessian eigenvalues and fragment of $log |H|$")

    rel_err = (err_sq / frob_sq) ** 0.5
    print("Estimated Hessian norm:", frob_sq.item() ** 0.5)
    print("Estimated approximation error:", err_sq.item() ** 0.5)
    print("RELATIVE ERROR:", (err_sq / frob_sq).item() ** 0.5)


.. image-sg:: /examples/images/sphx_glr_deep_learning_001.png
   :alt: Recovered Hessian eigenvalues and fragment of $log |H|$
   :srcset: /examples/images/sphx_glr_deep_learning_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Estimated Hessian norm: 16.62060580585183
    Estimated approximation error: 6.613247817472402e-05
    RELATIVE ERROR: 3.978945100209745e-06


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 20.490 seconds)


.. _sphx_glr_download_examples_deep_learning.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: deep_learning.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: deep_learning.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: deep_learning.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_