Source code for ecosound.soundscape.hmd

"""
HMD (Hybrid Millidecade) Class for Acoustic Data Analysis

"""

import xarray as xr
import numpy as np
import pandas as pd
from glob import glob
from dask.distributed import Client, LocalCluster
import dask
from pathlib import Path


[docs] class HMD: """ Hybrid Millidecade acoustic data processor for NetCDF spectral files. Parameters ---------- n_workers : int, optional Number of Dask workers (default: 4) memory_per_worker : str, optional Memory limit per worker (default: '4GB') temp_directory : str, optional Directory for Dask temporary files (default: system temp) use_dask : bool, optional Enable Dask parallel processing (default: True) use_processes : bool, optional Use processes instead of threads (default: auto-detect, False on Windows) Examples -------- >>> hmd = HMD(n_workers=4) >>> hmd.load_nc_files('path/to/deployment_01') >>> result = hmd.extract_band_levels([[50, 300]], ['ship']) >>> hmd.summary() """
[docs] def __init__(self, n_workers=4, memory_per_worker='4GB', temp_directory=None, use_dask=True, use_processes=None): """Initialize HMD processor with optional Dask cluster""" import platform import tempfile self.n_workers = n_workers self.memory_per_worker = memory_per_worker self.temp_directory = temp_directory or tempfile.gettempdir() self.use_dask = use_dask self.client = None self.ds = None if use_processes is None: use_processes = platform.system() != 'Windows' self.use_processes = use_processes if self.use_dask: self._setup_dask() else: print("Running without Dask (single-threaded mode)")
def load_nc_files(self, path, freq_range=None, time_range=None, recursive=False,
                  chunks=None, prefilter_by_date=False, filename_pattern=None):
    r"""
    Load NetCDF files from a directory into ``self.ds``.

    Files are discovered with ``*.nc`` globbing, optionally filtered by a
    filename regex and by date, then opened together with
    ``xarray.open_mfdataset`` (``h5netcdf`` engine, ``combine='by_coords'``).

    Parameters
    ----------
    path : str
        Path to directory containing .nc files.
    freq_range : tuple, optional
        Frequency range in Hz ``(min_freq, max_freq)``. Applied to every
        file as a label slice on ``frequency`` while it is being opened.
    time_range : tuple, optional
        Time range as strings, e.g. ``('2021-06-01', '2021-06-30')``.
        Start is inclusive, end is exclusive: ``[start, end)``. The
        exclusive end is obtained by subtracting one microsecond from the
        end timestamp before slicing.
    recursive : bool, optional
        If True, search subdirectories recursively (default: False).
    chunks : dict, optional
        Custom chunk sizes (default: ``{'time': 1440, 'frequency': 500}``).
        Ignored (``None`` is passed to xarray) when ``self.use_dask`` is
        falsy.
    prefilter_by_date : bool, optional
        If True and ``time_range`` is given, files are first passed through
        ``self._filter_files_by_date`` (default: False).
    filename_pattern : str, optional
        Regex searched (``re.search``) against each file *name*. Only
        matching files are loaded. Examples:

        - ``r'_\d{8}\.nc$'`` : files ending with ``_YYYYMMDD.nc``
        - ``r'^deployment_.*\.nc$'`` : files starting with ``'deployment_'``
        - ``r'.*_HMD_.*\.nc$'`` : files containing ``'_HMD_'``

        Default: None (load all .nc files).

    Returns
    -------
    self

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist, contains no ``.nc`` files, or no file
        survives the pattern / date filters.

    Examples
    --------
    >>> # Load all .nc files
    >>> hmd.load_nc_files('data/deployment_01')
    >>>
    >>> # Load only files ending with date pattern _YYYYMMDD.nc
    >>> hmd.load_nc_files('data/',
    ...                   recursive=True,
    ...                   filename_pattern=r'_\d{8}\.nc$')
    >>>
    >>> # Load only files with specific prefix
    >>> hmd.load_nc_files('data/',
    ...                   filename_pattern=r'^HMD_.*\.nc$')
    """
    import re

    if chunks is None:
        chunks = {'time': 1440, 'frequency': 500}

    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Path does not exist: {path}")

    # Find all .nc files (sorted so the load order is deterministic).
    finder = path.rglob if recursive else path.glob
    all_files = sorted(finder('*.nc'))
    if not all_files:
        # Build the suffix with its own leading space so the non-recursive
        # message does not end in a stray blank.
        scope = " or subdirectories" if recursive else ""
        raise FileNotFoundError(f"No .nc files found in {path}{scope}")
    print(f"Found {len(all_files)} .nc files")

    # Filter files by filename pattern if requested.
    if filename_pattern is not None:
        initial_count = len(all_files)
        pattern = re.compile(filename_pattern)
        all_files = [f for f in all_files if pattern.search(f.name)]
        excluded = initial_count - len(all_files)
        if excluded > 0:
            print(f"Filtered by filename pattern '{filename_pattern}': "
                  f"{len(all_files)} files (excluded {excluded})")
        if not all_files:
            raise FileNotFoundError(f"No files match pattern '{filename_pattern}'")

    # Filter files by date if requested.
    if prefilter_by_date and time_range is not None:
        all_files, skipped = self._filter_files_by_date(all_files, time_range)
        print(f"Pre-filtered to {len(all_files)} files (skipped {skipped} outside date range)")
        if not all_files:
            raise FileNotFoundError(f"No files found within time range {time_range}")

    # Group files by parent directory; each directory becomes a deployment.
    files_by_dir = {}
    for file in all_files:
        files_by_dir.setdefault(file.parent, []).append(file)

    # Show what we found.
    if len(files_by_dir) > 1:
        print(f"Files organized in {len(files_by_dir)} directories:")
        for dir_path, files in sorted(files_by_dir.items()):
            try:
                rel_path = dir_path.relative_to(path)
                dir_label = str(rel_path) if str(rel_path) != '.' else "(base)"
            except ValueError:
                dir_label = dir_path.name
            print(f" - {dir_label}: {len(files)} files")

    print(f"\nLoading {len(all_files)} files...")

    # Map every file path (as a string) to its deployment name.
    file_to_deployment = {}
    single_dir = len(files_by_dir) == 1
    for dir_path, files in files_by_dir.items():
        if single_dir:
            deployment_name = dir_path.name
        else:
            try:
                rel_path = dir_path.relative_to(path)
                if str(rel_path) != '.':
                    # Normalise Windows separators in the label.
                    deployment_name = str(rel_path).replace('\\', '/')
                else:
                    deployment_name = path.name
            except ValueError:
                deployment_name = dir_path.name
        for file in files:
            file_to_deployment[str(file)] = deployment_name

    def preprocess(ds):
        """Subset frequency and tag one file's dataset with its deployment."""
        # NOTE(review): the lookup relies on ds.encoding['source'] being
        # exactly str(file); any other spelling falls back to 'unknown'.
        source_file = ds.encoding.get('source', '')
        deployment_name = file_to_deployment.get(source_file, 'unknown')
        if freq_range is not None:
            ds = ds.sel(frequency=slice(freq_range[0], freq_range[1]))
        return ds.assign_coords({'deployment': deployment_name})

    # Load all files at once.
    self.ds = xr.open_mfdataset(
        [str(f) for f in all_files],
        engine='h5netcdf',
        chunks=chunks if self.use_dask else None,
        preprocess=preprocess,
        parallel=bool(self.use_dask),
        combine='by_coords',
        decode_timedelta=True,
        data_vars='minimal',
        coords='minimal',
        compat='override',
        lock=False if self.use_processes else None,
    )

    # Subset time range if specified (end exclusive).
    if time_range is not None:
        print("Subsetting time range...")
        # Label slices are end-inclusive, so step back one microsecond.
        end_time = pd.Timestamp(time_range[1]) - pd.Timedelta(microseconds=1)
        self.ds = self.ds.sel(time=slice(time_range[0], end_time))

    print(f"✓ Loaded dataset: {len(self.ds.time)} time points, "
          f"{len(self.ds.frequency)} frequency bins")

    # NOTE(review): 'deployment' is assigned above as a scalar coordinate,
    # so this branch only runs if combining promotes it to a dimension —
    # confirm it is reachable.
    if 'deployment' in self.ds.dims:
        print(f" Deployments: {list(self.ds.deployment.values)}")

    # Provide chunking advice if chunks don't align well.
    if self.ds.chunks:
        freq_chunks = self.ds.chunks.get('frequency', [])
        if len(freq_chunks) > 1:
            print("\n ℹ Note: Data has multiple frequency chunks which may cause warnings")
            print(" For better performance when extracting band levels, consider:")
            print(" hmd.rechunk({'time': 1440, 'frequency': -1})")

    return self
def extract_band_levels(self, freq_bands, band_names=None, persist=True):
    """
    Build one time series per requested frequency or frequency band.

    Parameters
    ----------
    freq_bands : list of lists/tuples
        One specification per output series:

        - ``[freq]``: single frequency; the nearest ``frequency`` label of
          ``self.ds.psd`` is selected (e.g. ``[100]`` for 100 Hz)
        - ``[freq_min, freq_max]``: band; ``self.ds`` is label-sliced on
          ``frequency`` and its ``psd`` is handed to
          ``self._integrate_band_db`` (e.g. ``[50, 300]``)
    band_names : list of str, optional
        Names for each band. If None, auto-generated as 'band_0',
        'band_1', etc.
    persist : bool
        If True, ``.persist()`` is called on the result before returning.

    Returns
    -------
    xarray.Dataset
        One variable per band name.

    Raises
    ------
    ValueError
        If ``band_names`` and ``freq_bands`` differ in length, or a
        specification does not have exactly 1 or 2 elements.

    Examples
    --------
    >>> # Single frequencies
    >>> result = hmd.extract_band_levels([[100], [500], [1000]])
    >>> # Frequency bands
    >>> bands = [[50, 300], [500, 2000], [2000, 10000]]
    >>> result = hmd.extract_band_levels(bands, band_names=['ship', 'fish', 'mammal'])
    >>> # Mixed, then access by name
    >>> result = hmd.extract_band_levels([[100], [50, 300]], band_names=['100Hz', 'ship'])
    >>> ship_noise = result['ship']
    """
    self._check_loaded()

    n_bands = len(freq_bands)
    if band_names is None:
        band_names = [f'band_{idx}' for idx in range(n_bands)]
    if len(band_names) != n_bands:
        raise ValueError(f"Number of band_names ({len(band_names)}) must match "
                         f"number of freq_bands ({n_bands})")

    series_by_name = {}
    for label, spec in zip(band_names, freq_bands):
        n_values = len(spec)
        if n_values not in (1, 2):
            raise ValueError(f"Each freq_spec must have 1 or 2 elements, "
                             f"got {n_values} for {label}")

        if n_values == 1:
            # Single frequency point: nearest-label lookup.
            target = spec[0]
            nearest = self.ds.psd.sel(frequency=target, method='nearest')
            # The scalar frequency coordinate differs per band and would
            # conflict when the series are combined, so drop it.
            series = nearest.drop_vars('frequency', errors='ignore')
            print(f" {label}: {target} Hz (single frequency)")
        else:
            # Frequency band: slice, then integrate.
            low, high = spec
            subset = self.ds.sel(frequency=slice(low, high))
            series = self._integrate_band_db(subset.psd)
            print(f" {label}: {low}-{high} Hz (integrated band)")

        series_by_name[label] = series

    combined = xr.Dataset(series_by_name)
    return combined.persist() if persist else combined
