.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_extending.py:


Extending With Custom Functionality
=======================================

In this example we see how to extend ``skerch`` with custom functionality.
Specifically:

* Adding a new recovery method for low-rank sketched algorithms
* Adding a new noise source

This showcases the versatility of ``skerch``: not only does it work on linops
that satisfy very simple interfaces, but it can also be extended and modified
with low coding overhead.

.. code-block:: Python

    from collections import defaultdict

    import torch

    from skerch.algorithms import SketchedAlgorithmDispatcher, ssvd
    from skerch.measurements import GaussianNoiseLinOp
    from skerch.synthmat import RandomLordMatrix


Creation of test data
---------------------

We start by sampling an (approximately) low-rank matrix using
:class:`skerch.synthmat.RandomLordMatrix`, and then running the built-in
:func:`skerch.algorithms.ssvd` via Nystrom recovery with Rademacher noise,
which yields a relative error of about 1%:

.. code-block:: Python

    SEED = 124816315799
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.complex128
    SHAPE, RANK, DECAY = (100, 200), 10, 0.1
    SKETCH_MEAS, TEST_MEAS = 50, 30

    # Approximately low-rank, non-symmetric test matrix
    mat = RandomLordMatrix.exp(
        SHAPE, RANK, DECAY, symmetric=False, device=DEVICE, dtype=DTYPE, psd=False
    )[0]

    # Built-in sketched SVD: Rademacher noise + Nystrom recovery
    sU, sS, sVh = ssvd(
        mat,
        DEVICE,
        DTYPE,
        SKETCH_MEAS,
        seed=SEED,
        noise_type="rademacher",
        recovery_type="nystrom",
        lstsq_rcond=1e-10,
    )
    smat = (sU * sS) @ sVh
    print(
        "Relative error (Rademacher+Nystrom):",
        (torch.dist(mat, smat) / mat.norm()).item(),
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Relative error (Rademacher+Nystrom): 0.010568960814429136


Testing a new recovery method
-----------------------------

Let's now test a new theory: since random matrices are so cool, maybe a
random sample is also a good recovery?
With ``skerch``, all we need to do is:

1. Define our new recovery method
2. Extend the dispatcher to provide the recovery as needed
3. Feed the requested string and dispatcher to the existing SVD algorithm

.. code-block:: Python

    def bogo_recovery(sketch_right, sketch_left, *args, **kwargs):
        """Just guess the output. How bad could it be?"""
        # Random orthonormal factors with the same shapes as the sketches
        U = torch.linalg.qr(torch.randn_like(sketch_right))[0]
        Vh = torch.linalg.qr(torch.randn_like(sketch_left.H))[0].H
        # Random non-negative "singular values" in descending order
        S = torch.randn_like(U[0]).abs().sort(descending=True)[0]
        if kwargs["as_svd"]:
            return U, S, Vh
        else:
            return U * S, Vh


    class BogoDispatcher(SketchedAlgorithmDispatcher):
        """A custom dispatcher that provides ``bogo_recovery``."""

        @staticmethod
        def recovery(recovery_type, hermitian=False):
            """Returns recovery function with given specs."""
            if recovery_type == "bogo":
                return bogo_recovery, None
            else:
                raise ValueError(f"Unknown recovery! {recovery_type}")


    bU, bS, bVh = ssvd(
        mat,
        DEVICE,
        DTYPE,
        SKETCH_MEAS,
        seed=SEED,
        noise_type="rademacher",
        recovery_type="bogo",  # changed!
        lstsq_rcond=1e-10,
        dispatcher=BogoDispatcher,  # changed!
    )
    bmat = (bU * bS) @ bVh
    print(
        "Relative error (Rademacher+Bogo):",
        (torch.dist(mat, bmat) / mat.norm()).item(),
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Relative error (Rademacher+Bogo): 2.381213423599033

Oops! With a relative error above 200%, it seems that ``bogo_recovery`` is
not a good method, and we should stick to the big guns. Good to know, and all
in a couple dozen lines of code!

.. note::

  Currently, recovery methods and dispatchers must fulfill particular
  interfaces (see :mod:`skerch.recovery` for examples). To try methods
  that deviate from those, the best practice is probably copy-pasting the
  ``ssvd`` function and adjusting the parts that break compatibility.


Testing a new measurement distribution
--------------------------------------

OK but hear me out: since random matrices are so cool, maybe some other
arbitrary form of random measurement also provides a good recovery?
Or maybe we suspect that a particular type of noise is best suited for a
particular setting of linear operators and algorithms?
With ``skerch``, this can be easily tested:

1. Define our new measurement linop by extending
   :class:`skerch.linops.ByBlockLinOp`
2. Extend the dispatcher to provide the measurement linop as needed
3. Feed the requested string and dispatcher to the existing SVD algorithm

.. code-block:: Python

    class GaussemacherNoiseLinOp(GaussianNoiseLinOp):
        """Gaussian noise with a hard lower bound on the magnitude."""

        REGISTER = defaultdict(list)
        THRESHOLD = 0.5

        def __init__(
            self, shape, seed, by_row=False, batch=None, blocksize=1, register=True
        ):
            super().__init__(
                shape, seed, by_row, batch, blocksize, register, 0.0, 1.0
            )

        def get_block(self, block_idx, input_dtype, input_device):
            result = super().get_block(block_idx, input_dtype, input_device)
            mag = result.abs()
            # Entries with magnitude below THRESHOLD are rescaled to (almost)
            # THRESHOLD; all other entries are left untouched
            scale = torch.where(
                mag < self.THRESHOLD,
                self.THRESHOLD / (mag + 1e-7),
                torch.ones_like(mag),
            )
            return result * scale


    class GaussemacherDispatcher(SketchedAlgorithmDispatcher):
        """A custom dispatcher that provides ``GaussemacherNoiseLinOp``."""

        @staticmethod
        def mop(noise_type, hw, seed, dtype, blocksize=1, register=False):
            """Returns measurement linop with given specs."""
            if "gaussemacher" in noise_type:
                mop = GaussemacherNoiseLinOp(
                    hw, seed, blocksize=blocksize, register=register
                )
            else:
                raise ValueError(f"Unknown noise type! {noise_type}")
            return mop


    gU, gS, gVh = ssvd(
        mat,
        DEVICE,
        DTYPE,
        SKETCH_MEAS,
        seed=SEED,
        noise_type="gaussemacher",  # changed!
        recovery_type="nystrom",
        lstsq_rcond=1e-10,
        dispatcher=GaussemacherDispatcher,  # changed!
    )
    gmat = (gU * gS) @ gVh
    print(
        "Relative error (Gaussemacher(0.5)+Nystrom):",
        (torch.dist(mat, gmat) / mat.norm()).item(),
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Relative error (Gaussemacher(0.5)+Nystrom): 0.0014695705235111575

So this actually works! On this matrix and seed, the relative error is
roughly 7 times lower than with Rademacher noise. Maybe random matrices
aren't that bad after all...

.. note::

  While ``skerch`` only requires the bare-minimum interface of
  ``.shape = (height, width)`` and ``@`` matmul for its *inputs*,
  the interface for its *measurement linops* is marginally more complex:
  in order to support batched measurements, new measurement linops are
  also expected to implement a ``get_blocks`` iterator, as shown in this
  example (see also :ref:`Linear Operators and Matrix-Freedom`).


In Summary:
-----------

* We have seen how to extend ``skerch`` with new low-rank recovery methods
  with just a few lines of code
* Similarly, we can also add new noise sources with little effort
* Still, some interfaces must be satisfied to run built-in code. Whenever
  your interfaces collide (e.g. you require a new type of input), the best
  advice is to copy-paste and modify the algorithm, which thanks to the
  modularity of ``skerch`` is also fairly low-effort


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.048 seconds)
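
To make the rescaling rule inside ``GaussemacherNoiseLinOp.get_block`` more
tangible, here is a minimal scalar sketch in plain Python. It needs neither
``torch`` nor ``skerch``. The names ``push_magnitude`` and ``samples`` are
illustrative assumptions and not part of the library; only the threshold
``0.5`` and the stabilizer ``1e-7`` are taken from the example above.

.. code-block:: Python

    THRESHOLD = 0.5  # same value as GaussemacherNoiseLinOp.THRESHOLD
    EPS = 1e-7  # same stabilizer as used in get_block


    def push_magnitude(z, threshold=THRESHOLD, eps=EPS):
        """Scalar analogue of the rescaling in ``get_block`` (hypothetical)."""
        mag = abs(z)
        # Small entries are scaled up to (almost) the threshold, large ones
        # are multiplied by 1 and hence stay untouched
        scale = threshold / (mag + eps) if mag < threshold else 1.0
        return z * scale


    # Two entries below the threshold, two entries above it
    samples = [0.1 + 0.0j, -0.3j, 0.6 + 0.8j, 2.0 - 1.0j]
    pushed = [push_magnitude(z) for z in samples]
    for z, p in zip(samples, pushed):
        print(f"|{z}| = {abs(z):.3f}  ->  |{p}| = {abs(p):.3f}")

Since the scale factor is real and positive, each entry keeps its phase and
only small magnitudes are raised. Because of the stabilizer, rescaled
magnitudes end up marginally below the threshold rather than exactly on it.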