import math
import numpy as np
from .backend import get_module
[docs]
def fun_meyer(x, param, engine: str = "auto"):
    """
    Compute a smooth window similar to the Meyer wavelet in frequency domain.

    Parameters
    ----------
    x : array_like
        Input grid, typically generated by `numpy.linspace`. Integer or
        boolean grids are evaluated as float64 so the smooth ramps are not
        truncated.
    param : array_like of shape (4,)
        Four strictly increasing values `[p0, p1, p2, p3]` that define the window:
        - The window is zero for `x <= p0` and `x >= p3`.
        - The window is one for `p1 <= x <= p2`.
        - It transitions smoothly from 0 to 1 over `[p0, p1]`.
        - It transitions smoothly from 1 to 0 over `[p2, p3]`.
    engine : str, optional
        Backend selector forwarded to `get_module` (default "auto").

    Returns
    -------
    w : ndarray
        The window values evaluated at each point in `x` (same shape as `x`).

    Raises
    ------
    ValueError
        If `param` does not hold exactly four strictly increasing values.
    """
    ncp = get_module(engine)
    if not (len(param) == 4 and param[0] < param[1] < param[2] < param[3]):
        raise ValueError("param should be of size 4, and p[0] < p[1] < p[2] < p[3]")
    p0, p1, p2, p3 = param

    x = ncp.asarray(x)
    # An integer/bool grid would give an integer ``y`` and the polynomial
    # ramp values would be silently truncated on assignment.
    if x.dtype.kind in "biu":
        x = x.astype(ncp.float64)

    # Coefficients (highest degree first) of
    # v(t) = -20 t^7 + 70 t^6 - 84 t^5 + 35 t^4, with v(0) = 0 and v(1) = 1.
    poly = ncp.array([-20, 70, -84, 35, 0, 0, 0, 0])

    y = ncp.ones_like(x)
    y[x <= p0] = 0.0
    y[x >= p3] = 0.0

    # Rising edge: t goes 0 -> 1 as x goes p0 -> p1.
    rising = (x >= p0) & (x <= p1)
    y[rising] = ncp.polyval(poly, (x[rising] - p0) / (p1 - p0))

    # Falling edge: t goes 1 -> 0 as x goes p2 -> p3.
    falling = (x >= p2) & (x <= p3)
    y[falling] = ncp.polyval(poly, (x[falling] - p3) / (p2 - p3))
    return y
[docs]
def bands2vec(imband, engine: str = "auto"):
    """
    Convert the dictionary of complex subbands into a real-valued compressed vector.

    Parameters
    ----------
    imband : dict
        Mapping from subband identifier (tuple) to ndarray of coefficients.
        The key (0,) corresponds to the low-frequency band; other keys
        represent detail subbands and are emitted in dict iteration order.
    engine : str, optional
        Backend selector forwarded to `get_module` (default "auto").

    Returns
    -------
    compressed : ndarray
        1D real array containing the concatenated coefficients:
        - First, the real part of the low-frequency band (imband[(0,)]);
          any imaginary part of the low band is discarded.
        - Then, for each detail subband, the interleaved real and imaginary
          parts, i.e. a0, b0, a1, b1, ..., where a is real and b is imaginary.
    """
    ncp = get_module(engine)
    pieces = [ncp.real(imband[(0,)]).ravel()]
    for band_id, subwin in imband.items():
        if band_id == (0,):
            continue
        flat = subwin.ravel()
        # Stacking (real, imag) along a new trailing axis and flattening yields
        # the a0, b0, a1, b1, ... interleave without a Python-level loop.
        pieces.append(ncp.stack((ncp.real(flat), ncp.imag(flat)), axis=-1).ravel())
    # One concatenation at the end instead of one per subband (avoids O(n^2) copying).
    return ncp.concatenate(pieces)
