Source code for mne_videobrowser.media.audio

"""Contains AudioFile interface and its implementations for reading audio files."""

# License: BSD-3-Clause
# Copyright (c) 2014 BioMag Laboratory, Helsinki University Central Hospital
# Copyright (c) 2025 Aalto University

import logging
import struct
from abc import ABC, abstractmethod
from fractions import Fraction

import numpy as np
import numpy.typing as npt
from scipy import signal

from ._helsinki_videomeg_file_utils import UnknownVersionError, read_block_attributes

logger = logging.getLogger(__name__)


class AudioFile(ABC):
    """Handles reading audio files."""

    def __init__(self, fname: str) -> None:
        """Initialize the audio file reader with the given file name."""
        self._fname = fname

    def __del__(self) -> None:
        """Ensure the audio file is released when the object is deleted."""
        self.close()

    @abstractmethod
    def __enter__(self) -> "AudioFile":
        """Enter the runtime context with opened audio file."""
        pass

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the runtime context and release the audio file."""
        self.close()

    @abstractmethod
    def close(self) -> None:
        """Release any resources held by the audio file.

        Should be safe to call multiple times.
        """
        pass

    @abstractmethod
    def get_audio_all_channels(
        self, sample_range: tuple[int, int] | None = None
    ) -> npt.NDArray[np.float32]:
        """Get audio data for all channels in the specified sample range.

        Parameters
        ----------
        sample_range : tuple[int, int] | None
            A tuple specifying the start and end (exclusive) sample indices to include
            in the output. If None (default), all the samples are included.

        Returns
        -------
        npt.NDArray[np.float32]
            A 2D array of shape (n_channels, n_samples) containing the audio data.
        """
        pass

    @abstractmethod
    def get_audio_mean(
        self, sample_range: tuple[int, int] | None = None
    ) -> npt.NDArray[np.float32]:
        """Get mean audio data across channels in the specified sample range.

        Parameters
        ----------
        sample_range : tuple[int, int] | None
            A tuple specifying the start and end (exclusive) sample indices to include
            in the output. If None (default), all the samples are included.

        Returns
        -------
        npt.NDArray[np.float32]
            A 1D array containing the mean audio data for the specified sample range.
        """
        pass

    def get_global_max_amplitude(self, chunk_duration_seconds: float = 5) -> float:
        """Get the maximum absolute amplitude across all channels in the audio file.

        Parameters
        ----------
        chunk_duration_seconds : float
            Duration of each chunk (in seconds) to read and process at a time.
            Default is 5 seconds.

        Returns
        -------
        float
            The maximum absolute amplitude found in the audio file.
        """
        if chunk_duration_seconds <= 0:
            raise ValueError("Chunk duration must be a positive number.")

        n_samples_per_chunk = int(chunk_duration_seconds * self.sampling_rate)

        max_amplitude = 0.0
        for start_sample in range(0, self.n_samples, n_samples_per_chunk):
            # Ensure we don't go beyond the total number of samples.
            end_sample = min(start_sample + n_samples_per_chunk, self.n_samples)

            # Get (n_channels, n_samples) array for the chunk.
            audio_chunk = self.get_audio_all_channels((start_sample, end_sample))

            # Find and update the maximum amplitude.
            chunk_max = np.max(np.abs(audio_chunk))
            if chunk_max > max_amplitude:
                max_amplitude = chunk_max

        return float(max_amplitude)

    def get_min_max_envelope(
        self,
        window_size: int,
        channel_idx: int | None,
        sample_range: tuple[int, int] | None = None,
    ) -> tuple[
        npt.NDArray[np.float64], npt.NDArray[np.float32], npt.NDArray[np.float32]
    ]:
        """Calculate min-max envelope of the audio data using non-overlapping windows.

        Divides the audio signal into consecutive non-overlapping windows of fixed size
        and computes the minimum and maximum values in each window, capturing amplitude
        variations over time.

        Parameters
        ----------
        window_size : int
            The number of audio samples in each window.
        channel_idx : int | None
            The zero-based index of the channel to calculate the envelope for. If None,
            the envelope is calculated for the mean signal across all channels.
        sample_range : tuple[int, int] | None, optional
            A tuple specifying the start and end (exclusive) sample indices to include
            in the calculation. If None (default), all the samples are included.

        Returns
        -------
        times : npt.NDArray[np.float64]
            A 1D array of time points corresponding to the start time of each window.
        min_envelope : npt.NDArray[np.float32]
            A 1D array containing the minimum values of the audio signal in each window.
        max_envelope : npt.NDArray[np.float32]
            A 1D array containing the maximum values of the audio signal in each window.
        """
        if window_size <= 0:
            raise ValueError("Window size must be a positive integer.")
        if channel_idx is not None and (
            channel_idx < 0 or channel_idx >= self.n_channels
        ):
            raise ValueError(
                f"Invalid channel index: {channel_idx}. "
                f"Must be in range [0, {self.n_channels - 1}]."
            )

        if channel_idx is None:
            audio_data = self.get_audio_mean(sample_range)
        else:
            audio_data = self.get_audio_all_channels(sample_range)[channel_idx, :]

        n_samples = len(audio_data)
        if n_samples < window_size:
            raise ValueError(
                f"Audio data length {len(audio_data)} is less than the window "
                f"size {window_size}."
            )

        # Pad the audio data with the last sample if necessary.
        remainder = n_samples % window_size
        if remainder != 0:
            pad_size = window_size - remainder
            audio_data = np.pad(audio_data, (0, pad_size), mode="edge")
        n_samples = len(audio_data)  # Update n_samples after padding
        assert n_samples % window_size == 0, "Remainder should be zero after padding."

        # Calculate the min-max envelope
        n_windows = n_samples // window_size
        audio_windows = audio_data.reshape(n_windows, window_size)
        min_envelope = np.min(audio_windows, axis=1)
        max_envelope = np.max(audio_windows, axis=1)

        # Calculate the time points for the start of each window
        start_sample = 0 if sample_range is None else sample_range[0]
        window_start_samples = np.arange(n_windows) * window_size + start_sample
        times = window_start_samples / self.sampling_rate  # Convert to seconds

        return times, min_envelope, max_envelope

    def resample_poly(
        self, target_rate: int, channel_idx: int | None
    ) -> npt.NDArray[np.float32]:
        """Resample the audio to the target sampling rate using polyphase filtering.

        Parameters
        ----------
        target_rate : int
            The desired sampling rate to resample the audio data to.
        channel_idx : int | None
            The zero-based index of the channel to resample. If None, the mean signal
            across all channels is resampled.

        Returns
        -------
        npt.NDArray[np.float32]
            A 1D array containing the resampled audio data.
        """
        if target_rate <= 0:
            raise ValueError("Target sampling rate must be a positive integer.")
        # Get the audio data to resample.
        if channel_idx is None:
            audio_data = self.get_audio_mean()
        else:
            audio_data = self.get_audio_all_channels()[channel_idx, :]

        if target_rate == self.sampling_rate:
            logger.info(
                "Target sampling rate is the same as the original. "
                "Returning original audio data without resampling."
            )
            return audio_data

        up, down = self._find_resample_factors(target_rate)
        if max(up, down) > 1000:
            logger.warning(
                f"Resampling factors are large {up}:{down}. This may lead to "
                "significant computational overhead. Consider using different "
                "resampling method or adjusting the target rate."
            )
        logger.info(
            f"Resampling audio from {self.sampling_rate} Hz to {target_rate} Hz "
            f"using polyphase filtering with factors {up}:{down}."
        )
        return signal.resample_poly(audio_data, up, down)

    @property
    def fname(self) -> str:
        """Return full path to the audio file that is being read."""
        return self._fname

    @property
    @abstractmethod
    def sampling_rate(self) -> int:
        """Return the nominal sampling rate of the audio."""
        pass

    @property
    @abstractmethod
    def n_channels(self) -> int:
        """Return the number of channels in the audio."""
        pass

    @property
    @abstractmethod
    def bit_depth(self) -> int:
        """Return the bit depth of the audio."""
        pass

    @property
    @abstractmethod
    def duration(self) -> float:
        """Return the duration of the audio in seconds."""
        pass

    @property
    @abstractmethod
    def n_samples(self) -> int:
        """Return the number of samples (per channel) in the audio."""
        pass

    def print_stats(self) -> None:
        """Print basic statistics about the audio file."""
        print(f"Stats for audio: {self.fname}")
        print(f"  - Number of channels: {self.n_channels}")
        print(f"  - Sampling rate: {self.sampling_rate} Hz")
        print(f"  - Bit depth: {self.bit_depth} bits")
        print(f"  - Duration: {self.duration:.2f} seconds")
        print(f"  - Number of samples per channel: {self.n_samples}")

    def _find_resample_factors(self, target_rate: int) -> tuple[int, int]:
        """Find the factors for up-and downsampling to match the target rate."""
        frac = Fraction(target_rate, self.sampling_rate)
        up, down = frac.numerator, frac.denominator
        return up, down


class AudioFileHelsinkiVideoMEG(AudioFile):
    """Read an audio file in the Helsinki VideoMEG project format.

    In addition to the properties of AudioFile interface, the following
    attributes are available:
        - ver: file format version read from the header (only 0 is supported)
        - buffer_timestamps_ms: buffers' timestamps (unix time in milliseconds)
        - format_string: format string for the audio data
        - buffer_size_bytes: buffer size (bytes)

    Parameters
    ----------
    fname : str
        Full path to the audio file.
    magic_str : str, optional
        Magic string that should be at the beginning of the audio file.
        Default is "HELSINKI_VIDEO_MEG_PROJECT_AUDIO_FILE".
    regression_segment_length : int, optional
        Length of segments (in seconds) used in piecewise linear regression to
        compute timestamps for all audio samples. Default is 20 seconds.

    Raises
    ------
    ValueError
        If the file does not start with ``magic_str`` or its data layout is
        inconsistent (block sizes, buffer size, decreasing timestamps).
    UnknownVersionError
        If the file format version is not 0.
    """

    def __init__(
        self,
        fname: str,
        magic_str: str = "HELSINKI_VIDEO_MEG_PROJECT_AUDIO_FILE",
        regression_segment_length: int = 20,
    ) -> None:
        super().__init__(fname)
        self._regression_segment_length = regression_segment_length
        self._data_file = open(self._fname, "rb")
        try:
            self._read_header(magic_str)
            self._index_blocks()
            self._compute_sample_layout()
            self._compute_audio_timestamps()  # sets self._audio_timestamps_ms
        except Exception:
            # Release the handle immediately instead of waiting for __del__.
            self.close()
            raise

    def __enter__(self) -> "AudioFileHelsinkiVideoMEG":
        """Enter the runtime context with opened audio file."""
        return self

    def close(self) -> None:
        """Close the audio file.

        Safe to call repeatedly and on a partially initialized instance: it only
        acts if the file handle exists and is still open. Errors raised while
        closing are logged as warnings instead of being propagated.
        """
        if hasattr(self, "_data_file") and not self._data_file.closed:
            try:
                self._data_file.close()
            except Exception as e:
                logger.warning(f"Error closing audio file {self._fname}: {e}")

    def get_audio_all_channels(
        self, sample_range: tuple[int, int] | None = None
    ) -> npt.NDArray[np.float32]:
        """Get audio data for all channels in the specified sample range.

        Parameters
        ----------
        sample_range : tuple[int, int] | None
            A tuple specifying the start and end (exclusive) sample indices to include
            in the output. If None (default), all the samples are included.

        Returns
        -------
        npt.NDArray[np.float32]
            A 2D array of shape (n_channels, n_samples) containing the audio data.
        """
        return self._get_audio_samples(
            sample_range if sample_range is not None else (0, self.n_samples)
        )

    def get_audio_mean(
        self, sample_range: tuple[int, int] | None = None
    ) -> npt.NDArray[np.float32]:
        """Get mean audio data across channels in the specified sample range.

        Parameters
        ----------
        sample_range : tuple[int, int] | None
            A tuple specifying the start and end (exclusive) sample indices to include
            in the output. If None (default), all the samples are included.

        Returns
        -------
        npt.NDArray[np.float32]
            A 1D array containing the mean audio data for the specified sample range.
        """
        audio_all_channels = self._get_audio_samples(
            sample_range if sample_range is not None else (0, self.n_samples)
        )
        return audio_all_channels.mean(axis=0)

    def get_audio_timestamps_ms(self) -> npt.NDArray[np.float64]:
        """Get timestamps for all audio samples in milliseconds.

        Returns
        -------
        npt.NDArray[np.float64]
            A 1D array containing timestamps for all audio samples in milliseconds.
        """
        return self._audio_timestamps_ms

    @property
    def sampling_rate(self) -> int:
        """Return the nominal sampling rate read from the file header."""
        return self._sampling_rate

    @property
    def n_channels(self) -> int:
        """Return the number of channels read from the file header."""
        return self._n_channels

    @property
    def bit_depth(self) -> int:
        """Return the bit depth derived from the sample format string."""
        return self._bit_depth

    @property
    def n_samples(self) -> int:
        """Return the number of samples per channel in the whole file."""
        return self._n_samples

    @property
    def duration(self) -> float:
        """Return the duration in seconds based on the nominal sampling rate."""
        return self.n_samples / self.sampling_rate

    def _read_header(self, magic_str: str) -> None:
        """Validate the magic string and read the file-level header fields."""
        magic_bytes = magic_str.encode("utf8")
        # Compare by encoded length so non-ASCII magic strings are handled too.
        if self._data_file.read(len(magic_bytes)) != magic_bytes:
            raise ValueError(
                f"File {self._fname} does not start with the expected "
                f"magic string: {magic_str}."
            )

        # Read properties from the file header.
        self.ver = struct.unpack("I", self._data_file.read(4))[0]
        if self.ver != 0:  # Can only read version 0.
            raise UnknownVersionError()
        self._sampling_rate, self._n_channels = struct.unpack(
            "II", self._data_file.read(8)
        )
        self.format_string = self._data_file.read(2).decode("ascii")

    def _index_blocks(self) -> None:
        """Record the payload position and timestamp of every audio data block."""
        # Now file position is at the beginning of audio data blocks.
        begin_data = self._data_file.tell()
        self._data_file.seek(0, 2)  # seek to end of file
        end_data = self._data_file.tell()
        # Seek back to the beginning of audio data blocks.
        self._data_file.seek(begin_data, 0)

        # Get the size of the payload in one audio data block and the total size
        # of the block (header + payload). Advances file position!
        _, first_payload_size, first_block_size = read_block_attributes(
            self._data_file, self.ver
        )
        self.buffer_size_bytes = first_payload_size  # size of audio data in one block
        self._data_file.seek(begin_data, 0)  # return to beginning

        if (end_data - begin_data) % first_block_size != 0:
            raise ValueError(
                "Audio data size is not a multiple of block size. "
                "The audio file may be corrupted."
            )

        # Read the positions and timestamps of all audio blocks.
        self._n_blocks = (end_data - begin_data) // first_block_size
        self.buffer_timestamps_ms = np.zeros(self._n_blocks, dtype=np.int64)
        self._audio_block_positions: list[int] = []
        for i in range(self._n_blocks):
            timestamp, payload_size, block_size = read_block_attributes(
                self._data_file, self.ver
            )
            if block_size != first_block_size:
                raise ValueError(
                    "Inconsistent block size while reading audio data. First block size"
                    f" was {first_block_size} bytes, but block {i} size is"
                    f" {block_size} bytes."
                )
            self._audio_block_positions.append(self._data_file.tell())
            self._data_file.seek(payload_size, 1)  # skip actual audio data (payload)
            self.buffer_timestamps_ms[i] = timestamp

        # Make sure that the timestamps are increasing
        if not np.all(np.diff(self.buffer_timestamps_ms) >= 0):
            raise ValueError(
                "Audio buffer timestamps must be non-decreasing but found "
                "decreasing values."
            )

    def _compute_sample_layout(self) -> None:
        """Derive per-sample sizes and total sample counts from the header."""
        # Calculate stats for a single sample.
        self._bit_depth = self._get_bit_depth(self.format_string)
        self._n_bytes_per_sample = struct.calcsize(self.format_string)

        # Calculate how many samples there is in one raw audio data buffer,
        # taking into account that the buffer contains interleaved samples
        # from all channels.
        one_sample_from_all_channels_size = self._n_channels * self._n_bytes_per_sample
        if self.buffer_size_bytes % one_sample_from_all_channels_size != 0:
            raise ValueError(
                "Audio buffer size is not a multiple of one sample from all channels."
            )
        self._n_samples_per_channel_per_buffer = (
            self.buffer_size_bytes // one_sample_from_all_channels_size
        )

        # Calculate total number of samples per channel in the whole audio.
        self._n_samples = self._n_samples_per_channel_per_buffer * self._n_blocks

    def _get_audio_samples(
        self, sample_range: tuple[int, int]
    ) -> npt.NDArray[np.float32]:
        """Get audio samples in the specified range (start inclusive, end exclusive).

        Determines the correct audio blocks to read from file and unpacks the samples.

        Raises
        ------
        ValueError
            If the range is out of bounds or empty.
        """
        start_sample, end_sample = sample_range
        if start_sample < 0 or end_sample > self.n_samples:
            raise ValueError("Sample range is out of bounds.")
        if start_sample >= end_sample:
            raise ValueError("Invalid sample range: start must be less than end.")

        n_samples_to_read = end_sample - start_sample
        duration_to_read = n_samples_to_read / self.sampling_rate

        # Determine which blocks to read.
        first_block_idx = start_sample // self._n_samples_per_channel_per_buffer
        last_block_idx = (end_sample - 1) // self._n_samples_per_channel_per_buffer
        n_blocks_to_read = last_block_idx - first_block_idx + 1

        logger.debug(
            f"Reading {duration_to_read:.2f} seconds ({n_samples_to_read} samples) of "
            f"audio data from blocks {first_block_idx} to {last_block_idx}."
        )

        # Allocate space for raw audio data from the blocks.
        block_data = bytearray(n_blocks_to_read * self.buffer_size_bytes)

        # Read the necessary blocks and concatenate their payloads (ignore headers).
        for block_idx in range(first_block_idx, last_block_idx + 1):
            # Determine where to copy the block data in the allocated bytearray.
            relative_block_idx = block_idx - first_block_idx
            block_start = relative_block_idx * self.buffer_size_bytes
            block_end = block_start + self.buffer_size_bytes
            # Read the block and copy its data.
            block_data[block_start:block_end] = self._read_block(block_idx)

        # Unpack the audio data from the read blocks to (n_channels, n_samples) array.
        unpacked_audio = self._unpack_audio(block_data, n_blocks_to_read)

        # Because we might have read more samples than requested (we read whole blocks),
        # determine the correct slice to return.
        first_block_start_sample = (
            first_block_idx * self._n_samples_per_channel_per_buffer
        )
        copy_start = start_sample - first_block_start_sample
        copy_end = copy_start + n_samples_to_read
        return unpacked_audio[:, copy_start:copy_end]

    def _read_block(self, block_idx: int) -> bytes:
        """Read the raw audio data from the specified block in the file."""
        # Seek to the beginning of the block payload (positions skip the header).
        block_pos = self._audio_block_positions[block_idx]
        self._data_file.seek(block_pos, 0)
        # Read the block data.
        return self._data_file.read(self.buffer_size_bytes)

    def _unpack_audio(
        self, audio_bytes: bytearray, n_blocks: int
    ) -> npt.NDArray[np.float32]:
        """Unpack given raw audio bytes from adjacent blocks.

        Parameters
        ----------
        audio_bytes : bytearray
            Raw audio bytes from adjacent blocks to unpack.
        n_blocks : int
            Number of blocks contained in audio_bytes.

        Returns
        -------
        npt.NDArray[np.float32]
            A 2D array of shape (n_channels, n_samples) containing the unpacked audio
            data.
        """
        n_samples_per_channel = self._n_samples_per_channel_per_buffer * n_blocks
        dtype = self._get_numpy_dtype()
        audio = np.frombuffer(audio_bytes, dtype=dtype).astype(np.float32)
        # Reshape (n_channels, n_samples) layout.
        # The data is interleaved, so reshape to (n_samples, n_channels) first
        # and then transpose.
        return audio.reshape(n_samples_per_channel, self.n_channels).T

    def _get_bit_depth(self, format_string: str) -> int:
        """Get the bit depth from the format string."""
        # Dictionary mapping format characters to bit depths
        bit_depth_map = {
            "b": 8,  # signed char
            "B": 8,  # unsigned char
            "h": 16,  # short
            "H": 16,  # unsigned short
            "i": 32,  # int
            "I": 32,  # unsigned int
            "l": 32,  # long
            "L": 32,  # unsigned long
            "q": 64,  # long long
            "Q": 64,  # unsigned long long
            "f": 32,  # float
            "d": 64,  # double
        }
        # Extract the format character, ignoring endianness indicators
        bit_depth_char = format_string[-1]
        if bit_depth_char not in bit_depth_map:
            raise ValueError(
                f"Unsupported bit depth character: {bit_depth_char} in format "
                f"string {format_string}"
            )
        return bit_depth_map[bit_depth_char]

    def _get_numpy_dtype(self) -> np.dtype:
        """Construct numpy dtype from the two-character format string.

        The first character is the byte-order prefix and the second the sample
        type code (as in the ``struct`` module).
        """
        # Determine the data type for numpy based on the format string.
        # NOTE(review): with the native "@" prefix, struct sizes "l"/"L" by the
        # platform's C long (8 bytes on LP64), which would disagree with the
        # 32-bit mapping below — confirm such formats never occur in files.
        dtype_map = {
            "b": np.int8,
            "B": np.uint8,
            "h": np.int16,
            "H": np.uint16,
            "i": np.int32,
            "I": np.uint32,
            "l": np.int32,
            "L": np.uint32,
            "q": np.int64,
            "Q": np.uint64,
            "f": np.float32,
            "d": np.float64,
        }
        sample_type = self.format_string[1]
        if sample_type not in dtype_map:
            raise ValueError(
                f"Unsupported sample type character: {sample_type} in format "
                f"string {self.format_string}"
            )
        numpy_dtype = np.dtype(dtype_map[sample_type])

        # Handle endianness.
        endian_char = self.format_string[0]
        if endian_char == "<":
            numpy_dtype = numpy_dtype.newbyteorder("<")
        elif endian_char == ">":
            numpy_dtype = numpy_dtype.newbyteorder(">")
        elif endian_char in ("=", "@"):
            # Native endianness
            numpy_dtype = numpy_dtype.newbyteorder("=")
        else:
            raise ValueError(
                f"Unsupported endianness character: {endian_char} in format "
                f"string {self.format_string}"
            )
        return numpy_dtype

    def _compute_audio_timestamps(self) -> None:
        """Transform sparse buffer timestamps into dense sample timestamps.

        Uses piecewise linear regression to estimate timestamps for all samples based
        on the buffer timestamps. Each buffer timestamp is associated with the last
        sample of that buffer. The result is stored in ``self._audio_timestamps_ms``.
        """
        # Create an array that contains the indices of the last sample in each buffer.
        # These indices correspond to the timestamps we have.
        buffer_end_indices = np.arange(
            self._n_samples_per_channel_per_buffer - 1,
            self.n_samples,
            self._n_samples_per_channel_per_buffer,
        )

        # Prepare arrays to hold the regression errors and the computed timestamps.
        # -1 marks "not yet set" and is checked by the assertions below.
        regression_errors = -np.ones(self._n_blocks, dtype=np.float64)
        # Double precision is important here!
        audio_timestamps_ms = -np.ones(self.n_samples, dtype=np.float64)

        # Split the data into segments for piecewise linear regression.
        segment_n_samples = self._regression_segment_length * self._sampling_rate
        split_indices = list(range(0, self.n_samples, segment_n_samples))
        if len(split_indices) > 1:
            # Merge the trailing partial segment into the previous one, so the
            # last segment might be up to twice as long as the others.
            split_indices[-1] = self.n_samples
        else:
            # Audio no longer than one segment: fit a single segment covering all
            # samples (otherwise no segment would be fitted at all).
            split_indices.append(self.n_samples)

        # Loop over the segments and perform linear regression.
        for i in range(len(split_indices) - 1):
            segment_start_idx = split_indices[i]
            segment_end_idx = split_indices[i + 1]

            # Find the buffers that have timestamps within the current segment.
            segment_mask = (buffer_end_indices >= segment_start_idx) & (
                buffer_end_indices < segment_end_idx
            )

            # Take the samples indices and timestamps.
            timestamp_indices = buffer_end_indices[segment_mask]
            timestamps_ms = self.buffer_timestamps_ms[segment_mask]

            # Fit a linear regression.
            p = np.polyfit(
                timestamp_indices,
                timestamps_ms,
                1,
            )

            # Compute the regression error for the known timestamps.
            regression_errors[segment_mask] = np.abs(
                np.polyval(p, timestamp_indices) - self.buffer_timestamps_ms[segment_mask]
            )

            # Compute timestamps for all samples in the segment.
            audio_timestamps_ms[segment_start_idx:segment_end_idx] = np.polyval(
                p, np.arange(segment_start_idx, segment_end_idx)
            )

        assert audio_timestamps_ms.min() >= 0, "All timestamps should be set"
        assert regression_errors.min() >= 0, "All regression errors should be set"
        logger.info(
            "Audio regression fit errors (abs): mean %.3f ms, median %.3f ms, "
            "max %.3f ms",
            regression_errors.mean(),
            np.median(regression_errors),
            regression_errors.max(),
        )

        # Make sure that the timestamps are non-decreasing.
        timestamps_diff = np.diff(audio_timestamps_ms)
        if not np.all(timestamps_diff >= 0):
            logger.warning(
                "Piecewise linear regression produced %d decreasing timestamps. "
                "Replacing the decreasing timestamps with the previous valid timestamp "
                "to ensure non-decreasing timestamps.",
                np.sum(timestamps_diff < 0),
            )
            audio_timestamps_ms = np.maximum.accumulate(audio_timestamps_ms)

        self._audio_timestamps_ms = audio_timestamps_ms