# Source code for ucurv.ucurv

import math
import numpy as np
from .util import fun_meyer
from .meyerwavelet import meyerfwdmd, meyerinvmd
from .backend import get_module
def ucurvhello():
    """Print a fixed greeting; handy as a smoke test that the package imports."""
    greeting = "Hello from ucurv!"
    print(greeting)
def tan_theta_grid(S1, S2, engine: str = "auto"):
    """
    Create a grid approximating the tan theta function as described in (28) of the paper.

    Parameters
    ----------
    S1, S2 : array
        Coordinate vectors (assumed 1-D). S1 indexes the rows and S2 the
        columns of the returned grid.
    engine : str, optional
        Backend name forwarded to ``get_module``.

    Returns
    -------
    grid : array
        Array of shape ``(len(S1), len(S2))`` (for 1-D inputs). Where
        ``|s2| <= |s1|`` it holds ``-s2/s1``; where ``|s1| < |s2|`` it holds
        ``s1/s2`` shifted by +2 (negative ratios) or -2 (positive ratios).
        Every entry whose S1 coordinate is ``>= 0`` is then overwritten with -2.
    """
    xp = get_module(engine)
    # u carries S2 values (varies along columns), v carries S1 values (varies along rows).
    u, v = xp.meshgrid(S2, S1)

    # Part 1: |u| <= |v| with v != 0  ->  -u / v
    near = xp.zeros_like(u)
    sel = xp.logical_and(v != 0, xp.abs(u) <= xp.abs(v))
    near[sel] = -u[sel] / v[sel]

    # Part 2: |v| < |u| with u != 0  ->  v / u, folded by +/-2 according to its sign
    ratio = xp.zeros_like(u)
    sel = xp.logical_and(u != 0, xp.abs(v) < xp.abs(u))
    ratio[sel] = v[sel] / u[sel]
    folded = ratio.copy()
    neg, pos = ratio < 0, ratio > 0
    folded[neg] = ratio[neg] + 2
    folded[pos] = ratio[pos] - 2

    grid = near + folded
    grid[v >= 0] = -2
    return grid
def fftflip(F, dirlist = None, engine: str = "auto"):
    """
    Flip and circularly shift an N-D FFT array so that frequency sign is reversed.

    This routine performs an axis-wise reversal and roll on the input array to
    map X(\\omega) to X(-\\omega) in its FFT representation. For each selected
    axis, elements are flipped (so that the zero-frequency component moves to
    the end), then rolled by one to restore the zero-frequency component to the
    first position. Along such an axis of length N the result is
    ``Fc[k] = F[(-k) mod N]``.

    The parameter `dirlist` selects the axes to be flipped. This is useful when
    flipping the angle functions along certain dimensions only.

    Parameters
    ----------
    F : ndarray
        Input array in the frequency domain (FFT output). Can be of any
        dimensionality.
    dirlist : int or sequence of ints, optional
        Axis or axes over which to apply the fftflip. Python ints, NumPy
        integers, lists, tuples and 1-D integer arrays are accepted. If None
        (default), all axes of `F` are processed.
    engine : str, optional
        Backend name forwarded to ``get_module``.

    Returns
    -------
    Fc : ndarray
        A new array of the same shape as `F`, with each selected axis flipped
        and rolled so that the frequency axis is negated. `F` is not modified.

    Notes
    -----
    - Reversing an axis in FFT output corresponds to replacing \\omega with -\\omega.
    - After `ncp.flip`, the zero-frequency component moves to the end of the
      axis. `ncp.roll` by +1 brings it back to index 0.
    - For multi-dimensional FFTs, reversing multiple axes implements sign
      inversion in each selected frequency dimension.
    """
    ncp = get_module(engine)
    # Copy so the result never aliases the input, even when no axis is processed.
    Fc = F.copy()
    if dirlist is None:
        axes = range(Fc.ndim)
    elif isinstance(dirlist, (int, np.integer)):
        axes = (int(dirlist),)
    else:
        # Any iterable of axes (previously only ``list`` was honoured; tuples were ignored).
        axes = [int(axis) for axis in dirlist]
    for axis in axes:
        Fc = ncp.roll(ncp.flip(Fc, axis), 1, axis=axis)
    return Fc
