#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Functionality to perform sketched measurements."""
from collections import defaultdict
import torch
import torch_dct as dct
from .linops import ByBlockLinOp
from .utils import (
COMPLEX_DTYPES,
BadSeedError,
BadShapeError,
gaussian_noise,
phase_noise,
rademacher_noise,
randperm,
)
# ##############################################################################
# # IID NOISE LINOPS
# ##############################################################################
[docs]
class RademacherNoiseLinOp(ByBlockLinOp):
"""Random linear operator with i.i.d. Rademacher entries.
.. warning::
Since this linop uses random generators and seeds to fetch the blocks,
it is important that two different instances do not overlap in seeds,
to prevent correlated noise. Use sufficiently far away seeds and
``register=True`` to test this behaviour.
:param shape: Shape of the linop as ``(h, w)``.
:param seed: Random seed used in :meth:`get_block` to sample random blocks.
:param by_row: See :class:`skerch.linops.ByBlockLinOp`.
:param register: If true, when the linop is created, its seed range
(going from ``seed`` to ``seed + max(h, w)``) is added to a class-wide
register, which raises a :class:`skerch.utils.BadSeedError` if there
are any other instances of this class with overlapping ranges. If
false, this behaviour is disabled.
"""
REGISTER = defaultdict(list)
[docs]
@classmethod
def check_register(cls):
"""Checks if two different-seeded linops have overlapping seeds."""
for reg_type, reg in cls.REGISTER.items():
sorted_reg = sorted(reg, key=lambda x: x[0])
for (_, end1), (beg2, _) in zip(sorted_reg[:-1], sorted_reg[1:]):
if end1 >= beg2:
clsname = cls.__name__
msg = (
f"Overlapping seeds when creating {clsname}! "
f"({reg_type}, {sorted_reg}). This is not necessarily "
"an issue, but may lead to different-seeded random "
"linops generating the same rows or columns. To "
"prevent this, ensure that the random seeds of "
"different noise linops are separated by more than the "
"number of rows/columns. To disable this behaviour, "
"initialize with register=False."
)
raise BadSeedError(msg)
def __init__(
self,
shape,
seed,
by_row=False,
batch=None,
blocksize=1,
register=True,
):
"""Initializer. See class docstring."""
super().__init__(shape, by_row, batch, blocksize)
self.seed = seed
#
if register:
seed_end = seed + (self.shape[0] if self.by_row else self.shape[1])
self.__class__.REGISTER["default"].append((seed, seed_end))
self.check_register()
[docs]
def get_block(self, block_idx, input_dtype, input_device):
"""Samples a vector with Rademacher i.i.d. noise.
See base class definition for details.
"""
idxs = self.get_vector_idxs(block_idx)
h, w = self.shape
bsize = len(idxs)
#
out_shape = (bsize, w) if self.by_row else (h, bsize)
result = ( # device always CPU to ensure determinism across devices
rademacher_noise(
out_shape, seed=self.seed + idxs.start, device="cpu"
)
.to(input_dtype)
.to(input_device)
)
return result
def __repr__(self):
"""Returns a string in the form <classname(shape), attr=value, ...>."""
clsname = self.__class__.__name__
byrow_s = ", by row" if self.by_row else ", by col"
batch_s = "" if self.batch is None else f", batch={self.batch}"
block_s = f", blocksize={self.blocksize}"
seed_s = f", seed={self.seed}"
#
feats = f"{byrow_s}{batch_s}{block_s}{seed_s}"
s = f"<{clsname}({self.shape[0]}x{self.shape[1]}){feats}>"
return s
[docs]
class GaussianNoiseLinOp(RademacherNoiseLinOp):
"""Random linear operator with i.i.d. Gaussian entries.
Like :class:`RademacherNoiseLinOp`, but with Gaussian noise.
"""
REGISTER = defaultdict(list)
def __init__(
self,
shape,
seed,
by_row=False,
batch=None,
blocksize=1,
register=True,
mean=0.0,
std=1.0,
):
"""Initializer. See class docstring."""
super().__init__(shape, seed, by_row, batch, blocksize, register)
self.mean = mean
self.std = std
[docs]
def get_block(self, block_idx, input_dtype, input_device):
"""Samples a vector with Gaussian i.i.d. noise.
See base class for details.
"""
idxs = self.get_vector_idxs(block_idx)
h, w = self.shape
bsize = len(idxs)
#
out_shape = (bsize, w) if self.by_row else (h, bsize)
result = gaussian_noise( # device always CPU to ensure determinism
out_shape,
self.mean,
self.std,
seed=self.seed + idxs.start,
dtype=input_dtype,
device="cpu",
).to(input_device)
return result
def __repr__(self):
"""Returns a string in the form <classname(shape), attr=value, ...>."""
clsname = self.__class__.__name__
byrow_s = ", by row" if self.by_row else ", by col"
batch_s = "" if self.batch is None else f", batch={self.batch}"
block_s = f", blocksize={self.blocksize}"
seed_s = f", seed={self.seed}"
stats_s = f", mean={self.mean}, std={self.std}"
#
feats = f"{byrow_s}{batch_s}{block_s}{seed_s}{stats_s}"
s = f"<{clsname}({self.shape[0]}x{self.shape[1]}){feats}>"
return s
[docs]
class PhaseNoiseLinOp(RademacherNoiseLinOp):
"""Random linear operator with i.i.d. complex entries in the unit circle.
Like :class:`RademacherNoiseLinOp`, but with phase noise. Must be of
complex datatype.
:param conj: For the same seed, the linear operators with true and false
``conj`` values are complex conjugates of each other.
"""
REGISTER = defaultdict(list)
def __init__(
self,
shape,
seed,
by_row=False,
batch=None,
blocksize=1,
register=True,
conj=False,
):
"""Initializer. See class docstring."""
super().__init__(shape, seed, by_row, batch, blocksize, register)
self.conj = conj
[docs]
def get_block(self, block_idx, input_dtype, input_device):
"""Samples a vector with i.i.d. phase noise.
See base class definition for details.
"""
idxs = self.get_vector_idxs(block_idx)
h, w = self.shape
bsize = len(idxs)
#
out_shape = (w, bsize) if self.by_row else (h, bsize)
result = phase_noise( # device always CPU to ensure determinism
out_shape, self.seed + idxs.start, input_dtype, device="cpu"
).to(input_device)
#
if self.conj:
result = result.conj()
return result
def __repr__(self):
"""Returns a string in the form <classname(shape), attr=value, ...>."""
clsname = self.__class__.__name__
byrow_s = ", by row" if self.by_row else ", by col"
batch_s = "" if self.batch is None else f", batch={self.batch}"
block_s = f", blocksize={self.blocksize}"
seed_s = f", seed={self.seed}"
conj_s = f", conj={self.conj}"
#
feats = f"{byrow_s}{batch_s}{block_s}{seed_s}{conj_s}"
s = f"<{clsname}({self.shape[0]}x{self.shape[1]}){feats}>"
return s
# ##############################################################################
# # SSRFT
# ##############################################################################
[docs]
class SSRFT:
r"""Scrambled Subsampled Randomized Fourier Transform (SSRFT).
This static class implements the forward and adjoint SSRFT, as described
in `[TYUC2019, 3.2] <https://arxiv.org/abs/1902.08651>`_:
.. math::
\text{SSRFT} = R\,\mathcal{F}\,\Pi\,\mathcal{F}\,\Pi'
Where :math:`R` is a random index-picker, \mathcal{F} is either a
DCT or a FFT (if ``x`` is complex), and :math:`\Pi, \Pi'` are
random permutations which also multiply entries by Rademacher or
phase noise (if ``x`` is complex).
"""
[docs]
@staticmethod
def ssrft(x, out_dims, seed=0b1110101001010101011, norm="ortho"):
r"""Forward SSRFT (see class docstring for definition).
:param x: Matrix to be projected, such that ``y = SSRFT @ x``
:param out_dims: Number of rows in ``y`` with ``rows(y) <= rows(x)``
:param seed: Random seed for the SSRFT.
:param norm: Norm for the FFT and DCT. Currently only ``ortho`` is
supported to ensure orthogonality.
"""
if norm != "ortho":
raise NotImplementedError("Unsupported norm! use ortho")
#
n = x.shape[-1]
if out_dims > n or out_dims <= 0:
raise ValueError(
"out_dims can't be larger than input dimension or <=0!"
)
# make sure all sources of randomness are CPU, to ensure cross-device
# consistency of the operator
seeds = [seed + i for i in range(5)]
if x.dtype in COMPLEX_DTYPES:
# first scramble: permute, phase noise, and FFT
x = x[..., randperm(n, seed=seeds[0], device="cpu")]
x = x * phase_noise(
x.shape[-1],
seed=seeds[1],
dtype=x.dtype,
device="cpu",
conj=False,
).to(x.device)
x = torch.fft.fft(x, norm=norm)
# second scramble: permute, phase noise, and FFT
x = x[..., randperm(n, seed=seeds[2], device="cpu")]
x = x * phase_noise(
x.shape[-1],
seed=seeds[3],
dtype=x.dtype,
device="cpu",
conj=False,
).to(x.device)
x = torch.fft.fft(x, norm=norm)
else:
# first scramble: permute, rademacher, and DCT
x = x[..., randperm(n, seed=seeds[0], device="cpu")]
x = x * rademacher_noise(x.shape[-1], seeds[1], device="cpu").to(
x.device
)
x = dct.dct(x, norm=norm)
# second scramble: permute, rademacher and DCT
x = x[..., randperm(n, seed=seeds[2], device="cpu")]
x = x * rademacher_noise(x.shape[-1], seeds[3], device="cpu").to(
x.device
)
x = dct.dct(x, norm=norm)
# extract random indices and return
x = x[..., randperm(n, seed=seeds[4], device="cpu")[:out_dims]]
return x
[docs]
@staticmethod
def issrft(x, out_dims, seed=0b1110101001010101011, norm="ortho"):
r"""Adjoint SSRFT (see class docstring for definition).
Inversion of the SSRFT, such that for a square SSRFT,
``x == issrft(ssrft(x))`` holds.
Note that this means that, for complex ``x``, the adjoint operation
involves complex conjugation as well. See class docstring and
:meth:`ssrft` for more details.
:param out_dims: In this case, instead of random index-picker, which
reduces dimension, we have an index embedding, which increases
dimension by placing the ``x`` entries in the corresponding indices
(and leaving the rest to zeros). For this reason,
``out_dims >= len(x)`` is required.
"""
if norm != "ortho":
raise NotImplementedError("Unsupported norm! use ortho")
n = x.shape[-1]
if out_dims < n:
raise ValueError("out_dims can't be smaller than input dimension!")
# make sure all sources of randomness are CPU, to ensure cross-device
# consistency of the operator
seeds = [seed + i for i in range(5)]
# create output and embed random indices
out = torch.zeros(
x.shape[:-1] + (out_dims,), dtype=x.dtype, device=x.device
)
out[..., randperm(out_dims, seed=seeds[4], device="cpu")[:n]] = x
#
if x.dtype in COMPLEX_DTYPES:
# invert second scramble: iFFT, rademacher, and inverse permutation
out = torch.fft.ifft(out, norm=norm)
out = out * phase_noise(
out.shape[-1],
seed=seeds[3],
dtype=x.dtype,
device="cpu",
conj=True,
).to(x.device)
out = out[
...,
randperm(out_dims, seed=seeds[2], device="cpu", inverse=True),
]
# invert first scramble: iFFT, rademacher, and inverse permutation
out = torch.fft.ifft(out, norm=norm)
out = out * phase_noise(
out.shape[-1],
seed=seeds[1],
dtype=x.dtype,
device="cpu",
conj=True,
).to(x.device)
out = out[
...,
randperm(out_dims, seed=seeds[0], device="cpu", inverse=True),
]
else:
# invert second scramble: iDCT, rademacher, and inverse permutation
out = dct.idct(out, norm=norm)
out = out * rademacher_noise(
out.shape[-1], seeds[3], device="cpu"
).to(x.device)
out = out[
...,
randperm(out_dims, seed=seeds[2], device="cpu", inverse=True),
]
# invert first scramble: iDCT, rademacher, and inverse permutation
out = dct.idct(out, norm=norm)
out = out * rademacher_noise(
out.shape[-1], seeds[1], device="cpu"
).to(x.device)
out = out[
...,
randperm(out_dims, seed=seeds[0], device="cpu", inverse=True),
]
#
return out
[docs]
class SsrftNoiseLinOp(ByBlockLinOp):
"""Linop for the Scrambled Subsampled Randomized Fourier Transform (SSRFT).
This class encapsulates the forward and adjoint SSRFT transforms into a
single linear operator with fixed shape and orthonormal columns, which is
deterministic for the same dtype, shape and seed (also across different
torch devices).
See :class:`SSRFT` for more details.
.. note::
This linop can either be square or tall, but never fat (i.e. width must
be less or equal than height). Since the SSRFT cannot increase the
dimensionality of its input, the forward matmul of this linop is actually
the inverse SSRFT, and the adjoint matmul is the forward SSRFT.
This slight change in format that doesn't really affect the semantics of
the SSRFT, and it makes it more compatible with other noise linops, which
are typically also tall instead of fat. It is also more common to think
about orthogonal columns than rows. To make it fat,
:class:`skerch.linops.TransposedLinOp` can still be used.
.. note::
Unlike classes extending :class:`skerch.linops.ByBlockLinOp`, in this
case it is not efficient to apply this operator by row/column. Instead,
this implementation applies the SSRFT directly to the input, by vector,
but it also implements ``get_vector`` via one-hot vecmul to
facilitate parallel measurements and fit the standard interface for
``skerch`` measurement linops.
"""
REGISTER = defaultdict(list)
[docs]
@classmethod
def check_register(cls):
"""Checks if two different-seeded linops have overlapping seeds."""
for reg_type, reg in cls.REGISTER.items():
sorted_reg = sorted(reg, key=lambda x: x[0])
for (_, end1), (beg2, _) in zip(sorted_reg[:-1], sorted_reg[1:]):
if end1 >= beg2:
clsname = cls.__name__
msg = (
f"Overlapping seeds when creating {clsname}! "
f"({reg_type}, {sorted_reg}). This is not necessarily "
"an issue, but may lead to different-seeded random "
"linops generating the same rows or columns. To "
"prevent this, ensure that the random seeds of "
"different noise linops are separated by more than the "
"number of rows/columns. To disable this behaviour, "
"initialize with register=False."
)
raise BadSeedError(msg)
def __init__(
self, shape, seed, batch=None, blocksize=1, norm="ortho", register=True
):
"""Initializer. See class docstring."""
by_row = False
super().__init__(shape, by_row, batch, blocksize)
h, w = shape
if w > h:
raise BadShapeError(
"Width > height not supported in SSRFT! use transposition or "
"adjoint matmul."
)
self.seed = seed
self.norm = norm
#
if register:
seed_end = seed + (self.shape[0] if self.by_row else self.shape[1])
self.__class__.REGISTER["default"].append((seed, seed_end))
self.check_register()
[docs]
def get_block(self, block_idx, input_dtype, input_device):
"""Samples a SSRFT block.
See base class definition for details.
"""
idxs = self.get_vector_idxs(block_idx)
h, w = self.shape
bsize = len(idxs)
#
onehot_mat = torch.zeros(
(bsize, w), dtype=input_dtype, device=input_device
)
onehot_mat[range(bsize), idxs] = 1
#
result = SSRFT.issrft(
onehot_mat,
self.shape[0],
seed=self.seed,
norm=self.norm,
).transpose(0, 1)
return result
def __repr__(self):
"""Returns a string in the form <classname(shape), attr=value, ...>."""
clsname = self.__class__.__name__
byrow_s = ", by row" if self.by_row else ", by col"
batch_s = "" if self.batch is None else f", batch={self.batch}"
block_s = f", blocksize={self.blocksize}"
seed_s = f", seed={self.seed}"
norm_s = f", norm={self.norm}"
#
feats = f"{byrow_s}{batch_s}{block_s}{seed_s}{norm_s}"
s = f"<{clsname}({self.shape[0]}x{self.shape[1]}){feats}>"
return s