nenuscanner/utils/mashal.py

36 lines
1.1 KiB
Python

import numpy as np
def marshal_arrays(arrays):
    """
    Flatten a sequence of numpy arrays into one 1-D array and record their shapes.

    The inverse operation is ``unmarshal_arrays(flat, shapes)``.

    Parameters:
    arrays (iterable of array-like): Arrays to be marshalled. Any iterable is
        accepted (including generators); it is materialized once. Scalars are
        allowed and are recorded with shape ``()``.

    Returns:
    tuple: A tuple containing:
        - flat (np.ndarray): A single 1-D array holding every element, in
          order (each input flattened in C order). Mixed input dtypes are
          promoted to a common dtype by ``np.concatenate``. For an empty
          input this is an empty 1-D array.
        - shapes (list of tuple): The shape of each original array, in order.
    """
    # Materialize first: the input is traversed twice (shapes + data), which
    # would silently lose all data if a one-shot iterator were passed.
    arrays = list(arrays)
    shapes = [np.shape(a) for a in arrays]
    if not arrays:
        # np.concatenate raises on an empty sequence; return an empty 1-D
        # array so the empty case round-trips cleanly.
        return np.empty(0), shapes
    flat = np.concatenate([np.reshape(a, -1) for a in arrays])
    return flat, shapes
def unmarshal_arrays(flat, shapes):
    """
    Rebuild the original list of numpy arrays from the flattened array and shapes.

    This is the inverse of ``marshal_arrays``: ``flat`` is cut into
    consecutive chunks whose sizes are the products of ``shapes``, and each
    chunk is reshaped back to its shape.

    Parameters:
    flat (np.ndarray): The single concatenated 1-D array of all elements.
    shapes (list of tuple): The shapes of the original arrays, in order.
        A shape of ``()`` denotes a scalar (one element).

    Returns:
    list of np.ndarray: The rebuilt arrays, in order. They are produced by
    ``np.split`` and ``np.reshape`` and may therefore share memory with
    ``flat``; copy them if ``flat`` will be modified afterwards.

    Raises:
    ValueError: If the number of elements implied by ``shapes`` differs from
        the number of elements in ``flat``.
    """
    flat = np.asarray(flat)
    # np.prod(()) returns the float 1.0, so coerce every size to int.
    sizes = [int(np.prod(shape)) for shape in shapes]
    total = sum(sizes)
    if total != flat.size:
        raise ValueError(
            f"shapes describe {total} elements but flat has {flat.size}"
        )
    # Split points are the running totals, minus the last one (the end).
    splits = np.cumsum(sizes, dtype=int)[:-1]
    chunks = np.split(flat, splits)
    return [np.reshape(chunk, shape) for chunk, shape in zip(chunks, shapes)]