Extending With Custom Functionality
In this example we see how to extend skerch with custom functionality.
Specifically:
Adding a new recovery method for low-rank sketched algorithms
Adding a new noise source
This showcases the versatility of skerch: not only does it work on linops
that satisfy very simple interfaces, but it can also be extended
and modified with low coding overhead.
from collections import defaultdict
import torch
from skerch.algorithms import SketchedAlgorithmDispatcher, ssvd
from skerch.measurements import GaussianNoiseLinOp
from skerch.synthmat import RandomLordMatrix
Creation of test data
We start by sampling an (approximately) low-rank matrix using
skerch.synthmat.RandomLordMatrix, and then running the built-in
skerch.algorithms.ssvd() via Nystrom recovery with Rademacher noise,
yielding good accuracy:
SEED = 124816315799
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.complex128
SHAPE, RANK, DECAY = (100, 200), 10, 0.1
SKETCH_MEAS, TEST_MEAS = 50, 30
mat = RandomLordMatrix.exp(
    SHAPE, RANK, DECAY, symmetric=False, device=DEVICE, dtype=DTYPE, psd=False
)[0]
sU, sS, sVh = ssvd(
    mat,
    DEVICE,
    DTYPE,
    SKETCH_MEAS,
    seed=SEED,
    noise_type="rademacher",
    recovery_type="nystrom",
    lstsq_rcond=1e-10,
)
smat = (sU * sS) @ sVh
print(
    "Relative error (Rademacher+Nystrom):",
    (torch.dist(mat, smat) / mat.norm()).item(),
)
Relative error (Rademacher+Nystrom): 0.010568960814429136
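The following is a minimal NumPy sketch of a generalized Nyström sketched SVD with Rademacher test matrices. It is meant to give some intuition for what a Nystrom recovery does with the left and right sketches. It is a swapped-in textbook technique, not skerch's implementation, and several parts are our own assumptions:

- The helpers lowrank_exp_matrix and nystrom_ssvd are hypothetical.
- The test spectrum (RANK unit singular values followed by exp(-DECAY * i) decay) is our choice, not RandomLordMatrix's exact recipe.
- The 1.5x oversampling of the left sketch is also our choice.

For these reasons, its errors are not comparable to the number printed above.

```python
import numpy as np


def lowrank_exp_matrix(shape, rank, decay, rng):
    """Hypothetical test matrix: `rank` unit singular values followed by an
    exponentially decaying tail (our assumption, not RandomLordMatrix's)."""
    h, w = shape
    k = min(h, w)
    # Random orthonormal singular vectors via QR of Gaussian matrices
    left, _ = np.linalg.qr(rng.standard_normal((h, k)))
    right, _ = np.linalg.qr(rng.standard_normal((w, k)))
    tail = np.exp(-decay * np.arange(1, k - rank + 1))
    svals = np.concatenate([np.ones(rank), tail])
    return (left * svals) @ right.T


def nystrom_ssvd(A, num_meas, rng, rcond=1e-10):
    """Generalized Nystrom sketched SVD with Rademacher test matrices:
    A ~= (A @ Om) @ pinv(Psi.T @ A @ Om) @ (Psi.T @ A)."""
    h, w = A.shape
    left_meas = min(h, (3 * num_meas) // 2)  # oversample the left sketch
    Om = rng.choice([-1.0, 1.0], size=(w, num_meas))
    Psi = rng.choice([-1.0, 1.0], size=(h, left_meas))
    Y = A @ Om        # right sketch, shape (h, num_meas)
    Z = Psi.T @ A     # left sketch, shape (left_meas, w)
    core = Psi.T @ Y  # shape (left_meas, num_meas)
    Q, R = np.linalg.qr(Y)
    # pinv(core) @ Z computed as a least-squares solve, then a small SVD
    X = np.linalg.lstsq(core, Z, rcond=rcond)[0]
    Ub, S, Vh = np.linalg.svd(R @ X, full_matrices=False)
    return Q @ Ub, S, Vh


rng = np.random.default_rng(0)
A = lowrank_exp_matrix((100, 200), rank=10, decay=0.1, rng=rng)
U, S, Vh = nystrom_ssvd(A, num_meas=50, rng=rng)
err = np.linalg.norm(A - (U * S) @ Vh) / np.linalg.norm(A)
# Best possible rank-50 relative error (truncated SVD), for reference
sv = np.linalg.svd(A, compute_uv=False)
opt = np.sqrt(np.sum(sv[50:] ** 2)) / np.linalg.norm(A)
print(f"Nystrom relative error: {err:.4f} (rank-50 optimum: {opt:.4f})")
```

The recovered factors have rank at most 50. By the Eckart-Young theorem, the printed optimum is therefore a lower bound on the achievable error.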
Testing a new recovery method
Let’s now test a new theory: Since random matrices are so cool, maybe a
random sample is also a good recovery? With skerch, all we need to do is:
Define our new recovery method
Extend the dispatcher to provide the recovery as needed
Feed the requested string and dispatcher to the existing SVD algorithm
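def bogo_recovery(sketch_right, sketch_left, *args, **kwargs):
    """Just guess the output. How bad could it be?"""
    # Random orthonormal bases with the same shapes as the sketches
    # (the sketch contents are ignored entirely)
    U = torch.linalg.qr(torch.randn_like(sketch_right))[0]
    Vh = torch.linalg.qr(torch.randn_like(sketch_left.H))[0].H
    # Random nonnegative "singular values", sorted in descending order
    S = torch.randn_like(U[0]).abs().sort(descending=True)[0]
    if kwargs["as_svd"]:
        return U, S, Vh
    else:
        return U * S, Vh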
class BogoDispatcher(SketchedAlgorithmDispatcher):
    """A custom dispatcher that provides ``bogo_recovery``."""

    @staticmethod
    def recovery(recovery_type, hermitian=False):
        """Returns the recovery function with the given specs."""
        if recovery_type == "bogo":
            return bogo_recovery, None
        else:
            raise ValueError(f"Unknown recovery! {recovery_type}")
bU, bS, bVh = ssvd(
    mat,
    DEVICE,
    DTYPE,
    SKETCH_MEAS,
    seed=SEED,
    noise_type="rademacher",
    recovery_type="bogo",  # changed!
    lstsq_rcond=1e-10,
    dispatcher=BogoDispatcher,  # changed!
)
bmat = (bU * bS) @ bVh
print(
    "Relative error (Rademacher+Bogo):",
    (torch.dist(mat, bmat) / mat.norm()).item(),
)
Relative error (Rademacher+Bogo): 2.381213423599033
Oops! It seems that bogo_recovery is not a good method, and we should
stick to the big guns. Good to know, and all in a couple dozen lines
of code!
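Part of the reason bogo_recovery fails so badly is that it only uses the shapes of the sketches, so its bases are independent of mat. Consider a fixed unit vector projected onto a uniformly random k-dimensional subspace of an n-dimensional space. On average it keeps only k/n of its energy, so the leading singular directions of mat are largely missed. The NumPy snippet below checks this average numerically for n = 100 and k = 50. It is real-valued for simplicity, and the helper name random_subspace_capture is illustrative.

```python
import numpy as np


def random_subspace_capture(n, k, trials, rng):
    """Average fraction of a fixed unit vector's energy captured by the
    orthogonal projection onto a uniformly random k-dim subspace of R^n."""
    x = np.zeros(n)
    x[0] = 1.0  # any fixed unit vector works, by rotational invariance
    fractions = []
    for _ in range(trials):
        # Orthonormal basis of a random subspace (QR of a Gaussian matrix)
        Q, _ = np.linalg.qr(rng.standard_normal((n, k)))
        fractions.append(np.sum((Q.T @ x) ** 2))
    return float(np.mean(fractions))


rng = np.random.default_rng(0)
captured = random_subspace_capture(n=100, k=50, trials=2000, rng=rng)
print(f"captured energy: {captured:.3f} (k/n = {50 / 100})")
```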
Note
Currently, recovery methods and dispatchers must fulfill particular
interfaces (see skerch.recovery for examples). To try methods
that deviate from them, the best practice is probably to copy-paste
the ssvd function and adjust the parts that break compatibility.
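A design note on dispatchers: BogoDispatcher above raises ValueError for every key except "bogo", so built-in recoveries such as "nystrom" are not reachable through it. One way to keep both is to delegate unknown keys to the parent class. The sketch below uses hypothetical stand-ins (BaseDispatcher, ExtendedDispatcher and a placeholder bogo_recovery) instead of skerch's SketchedAlgorithmDispatcher, whose full interface is not reproduced here.

```python
class BaseDispatcher:
    """Hypothetical stand-in for a built-in dispatcher (not skerch's API)."""

    @staticmethod
    def recovery(recovery_type, hermitian=False):
        if recovery_type == "nystrom":
            # Placeholder for a built-in recovery function
            return "builtin_nystrom", None
        raise ValueError(f"Unknown recovery! {recovery_type}")


def bogo_recovery(*args, **kwargs):
    """Placeholder for a custom recovery function."""
    return None


class ExtendedDispatcher(BaseDispatcher):
    """Adds "bogo" and falls back to the parent for every other key."""

    @staticmethod
    def recovery(recovery_type, hermitian=False):
        if recovery_type == "bogo":
            return bogo_recovery, None
        # Delegate everything else, so built-in options keep working
        return BaseDispatcher.recovery(recovery_type, hermitian)


print(ExtendedDispatcher.recovery("nystrom"))
```

Unknown keys still raise ValueError, now from the parent, so typos are not silently accepted.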
Testing a new measurement distribution
OK, but hear me out: since random matrices are so cool, maybe some other
arbitrary form of random measurement also provides a good recovery? Or
maybe we suspect that a particular type of noise is best suited for a
particular setting of linear operators and algorithms? With skerch,
this can be easily tested:
Define our new measurement linop by extending skerch.linops.ByBlockLinOp
Extend the dispatcher to provide the measurement linop as needed
Feed the requested string and dispatcher to the existing SVD algorithm
class GaussemacherNoiseLinOp(GaussianNoiseLinOp):
    """Gaussian noise with magnitudes floored at (approximately) THRESHOLD."""

    REGISTER = defaultdict(list)
    THRESHOLD = 0.5

    def __init__(
        self, shape, seed, by_row=False, batch=None, blocksize=1, register=True
    ):
        # 0.0 and 1.0 are assumed to be the mean and std of the base Gaussian
        super().__init__(
            shape, seed, by_row, batch, blocksize, register, 0.0, 1.0
        )

    def get_block(self, block_idx, input_dtype, input_device):
        result = super().get_block(block_idx, input_dtype, input_device)
        mag = result.abs()
        # Lift entries with magnitude below THRESHOLD to just under
        # THRESHOLD, keeping their sign/phase; leave the rest unchanged
        scale = torch.where(
            mag < self.THRESHOLD,
            self.THRESHOLD / (mag + 1e-7),
            torch.ones_like(mag),
        )
        return result * scale
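class GaussemacherDispatcher(SketchedAlgorithmDispatcher):
    """A custom dispatcher that provides ``GaussemacherNoiseLinOp``."""

    @staticmethod
    def mop(noise_type, hw, seed, dtype, blocksize=1, register=False):
        """Returns the measurement linop for the given noise type."""
        if "gaussemacher" in noise_type:
            # dtype is not forwarded here: get_block receives it at runtime
            mop = GaussemacherNoiseLinOp(
                hw, seed, blocksize=blocksize, register=register
            )
        else:
            raise ValueError(f"Unknown noise type! {noise_type}")
        return mop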
gU, gS, gVh = ssvd(
    mat,
    DEVICE,
    DTYPE,
    SKETCH_MEAS,
    seed=SEED,
    noise_type="gaussemacher",  # changed!
    recovery_type="nystrom",
    lstsq_rcond=1e-10,
    dispatcher=GaussemacherDispatcher,  # changed!
)
gmat = (gU * gS) @ gVh
print(
    "Relative error (Gaussemacher(0.5)+Nystrom):",
    (torch.dist(mat, gmat) / mat.norm()).item(),
)
Relative error (Gaussemacher(0.5)+Nystrom): 0.0014695705235111575
So this actually works, and for this particular matrix and seed it even beats the Rademacher baseline! Maybe random matrices aren’t that bad after all…
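To see what the new measurements look like, the magnitude rule from GaussemacherNoiseLinOp.get_block can be reproduced on plain NumPy samples. We assume circularly-symmetric complex Gaussian entries with unit variance. This matches the complex DTYPE, but not necessarily how skerch samples them, and gaussemacher is a hypothetical helper. Entries with magnitude at least THRESHOLD are left untouched. The others keep their phase and are lifted to just under THRESHOLD, so no entry ends up close to zero.

```python
import numpy as np

THRESHOLD = 0.5


def gaussemacher(shape, threshold, rng):
    """Return (raw, floored): complex Gaussian samples and a copy whose
    magnitudes below `threshold` are lifted to just under `threshold`."""
    # Assumption: circularly-symmetric complex Gaussian with E|x|^2 = 1
    x = (rng.standard_normal(shape) + 1j * rng.standard_normal(shape)) / np.sqrt(2)
    mag = np.abs(x)
    # Same rule as GaussemacherNoiseLinOp.get_block above
    scale = np.where(mag < threshold, threshold / (mag + 1e-7), 1.0)
    return x, x * scale


rng = np.random.default_rng(0)
raw, noise = gaussemacher((200, 50), THRESHOLD, rng)
print("min |raw|:", np.abs(raw).min(), "min |noise|:", np.abs(noise).min())
```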
Note
While skerch only requires the bare-minimum interface of
.shape = (height, width) and @ matmul for its inputs, the
interface for its measurement linops is marginally more complex:
to support batched measurements, new measurement linops are
also expected to provide their entries block by block through the
get_blocks iterator, which this example supports by overriding get_block
(see also Linear Operators and Matrix-Freedom).
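The block-wise idea can be illustrated with a small, self-contained NumPy class that exposes .shape, @ and a get_block method without ever storing its full matrix. Everything here is hypothetical: the name SeededBlockNoise, the per-block seeding (seed + block_idx) and the column-block layout are our assumptions. Note also that skerch's get_block additionally takes a dtype and device, as in GaussemacherNoiseLinOp above.

```python
import numpy as np


class SeededBlockNoise:
    """Hypothetical matrix-free Gaussian linop, generated block by block.

    Each block of `blocksize` columns is regenerated from `seed + block_idx`,
    so the full matrix is never stored and results are reproducible.
    """

    def __init__(self, shape, seed, blocksize=1):
        self.shape = shape
        self.seed = seed
        self.blocksize = blocksize

    def num_blocks(self):
        return -(-self.shape[1] // self.blocksize)  # ceiling division

    def get_block(self, block_idx):
        """Return the dense column block number `block_idx`."""
        h, w = self.shape
        start = block_idx * self.blocksize
        stop = min(start + self.blocksize, w)
        rng = np.random.default_rng(self.seed + block_idx)
        return rng.standard_normal((h, stop - start))

    def __matmul__(self, x):
        """Compute self @ x by accumulating one column block at a time."""
        out = np.zeros((self.shape[0],) + x.shape[1:])
        for idx in range(self.num_blocks()):
            start = idx * self.blocksize
            block = self.get_block(idx)
            out += block @ x[start:start + block.shape[1]]
        return out


op = SeededBlockNoise((6, 10), seed=123, blocksize=4)
dense = np.hstack([op.get_block(i) for i in range(op.num_blocks())])
x = np.arange(20.0).reshape(10, 2)
print(np.allclose(op @ x, dense @ x))
```

Because each block can be recomputed from its own seed on demand, peak memory for the measurement linop scales with blocksize rather than with its full width.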
In summary:
We have seen how to extend skerch with new low-rank recovery methods in just a few lines of code.
Similarly, we can add new noise sources with little effort.
Still, some interfaces must be satisfied to run built-in code. Whenever your interfaces collide (e.g. you require a new type of input), the best advice is to copy-paste and modify the algorithm, which, thanks to the modularity of skerch, is also fairly low-effort.