[docs]
def vec2bands(imband, udct, engine: str = "auto"):
    """
    Reconstruct the dictionary of complex subbands from a compressed real-valued vector.

    Parameters
    ----------
    imband : array_like
        1D real vector produced by `bands2vec`, containing:
        - First, the low-frequency band samples.
        - Then, for each detail subband, interleaved real and imaginary parts.
    udct : object
        An object providing these attributes:
        - sz : sequence of ints, the original image dimensions.
        - res : int, the number of decomposition levels.
        - Sampling : dict mapping `(id[0], id[1])` tuples to sampling-factor
          arrays (must support `.tolist()`).
        - Msubwin : dict whose keys are the detail subband ids (tuples), in
          the same order used when the vector was built.
    engine : str, optional
        Backend selector forwarded to `get_module` (default "auto").

    Returns
    -------
    uncompressed : dict
        Mapping from subband identifier (tuple) to ndarray of coefficients.
        The (0,) entry is the real low band with each dimension of `udct.sz`
        divided by 2**(res - 1); detail bands are complex (complex128) arrays
        with each dimension of `udct.sz` divided by its sampling factor.

    Raises
    ------
    ValueError
        If the vector is too short to hold a detail subband.
    """
    ncp = get_module(engine)
    vec = ncp.asarray(imband)

    # The low band comes first and is stored as real values only.
    low_shape = tuple(sz_i // (2 ** (udct.res - 1)) for sz_i in udct.sz)
    count = math.prod(low_shape)
    uncompressed = {(0,): ncp.reshape(vec[:count], low_shape)}

    pos = count
    for band_id in udct.Msubwin:
        sampling = udct.Sampling[(band_id[0], band_id[1])].tolist()
        shape = tuple(sz_i // s for sz_i, s in zip(udct.sz, sampling))
        length = math.prod(shape)
        chunk = vec[pos:pos + 2 * length]
        if chunk.shape[0] != 2 * length:
            raise ValueError(
                f"vector too short for subband {band_id}: expected {2 * length} "
                f"values at offset {pos}, got {chunk.shape[0]}"
            )
        # Even positions hold real parts, odd positions imaginary parts.
        coeffs = (chunk[0::2] + 1j * chunk[1::2]).astype(complex)
        uncompressed[band_id] = ncp.reshape(coeffs, shape)
        pos += 2 * length
    return uncompressed
[docs]
def ucurv2d_show(imband, udct, engine: str = "auto"):
    """
    Note: currently broken
    Assemble and visualize a 2D curvelet transform by concatenating its subbands.

    Parameters
    ----------
    imband : dict
        Mapping from subband identifiers to 2D complex arrays.
        - The low-frequency band is stored under key `(0,)`.
        - Detail bands are looked up as `(rs, direction, i)`.
    udct : object
        A curvelet transform descriptor with attributes:
        - `dim` (int): number of dimensions (must be 2).
        - `cfg` (list of tuple): number of subbands at each resolution/direction.
        - `res` (int): total number of resolution levels.
        - `sz` (tuple of int): original image shape `(height, width)`.
    engine : str, optional
        Backend selector forwarded to `get_module` (default "auto").

    Returns
    -------
    display : ndarray of complex
        A 2D array of shape `(sz[0], W)` where `W` is the concatenated width of
        all subbands laid out for display. The first block is the low-frequency
        image, followed by one zero-padded block per resolution level.

    Raises
    ------
    Exception
        If `udct.dim != 2`, since this function only supports 2D transforms.
    """
    ncp = get_module(engine)
    if udct.dim != 2:
        raise Exception(" ucurv2d_show only work with 2D transform")
    cfg = udct.cfg
    n_rows = udct.sz[0]

    imlist = []
    for rs in range(udct.res):
        dirim = []
        for direction in (0, 1):
            bandlist = [imband[(rs, direction, i)] for i in range(cfg[rs][direction])]
            # Direction-0 bands are placed side by side (axis 1); direction-1
            # bands are stacked vertically (axis 0).
            dirim.append(ncp.concatenate(bandlist, axis=1 - direction))
        # Fold the tall direction-1 stack into three horizontal strips.
        # NOTE(review): the fixed split into 3 and the later axis-0 concatenation
        # assume the strip widths line up with dirim[0]; confirm for all cfg.
        sp0 = dirim[1].shape[0] // 3
        d1 = ncp.concatenate(
            [dirim[1][:sp0, :], dirim[1][sp0:2 * sp0, :], dirim[1][2 * sp0:, :]],
            axis=1,
        )
        dimg = ncp.concatenate([dirim[0], d1], axis=0)
        dshape = dimg.shape
        # Builtin max on the shape tuple: ``ncp.max`` rejects tuples under CuPy.
        dimg2 = ncp.zeros((n_rows, max(dshape)), dtype=complex)
        dimg2[:dshape[0], :dshape[1]] = dimg
        imlist.append(dimg2)
    dimg2 = ncp.concatenate(imlist, axis=1)

    # Pad the low band to the full display height and prepend it.
    lowband = imband[(0,)]
    lbshape = lowband.shape
    iml = ncp.zeros((n_rows, lbshape[1]), dtype=complex)
    iml[:lbshape[0], :] = lowband
    return ncp.concatenate([iml, dimg2], axis=1)