def angle_fun(Mgrid, n, alpha, bandpass = None, engine: str = "auto"):
    """
    Return the angle Meyer windows as in Figure 8 of the paper.

    Compute directional Meyer window functions for angular decomposition.

    Parameters
    ----------
    Mgrid : ndarray
        Grid of normalized angle values (e.g. the output of ``tan_theta_grid``).
        It must have at least two axes, since axes 0 and 1 are rolled.
    n : int
        Total number of angular directions (intended to be positive and even).
        For odd ``n`` the middle entry is written twice and ends up holding the
        flipped window.
    alpha : float
        Angular transition parameter controlling the width of each subband.
    bandpass : ndarray, optional
        An array of the same shape as `Mgrid` multiplied into each window.
    engine : str, optional
        Backend name forwarded to ``get_module``, ``fun_meyer`` and ``fftflip``.

    Returns
    -------
    Mang : list of ndarray
        A list of length `n`. Entry ``k`` for ``k < ceil(n/2)`` is the k-th
        window, rolled by ``3*shape//4`` on axes (0, 1). Entry ``n-1-k`` is that
        same window passed through ``fftflip`` along axis 1.
    """
    xp = get_module(engine)
    width = 2 / n
    # Meyer break points of a sector, relative to the sector's left edge.
    breaks = width * xp.array([-alpha, alpha, 1 - alpha, 1 + alpha])
    offset = 3 * xp.array(Mgrid.shape) // 4
    windows = [[] for _ in range(n)]
    for k in range(math.ceil(n / 2)):
        win = fun_meyer(Mgrid, -1 + k * width + breaks, engine = engine)
        if bandpass is not None:
            win = win * bandpass
        win = xp.roll(win, offset, (0, 1))
        windows[k] = win.copy()
        # Mirror window: frequency negated along axis 1 only.
        windows[n - 1 - k] = fftflip(win, 1, engine = engine)
    return windows
def downsamp(band, samp, shift = None, engine: str = "auto"):
    """
    Downsample an N-D array by keeping every ``samp[i]``-th sample along axis i.

    Parameters
    ----------
    band : ndarray
        Array to decimate (NumPy or CuPy).
    samp : sequence of int
        Decimation factor per axis. Only the first ``len(samp)`` axes are
        decimated. Any number of axes is supported (the previous version
        returned None outside 2..5 axes).
    shift : sequence of int, optional
        Starting offset per axis. Defaults to 0 for every axis.
    engine : str, optional
        Kept for API compatibility. Basic slicing works identically for every
        backend, so no backend module is needed.

    Returns
    -------
    ndarray
        ``band[shift[0]::samp[0], shift[1]::samp[1], ...]``, a view of `band`.
    """
    if shift is None:
        shift = [0] * band.ndim
    # int() accepts Python ints, NumPy scalars and 0-d device arrays alike.
    index = tuple(slice(int(shift[i]), None, int(samp[i])) for i in range(len(samp)))
    return band[index]
def upsamp(band, samp, shift = None, engine: str = "auto"):
    """
    Upsample an N-D array by inserting zeros between its samples.

    Parameters
    ----------
    band : ndarray
        Array to expand.
    samp : sequence of int
        Upsampling factor per axis. Must have one entry per axis of `band`.
        Any number of axes is supported (the previous version returned an
        all-zero array outside 2..5 axes).
    shift : sequence of int, optional
        Offset per axis at which the samples of `band` are placed. Defaults
        to 0 for every axis.
    engine : str, optional
        Backend name forwarded to ``get_module``; selects the array library
        used to allocate the output.

    Returns
    -------
    bandup : ndarray
        Complex array of shape ``band.shape * samp`` that is zero everywhere
        except ``bandup[shift[0]::samp[0], ...] = band``.

    Raises
    ------
    ValueError
        If ``len(samp)`` differs from ``band.ndim``.
    """
    ncp = get_module(engine)
    ndim = band.ndim
    if len(samp) != ndim:
        raise ValueError(f"samp has {len(samp)} factors but band has {ndim} dimensions")
    if shift is None:
        shift = [0] * ndim
    # int() works whether samp holds Python ints, NumPy scalars or CuPy 0-d arrays.
    factors = [int(k) for k in samp]
    shape = tuple(int(n) * k for n, k in zip(band.shape, factors))
    bandup = ncp.zeros(shape, dtype=complex)
    bandup[tuple(slice(int(shift[i]), None, factors[i]) for i in range(ndim))] = band
    return bandup