def compute_timeseries_stats(self, data, percentiles=(10, 50, 90), resolution='1h'):
    """
    Compute per-time-bin statistics (mean and percentiles) on dB data.

    The input is resampled along ``time`` at ``resolution`` with
    left-labelled bins. For every variable the following are produced:

    - ``'mean'``: values are converted to linear (``10 ** (x / 10)``),
      averaged per bin (NaNs skipped), and converted back with
      ``10 * log10``; non-positive or non-finite averages become NaN.
    - ``'L{p}'``: the ``p``-th percentile per bin (NaNs skipped), one
      variable per entry of ``percentiles``.
    - ``'count'``: number of non-NaN samples per bin.

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset
        Time series with a ``time`` dimension and no ``frequency``
        dimension. Every variable of a Dataset is checked.
    percentiles : iterable of int/float, optional
        Percentiles in the 0-100 range (default: 10, 50, 90). ``None`` or
        an empty iterable skips percentile computation.
    resolution : str, optional
        Resampling frequency passed to ``resample`` (default: '1h').

    Returns
    -------
    xarray.Dataset
        For a DataArray, or a Dataset with exactly one variable: variables
        named ``mean``, ``L{p}``, ``count``. For a Dataset with several
        variables: variables named ``{var}_{stat}``.

    Raises
    ------
    ValueError
        If ``data`` is neither a DataArray nor a Dataset, contains no
        variables, or any variable is not a pure time series.
    """
    import numpy as np
    import xarray as xr

    # Determine data type and structure.
    if isinstance(data, xr.DataArray):
        process_data = {'data': data}
        is_dataset = False
    elif isinstance(data, xr.Dataset):
        process_data = {var: data[var] for var in data.data_vars}
        is_dataset = True
    else:
        raise ValueError("data must be xarray.DataArray or Dataset")

    if not process_data:
        raise ValueError("data contains no variables to process")

    # Every variable must be a time series, not only the first one.
    for arr in process_data.values():
        if 'time' not in arr.dims or 'frequency' in arr.dims:
            raise ValueError("Data must only have a 'time' dimension (time series)")

    # Accept any iterable (list, tuple, array); None means "no percentiles".
    percentiles = [] if percentiles is None else list(percentiles)
    quantile_levels = [p / 100 for p in percentiles]

    def safe_log10(x):
        """Return log10(x) where x is finite and positive, NaN elsewhere."""
        # np.where evaluates log10 on every element, so silence the
        # divide/invalid warnings raised for the entries that get masked.
        with np.errstate(divide='ignore', invalid='ignore'):
            valid = (x > 0) & np.isfinite(x)
            return np.where(valid, np.log10(x), np.nan)

    print(f"Computing time series statistics at {resolution} resolution...")

    results = {}
    for name, arr in process_data.items():
        print(f" Processing {name}...")

        # In-memory (numpy) input is wrapped as a Dask array as well, to
        # avoid embedding large arrays in the task graph.
        is_lazy = hasattr(arr.data, 'dask')
        if not is_lazy:
            print(f" Converting computed data back to Dask array...")

        if resolution in ('1h', '1H'):
            time_chunk = 24 * 60  # 1 day worth of minute data
        elif resolution in ('1D', '1d'):
            # NOTE(review): the in-memory branch uses 365 ("1 year of daily
            # data") although the input is chunked, not the output —
            # kept as-is, confirm the intended size.
            time_chunk = 30 * 1440 if is_lazy else 365
        else:
            time_chunk = 10000 if is_lazy else min(10000, len(arr.time))
        arr = arr.chunk({'time': time_chunk})

        resampler = arr.resample(time=resolution, label='left')
        stats_dict = {}

        # Count valid observations per bin.
        count = resampler.count()

        # dB mean: convert to linear -> average -> convert back.
        linear = 10.0 ** (arr / 10.0)
        linear_mean = linear.resample(time=resolution, label='left').mean(skipna=True)
        stats_dict['mean'] = 10 * xr.apply_ufunc(
            safe_log10,
            linear_mean,
            dask='parallelized',
            output_dtypes=[float],
        )

        # Compute all percentiles in one call, then split them into
        # separate variables without the 'quantile' coordinate.
        if quantile_levels:
            all_quantiles = resampler.quantile(quantile_levels, skipna=True)
            for i, p in enumerate(percentiles):
                stats_dict[f'L{p}'] = all_quantiles.isel(quantile=i).drop_vars('quantile')

        # Add count as well for diagnostics.
        stats_dict['count'] = count

        var_stats = xr.Dataset(stats_dict)

        # Persist immediately to trigger computation and cache results.
        if self.use_dask:
            var_stats = var_stats.persist()
            print(f" ✓ {name} persisted to distributed memory")

        results[name] = var_stats

    print(f"✓ Time series statistics complete ({resolution} resolution)")

    if not is_dataset:
        return results['data']
    if len(results) == 1:
        return next(iter(results.values()))

    # Several variables: flatten into one Dataset with prefixed names.
    combined = xr.Dataset()
    for var_name, var_stats in results.items():
        for stat_name in var_stats.data_vars:
            combined[f'{var_name}_{stat_name}'] = var_stats[stat_name]
    return combined
def compute_frequency_stats(self, data=None, percentiles=(1, 10, 25, 50, 75, 90, 99),
                            include_mean=False, persist=True):
    """
    Compute statistics across the time dimension for each frequency.

    Produces a frequency-domain summary of spectral data (PSD): for each
    frequency, percentiles (and optionally the dB mean) over all time
    samples, ignoring NaNs.

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset, optional
        Spectral data in dB with 'frequency' and 'time' dimensions. Every
        variable of a Dataset is checked. If None, ``self.ds.psd`` is used
        (default: None).
    percentiles : iterable of int/float, optional
        Percentiles in the 0-100 range (default: 1, 10, 25, 50, 75, 90,
        99). Each is stored as ``'p{value}'`` (e.g. ``p50`` for the
        median). ``None`` or an empty iterable skips percentiles.
    include_mean : bool, optional
        If True, also compute the dB mean (convert to linear, average,
        convert back). Default is False.
    persist : bool, optional
        If True and ``self.use_dask`` is truthy, the per-variable result is
        persisted right after it is built. Default is True.

    Returns
    -------
    xarray.Dataset
        Statistics without the 'time' dimension. Variables include:

        - ``'p{N}'``: Nth percentile for each frequency
        - ``'mean'``: dB mean (only if ``include_mean=True``)
        - ``'count'``: number of non-NaN observations at each frequency

        For a Dataset input with several variables, names are prefixed as
        ``{var}_{stat}``.

    Raises
    ------
    ValueError
        If ``data`` is neither a DataArray nor a Dataset, contains no
        variables, or any variable lacks a 'time' or 'frequency' dimension.

    Examples
    --------
    >>> # Compute default percentiles for PSD
    >>> hmd = HMD(n_workers=4)
    >>> hmd.load_nc_files('deployment_01/')
    >>> freq_stats = hmd.compute_frequency_stats()
    >>>
    >>> # Custom percentiles without persisting
    >>> freq_stats = hmd.compute_frequency_stats(
    ...     percentiles=[5, 50, 95],
    ...     include_mean=True,
    ...     persist=False
    ... )
    >>>
    >>> # Use with custom data
    >>> freq_stats = hmd.compute_frequency_stats(data=hmd.ds.psd)

    Notes
    -----
    - For dB data, mean is computed as: convert to linear → average →
      convert back
    - All percentiles are computed together in one quantile call
    - ``plot_psd`` uses this method for its statistics
    """
    import numpy as np
    import xarray as xr

    # Use PSD by default.
    if data is None:
        self._check_loaded()
        data = self.ds.psd

    # Determine data type and structure.
    if isinstance(data, xr.DataArray):
        process_data = {'data': data}
        is_dataset = False
    elif isinstance(data, xr.Dataset):
        process_data = {var: data[var] for var in data.data_vars}
        is_dataset = True
    else:
        raise ValueError("data must be xarray.DataArray or Dataset")

    if not process_data:
        raise ValueError("data contains no variables to process")

    # Every variable must be spectral data, not only the first one.
    for arr in process_data.values():
        if 'time' not in arr.dims or 'frequency' not in arr.dims:
            raise ValueError("Data must have both 'time' and 'frequency' dimensions (spectral data)")

    # Accept any iterable (list, tuple, array); None means "no percentiles".
    percentiles = [] if percentiles is None else list(percentiles)
    quantile_levels = [p / 100 for p in percentiles]

    def safe_log10(x):
        """Return log10(x) where x is finite and positive, NaN elsewhere."""
        # np.where evaluates log10 on every element, so silence the
        # divide/invalid warnings raised for the entries that get masked.
        with np.errstate(divide='ignore', invalid='ignore'):
            valid = (x > 0) & np.isfinite(x)
            return np.where(valid, np.log10(x), np.nan)

    print(f"Computing frequency statistics across time dimension...")

    results = {}
    for name, arr in process_data.items():
        print(f" Processing {name}...")

        # Keep all frequencies in one chunk and split along time. In-memory
        # (numpy) input is wrapped as a Dask array as well.
        if hasattr(arr.data, 'dask'):
            time_chunk = 10000
        else:
            print(f" Converting computed data back to Dask array...")
            time_chunk = min(10000, len(arr.time))
        arr = arr.chunk({'time': time_chunk, 'frequency': -1})

        stats_dict = {}

        # Count valid observations per frequency.
        count = arr.count(dim='time')

        if include_mean:
            # dB mean: convert to linear -> average -> convert back.
            linear = 10.0 ** (arr / 10.0)
            linear_mean = linear.mean(dim='time', skipna=True)
            stats_dict['mean'] = 10 * xr.apply_ufunc(
                safe_log10,
                linear_mean,
                dask='parallelized',
                output_dtypes=[float],
            )

        if quantile_levels:
            print(f" Computing percentiles: {percentiles}...")
            # NOTE(review): 'time' is split into several chunks above; some
            # xarray/dask versions may require a single chunk along the
            # reduced dimension for quantile — confirm on the supported
            # versions.
            all_quantiles = arr.quantile(quantile_levels, dim='time', skipna=True)
            # Split into one variable per percentile, dropping the
            # 'quantile' coordinate.
            for i, p in enumerate(percentiles):
                stats_dict[f'p{p}'] = all_quantiles.isel(quantile=i).drop_vars('quantile')

        # Add count for diagnostics.
        stats_dict['count'] = count

        var_stats = xr.Dataset(stats_dict)

        # Persist immediately to trigger computation and cache results.
        if persist and self.use_dask:
            var_stats = var_stats.persist()
            print(f" ✓ {name} persisted to distributed memory")

        results[name] = var_stats

    print(f"✓ Frequency statistics complete")

    if not is_dataset:
        return results['data']
    if len(results) == 1:
        return next(iter(results.values()))

    # Several variables: flatten into one Dataset with prefixed names.
    combined = xr.Dataset()
    for var_name, var_stats in results.items():
        for stat_name in var_stats.data_vars:
            combined[f'{var_name}_{stat_name}'] = var_stats[stat_name]
    return combined