def get_folding_indices(arr, fold_ratios):
    """
    Return aligned index vectors describing how an N-D array folds onto a smaller one.

    Folding maps the coordinate ``c`` along each axis to ``c % (size // ratio)``.
    Only non-zero entries of `arr` are considered, and an entry is kept only if
    the folded bucket it lands in has a non-zero total. Entries that cancel
    out (e.g. 5 + -5) are therefore dropped.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    fold_ratios : sequence of int
        Folding ratio for each dimension; each axis size must be divisible by it.

    Returns
    -------
    tuple of np.ndarray
        ``(orig_indices, folded_indices)``: flat indices into `arr` (ascending)
        and the matching flat indices into the folded array.

    Raises
    ------
    ValueError
        If the number of ratios differs from ``arr.ndim`` or an axis size is
        not divisible by its ratio.
    """
    if len(fold_ratios) != arr.ndim:
        raise ValueError("Fold ratios length must match array dimensions.")

    small_dims = []
    for d, r in zip(arr.shape, fold_ratios):
        if d % r:
            raise ValueError(f"Dimension size {d} not divisible by ratio {r}")
        small_dims.append(d // r)
    small_shape = tuple(small_dims)

    # Non-zero entries of the large array and their values.
    src = np.flatnonzero(arr)
    src_vals = arr.ravel()[src]

    # Fold each coordinate with a modulo, then re-flatten in the small shape.
    wrapped = tuple(
        coord % size
        for coord, size in zip(np.unravel_index(src, arr.shape), small_shape)
    )
    dst = np.ravel_multi_index(wrapped, small_shape)

    # Accumulate values per bucket so cancelling contributions can be detected.
    bucket_totals = np.zeros(np.prod(small_shape), dtype=arr.dtype)
    np.add.at(bucket_totals, dst, src_vals)
    alive = bucket_totals[dst] != 0
    return src[alive], dst[alive]
def update_win_overlaps_numpy(win, orig_indices, folded_indices):
    """
    Collapse each group of overlapping (aliased) window entries onto one entry.

    Entries of `win` at ``orig_indices`` are grouped by ``folded_indices``
    (entries sharing a folded index overlap after decimation). In every group:

    1. the entry holding the MAX value is the "winner". Ties are broken in
       favour of the first occurrence in ``orig_indices``, so exactly one entry
       per group is kept and no aliasing remains;
    2. the winner receives the SUM of all values in the group (its own
       included), so the group's total is preserved;
    3. all other entries of the group are set to 0.

    Parameters
    ----------
    win : np.ndarray
        Window array; modified IN PLACE through flat indexing.
    orig_indices : array-like of int
        Flat indices into `win` (e.g. from ``get_folding_indices``).
    folded_indices : array-like of int
        Flat folded index for each entry of `orig_indices`.

    Returns
    -------
    np.ndarray
        The same `win` object, after the update.
    """
    flat_orig = np.asarray(orig_indices).ravel()
    flat_folded = np.asarray(folded_indices).ravel()
    if flat_orig.size == 0:
        # Nothing overlaps (e.g. an all-zero window): leave it untouched.
        return win

    vals = win.flat[flat_orig]
    num_bins = int(flat_folded.max()) + 1

    # Per-group sum and max via unbuffered scatter (handles repeated bins correctly).
    group_sums = np.zeros(num_bins, dtype=win.dtype)
    group_maxs = np.full(num_bins, -np.inf, dtype=win.dtype)
    np.add.at(group_sums, flat_folded, vals)
    np.maximum.at(group_maxs, flat_folded, vals)

    # Candidates equal to their group max; keep only the first one per group,
    # otherwise tied entries would all receive the sum and still alias.
    is_max = vals == group_maxs[flat_folded]
    positions = np.arange(flat_orig.size)
    first_max = np.full(num_bins, flat_orig.size, dtype=positions.dtype)
    np.minimum.at(first_max, flat_folded[is_max], positions[is_max])
    is_winner = positions == first_max[flat_folded]

    new_vals = np.zeros_like(vals)
    new_vals[is_winner] = group_sums[flat_folded[is_winner]]
    win.flat[flat_orig] = new_vals
    return win
#### Class holding all curvelet windows and related parameters built from the transform configuration
class Udct:
    """
    Configuration, decimation vectors and windows for the ucurv transform.

    Constructing an instance computes the sampling (decimation) vectors, the
    directional curvelet windows and the lowpass window for the given data
    size and angular configuration. The instance is consumed by ``ucurvfwd``
    and ``ucurvinv``.

    Parameters
    ----------
    sz : tuple of int
        Size of the input data, e.g. ``(256, 256)`` for a 2D image. Every
        (effective) axis size must be divisible by the decimation factors in
        ``Sampling``; window construction raises ``ValueError`` otherwise.
    cfg : list of list of int
        One inner list per resolution level. ``cfg[rs][d]`` is the number of
        angular directions along dimension ``d`` at level ``rs``. The sampling
        factors use ``cfg[rs][d] // 3``, so the values are presumably
        multiples of 3.
    complex : bool, optional
        If True, the transform outputs complex curvelet coefficients (each
        window gets a conjugate subband). Default is False.
    sparse : bool, optional
        If True, each window is stored as ``(flat_indices, values,
        folded_indices)`` holding only its non-zero entries. This significantly
        reduces the memory required for the windows. Default is False.
    high : str, optional
        Transform used on the highest resolution: ``'curvelet'`` (default) or
        ``'wavelet'``. Any value other than ``'curvelet'`` builds the curvelet
        windows for half of `sz`, which reduces the redundancy of the transform.
    engine : str, optional
        Backend name passed to ``get_module`` by the transform functions.
        Windows are always built with NumPy. When ``engine == "cupy"`` the
        sampling vectors and windows are copied to the GPU at the end.

    Attributes
    ----------
    sz : tuple
        Size the curvelet windows are built for (``sz`` or ``sz // 2``).
    Sampling : dict
        Key ``0`` holds the lowpass decimation vector. Key ``(rs, ipyr)`` holds
        the decimation vector of pyramid ``ipyr`` at level ``rs``.
    Msubwin : dict
        Windows keyed by ``(rs, ipyr, a1, ..., a_{dim-1})``, dense or sparse.
    FL : ndarray or tuple
        Lowpass window, dense or sparse.
    Mang2 : dict
        2-D angle functions keyed by ``(rs, ipyr, idir)``.
    len : int
        Output length estimate.
    """

    def __init__(self, sz, cfg, complex = False, sparse = False, high = 'curvelet', engine: str = "auto"):
        # Just build all windows on the CPU (much faster), then lift to the GPU if engine is cupy.
        ncp = get_module("numpy")
        self.name = "ucurv"
        self.engine = engine
        # With a non-curvelet highest level, the curvelet windows cover half the input size.
        if high != 'curvelet':
            self.sz = tuple(ncp.array(sz)//2)
        else:
            self.sz = tuple(sz)
        self.cfg = tuple(cfg)
        self.complex = complex
        self.sparse = sparse
        self.high = high
        self.dim = len(sz)
        self.res = len(cfg)
        dim = len(sz)
        res = len(cfg)
        self.Sampling = {}

        # calculate output len
        clen = ncp.prod(ncp.array(self.sz))//((2**self.dim)**(self.res-1))
        self.len = clen
        for i in range(self.res):
            clen = clen*((2**self.dim)**i)
            self.len = self.len + clen*self.dim*3**(self.dim-1)//2**(self.dim-1)

        # create the subsampling vectors (key 0 = lowpass, (rs, ipyr) = directional)
        self.Sampling[(0)] = 2**(res-1)*ncp.ones(dim, dtype = int)
        for rs in range(res):
            for ipyr in range(dim):
                dmat = []
                for idir in range(dim):
                    if idir == ipyr:
                        dmat.append(2**(res-rs))
                    else:
                        dmat.append(2*(cfg[rs][idir]//3)*2**(res-rs-1))
                self.Sampling[(rs, ipyr)] = ncp.array(dmat, dtype = int)

        # the grid for angle functions
        Sgrid = [
            ncp.linspace(-1.5 * ncp.pi, 0.5 * ncp.pi - ncp.pi / (self.sz[ind] / 2), self.sz[ind])
            for ind in range(dim)
        ]

        # create the 1D smooth functions to be used in angle function creation
        r = ncp.pi*ncp.array([1/3, 2/3, 2/3, 4/3])
        alpha = 0.1
        f1d = {}
        for ind in range(dim):
            for rs in range(res):
                f1d[(rs, ind)] = fun_meyer(ncp.abs(Sgrid[ind]), [-2, -1, r[0]/2**(res-1-rs), r[1]/2**(res-1-rs)])
            f1d[(res, ind)] = fun_meyer(ncp.abs(Sgrid[ind]), [-2, -1, r[2], r[3]])

        # the grid for the lowpass function FL
        SLgrid = [
            ncp.linspace(-ncp.pi, ncp.pi - ncp.pi / (self.sz[ind] / 2), self.sz[ind])
            for ind in range(dim)
        ]
        FL = ncp.ones([1])
        for ind in range(dim):
            fl1d = fun_meyer(ncp.abs(SLgrid[ind]), [-2, -1, r[0]/2**(res-1), r[1]/2**(res-1)])
            FL = ncp.kron(FL, fl1d.flatten())
        FL = FL.reshape(self.sz)

        # Mang2 holds every 2D angle function needed to build the dim-dimensional
        # angle pyramids, keyed by (resolution, pyramid dimension, other dimension).
        # Each value is the list of angle functions for that resolution/direction.
        Mang2 = {}
        for rs in range(res):
            for ind in range(dim):
                for idir in range(dim):
                    if idir == ind:
                        # skip the dimension that is the same as the pyramid
                        continue
                    Mg0 = tan_theta_grid(Sgrid[ind], Sgrid[idir], engine = "numpy")
                    # create the bandpass function
                    BP1 = ncp.outer(f1d[(rs, ind)], f1d[(rs, idir)])
                    BP2 = ncp.outer(f1d[(rs+1, ind)], f1d[(rs+1, idir)])
                    bandpass = (BP2 - BP1)**(1./(dim-1.))
                    Mang2[(rs, ind, idir)] = angle_fun(Mg0, cfg[rs][idir], alpha, bandpass, engine = "numpy")
        self.Mang2 = Mang2

        #################################
        Msubwin = {}
        for rs in range(res):
            # For each pyramid: every combination of one angle index per non-pyramid dimension.
            id_angle_lists = []
            for ipyr in range(dim):
                dlist = [j for j in range(dim) if j != ipyr]
                combos = [[a] for a in range(cfg[rs][dlist[0]])]
                for k in range(1, dim - 1):
                    combos = [z + [a] for z in combos for a in range(cfg[rs][dlist[k]])]
                id_angle_lists.append(combos)

            for ipyr, id_angle_list in enumerate(id_angle_lists):
                # dlist: the dimensions other than the pyramid dimension, in order
                dlist = [j for j in range(dim) if j != ipyr]
                for alist in id_angle_list:
                    subband = ncp.ones(self.sz)
                    for idim, aid in enumerate(alist):
                        odir = dlist[idim]
                        # Broadcast the 2D angle function along all other dimensions and
                        # move its two axes to (ipyr, odir). The shape must use self.sz:
                        # the angle functions are built on self.sz, which differs from
                        # sz when high != 'curvelet'.
                        shape_scratch = [self.sz[ipyr], self.sz[odir]] + [1] * (dim - 2)
                        F_reshaped = Mang2[(rs, ipyr, odir)][aid].reshape(shape_scratch)
                        angkron = ncp.moveaxis(F_reshaped, [0, 1], [ipyr, odir])
                        subband = subband*angkron
                    Msubwin[tuple([rs, ipyr] + alist)] = subband.copy()

        #################################
        # Remove aliasing overlaps (update_win_overlaps_numpy writes into subwin,
        # so Msubwin entries are updated too) and accumulate the normalisation sum.
        sumall = ncp.zeros(self.sz)
        for key, subwin in Msubwin.items():
            orig_indices, folded_indices = get_folding_indices(subwin, self.Sampling[(key[0], key[1])])
            subwin = update_win_overlaps_numpy(subwin, orig_indices, folded_indices)
            sumall = sumall + subwin
        sumall = sumall + fftflip(sumall, engine = "numpy")
        sumall = sumall + FL

        self.Msubwin = {}
        for key, subwin in Msubwin.items():
            samp = self.Sampling[(key[0], key[1])]
            win = ncp.fft.fftshift(ncp.sqrt(2*ncp.prod(samp)*subwin/sumall))
            if sparse:
                orig_indices, folded_indices = get_folding_indices(win, samp)
                self.Msubwin[key] = (orig_indices, win.ravel()[orig_indices], folded_indices)
            else:
                self.Msubwin[key] = win

        win = ncp.sqrt(ncp.prod(self.Sampling[(0)]))*ncp.fft.fftshift(ncp.sqrt(FL/sumall))
        if sparse:
            orig_indices, folded_indices = get_folding_indices(win, self.Sampling[(0)])
            self.FL = (orig_indices, win.ravel()[orig_indices], folded_indices)
        else:
            self.FL = win

        if engine == "cupy":
            self._lift_to_gpu()

    def _lift_to_gpu(self):
        """Copy the sampling vectors and all windows (dense or sparse) to CuPy arrays."""
        cp = get_module("cupy")
        for key, mat in self.Sampling.items():
            self.Sampling[key] = cp.asarray(mat)
        for key, w in self.Msubwin.items():
            if self.sparse:
                # Sparse entries are (flat_indices, values, folded_indices): convert each array.
                self.Msubwin[key] = tuple(cp.asarray(part) for part in w)
            else:
                self.Msubwin[key] = cp.asarray(w)
        if self.sparse:
            self.FL = tuple(cp.asarray(part) for part in self.FL)
        else:
            self.FL = cp.asarray(self.FL)
def ucurvfwd(img, udct):
    """
    Forward Uniform Discrete Curvelet Transform (UDCT).

    Computes the forward UDCT coefficients of an N-D input signal `img` using
    the precomputed windows and sampling vectors in `udct`. Output lives on
    the CPU or GPU depending on ``udct.engine`` (numpy/cupy).

    Parameters
    ----------
    img : ndarray
        Input real- or complex-valued array of dimension ``udct.dim``. For
        ``high='curvelet'`` its shape must equal ``udct.sz`` (checked with an
        assert). For ``high='wavelet'``, the highest resolution is first
        decomposed with Meyer wavelets and band 0 is passed on to the curvelets.
    udct : Udct
        Precomputed windows, sampling factors and configuration.

    Returns
    -------
    imband : dict
        UDCT coefficients keyed by tuples:

        - ``(0,)``: lowpass coefficients.
        - ``(rs, ipyr, a1, ..., ak)``: directional subband at resolution `rs`,
          pyramid `ipyr` and angular indices ``a1..ak``, with
          ``k = udct.dim - 1`` (one angle index per non-pyramid dimension).
        - Complex UDCT only: ``(rs, ipyr + udct.dim, a1, ..., ak)`` holds the
          subband of the ``fftflip``-ped window.
        - ``high='wavelet'`` only: ``(udct.res, i)`` for ``i >= 1`` holds the
          i-th Meyer wavelet band.

        Each subband is decimated according to ``udct.Sampling``.

    Notes
    -----
    - The inverse transform is ``ucurvinv``.
    - With ``udct.complex`` each window and its flipped version are applied
      separately, each scaled by ``sqrt(0.5)``.
    - With ``udct.sparse`` windows are rebuilt densely on the fly.
    """
    engine = udct.engine
    ncp = get_module(engine)

    def densify(window):
        # Sparse windows are (flat_indices, values, ...): rebuild the dense array.
        if not udct.sparse:
            return window
        dense = ncp.zeros(udct.sz)
        dense.flat[window[0]] = window[1]
        return dense

    if engine == "cupy":
        # move onto the GPU if using cupy
        img = ncp.asarray(img)
    if udct.high == 'curvelet':
        assert img.shape == udct.sz
    Sampling = udct.Sampling
    FL = densify(udct.FL)

    imband = {}
    if udct.high == 'wavelet':
        for i, band in enumerate(meyerfwdmd(img, engine = engine)):
            if i == 0:
                # Band 0 is decomposed further by the curvelet windows below.
                imf = ncp.fft.fftn(band)
            else:
                imband[(udct.res, i)] = band
    else:
        imf = ncp.fft.fftn(ncp.array(img))

    if udct.complex:
        imband[(0,)] = downsamp(ncp.fft.ifftn(imf*FL), Sampling[(0)], engine = engine)
        for key, window in udct.Msubwin.items():
            subwin = densify(window)
            samp = Sampling[(key[0], key[1])]
            bandfilt = ncp.sqrt(0.5)*ncp.fft.ifftn(imf*subwin)
            imband[key] = downsamp(bandfilt, samp, engine = engine)
            # Conjugate counterpart: flipped window, pyramid index offset by dim.
            conj_key = list(key)
            conj_key[1] = conj_key[1] + udct.dim
            bandfilt = ncp.sqrt(0.5)*ncp.fft.ifftn(imf*fftflip(subwin, engine = engine))
            imband[tuple(conj_key)] = downsamp(bandfilt, samp, engine = engine)
    else:
        bandfilt = ncp.real(ncp.fft.ifftn(imf*FL))
        imband[(0,)] = downsamp(bandfilt, Sampling[(0)], engine = engine)
        for key, window in udct.Msubwin.items():
            subwin = densify(window)
            # Unlike the lowpass band, directional coefficients stay complex here.
            bandfilt = ncp.fft.ifftn(imf*subwin)
            imband[key] = downsamp(bandfilt, Sampling[(key[0], key[1])], engine = engine)
    return imband
##############
def ucurvinv(imband, udct):
    """
    Inverse Uniform Discrete Curvelet Transform (UDCT).

    Reconstructs a signal from the UDCT coefficients produced by ``ucurvfwd``.

    Parameters
    ----------
    imband : dict
        UDCT coefficients following the key convention of ``ucurvfwd``:

        - ``(0,)``: lowpass coefficients.
        - ``(rs, ipyr, a1, ..., ak)``: directional subbands.
        - Complex UDCT: the conjugate subbands ``(rs, ipyr + udct.dim, ...)``.
        - ``high='wavelet'``: Meyer wavelet bands ``(udct.res, i)``, i >= 1.
    udct : Udct
        Precomputed windows, sampling factors and configuration.

    Returns
    -------
    recon : ndarray
        The reconstruction. For ``high='curvelet'`` its shape is ``udct.sz``.
        It is real-valued unless ``udct.complex`` is set.

    Notes
    -----
    - Each subband is zero-upsampled, filtered with its window in the
      frequency domain and summed.
    - Wavelet bands are handed to ``meyerinvmd`` ordered by band index ``i``,
      independently of the dict's insertion order.
    """
    engine = udct.engine
    ncp = get_module(engine)
    Sampling = udct.Sampling

    def densify(window):
        # Sparse windows are (flat_indices, values, ...): rebuild the dense array.
        if not udct.sparse:
            return window
        dense = ncp.zeros(udct.sz)
        dense.flat[window[0]] = window[1]
        return dense

    imlow = upsamp(imband[(0,)], Sampling[(0)], engine = engine)
    FL = densify(udct.FL)
    recon = ncp.fft.ifftn(ncp.fft.fftn(imlow)*FL)
    if not udct.complex:
        recon = ncp.real(recon)

    for key, window in udct.Msubwin.items():
        if udct.high != 'curvelet' and key[0] == udct.res:
            continue
        subwin = densify(window)
        samp = Sampling[(key[0], key[1])]
        bandup = upsamp(imband[key], samp, engine = engine)
        if udct.complex:
            recon = recon + ncp.sqrt(0.5)*ncp.fft.ifftn(ncp.fft.fftn(bandup)*subwin)
            conj_key = list(key)
            conj_key[1] = conj_key[1] + udct.dim
            bandup = upsamp(imband[tuple(conj_key)], samp, engine = engine)
            recon = recon + ncp.sqrt(0.5)*ncp.fft.ifftn(ncp.fft.fftn(bandup)*fftflip(subwin, engine = engine))
        else:
            recon = recon + ncp.real(ncp.fft.ifftn(ncp.fft.fftn(bandup)*subwin))

    if udct.high == 'wavelet':
        # The curvelet reconstruction is band 0; wavelet bands follow in index order.
        detail_keys = sorted((key for key in imband if key[0] == udct.res), key=lambda k: k[1])
        recon = meyerinvmd([recon] + [imband[k] for k in detail_keys], engine = engine)
    return recon