def plot_psd(self, style='quantile', percentiles=(1, 10, 25, 50, 75, 90, 99),
             freq_range=None, db_range=None, scale='log', cmap='binary', colors=None,
             linewidth=1.5, alpha=0.7, legend_loc='best', title=None, figsize=(10, 6),
             dpi=100, save_path=None, show=True, return_data=False):
    """
    Plot Power Spectral Density (PSD) with statistical summaries across frequency.

    Shows the distribution of sound levels across frequency as percentile
    lines, as a per-frequency histogram heatmap, or both.

    Parameters
    ----------
    style : str, optional
        Visualization style (default: 'quantile'):

        - 'quantile': line plot of the requested percentiles
        - 'density': 2D histogram heatmap showing the distribution
        - 'both': density heatmap with percentile lines on top
    percentiles : iterable of float, optional
        Percentiles to display (default: 1, 10, 25, 50, 75, 90, 99).
        Each percentile is plotted as a separate line labeled L1, L10, etc.
    freq_range : tuple, optional
        Frequency range to display (min_freq, max_freq) in Hz.
        If None, uses the full frequency range of the data.
    db_range : tuple, optional
        dB limits (min_db, max_db) for the y-axis and the density bins.
        If None, density bins span the 1st-99th percentile of a time
        subsample and the y-axis auto-scales.
    scale : str, optional
        Frequency axis scale: 'log' (default); any other value leaves the
        axis linear.
    cmap : str, optional
        Matplotlib colormap for density plots (default: 'binary').
    colors : str, list, or colormap, optional
        Colors for percentile lines. Can be:

        - None: 'tab10' (up to 10 lines) or 'tab20'
        - String: colormap name (e.g. 'viridis'); a string that is not a
          registered colormap is used as a single color for all lines
        - Colormap object: sampled evenly
        - List: explicit colors, cycled if shorter than ``percentiles``
    linewidth : float, optional
        Line width for percentile curves (default: 1.5)
    alpha : float, optional
        Line transparency (default: 0.7)
    legend_loc : str, optional
        Legend location (default: 'best'). Besides the Matplotlib
        locations, accepts 'outside right', 'outside right top',
        'outside right bottom', 'outside left', 'outside left top' and
        'outside left bottom'.
    title : str, optional
        Plot title (auto-generated from the data's time span if None)
    figsize : tuple, optional
        Figure size in inches (default: (10, 6))
    dpi : int, optional
        Resolution in dots per inch (default: 100)
    save_path : str, optional
        Path to save figure (PNG, PDF, etc.)
    show : bool, optional
        Whether to display the figure (default: True). When False the
        figure is closed before being returned.
    return_data : bool, optional
        If True, return the computed percentile statistics instead of
        plotting.

    Returns
    -------
    matplotlib.figure.Figure or dict
        Figure object if return_data=False, otherwise a dict mapping
        ``'p{N}'`` to the corresponding statistics array.

    Raises
    ------
    ValueError
        If a plot is requested with a ``style`` other than 'quantile',
        'density' or 'both'.

    Examples
    --------
    >>> hmd = HMD(n_workers=4)
    >>> hmd.load_nc_files('deployment_01/', time_range=('2020-01-01', '2020-02-01'))
    >>> hmd.plot_psd(style='quantile')  # Shows L1, L10, L25, L50, L75, L90, L99
    >>>
    >>> # Custom percentiles with legend outside and custom title
    >>> hmd.plot_psd(style='quantile',
    ...              percentiles=[5, 50, 95],
    ...              title='Acoustic Spectral Characteristics',
    ...              legend_loc='outside right top')
    >>>
    >>> # Density plot with custom frequency range
    >>> hmd.plot_psd(style='density', freq_range=(20, 2000), scale='log')
    >>>
    >>> # Get statistics for further analysis
    >>> stats = hmd.plot_psd(return_data=True)
    """
    import matplotlib.pyplot as plt
    import numpy as np

    self._check_loaded()

    # An unknown style would otherwise produce an empty figure silently.
    if not return_data and style not in ('quantile', 'density', 'both'):
        raise ValueError(f"Unknown style: {style}. "
                         f"Choose from: quantile, density, both")

    # Materialise once so any iterable works and the list form is what the
    # statistics routine receives.
    percentiles = list(percentiles)

    print(f"Creating PSD plot (style: {style})...")

    # Subset frequency range if specified.
    if freq_range is not None:
        print(f" Subsetting frequency range: {freq_range[0]}-{freq_range[1]} Hz")
        psd = self.ds.psd.sel(frequency=slice(freq_range[0], freq_range[1]))
    else:
        psd = self.ds.psd

    stats_ds = self.compute_frequency_stats(
        data=psd,
        percentiles=percentiles,
        include_mean=False,
        persist=True,
    )
    freqs = stats_ds.coords['frequency'].values

    # Only percentile variables ('p...') are exposed; 'count' is dropped.
    percentile_vars = [var for var in stats_ds.data_vars if var.startswith('p')]

    if return_data:
        return {var: stats_ds[var] for var in percentile_vars}

    # Bring the percentile curves into memory for plotting.
    stats = {var: stats_ds[var].compute() for var in percentile_vars}

    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    # Density first, so it sits behind the percentile lines.
    if style in ('density', 'both'):
        print(" Creating density plot...")

        if db_range is not None:
            db_min, db_max = db_range
        else:
            # Estimate the range from roughly 10000 evenly spaced time
            # samples instead of loading the full array.
            print(" Estimating dB range from data sample...")
            step = max(1, len(psd.time) // 10000)
            sample = psd.isel(time=slice(0, None, step)).compute()
            db_min = np.nanpercentile(sample.values, 1)
            db_max = np.nanpercentile(sample.values, 99)

        n_db_bins = 100
        db_bins = np.linspace(db_min, db_max, n_db_bins + 1)

        # Density matrix (frequency x dB level), filled one frequency at a
        # time to bound memory use.
        density_matrix = np.zeros((len(freqs), n_db_bins))
        print(f" Computing density histograms for {len(freqs)} frequency bins...")
        for i, freq in enumerate(freqs):
            freq_values = psd.sel(frequency=freq, method='nearest').compute().values
            freq_values = freq_values[~np.isnan(freq_values)]

            if len(freq_values) > 0:
                hist, _ = np.histogram(freq_values, bins=db_bins)
                total = hist.sum()
                # Normalise to a probability per bin; all-out-of-range
                # data stays at zero.
                density_matrix[i, :] = hist / total if total > 0 else hist

            if (i + 1) % 100 == 0:
                print(f" Progress: {i + 1}/{len(freqs)} frequencies")

        # Cell edges: midpoints between frequencies, extended by 5% at
        # both ends.
        freq_edges = np.concatenate([
            [freqs[0] * 0.95],
            (freqs[:-1] + freqs[1:]) / 2,
            [freqs[-1] * 1.05],
        ])

        im = ax.pcolormesh(
            freq_edges,
            db_bins,
            density_matrix.T,
            cmap=cmap,
            shading='flat',
            rasterized=True,
            zorder=1,  # Behind the lines
        )

        # Slightly smaller colorbar when lines are overlaid.
        if style == 'both':
            cbar = plt.colorbar(im, ax=ax, pad=0.02, fraction=0.046)
        else:
            cbar = plt.colorbar(im, ax=ax, pad=0.02)
        cbar.set_label('Probability Density', fontsize=10, fontweight='bold')

    # Percentile lines on top.
    if style in ('quantile', 'both'):
        percentiles_sorted = sorted(percentiles)
        n_percentiles = len(percentiles_sorted)
        span = max(n_percentiles - 1, 1)

        if colors is None:
            # Qualitative palette; integer lookup picks discrete entries.
            cmap_obj = plt.get_cmap('tab10' if n_percentiles <= 10 else 'tab20')
            line_colors = [cmap_obj(i % cmap_obj.N) for i in range(n_percentiles)]
        elif isinstance(colors, str):
            try:
                cmap_obj = plt.get_cmap(colors)
            except ValueError:
                # Not a registered colormap: treat as a single color.
                line_colors = [colors] * n_percentiles
            else:
                line_colors = [cmap_obj(i / span) for i in range(n_percentiles)]
        elif hasattr(colors, 'N'):
            # Colormap object.
            line_colors = [colors(i / span) for i in range(n_percentiles)]
        else:
            # Explicit list of colors.
            line_colors = colors

        for idx, p in enumerate(percentiles_sorted):
            ax.plot(freqs, stats[f'p{p}'].values,
                    color=line_colors[idx % len(line_colors)],
                    linewidth=linewidth,
                    alpha=alpha,
                    label=f'L{p}',
                    zorder=3)  # Above the density mesh

        # "outside ..." shortcuts map to (loc, bbox_to_anchor) pairs.
        outside_legend = {
            'outside right': ('center left', (1.15, 0.5)),
            'outside right top': ('upper left', (1.15, 1.0)),
            'outside right bottom': ('lower left', (1.15, 0.0)),
            'outside left': ('center right', (-0.02, 0.5)),
            'outside left top': ('upper right', (-0.02, 1.0)),
            'outside left bottom': ('lower right', (-0.02, 0.0)),
        }
        if legend_loc in outside_legend:
            loc, anchor = outside_legend[legend_loc]
            ax.legend(loc=loc, bbox_to_anchor=anchor, framealpha=0.9, ncol=1)
        else:
            ax.legend(loc=legend_loc, framealpha=0.9)

    # Scale and labels (common for all styles).
    if scale == 'log':
        ax.set_xscale('log')
    ax.set_xlabel('Frequency (Hz)', fontsize=11, fontweight='bold')
    ax.set_ylabel('Power Spectral Density (dB re 1 µPa²/Hz)', fontsize=11, fontweight='bold')
    if db_range is not None:
        ax.set_ylim(db_range)

    # Full grid for lines only; x-axis grid only when a density mesh is
    # drawn, since horizontal grid lines would obscure it.
    grid_axis = 'both' if style == 'quantile' else 'x'
    ax.grid(True, which='major', alpha=0.3, linestyle='--', axis=grid_axis)
    ax.grid(True, which='minor', alpha=0.15, linestyle=':', axis=grid_axis)
    ax.minorticks_on()

    if title is not None:
        ax.set_title(title, fontsize=12, fontweight='bold')
    else:
        # Auto-generate a title from the style and the data's time span.
        time_start = pd.to_datetime(self.ds.time.min().values).strftime('%Y-%m-%d')
        time_end = pd.to_datetime(self.ds.time.max().values).strftime('%Y-%m-%d')
        if style == 'both':
            subtitle = f'Power Spectral Density Analysis\n{time_start} to {time_end}'
        elif style == 'quantile':
            subtitle = f'Power Spectral Density - Quantile Plot\n{time_start} to {time_end}'
        else:  # density
            subtitle = f'Power Spectral Density - Density Plot\n{time_start} to {time_end}'
        ax.set_title(subtitle, fontsize=12, fontweight='bold')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
        print(f"✓ Saved PSD plot to {save_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return fig
def plot_ltsa(self, bin='1H', freq_range=None, db_range=(32, 108), scale='log',
              cmap='rainbow', statistic='median', plot_date_range=None,
              title=None, figsize=(14, 6), dpi=100, save_path=None,
              show=True, return_data=False):
    """
    Plot a Long-Term Spectral Average (LTSA) spectrogram.

    The PSD variable of the loaded dataset is resampled along time into
    ``bin``-wide intervals, reduced with ``statistic``, and drawn with time
    on the x-axis, frequency on the y-axis and level as colour.

    Parameters
    ----------
    bin : str, optional
        Resampling interval passed to ``resample(time=bin)`` (default: '1H').
        Examples: '1H', '6H', '1D', '1W'.
    freq_range : tuple, optional
        (min_freq, max_freq) in Hz. Used both to subset the data and as the
        y-axis limits. If None, the full frequency range is used.
    db_range : tuple, optional
        (min_db, max_db) colour limits. If None, the 1st and 99th
        percentiles of the binned values are used.
    scale : str, optional
        'log' (default) puts the frequency axis on a log scale; any other
        value leaves it linear.
    cmap : str, optional
        Matplotlib colormap name (default: 'rainbow').
    statistic : str, optional
        Reduction applied to each bin: 'median' (default), 'mean', 'min',
        'max' or 'std'.
    plot_date_range : tuple, list, or str, optional
        X-axis limits. Either a pair of dates (anything ``pd.Timestamp``
        accepts), the string 'fullyear' (Jan 1 - Dec 31 of the year of the
        first time bin), or None to use the data range.
    title : str, optional
        Plot title (auto-generated if None).
    figsize : tuple, optional
        Figure size in inches (default: (14, 6)).
    dpi : int, optional
        Figure and saved-image resolution (default: 100).
    save_path : str, optional
        If given, the figure is written to this path.
    show : bool, optional
        If True (default) the figure is displayed, otherwise it is closed.
    return_data : bool, optional
        If True, return the binned data instead of plotting.

    Returns
    -------
    matplotlib.figure.Figure or xarray.DataArray
        The figure, or the binned data when ``return_data`` is True.

    Raises
    ------
    ValueError
        If ``statistic`` is not one of the supported names, or (when
        plotting) ``plot_date_range`` is neither 'fullyear' nor a pair.

    Examples
    --------
    >>> hmd = HMD(n_workers=4)
    >>> hmd.load_nc_files('deployment_01/', time_range=('2020-01-01', '2020-02-01'))
    >>> hmd.plot_ltsa(bin='6H', freq_range=(50, 1000), scale='log')
    >>> hmd.plot_ltsa(bin='1D', plot_date_range='fullyear')
    >>> ltsa_data = hmd.plot_ltsa(bin='1H', return_data=True)
    """
    # Fail fast on bad arguments: checking them before the (potentially
    # expensive) resample/compute avoids wasted work and a leaked figure.
    if statistic not in ('median', 'mean', 'min', 'max', 'std'):
        raise ValueError(f"Unknown statistic: {statistic}. "
                         f"Choose from: median, mean, min, max, std")

    use_full_year = (isinstance(plot_date_range, str)
                     and plot_date_range == 'fullyear')
    use_explicit_range = (isinstance(plot_date_range, (list, tuple))
                          and len(plot_date_range) == 2)
    # plot_date_range only matters when a figure is actually drawn.
    if (not return_data and plot_date_range is not None
            and not (use_full_year or use_explicit_range)):
        raise ValueError(
            "plot_date_range must be 'fullyear' or a tuple/list of two dates. "
            f"Got: {plot_date_range}"
        )

    self._check_loaded()

    print(f"Creating LTSA plot with {bin} bins...")

    # Subset frequency range if specified
    if freq_range is not None:
        print(f" Subsetting frequency range: {freq_range[0]}-{freq_range[1]} Hz")
        data = self.ds.sel(frequency=slice(freq_range[0], freq_range[1]))
    else:
        data = self.ds

    psd = data.psd

    # Choose a time chunk size suited to the bin width. The sizes for
    # already-lazy and already-computed data are kept exactly as before.
    # NOTE(review): the comments in the previous version suggest minute
    # resolution input; the 365-sample chunk used for eager data with daily
    # bins looks small under that assumption - confirm.
    is_lazy = hasattr(psd.data, 'dask')
    if not is_lazy:
        print(" Converting computed data back to Dask array for optimal resampling...")
    if bin in ('1h', '1H'):
        time_chunk = 24 * 60
    elif bin in ('6h', '6H'):
        time_chunk = 24 * 10
    elif bin in ('1D', '1d'):
        time_chunk = 30 * 1440 if is_lazy else 365
    else:
        time_chunk = 10000 if is_lazy else min(10000, len(psd.time))
    psd = psd.chunk({'time': time_chunk, 'frequency': -1})

    # Resample in time; `statistic` was validated above, so it names a
    # reduction method of the resampler.
    print(f" Resampling time to {bin} intervals using {statistic}...")
    resampler = psd.resample(time=bin, label='left')
    ltsa_data = getattr(resampler, statistic)(skipna=True)

    # Persist to distributed memory if using Dask
    if self.use_dask and hasattr(ltsa_data, 'persist'):
        print(" Persisting LTSA data to distributed memory...")
        ltsa_data = ltsa_data.persist()

    if hasattr(ltsa_data, 'compute'):
        print(" Computing LTSA data...")
        ltsa_data = ltsa_data.compute()

    print(f" LTSA shape: {ltsa_data.shape[0]} time bins × {ltsa_data.shape[1]} frequency bins")

    if return_data:
        return ltsa_data

    # Imported only when a figure is needed, so return_data=True works
    # without touching matplotlib.
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    times = ltsa_data.coords['time'].values
    freqs = ltsa_data.coords['frequency'].values
    values = ltsa_data.values  # (time, frequency)

    # Colour limits: fixed if given, otherwise robust percentiles
    if db_range is not None:
        vmin, vmax = db_range
    else:
        vmin, vmax = np.nanpercentile(values, [1, 99])

    im = ax.pcolormesh(
        times,
        freqs,
        values.T,  # pcolormesh expects (frequency, time) for C
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        shading='nearest',
        rasterized=True,  # keeps rendering and file size manageable
    )

    if scale == 'log':
        ax.set_yscale('log')

    if freq_range is not None:
        ax.set_ylim(freq_range)

    # X-axis limits (plot_date_range was validated up front)
    if use_full_year:
        # The year is taken from the first time bin
        year = pd.to_datetime(times[0]).year
        xlim_start = pd.Timestamp(f'{year}-01-01')
        xlim_end = pd.Timestamp(f'{year}-12-31 23:59:59')
        ax.set_xlim(xlim_start, xlim_end)
        print(f" Setting x-axis to full year: {year}")
    elif use_explicit_range:
        xlim_start = pd.Timestamp(plot_date_range[0])
        xlim_end = pd.Timestamp(plot_date_range[1])
        ax.set_xlim(xlim_start, xlim_end)
        print(f" Setting x-axis range: {xlim_start.strftime('%Y-%m-%d')} to {xlim_end.strftime('%Y-%m-%d')}")

    ax.xaxis.set_major_formatter(
        mdates.ConciseDateFormatter(ax.xaxis.get_major_locator())
    )

    ax.set_xlabel('Time', fontsize=11, fontweight='bold')
    ax.set_ylabel('Frequency (Hz)', fontsize=11, fontweight='bold')

    if title is None:
        time_start = pd.to_datetime(times[0]).strftime('%Y-%m-%d')
        time_end = pd.to_datetime(times[-1]).strftime('%Y-%m-%d')
        title = f'Long-Term Spectral Average (LTSA)\n{time_start} to {time_end} | Bin: {bin} | Statistic: {statistic.capitalize()}'

    ax.set_title(title, fontsize=12, fontweight='bold', pad=10)

    cbar = fig.colorbar(im, ax=ax, pad=0.02)
    cbar.set_label('Power Spectral Density (dB re 1 µPa²/Hz)', fontsize=10, fontweight='bold')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        print(f"✓ Saved LTSA plot to {save_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return fig
def plot_multiyear_overlay(self, data, band_names=None, time_axis='dayofyear',
                           smoothing_window=None, title=None, xlabel=None,
                           ylabel='Sound Level (dB)', figsize=None, colors=None,
                           alpha=0.7, linewidth=1.5, grid=True, legend_loc='best',
                           save_path=None, show=True, show_median=False,
                           median_color='black', median_linewidth=2.5,
                           median_linestyle='-', median_alpha=1.0,
                           show_percentile_range=False, percentiles=(10, 90),
                           range_color='lightgray', range_alpha=0.3, **kwargs):
    """
    Plot a multi-year overlay of time series with aligned time axes.

    Each calendar year is drawn as its own line, aligned by day of year,
    week of year, month or day of week, so seasonal patterns can be
    compared across years. Lines are broken where the aligned time axis
    has gaps so missing periods are not bridged.

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset
        Time series spanning one or more years. A DataArray gives one
        plot; a Dataset gives one subplot per selected variable.
    band_names : str or list of str, optional
        For Dataset input, the variables to plot (default: all).
    time_axis : str, optional
        Alignment: 'dayofyear' (default), 'weekofyear' (ISO week),
        'month' or 'dayofweek'.
    smoothing_window : int, optional
        Window (in samples) of a centred rolling mean applied per year.
        Values of None or <= 1 disable smoothing.
    title : str, optional
        Plot title (auto-generated for a single variable if None).
    xlabel : str, optional
        X-axis label (default derived from ``time_axis``).
    ylabel : str, optional
        Y-axis label (default: 'Sound Level (dB)').
    figsize : tuple, optional
        Figure size. Default (14, 6) for one variable, otherwise
        (14, min(4 * n_vars, 12)).
    colors : str, list, or colormap, optional
        None uses 'tab10'/'tab20'; a string is tried as a colormap name
        and otherwise used as a single colour for every year; an object
        with an ``N`` attribute is treated as a colormap; anything else
        is used as a sequence of colours (cycled if too short).
    alpha, linewidth : float, optional
        Line transparency and width for the per-year lines.
    grid : bool, optional
        Show grid (default: True).
    legend_loc : str, optional
        Any matplotlib legend location, or one of 'outside right',
        'outside right top', 'outside right bottom', 'outside left',
        'outside left top', 'outside left bottom'.
    save_path : str, optional
        If given, the figure is written to this path at 300 dpi.
    show : bool, optional
        If True (default) the figure is displayed, otherwise closed.
    show_median : bool, optional
        Draw the median across all years for each aligned time value.
    median_color, median_linewidth, median_linestyle, median_alpha
        Styling of the median line.
    show_percentile_range : bool, optional
        Shade the range between two percentiles across all years.
    percentiles : sequence of two numbers, optional
        Lower and upper percentiles (default: (10, 90)).
    range_color, range_alpha
        Styling of the percentile band.
    **kwargs : dict
        Extra arguments forwarded to ``ax.plot`` for the per-year lines.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        For an unknown ``time_axis``, a ``percentiles`` argument that does
        not have exactly two entries (when the band is requested), input
        that is not a DataArray/Dataset, unknown ``band_names``, or when
        there is nothing to plot.

    Examples
    --------
    >>> stats = hmd.compute_timeseries_stats(band_levels, resolution='1D')
    >>> hmd.plot_multiyear_overlay(stats['ship_mean'],
    ...                            time_axis='dayofyear',
    ...                            smoothing_window=7)
    >>> hmd.plot_multiyear_overlay(stats,
    ...                            band_names=['ship_mean', 'fish_mean'],
    ...                            time_axis='weekofyear')
    >>> hmd.plot_multiyear_overlay(stats['ship_mean'], colors='viridis',
    ...                            legend_loc='outside right')
    """
    # Per-axis default label and x-limits.
    axis_settings = {
        'dayofyear': ('Day of Year', (1, 366)),
        'weekofyear': ('Week of Year', (1, 53)),
        'month': ('Month', (1, 12)),
        'dayofweek': ('Day of Week', (0, 6)),
    }

    # Validate arguments before computing data or opening a figure.
    if time_axis not in axis_settings:
        raise ValueError(f"Unknown time_axis: {time_axis}. "
                         f"Choose from: dayofyear, weekofyear, month, dayofweek")
    if show_percentile_range and len(percentiles) != 2:
        raise ValueError("percentiles must be a list of exactly 2 values [lower, upper]")
    default_xlabel, xlim = axis_settings[time_axis]

    import matplotlib.pyplot as plt

    # Determine what we're plotting
    if isinstance(data, xr.DataArray):
        plot_data = {data.name or 'Time Series': data}
    elif isinstance(data, xr.Dataset):
        if band_names is not None:
            if isinstance(band_names, str):
                band_names = [band_names]
            missing = [b for b in band_names if b not in data.data_vars]
            if missing:
                raise ValueError(f"Band names not found in Dataset: {missing}")
            plot_data = {var: data[var] for var in band_names}
        else:
            plot_data = {var: data[var] for var in data.data_vars}
    else:
        raise ValueError("data must be xarray.DataArray or xarray.Dataset")

    if not plot_data:
        # plt.subplots(0, 1) would otherwise fail with an opaque error
        raise ValueError("No variables to plot: the selection is empty")

    # Compute data if needed
    for name in plot_data:
        if hasattr(plot_data[name], 'compute'):
            print(f"Computing {name}...")
            plot_data[name] = plot_data[name].compute()

    n_vars = len(plot_data)
    is_single = n_vars == 1

    if figsize is None:
        figsize = (14, 6) if is_single else (14, min(4 * n_vars, 12))

    # squeeze=False gives a 2-D array of axes for any n_vars
    fig, axes = plt.subplots(n_vars, 1, figsize=figsize,
                             sharex=not is_single, squeeze=False)
    axes = axes.flatten()

    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    for var_idx, (var_name, ts) in enumerate(plot_data.items()):
        ax = axes[var_idx]

        # Convert to pandas for easier time manipulation. Rows without a
        # timestamp cannot be placed on any aligned axis, so drop them too.
        df = ts.to_dataframe(name='value').reset_index()
        df['time'] = pd.to_datetime(df['time'])
        df = df.dropna(subset=['time', 'value'])

        df['year'] = df['time'].dt.year
        if time_axis == 'dayofyear':
            df['time_aligned'] = df['time'].dt.dayofyear
        elif time_axis == 'weekofyear':
            # isocalendar() yields a nullable UInt32 column; use plain ints
            # so diff/masking/plotting behave like the other axes.
            df['time_aligned'] = df['time'].dt.isocalendar().week.astype('int64')
        elif time_axis == 'month':
            df['time_aligned'] = df['time'].dt.month
        else:  # 'dayofweek'
            df['time_aligned'] = df['time'].dt.dayofweek

        years = sorted(df['year'].unique())
        n_years = len(years)

        # Set up colours. pyplot.get_cmap is used because
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        if colors is None:
            cmap = plt.get_cmap('tab10' if n_years <= 10 else 'tab20')
            year_colors = [cmap(i % cmap.N) for i in range(n_years)]
        elif isinstance(colors, str):
            try:
                cmap = plt.get_cmap(colors)
            except ValueError:
                # Not a registered colormap name: treat as a single colour
                year_colors = [colors] * n_years
            else:
                year_colors = [cmap(i / max(n_years - 1, 1)) for i in range(n_years)]
        elif hasattr(colors, 'N'):
            # Colormap object
            year_colors = [colors(i / max(n_years - 1, 1)) for i in range(n_years)]
        else:
            # Assume an explicit sequence of colours
            year_colors = colors

        for year_idx, year in enumerate(years):
            year_data = df[df['year'] == year].copy()
            year_data = year_data.sort_values('time_aligned')

            if smoothing_window is not None and smoothing_window > 1:
                year_data['value'] = year_data['value'].rolling(
                    window=smoothing_window, center=True, min_periods=1
                ).mean()

            color = year_colors[year_idx % len(year_colors)]

            # Detect gaps on the aligned axis so the line can be broken
            # there instead of bridging missing periods.
            time_diffs = year_data['time_aligned'].diff()
            if len(time_diffs) > 1:
                # Median positive step = typical sampling step
                typical_step = time_diffs[time_diffs > 0].median()
                gap_threshold = typical_step * 1.5
            else:
                gap_threshold = 2

            is_gap = time_diffs > gap_threshold
            segment_breaks = year_data.index[is_gap].tolist()
            segment_bounds = ([year_data.index[0]] + segment_breaks
                              + [year_data.index[-1]])

            first_segment = True
            last_segment_idx = len(segment_bounds) - 2
            for i in range(len(segment_bounds) - 1):
                start_idx = segment_bounds[i]
                end_idx = segment_bounds[i + 1]

                segment = year_data.loc[start_idx:end_idx]
                if i < last_segment_idx:
                    # The end label is the first point after the gap; it
                    # belongs to the next segment.
                    segment = segment.iloc[:-1]

                if len(segment) == 0:
                    continue

                # Label only the first segment to avoid duplicate legend entries
                label_text = str(year) if first_segment else None
                ax.plot(segment['time_aligned'], segment['value'],
                        label=label_text, color=color, alpha=alpha,
                        linewidth=linewidth, **kwargs)
                first_segment = False

        if show_median:
            # Median across all years for each aligned time value
            median_by_aligned = df.groupby('time_aligned')['value'].median()
            ax.plot(median_by_aligned.index, median_by_aligned.values,
                    color=median_color, linewidth=median_linewidth,
                    linestyle=median_linestyle, alpha=median_alpha,
                    label='Median', zorder=10)  # drawn on top

        if show_percentile_range:
            grouped = df.groupby('time_aligned')['value']
            lower_percentile = grouped.quantile(percentiles[0] / 100)
            upper_percentile = grouped.quantile(percentiles[1] / 100)
            ax.fill_between(lower_percentile.index,
                            lower_percentile.values,
                            upper_percentile.values,
                            color=range_color, alpha=range_alpha,
                            label=f'{percentiles[0]}-{percentiles[1]}% Range',
                            zorder=1)  # drawn behind the lines

        final_xlabel = xlabel if xlabel is not None else default_xlabel
        ax.set_xlabel(final_xlabel, fontsize=10)
        ax.set_ylabel(ylabel if is_single else f'{var_name}\n({ylabel})', fontsize=10)
        ax.set_xlim(xlim)

        if time_axis == 'dayofyear':
            # Tick at the first day of each month (non-leap year)
            month_starts = [1, 32, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335]
            ax.set_xticks(month_starts)
            ax.set_xticklabels(month_names)
        elif time_axis == 'weekofyear':
            ax.set_xticks(np.arange(1, 54, 4))  # roughly monthly
        elif time_axis == 'month':
            ax.set_xticks(range(1, 13))
            ax.set_xticklabels(month_names)
        else:  # 'dayofweek'
            day_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
            ax.set_xticks(range(7))
            ax.set_xticklabels(day_names)

        if grid:
            ax.grid(True, alpha=0.3, linestyle='--')

        # 'outside ...' locations map to (loc, anchor) pairs just beyond
        # the axes; everything else is passed straight to matplotlib.
        outside_legends = {
            'outside right': ('center left', (1.02, 0.5)),
            'outside right top': ('upper left', (1.02, 1.0)),
            'outside right bottom': ('lower left', (1.02, 0.0)),
            'outside left': ('center right', (-0.02, 0.5)),
            'outside left top': ('upper right', (-0.02, 1.0)),
            'outside left bottom': ('lower right', (-0.02, 0.0)),
        }
        if legend_loc in outside_legends:
            loc, anchor = outside_legends[legend_loc]
            ax.legend(loc=loc, bbox_to_anchor=anchor,
                      framealpha=0.9, title='Year', ncol=1)
        else:
            ax.legend(loc=legend_loc, framealpha=0.9, title='Year',
                      ncol=max(1, min(n_years, 6)))

        if is_single:
            if title is None:
                smooth_text = f" ({smoothing_window}-{time_axis.replace('of', ' ')} smoothing)" if smoothing_window else ""
                title = f'Multi-Year Overlay: {var_name}{smooth_text}'
            ax.set_title(title, fontsize=12, fontweight='bold')
        else:
            ax.set_title(var_name, fontsize=11, fontweight='bold')

    # Overall title for multiple variables
    if not is_single and title:
        fig.suptitle(title, fontsize=13, fontweight='bold', y=0.995)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved to {save_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return fig
def plot_timeseries(self, data, band_names=None, overlay=True, title=None,
                    ylabel='Sound Level (dB)', xlabel='Time', figsize=None,
                    colors=None, alpha=0.7, linewidth=1.5, grid=True,
                    legend_loc='best', sharex=True, sharey=False,
                    save_path=None, show=None, **kwargs):
    """
    Plot time series of pre-extracted acoustic data.

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset
        Pre-computed time series. A DataArray is plotted as one series; a
        Dataset contributes the selected (or all) data variables.
    band_names : str or list of str, optional
        For Dataset input, the variables to plot (default: all). Ignored
        for DataArray input.
    overlay : bool, optional
        If True (default) all series share one axes; if False each series
        gets its own subplot.
    title : str, optional
        Plot title (auto-generated if None).
    ylabel : str, optional
        Y-axis label (default: 'Sound Level (dB)').
    xlabel : str, optional
        X-axis label (default: 'Time').
    figsize : tuple, optional
        Figure size. Defaults: single series (12, 4); overlaid (14, 6);
        subplots (12, min(3 * n_series, 12)).
    colors : str or list, optional
        One colour for every series, or a list that is cycled.
    alpha : float, optional
        Line transparency (default: 0.7).
    linewidth : float, optional
        Line width (default: 1.5).
    grid : bool, optional
        Show grid (default: True).
    legend_loc : str, optional
        Legend location, only used when several series are overlaid.
    sharex, sharey : bool, optional
        Axis sharing between subplots when ``overlay=False``.
    save_path : str, optional
        If given, the figure is written to this path at 300 dpi.
    show : bool or None, optional
        None (default) keeps the historical behaviour: the figure is
        displayed only when ``save_path`` is not given. True always
        displays it; False never displays it and closes the figure, as
        the other plotting methods do.
    **kwargs : dict
        Extra arguments forwarded to ``ax.plot``.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``data`` is not a DataArray/Dataset, a requested band is
        missing, or there is nothing to plot.

    Examples
    --------
    >>> hmd.plot_timeseries(result['ship'])
    >>> hmd.plot_timeseries(result, band_names=['ship', 'fish'])
    >>> hmd.plot_timeseries(result, band_names=['ship', 'fish'], overlay=False)
    >>> hmd.plot_timeseries(result, overlay=False, sharey=True)
    """
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates

    # Determine what we're plotting
    if isinstance(data, xr.DataArray):
        plot_data = {data.name or 'Time Series': data}
    elif isinstance(data, xr.Dataset):
        if band_names is not None:
            if isinstance(band_names, str):
                band_names = [band_names]
            missing = [b for b in band_names if b not in data.data_vars]
            if missing:
                raise ValueError(f"Band names not found in Dataset: {missing}")
            plot_data = {var: data[var] for var in band_names}
        else:
            plot_data = {var: data[var] for var in data.data_vars}
    else:
        raise ValueError("data must be xarray.DataArray or xarray.Dataset")

    if not plot_data:
        # Without this, the empty case fails later with an IndexError or
        # inside plt.subplots(0, 1).
        raise ValueError("No variables to plot: the selection is empty")

    # Ensure data is computed (not lazy)
    for name in plot_data:
        if hasattr(plot_data[name], 'compute'):
            print(f"Computing {name}...")
            plot_data[name] = plot_data[name].compute()

    n_series = len(plot_data)
    is_single = n_series == 1
    shared_axes = is_single or overlay  # everything drawn on one axes

    if figsize is None:
        if is_single:
            figsize = (12, 4)
        elif overlay:
            figsize = (14, 6)
        else:
            figsize = (12, min(3 * n_series, 12))  # cap height at 12 in

    if colors is None:
        if shared_axes:
            colors = ['steelblue', 'darkgreen', 'coral', 'purple', 'brown',
                      'pink', 'gray', 'olive', 'cyan', 'red']
        else:
            # Separate subplots do not need distinct colours
            colors = ['steelblue'] * n_series
    elif isinstance(colors, str):
        colors = [colors] * n_series

    if shared_axes:
        fig, shared_ax = plt.subplots(figsize=figsize)
        axes = [shared_ax] * n_series
        unique_axes = [shared_ax]
    else:
        fig, axes = plt.subplots(n_series, 1, figsize=figsize,
                                 sharex=sharex, sharey=sharey, squeeze=False)
        axes = axes.flatten()
        unique_axes = list(axes)

    for idx, (name, ts) in enumerate(plot_data.items()):
        ax = axes[idx]
        # Only overlaid multi-series plots need legend labels
        label = name if (overlay and not is_single) else None
        ax.plot(ts.time.values, ts.values, label=label,
                color=colors[idx % len(colors)], alpha=alpha,
                linewidth=linewidth, **kwargs)

        if not shared_axes:
            ax.set_title(name, fontsize=11, fontweight='bold')
            ax.set_ylabel(ylabel, fontsize=9)

    # Date formatting and grid, applied once per distinct axes
    for ax in unique_axes:
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax.xaxis.set_major_locator(mdates.AutoDateLocator())
        if grid:
            ax.grid(True, alpha=0.3, linestyle='--')

    if shared_axes:
        ax = unique_axes[0]
        if title is None:
            if is_single:
                name = list(plot_data.keys())[0]
                title = f'Sound Level Time Series: {name}'
            else:
                title = f'Acoustic Time Series ({n_series} bands)'
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.set_xlabel(xlabel, fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)
        if overlay and not is_single:
            ax.legend(loc=legend_loc, framealpha=0.9)
    else:
        if title is None:
            title = f'Acoustic Time Series ({n_series} bands)'
        fig.suptitle(title, fontsize=13, fontweight='bold', y=1.01)
        # Only the bottom subplot carries the x-axis label
        axes[-1].set_xlabel(xlabel, fontsize=10)

    fig.autofmt_xdate(rotation=45, ha='right')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved to {save_path}")

    if show is None:
        # Historical behaviour: display only when nothing was saved
        if not save_path:
            plt.show()
    elif show:
        plt.show()
    else:
        plt.close(fig)

    return fig
def timeseries_dashboard(self, data, title='Time Series Dashboard', port=5006,
                         show=True, save_html=None):
    """
    Create an interactive Panel dashboard for exploring time series.

    The dashboard offers a variable selector, colour scheme, line width,
    transparency and date-range controls, and draws the selected variables
    as overlaid HoloViews curves (Bokeh backend).

    Parameters
    ----------
    data : xarray.Dataset or list of xarray.Dataset
        Time series data. A list of Datasets is overlaid on the same plot;
        variables are then prefixed with 'Dataset<N>_'. Lazy datasets are
        computed first. The caller's list is not modified.
    title : str, optional
        Dashboard title.
    port : int, optional
        Port for the web server (default: 5006).
    show : bool, optional
        If True (default), serve the dashboard and open it in a browser.
    save_html : str, optional
        If given, save the dashboard to this HTML file.

    Returns
    -------
    panel.template.FastListTemplate
        The assembled dashboard template.

    Raises
    ------
    ValueError
        If ``data`` is an empty list or contains a non-Dataset item.
    ImportError
        If holoviews, panel or bokeh is not installed.

    Examples
    --------
    >>> # Single dataset
    >>> stats = hmd.compute_timeseries_stats(band_levels, resolution='1h')
    >>> hmd.timeseries_dashboard(stats.compute())
    >>> # Multiple datasets overlaid
    >>> hmd.timeseries_dashboard([band_levels.compute(), stats.compute()])
    """
    # Work on a copy so computing lazy datasets below never replaces items
    # in the list the caller passed in.
    data_list = list(data) if isinstance(data, list) else [data]
    if not data_list:
        raise ValueError("data must contain at least one xarray.Dataset")

    try:
        import holoviews as hv
        import panel as pn
        from bokeh.models import HoverTool
    except ImportError as err:
        raise ImportError(
            "Interactive plotting requires holoviews and panel. Install with:\n"
            " pip install holoviews panel bokeh"
        ) from err

    hv.extension('bokeh')
    pn.extension()

    # Compute all datasets if needed
    for i in range(len(data_list)):
        if hasattr(data_list[i], 'compute'):
            print(f"Computing dataset {i + 1}/{len(data_list)}...")
            data_list[i] = data_list[i].compute()

    # Validate all are Datasets
    for i, ds in enumerate(data_list):
        if not isinstance(ds, xr.Dataset):
            raise ValueError(f"Item {i} must be xarray.Dataset, got {type(ds)}")

    # Display name -> (dataset index, original variable name). Names are
    # prefixed only when several datasets are shown together.
    all_vars = []
    var_to_dataset = {}
    for idx, ds in enumerate(data_list):
        prefix = f"Dataset{idx + 1}_" if len(data_list) > 1 else ""
        for var in ds.data_vars:
            display_name = f"{prefix}{var}"
            all_vars.append(display_name)
            var_to_dataset[display_name] = (idx, var)

    var_selector = pn.widgets.MultiChoice(
        name='Select Variables to Plot',
        options=all_vars,
        value=all_vars[:5],  # default to the first five variables
        width=400
    )

    color_schemes = {
        'Default': ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'],
        'Viridis': ['#440154', '#31688e', '#35b779', '#fde724', '#b5de2b'],
        'Warm': ['#d62728', '#ff7f0e', '#bcbd22', '#e377c2', '#8c564b'],
        'Cool': ['#1f77b4', '#17becf', '#9467bd', '#2ca02c', '#7f7f7f']
    }

    color_selector = pn.widgets.Select(
        name='Color Scheme',
        options=list(color_schemes.keys()),
        value='Default',
        width=200
    )

    # Date picker spans the union of all datasets' time ranges
    time_min = min(pd.Timestamp(ds.time.min().values) for ds in data_list)
    time_max = max(pd.Timestamp(ds.time.max().values) for ds in data_list)

    date_range = pn.widgets.DatetimeRangePicker(
        name='Date Range',
        start=time_min,
        end=time_max,
        value=(time_min, time_max),
        width=400
    )

    line_width_slider = pn.widgets.FloatSlider(
        name='Line Width',
        start=0.5,
        end=5,
        step=0.5,
        value=2,
        width=200
    )

    alpha_slider = pn.widgets.FloatSlider(
        name='Transparency',
        start=0.1,
        end=1.0,
        step=0.1,
        value=0.8,
        width=200
    )

    @pn.depends(var_selector.param.value, color_selector.param.value,
                date_range.param.value, line_width_slider.param.value,
                alpha_slider.param.value)
    def create_plot(selected_vars, color_scheme, date_range_val, line_width, alpha):
        """Rebuild the overlay whenever any control changes."""
        if not selected_vars:
            return pn.pane.Markdown("### Please select at least one variable to plot")

        colors = color_schemes[color_scheme]
        curves = []

        for idx, display_var in enumerate(selected_vars):
            ds_idx, original_var = var_to_dataset[display_var]
            ts = data_list[ds_idx][original_var]

            # Filter by date range
            if date_range_val:
                start, end = date_range_val
                ts = ts.sel(time=slice(start, end))

            df = ts.to_dataframe(name='value').reset_index()
            df = df.dropna(subset=['value'])

            if len(df) == 0:
                continue

            # Series name is carried as a value dimension for tooltips
            df['series'] = display_var

            curve = hv.Curve(df, kdims=['time'], vdims=['value', 'series'],
                             label=display_var)
            curve = curve.opts(
                color=colors[idx % len(colors)],
                line_width=line_width,
                alpha=alpha
            )
            curves.append(curve)

        if not curves:
            return pn.pane.Markdown("### No valid data in selected range")

        # NOTE(review): this custom HoverTool is constructed but never
        # passed to the plot - the overlay below requests the built-in
        # 'hover' tool by name, so these tooltips are not in effect.
        # Confirm the intended wiring before changing it.
        hover = HoverTool(
            tooltips=[
                ('Series', '@series'),
                ('Time', '@time{%Y-%m-%d %H:%M:%S}'),
                ('Value', '@value{0.3f} dB'),
            ],
            formatters={
                '@time': 'datetime',
            },
            mode='mouse'
        )

        # Layout options go on the Overlay, not the individual Curves
        overlay = hv.Overlay(curves).opts(
            width=1400,
            height=600,
            xlabel='Time',
            ylabel='Sound Level (dB)',
            title='Time Series Comparison',
            legend_position='top_right',
            show_grid=True,
            toolbar='above',
            tools=['hover', 'pan', 'wheel_zoom', 'box_zoom', 'reset', 'save'],
            active_tools=['pan', 'wheel_zoom']
        )

        return overlay

    # Info panel with dataset information
    info_text = "## Dataset Information\n\n"
    for i, ds in enumerate(data_list):
        prefix = f"Dataset {i + 1}" if len(data_list) > 1 else "Dataset"
        info_text += f"**{prefix}:**\n"
        info_text += f"- Variables: {len(ds.data_vars)}\n"
        info_text += f"- Time points: {len(ds.time)}\n"
        ds_min = pd.Timestamp(ds.time.min().values)
        ds_max = pd.Timestamp(ds.time.max().values)
        info_text += f"- Date range: {ds_min.date()} to {ds_max.date()}\n\n"

    info_text += f"**Total variables available:** {len(all_vars)}\n"

    template = pn.template.FastListTemplate(
        title=title,
        sidebar=[
            pn.pane.Markdown("## Controls"),
            var_selector,
            pn.layout.Divider(),
            color_selector,
            line_width_slider,
            alpha_slider,
            pn.layout.Divider(),
            date_range,
            pn.layout.Divider(),
            pn.pane.Markdown(info_text)
        ],
        main=[
            pn.Column(create_plot)
        ],
        accent_base_color="#2196F3",
        header_background="#2196F3"
    )

    if save_html:
        template.save(save_html, embed=True)
        print(f"✓ Saved interactive dashboard to {save_html}")

    if show:
        print(f"✓ Opening advanced dashboard in browser (port {port})...")
        template.show(port=port, threaded=True)

    return template
def summary(self):
    """
    Print an overview of the loaded dataset to stdout.

    Reports the time span, number of time points, frequency span, number
    of frequency bins, PSD shape, total size in GB, deployment
    information (when the dataset exposes a ``deployment`` attribute),
    the data variables and, if present, the chunk layout. When no dataset
    is loaded a hint is printed instead. Returns None.
    """
    ds = self.ds
    if ds is None:
        print("No data loaded. Use .load_nc_files() first.")
        return

    rule = "=" * 70
    print(rule)
    print("HYBRID MILLIDECADE DATASET SUMMARY")
    print(rule)

    time_coord = ds.time
    print(f"Time range : {time_coord.min().values} to {time_coord.max().values}")
    print(f"Time points : {len(time_coord)}")

    freq_coord = ds.frequency
    print(f"Frequency range: {freq_coord.min().values:.1f} - {freq_coord.max().values:.1f} Hz")
    print(f"Frequency bins : {len(freq_coord)}")

    print(f"PSD shape : {ds.psd.shape}")
    print(f"Total size : {ds.nbytes / 1e9:.3f} GB")

    if hasattr(ds, 'deployment'):
        # A 'deployment' dimension means several deployments are stacked;
        # otherwise the attribute holds a single value.
        if 'deployment' in ds.dims:
            print(f"Deployments : {list(ds.deployment.values)}")
        else:
            print(f"Deployment : {ds.deployment.values}")

    print(f"Data variables : {list(ds.data_vars)}")

    if ds.chunks:
        print(f"Chunks : {dict(ds.chunks)}")

    print(rule)
def rechunk(self, chunks=None):
    """
    Rechunk the loaded dataset in place.

    Rechunking can help when the current chunks do not line up with the
    stored data or with the dimension an analysis iterates over.

    Parameters
    ----------
    chunks : dict, optional
        Chunk size per dimension, passed straight to ``Dataset.chunk``.
        When None, ``{'time': 1440, 'frequency': -1}`` is used (a layout
        aimed at band extraction). Other layouts that may be useful:

        - time-series extraction: ``{'time': -1, 'frequency': 'auto'}``
        - spectral statistics: ``{'time': 'auto', 'frequency': -1}``

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If no dataset has been loaded (raised by ``_check_loaded``).

    Examples
    --------
    >>> # Rechunk for time-series extraction (extract_band_levels)
    >>> hmd.rechunk({'time': -1, 'frequency': 'auto'})
    >>>
    >>> # Rechunk for spectral statistics
    >>> hmd.rechunk({'time': 'auto', 'frequency': -1})
    >>>
    >>> # Custom chunking
    >>> hmd.rechunk({'time': 10000, 'frequency': 500})

    Notes
    -----
    - Use -1 to put an entire dimension into a single chunk
    - Use 'auto' to let Dask pick the size
    - When ``use_dask`` is enabled the rechunked dataset is persisted so
      the rechunking work is not repeated by later operations
    - Call this after load_nc_files() and before analysis operations
    """
    self._check_loaded()

    # No layout given: fall back to the band-extraction default
    target_layout = {'time': 1440, 'frequency': -1} if chunks is None else chunks

    print(f"Rechunking dataset with chunks: {target_layout}")
    print(" This may take a moment but will improve downstream performance...")

    self.ds = self.ds.chunk(target_layout)

    if self.use_dask:
        # Materialise once so downstream operations reuse the new layout
        print(" Persisting rechunked data to distributed memory...")
        self.ds = self.ds.persist()

    print("✓ Rechunking complete")
    return self
def subset(self, freq_range=None, time_range=None, persist=True):
    """
    Return a frequency and/or time slice of the loaded dataset.

    Parameters
    ----------
    freq_range : sequence, optional
        ``(min_freq, max_freq)``; selected with a label-based slice on the
        'frequency' coordinate.
    time_range : sequence, optional
        ``(start, end)``; selected with a label-based slice on the 'time'
        coordinate.
    persist : bool, optional
        If True (default), call ``.persist()`` on the selection before
        returning it.

    Returns
    -------
    The selected dataset. ``self.ds`` itself is left untouched.

    Raises
    ------
    ValueError
        If no dataset has been loaded (raised by ``_check_loaded``).
    """
    self._check_loaded()

    view = self.ds
    # Apply the frequency selection first, then the time selection
    for dimension, bounds in (('frequency', freq_range), ('time', time_range)):
        if bounds is None:
            continue
        window = slice(bounds[0], bounds[1])
        view = view.sel(**{dimension: window})

    return view.persist() if persist else view
def analyze_spatial_correlation(self, timeseries, spatial_grid, method='pearson',
                                min_periods=None, fill_grid_nan=True, absolute=False,
                                compute=True, use_lag=False, max_lag=None):
    """
    Compute spatial correlation between a time series and gridded spatial data.

    This method correlates a single time series (e.g., acoustic levels at a
    point) with the time series at each spatial grid cell (e.g., vessel
    counts across a region). The per-cell function is applied with
    ``xarray.apply_ufunc`` so Dask-backed grids are processed in parallel.

    Parameters
    ----------
    timeseries : xarray.DataArray
        1D time series with 'time' dimension (e.g., SPL at recorder location)
    spatial_grid : xarray.DataArray
        Gridded data with dimensions (time, latitude, longitude)
        (e.g., vessel counts on a spatial grid)
    method : str, optional
        Correlation method: 'pearson' (default), 'spearman', or 'kendall'
    min_periods : int, optional
        Minimum number of time points that are valid (non-NaN) in both
        series for a grid cell to be evaluated. If None (default), uses
        ``max(3, 50% of the overlapping time points)``.
    fill_grid_nan : bool, optional
        If True (default), fill NaN values in spatial_grid with zeros.
        This is useful for vessel count grids where NaN typically means
        zero vessels. For in-memory grids a warning with the NaN count is
        printed when NaNs are found. For Dask-backed grids the fill is
        applied lazily and the NaN count is not computed.
    absolute : bool, optional
        If True, return absolute values of correlation coefficients [0, 1].
        If False (default), return full correlation values [-1, 1].
    compute : bool, optional
        If True (default), compute the result immediately and print summary
        statistics. If False, return the (possibly lazy) result as is.
    use_lag : bool, optional
        If True, search over time lags and keep, per grid cell, the
        coefficient with the largest absolute value. If False (default),
        correlate at lag 0 only.
    max_lag : int, optional
        Maximum time lag (in time steps) to consider when use_lag=True.
        Lags from -max_lag to +max_lag are evaluated. If None and
        use_lag=True, defaults to ``min(50, n_times // 4)``.
        Positive lags mean the grid leads the timeseries.

    Returns
    -------
    xarray.DataArray or xarray.Dataset
        If ``use_lag=False``: 2D correlation map with dimensions
        ``(latitude, longitude)`` containing correlation coefficients in
        [-1, 1] (or [0, 1] if ``absolute=True``).
        If ``use_lag=True``: Dataset with two DataArrays: ``'correlation'``
        (coefficient with the largest magnitude at each grid cell) and
        ``'lag'`` (time lag in steps where it occurs; positive means the
        grid leads the timeseries).

    Raises
    ------
    ValueError
        If the inputs are not DataArrays, lack the required dimensions,
        ``method`` is unknown, ``use_lag=True`` is combined with a method
        other than 'pearson', or the two inputs share no time points.

    Examples
    --------
    >>> # Correlate acoustic levels with vessel grid
    >>> with AISQueryHelper(db_file) as ais:
    ...     vessel_grid = ais.create_gridded_vessel_counts(...)
    >>>
    >>> with HMD(n_workers=8) as hmd:
    ...     hmd.load_nc_files(deployment_dir, time_range=time_range)
    ...     band_levels = hmd.extract_band_levels([[50, 300]], ['ship'])
    ...     stats = hmd.compute_timeseries_stats(band_levels, resolution='1h')
    ...
    ...     # Correlate mean ship noise with vessel counts (fills NaNs with 0)
    ...     corr_map = hmd.analyze_spatial_correlation(
    ...         stats['ship_mean'],
    ...         vessel_grid,
    ...         method='pearson',
    ...         fill_grid_nan=True  # Default, fills NaNs in grid with 0
    ...     )
    ...
    ...     # Get absolute correlation (strength regardless of direction)
    ...     corr_map_abs = hmd.analyze_spatial_correlation(
    ...         stats['ship_mean'],
    ...         vessel_grid,
    ...         absolute=True  # Returns values in [0, 1]
    ...     )
    ...
    ...     # Or keep NaNs in the grid (may result in NaN correlations)
    ...     corr_map = hmd.analyze_spatial_correlation(
    ...         stats['ship_mean'],
    ...         vessel_grid,
    ...         fill_grid_nan=False
    ...     )
    ...
    ...     # Plot correlation map
    ...     corr_map.plot()
    ...
    ...     # Use cross-correlation with lag to account for temporal offset
    ...     result = hmd.analyze_spatial_correlation(
    ...         stats['ship_mean'],
    ...         vessel_grid,
    ...         use_lag=True,
    ...         max_lag=24  # Search up to 24 time steps of lag
    ...     )
    ...
    ...     # Access the correlation and lag maps
    ...     corr_map_lag = result['correlation']  # Strongest correlation at each point
    ...     lag_map = result['lag']  # Lag in time steps where it occurs
    ...
    ...     # Plot results
    ...     corr_map_lag.plot()
    ...     lag_map.plot()  # Positive = grid leads, negative = timeseries leads

    Notes
    -----
    - Time coordinates must overlap between timeseries and spatial_grid
    - Time points that are NaN in either series are excluded per grid cell
    - Grid cells with too few valid points, or with a constant series,
      return NaN
    - When use_lag=True, only Pearson correlation is supported
    - When use_lag=True, NaN points are dropped *before* the series are
      shifted, so lags are exact only for gap-free series; each lagged
      coefficient reuses the full-series mean and standard deviation
    """
    import xarray as xr
    import numpy as np
    from scipy.stats import pearsonr, spearmanr, kendalltau

    # ---- Validate inputs -------------------------------------------------
    if not isinstance(timeseries, xr.DataArray):
        raise ValueError("timeseries must be an xarray.DataArray")
    if not isinstance(spatial_grid, xr.DataArray):
        raise ValueError("spatial_grid must be an xarray.DataArray")
    if 'time' not in timeseries.dims:
        raise ValueError("timeseries must have 'time' dimension")
    required_dims = {'time', 'latitude', 'longitude'}
    if not required_dims.issubset(spatial_grid.dims):
        raise ValueError(f"spatial_grid must have dimensions: {required_dims}")
    if timeseries.dims != ('time',):
        raise ValueError("timeseries must be 1D with only 'time' dimension")

    corr_funcs = {
        'pearson': pearsonr,
        'spearman': spearmanr,
        'kendall': kendalltau
    }
    if method not in corr_funcs:
        raise ValueError(f"method must be one of {list(corr_funcs.keys())}")
    if use_lag and method != 'pearson':
        raise ValueError("use_lag=True requires method='pearson' (other correlation methods not supported for lag analysis)")
    corr_func = corr_funcs[method]

    print(f"Computing {method} correlation between time series and spatial grid...")
    print(f" Time series shape: {timeseries.shape}")
    print(f" Spatial grid shape: {spatial_grid.shape}")

    # ---- Handle NaN values in the spatial grid ---------------------------
    if fill_grid_nan:
        # `.chunks` is None unless the data is Dask-backed. (Every DataArray
        # has a `.compute` method, so hasattr(..., 'compute') cannot tell
        # lazy from in-memory data.)
        if spatial_grid.chunks is not None:
            # Counting NaNs would force a full pass over the grid; fill lazily.
            print(" → spatial_grid is Dask-backed: filling NaN values (if any) with 0 lazily")
            spatial_grid = spatial_grid.fillna(0)
        else:
            n_nans = int(np.isnan(spatial_grid.values).sum())
            if n_nans > 0:
                print(f" ⚠ WARNING: NaN values found in spatial_grid (count: {n_nans})")
                print(" → Filling NaN values with 0 (typical for vessel count grids)")
                spatial_grid = spatial_grid.fillna(0)

    # ---- Align on the shared time stamps ---------------------------------
    timeseries_aligned, grid_aligned = xr.align(
        timeseries,
        spatial_grid,
        join='inner',  # Only keep overlapping times
        copy=False
    )

    n_times = len(timeseries_aligned.time)
    print(f" Overlapping time points: {n_times}")
    if n_times == 0:
        raise ValueError("No overlapping time points between timeseries and spatial_grid")

    if min_periods is None:
        min_periods = max(3, int(0.5 * n_times))  # At least 3 or 50% of points
    print(f" Minimum valid points required: {min_periods}")

    if use_lag and max_lag is None:
        max_lag = min(50, n_times // 4)
        print(f" Using default max_lag: {max_lag}")

    # The reference series is small (1D); `.values` materialises it whether
    # or not it is Dask-backed.
    ts_values = timeseries_aligned.values

    # apply_ufunc(dask='parallelized') requires every core dimension to sit
    # in a single chunk, so collapse the time chunks of a Dask-backed grid.
    if grid_aligned.chunks is not None:
        grid_aligned = grid_aligned.chunk({'time': -1})

    if use_lag:
        def correlate_grid_cell_with_lag(grid_cell_values):
            """
            Scan lags in [-max_lag, max_lag] for one grid cell.

            Returns a 2-element array: [strongest_corr, lag_of_strongest_corr].
            """
            valid_mask = ~(np.isnan(ts_values) | np.isnan(grid_cell_values))
            if np.sum(valid_mask) < min_periods:
                return np.array([np.nan, np.nan])

            ts_valid = ts_values[valid_mask]
            grid_valid = grid_cell_values[valid_mask]

            # A constant series has no defined correlation
            ts_std = np.std(ts_valid)
            grid_std = np.std(grid_valid)
            if ts_std == 0 or grid_std == 0:
                return np.array([np.nan, np.nan])

            # Standardise once; each lagged product mean is then the
            # (approximate) Pearson coefficient at that lag.
            ts_norm = (ts_valid - np.mean(ts_valid)) / ts_std
            grid_norm = (grid_valid - np.mean(grid_valid)) / grid_std
            n_valid = ts_norm.size

            max_corr = np.nan
            best_lag = 0
            for lag in range(-max_lag, max_lag + 1):
                shift = abs(lag)
                if shift >= n_valid:
                    continue  # No overlap left at this lag
                if lag > 0:
                    # Grid leads: ts[t + lag] is paired with grid[t]
                    corr_value = np.mean(ts_norm[lag:] * grid_norm[:-lag])
                elif lag < 0:
                    # Timeseries leads: ts[t] is paired with grid[t + |lag|]
                    corr_value = np.mean(ts_norm[:-shift] * grid_norm[shift:])
                else:
                    corr_value = np.mean(ts_norm * grid_norm)

                # Strict '>' keeps the first (most negative) lag on ties
                if np.isnan(max_corr) or abs(corr_value) > abs(max_corr):
                    max_corr = corr_value
                    best_lag = lag

            return np.array([max_corr, best_lag])

        result_array = xr.apply_ufunc(
            correlate_grid_cell_with_lag,
            grid_aligned,
            input_core_dims=[['time']],
            output_core_dims=[['stat']],  # 'stat' holds [corr, lag]
            vectorize=True,
            dask='parallelized',
            output_dtypes=[float],
            dask_gufunc_kwargs={'output_sizes': {'stat': 2}}
        )

        correlation_map = result_array.isel(stat=0)
        lag_map = result_array.isel(stat=1)

    else:
        def correlate_grid_cell(grid_cell_values):
            """Correlate the reference series with one grid cell at lag 0."""
            valid_mask = ~(np.isnan(ts_values) | np.isnan(grid_cell_values))
            if np.sum(valid_mask) < min_periods:
                return np.nan

            ts_valid = ts_values[valid_mask]
            grid_valid = grid_cell_values[valid_mask]

            # A constant series has no defined correlation
            if np.std(ts_valid) == 0 or np.std(grid_valid) == 0:
                return np.nan

            try:
                corr, _ = corr_func(ts_valid, grid_valid)
            except Exception:
                # Deliberate best-effort: a single problematic cell yields
                # NaN instead of aborting the whole map.
                return np.nan
            return corr

        correlation_map = xr.apply_ufunc(
            correlate_grid_cell,
            grid_aligned,
            input_core_dims=[['time']],  # Apply function along time dimension
            vectorize=True,              # Loop over latitude/longitude
            dask='parallelized',         # Use Dask for parallel processing
            output_dtypes=[float],
        )

    # ---- Post-processing and metadata ------------------------------------
    if absolute:
        print(" Taking absolute values of correlations...")
        correlation_map = np.abs(correlation_map)

    corr_type = 'absolute_' if absolute else ''
    valid_range = [0.0, 1.0] if absolute else [-1.0, 1.0]

    correlation_map.name = f'{corr_type}{method}_correlation'
    correlation_map.attrs = {
        'long_name': f'{method.capitalize()} Correlation Coefficient' + (' (Absolute)' if absolute else ''),
        'description': 'Spatial correlation between time series and gridded data',
        'method': method,
        'absolute': absolute,
        'min_periods': min_periods,
        'n_time_points': n_times,
        'valid_range': valid_range
    }

    if compute:
        print(" Computing correlations...")
        correlation_map = correlation_map.compute()
        if use_lag:
            lag_map = lag_map.compute()

        # Report statistics
        corr_is_valid = ~np.isnan(correlation_map.values)
        valid_cells = np.sum(corr_is_valid)
        total_cells = correlation_map.size
        print("✓ Correlation analysis complete")
        print(f" Valid grid cells: {valid_cells}/{total_cells} ({100*valid_cells/total_cells:.1f}%)")

        if valid_cells > 0:
            corr_vals = correlation_map.values[corr_is_valid]
            print(f" Correlation range: [{np.min(corr_vals):.3f}, {np.max(corr_vals):.3f}]")
            print(f" Mean correlation: {np.mean(corr_vals):.3f}")

            if use_lag:
                lag_vals = lag_map.values[~np.isnan(lag_map.values)]
                print(f" Lag range: [{int(np.min(lag_vals))} to {int(np.max(lag_vals))}] time steps")
                print(f" Mean lag: {np.mean(lag_vals):.1f} time steps")

    if not use_lag:
        return correlation_map

    # use_lag=True: bundle the correlation and lag maps into one Dataset
    lag_map.name = 'lag'
    lag_map.attrs = {
        'long_name': 'Time Lag at Maximum Correlation',
        'description': 'Lag in time steps where maximum correlation occurs',
        'units': 'time steps',
        'positive_lag_means': 'grid leads timeseries',
        'negative_lag_means': 'timeseries leads grid',
        'max_lag_searched': max_lag
    }
    return xr.Dataset({
        'correlation': correlation_map,
        'lag': lag_map
    })
def save_to_csv(self, data, output_path, wide_format=True):
    """
    Save xarray Dataset or DataArray to CSV file(s).

    Parameters
    ----------
    data : xarray.Dataset or xarray.DataArray
        Data to save. Lazy data is computed before writing.
    output_path : str
        Output file path. For Datasets with multiple variables:
        - If wide_format=True: saves to single file at this path
        - If wide_format=False: saves one file per variable, named
          ``<stem>_<variable><suffix>`` next to this path
    wide_format : bool, optional
        If True (default), save all variables in one wide CSV.
        If False, save each variable to a separate CSV file.
        Only affects Datasets.

    Raises
    ------
    ValueError
        If ``data`` is neither an xarray.DataArray nor an xarray.Dataset.

    Notes
    -----
    An unnamed DataArray is written with the column name ``'value'``,
    because xarray cannot convert an unnamed array to a DataFrame.

    Examples
    --------
    >>> # Save statistics to CSV
    >>> stats = hmd.compute_timeseries_stats(band_levels, resolution='1h')
    >>> hmd.save_to_csv(stats, 'statistics.csv')

    >>> # Save each variable separately
    >>> hmd.save_to_csv(stats, 'statistics.csv', wide_format=False)
    """
    from pathlib import Path

    # Reject unsupported input up front, before triggering any computation
    if not isinstance(data, (xr.DataArray, xr.Dataset)):
        raise ValueError("data must be xarray.DataArray or Dataset")

    # compute() is a cheap no-op for in-memory data and loads lazy data
    print("Computing data before saving...")
    data = data.compute()

    if isinstance(data, xr.DataArray):
        # to_dataframe() raises for unnamed arrays, so supply a fallback name
        column_name = data.name if data.name is not None else 'value'
        data.to_dataframe(name=column_name).to_csv(output_path)
        print(f"✓ Saved to {output_path}")
        return

    if wide_format:
        # All variables side by side in a single CSV
        data.to_dataframe().to_csv(output_path)
        print(f"✓ Saved {len(data.data_vars)} variables to {output_path}")
        return

    # One CSV per variable: <stem>_<variable><suffix>
    base = Path(output_path)
    for var_name in data.data_vars:
        var_path = base.parent / f"{base.stem}_{var_name}{base.suffix}"
        data[var_name].to_dataframe().to_csv(var_path)
        print(f"✓ Saved {var_name} to {var_path}")
@staticmethod
def _integrate_band_db(db_values, freq_axis='frequency'):
    """
    Integrate dB values across a frequency band.

    The levels are converted to linear units (``10 ** (dB / 10)``), summed
    along ``freq_axis`` and converted back with ``10 * log10``.

    Parameters
    ----------
    db_values : array-like with a ``sum(dim=...)`` method
        Levels in dB (presumably an xarray object — confirm with callers).
    freq_axis : str, optional
        Name of the dimension to sum over (default: 'frequency').

    Returns
    -------
    Integrated level in dB, with ``freq_axis`` reduced.
    """
    linear = 10 ** (db_values / 10)
    integrated_linear = linear.sum(dim=freq_axis)
    return 10 * np.log10(integrated_linear)


def _check_loaded(self):
    """Raise ValueError if no dataset has been loaded yet."""
    if self.ds is None:
        raise ValueError("No data loaded. Use .load_nc_files() first.")


def _setup_dask(self):
    """
    Start a local Dask cluster and attach a client to it.

    Uses ``n_workers``, ``memory_per_worker``, ``temp_directory`` and
    ``use_processes`` from the instance; workers get 2 threads each in
    process mode and 4 in threaded mode. The client is stored on
    ``self.client``.

    If starting in process mode fails with a multiprocessing/bootstrapping
    RuntimeError, the setup is retried once in threaded mode. Any other
    RuntimeError, or the same error in threaded mode, is re-raised.
    """
    try:
        cluster = LocalCluster(
            n_workers=self.n_workers,
            threads_per_worker=2 if self.use_processes else 4,
            processes=self.use_processes,
            memory_limit=self.memory_per_worker,
            local_directory=self.temp_directory
        )
        try:
            self.client = Client(cluster)
        except Exception:
            # Do not leave an orphaned cluster behind if the client fails
            cluster.close()
            raise

        mode = "processes" if self.use_processes else "threads"
        print(f"✓ Dask cluster started ({mode} mode)")
        print(f" Dashboard: {self.client.dashboard_link}")

    except RuntimeError as e:
        is_multiprocessing_issue = "multiprocessing" in str(e) or "bootstrapping" in str(e)
        # Only fall back when we were actually in process mode; retrying in
        # threaded mode again would recurse without end.
        if not (is_multiprocessing_issue and self.use_processes):
            raise

        print("\n" + "=" * 70)
        print("ERROR: Windows multiprocessing issue detected!")
        print("=" * 70)
        print("\nSOLUTION: Use threaded mode")
        print(" hmd = HMD(use_processes=False)")
        print("=" * 70 + "\n")

        print("→ Automatically falling back to threaded mode...")
        self.use_processes = False
        self._setup_dask()


@staticmethod
def _filter_files_by_date(files, time_range):
    """
    Filter files based on the date in their filename (end date exclusive).

    Parameters
    ----------
    files : iterable of pathlib.Path or str
        Candidate files.
    time_range : sequence or None
        ``(start, end)`` parseable by ``pandas.to_datetime``. A file is kept
        when ``start <= file_date < end``. If None, nothing is filtered.

    Returns
    -------
    tuple
        ``(kept_files, skipped_count)``. Files whose name carries no
        parseable date are kept (and a note with their count is printed).
    """
    import pandas as pd
    from pathlib import Path

    if time_range is None:
        return files, 0

    start_date = pd.to_datetime(time_range[0]).to_pydatetime()
    end_date = pd.to_datetime(time_range[1]).to_pydatetime()

    filtered_files = []
    skipped_count = 0
    unparseable_count = 0

    for file in files:
        # Path(...) lets plain string paths work as well as Path objects
        file_date = HMD._parse_date_from_filename(Path(file).name)

        if file_date is None:
            # Cannot decide from the name: keep the file to be safe
            filtered_files.append(file)
            unparseable_count += 1
        elif start_date <= file_date < end_date:  # End date is exclusive
            filtered_files.append(file)
        else:
            skipped_count += 1

    if unparseable_count > 0:
        print(f" Note: {unparseable_count} file(s) don't match date pattern")

    return filtered_files, skipped_count


@staticmethod
def _parse_date_from_filename(filename):
    """
    Extract the date from a filename ending with '_YYYYMMDD.nc'.

    Returns
    -------
    datetime.datetime or None
        The parsed date, or None when the name does not end with the
        pattern or the eight digits are not a valid calendar date.
    """
    import re
    from datetime import datetime

    match = re.search(r'_(\d{8})\.nc$', str(filename))
    if match is None:
        return None
    try:
        return datetime.strptime(match.group(1), '%Y%m%d')
    except ValueError:
        # Eight digits that are not a real date (e.g. month 13)
        return None
def close(self):
    """
    Close the Dask client and its cluster, and clean up resources.

    The client is attached to a cluster object that was created separately,
    and closing a client does not shut down a cluster that was handed to
    it. The cluster is therefore closed explicitly as well. ``self.client``
    is reset to None even if closing raises, so the method is safe to call
    again. Does nothing when no client is active.
    """
    client = self.client
    if client is None:
        return

    print("Closing Dask client...")
    cluster = getattr(client, 'cluster', None)
    try:
        client.close()
    finally:
        # Always drop the reference and release the workers, even if the
        # client itself failed to close cleanly.
        self.client = None
        if cluster is not None:
            cluster.close()
    print("✓ Dask client closed")
def __enter__(self):
    """Enter the ``with`` block and hand back this processor."""
    return self


def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave the ``with`` block, releasing resources through ``close()``."""
    self.close()


def __repr__(self):
    """
    Return a short description of the processor.

    Shows the number of time points, the number of frequency bins and the
    dataset size in GB, or a placeholder when nothing is loaded.
    """
    dataset = self.ds
    if dataset is None:
        return "HMD(no data loaded)"

    n_time = len(dataset.time)
    n_freq = len(dataset.frequency)
    gigabytes = dataset.nbytes / 1e9
    return f"HMD(time={n_time}, frequency={n_freq}, size={gigabytes:.2f}GB)"
# ---------------------------------------------------------------------------
# Disabled methods, kept commented out for reference.
# ---------------------------------------------------------------------------

# def analyze_frequency_bands(self, bands_dict, time_resolution='1H'):
#     """
#     Analyze multiple frequency bands efficiently.
#
#     Parameters
#     ----------
#     bands_dict : dict
#         Dictionary of band names and frequency ranges (Hz)
#     time_resolution : str
#         Temporal resolution ('1T', '1H', '1D', etc.)
#
#     Returns
#     -------
#     dict
#
#     Examples
#     --------
#     >>> bands = {'low': (100, 500), 'mid': (500, 2000)}
#     >>> results = hmd.analyze_frequency_bands(bands, time_resolution='1H')
#     """
#     self._check_loaded()
#
#     results = {}
#
#     for band_name, (f_min, f_max) in bands_dict.items():
#         print(f"Processing band '{band_name}': {f_min}-{f_max} Hz")
#
#         band_data = self.ds.sel(frequency=slice(f_min, f_max))
#
#         linear_power = 10 ** (band_data.psd / 10)
#         integrated_power = linear_power.sum(dim='frequency')
#         integrated_level = 10 * np.log10(integrated_power)
#
#         stats = {
#             'integrated': integrated_level,
#             'max': band_data.psd.max(dim='frequency'),
#             'median': band_data.psd.median(dim='frequency'),
#             'std': band_data.psd.std(dim='frequency')
#         }
#
#         if time_resolution != '1T':
#             stats = {k: v.resample(time=time_resolution).mean()
#                      for k, v in stats.items()}
#
#         results[band_name] = stats
#
#     print("Computing statistics...")
#     computed_results = {}
#     for band_name in results:
#         computed_results[band_name] = dask.compute(results[band_name])[0]
#
#     print("✓ Band analysis complete")
#     return computed_results

# def to_zarr(self, output_path, chunks=None):
#     """Convert dataset to Zarr format"""
#     self._check_loaded()
#
#     if chunks is not None:
#         ds_rechunked = self.ds.chunk(chunks)
#     else:
#         ds_rechunked = self.ds
#
#     print(f"Converting to Zarr format: {output_path}")
#     ds_rechunked.to_zarr(output_path, mode='w', consolidated=True)
#     print("✓ Conversion complete")

# def analyze_temporal_patterns(self, freq_band=None):
#     """Analyze daily and hourly patterns"""
#     self._check_loaded()
#
#     print("Analyzing temporal patterns...")
#
#     linear_power = 10 ** (self.ds.psd / 10)
#
#     if freq_band is not None:
#         linear_power = linear_power.sel(
#             frequency=slice(freq_band[0], freq_band[1])
#         )
#
#     integrated_power = linear_power.sum(dim='frequency')
#     data_db = 10 * np.log10(integrated_power)
#
#     hour_of_day = data_db.groupby('time.hour').mean()
#     day_of_week = data_db.groupby('time.dayofweek').mean()
#     daily_mean = data_db.resample(time='1D').mean()
#     hourly_mean = data_db.resample(time='1H').mean()
#
#     result = dask.compute({
#         'hour_of_day': hour_of_day,
#         'day_of_week': day_of_week,
#         'daily': daily_mean,
#         'hourly': hourly_mean
#     })[0]
#
#     print("✓ Temporal pattern analysis complete")
#     return result

# def resample_temporal(self, resolution='1H', freq_band=None):
#     """Resample data to coarser temporal resolution"""
#     self._check_loaded()
#
#     if freq_band is not None:
#         band_data = self.ds.sel(frequency=slice(freq_band[0], freq_band[1]))
#         linear_power = 10 ** (band_data.psd / 10)
#         integrated_power = linear_power.sum(dim='frequency')
#         data_db = 10 * np.log10(integrated_power)
#         resampled = data_db.resample(time=resolution).mean()
#     else:
#         resampled = self.ds.psd.resample(time=resolution).mean()
#
#     return resampled


if __name__ == "__main__":
    # Example: Extract band levels and use plot_timeseries.
    # The context manager guarantees the Dask client is closed on exit,
    # including when one of the calls below raises.
    with HMD(n_workers=4) as hmd:
        hmd.load_nc_files('path/to/deployment_01')

        # Extract bands
        freq_bands = [[100], [50, 300], [1000], [2000, 10000]]
        names = ['100Hz', 'ship', '1kHz', 'mammal']
        result = hmd.extract_band_levels(freq_bands, band_names=names)

        # Materialise the lazy result once and reuse it for every plot
        levels = result.compute()

        # Plot with new unified method - overlaid (default)
        hmd.plot_timeseries(levels)

        # Plot as separate subplots
        hmd.plot_timeseries(levels, overlay=False)

        # Plot specific bands only
        hmd.plot_timeseries(levels, band_names=['ship', 'mammal'])

        # Plot single band
        hmd.plot_timeseries(levels['ship'])

# # In plot_hmd_test.py
#
# # Example 1: Single dataset
# stats_timeseries = hmd.compute_timeseries_stats(band_levels, resolution='1h')
# hmd.plot_interactive_advanced(stats_timeseries.compute())
#
# # Example 2: Plot raw band levels
# band_levels_computed = band_levels.compute()
# hmd.plot_interactive_advanced(band_levels_computed, title='Raw Band Levels')
#
# # Example 3: Compare raw data with statistics (MULTIPLE DATASETS)
# band_levels_computed = band_levels.compute()
# stats_computed = stats_timeseries.compute()
#
# hmd.plot_interactive_advanced(
#     [band_levels_computed, stats_computed],
#     title='Raw Data vs Hourly Statistics'
# )
#
# # Example 4: Compare multiple time periods or deployments
# deployment1 = hmd.subset(time_range=('2018-11-01', '2018-11-15')).compute()
# deployment2 = hmd.subset(time_range=('2018-11-16', '2018-11-30')).compute()
#
# hmd.plot_interactive_advanced(
#     [deployment1, deployment2],
#     title='Deployment Comparison'
# )
#
# # Example 5: Save to HTML without opening browser
# hmd.plot_interactive_advanced(
#     [band_levels_computed, stats_computed],
#     save_html='comparison_dashboard.html',
#     show=False
# )