
Model Training

Model training utilities for AI-based audio processing.

This module provides functions to build, train, and save machine learning models.

Author: Esgr0bar

build_model(input_shape)

Build a Convolutional Neural Network (CNN) model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_shape | tuple | Shape of the input data (height, width, channels). | required |

Returns:

| Type | Description |
|---|---|
| keras.Model | Compiled CNN model. |

Source code in src/model_training.py
```python
# Module-level imports (not shown in this excerpt; assumed to be the tf.keras API):
from tensorflow.keras import layers, models


def build_model(input_shape):
    """
    Build a Convolutional Neural Network (CNN) model.

    Args:
        input_shape (tuple): Shape of the input data (height, width, channels).

    Returns:
        keras.Model: Compiled CNN model.
    """
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(1)  # Single linear unit for regression; for classification, use a softmax/sigmoid head and a matching loss
    ])

    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model
```
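Every Conv2D layer in `build_model` uses a 3×3 kernel with Keras's default `'valid'` padding, and every MaxPooling2D layer halves the spatial size (rounding down), so small inputs shrink to nothing before `Flatten`. Working the arithmetic backwards gives a minimum input of 18×18. The sketch below traces an input shape through the same layer stack and reports each output shape and parameter count; `cnn_shape_summary` is a hypothetical helper written for illustration, not part of `src/model_training.py`, and it assumes the layer stack above is unchanged. Its counts should agree with what `model.summary()` reports.

```python
from typing import List, Tuple


def cnn_shape_summary(input_shape: Tuple[int, int, int]) -> List[Tuple[str, Tuple[int, ...], int]]:
    """Hypothetical helper: trace output shapes and parameter counts through
    the layer stack used in build_model (not part of src/model_training.py).

    Assumes 3x3 Conv2D layers with 'valid' padding and stride 1, and 2x2
    MaxPooling2D layers with stride 2, exactly as build_model defines them.
    """
    h, w, c = input_shape
    rows: List[Tuple[str, Tuple[int, ...], int]] = []

    def conv(name: str, filters: int) -> None:
        nonlocal h, w, c
        h, w = h - 2, w - 2  # a 3x3 'valid' convolution trims one pixel per border
        if h < 1 or w < 1:
            raise ValueError(f"input_shape {input_shape} is too small at {name}")
        params = (3 * 3 * c + 1) * filters  # kernel weights plus one bias per filter
        c = filters
        rows.append((name, (h, w, c), params))

    def pool(name: str) -> None:
        nonlocal h, w
        h, w = h // 2, w // 2  # 2x2 pooling with stride 2 rounds down
        rows.append((name, (h, w, c), 0))  # pooling has no trainable weights

    conv("conv1", 32)
    pool("pool1")
    conv("conv2", 64)
    pool("pool2")
    conv("conv3", 64)

    flat = h * w * c
    rows.append(("flatten", (flat,), 0))
    rows.append(("dense64", (64,), (flat + 1) * 64))  # weights plus biases
    rows.append(("output", (1,), 64 + 1))             # Dense(1): 64 weights + 1 bias
    return rows


if __name__ == "__main__":
    summary = cnn_shape_summary((128, 128, 1))
    for name, shape, params in summary:
        print(f"{name:8} {str(shape):15} {params:>9,}")
    print("total params:", sum(p for _, _, p in summary))
```

For a 128×128 single-channel input this gives 3,267,137 parameters, roughly 98% of them in the first Dense layer, so the model's size grows approximately with the input's height × width.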

save_model(model, file_path)

Save the trained model to a file.

Parameters:

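| Name | Type | Description | Default |
|---|---|---|---|
| model | keras.Model | Trained model. | required |
| file_path | str | Path to save the model. With Keras 3, the path must end in `.keras` (native format) or `.h5` (legacy HDF5 format). | required |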

Returns:

| Type | Description |
|---|---|
| None | Nothing is returned; the model is written to `file_path`. |

Source code in src/model_training.py
```python
def save_model(model, file_path):
    """
    Save the trained model to a file.

    Args:
        model (keras.Model): Trained model.
        file_path (str): Path to save the model.

    Returns:
        None
    """
    model.save(file_path)
```
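Because Keras 3's `model.save()` rejects paths that do not end in `.keras` or `.h5`, a caller may want to normalize the path before calling `save_model`. The sketch below is a hypothetical helper, `with_keras_extension`, that is not part of this module; it assumes the Keras 3 extension rule described in the parameter table above and simply appends `.keras` when neither accepted suffix is present.

```python
# Suffixes assumed to be accepted by Keras 3's model.save() for single-file saving.
ACCEPTED_SUFFIXES = (".keras", ".h5")


def with_keras_extension(file_path: str) -> str:
    """Hypothetical helper: append '.keras' unless the path already ends in an accepted suffix."""
    if file_path.endswith(ACCEPTED_SUFFIXES):
        return file_path
    # Append rather than replace, so names like 'run.v2' keep their full stem.
    return file_path + ".keras"
```

Usage would look like `save_model(model, with_keras_extension("models/audio_cnn"))`, which writes `models/audio_cnn.keras`. Appending instead of swapping the suffix avoids silently truncating dotted names.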

train_model(model, X_train, y_train, epochs=10, validation_split=0.2)

Train the CNN model on the provided data.

Parameters:

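| Name | Type | Description | Default |
|---|---|---|---|
| model | keras.Model | Compiled model. | required |
| X_train | numpy.ndarray | Training inputs, shaped (samples, height, width, channels) to match the model's `input_shape`. | required |
| y_train | numpy.ndarray | Target values, one per sample in `X_train`. | required |
| epochs | int | Number of passes over the training data. | 10 |
| validation_split | float | Fraction of the training data held out for validation. Keras takes these samples from the end of `X_train` and `y_train`, before shuffling. | 0.2 |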

Returns:

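| Type | Description |
|---|---|
| keras.callbacks.History | Training history. Its `history` attribute maps each metric name (for this model: `loss`, `mae`, `val_loss`, `val_mae`) to a list of per-epoch values. |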

Source code in src/model_training.py
```python
def train_model(model, X_train, y_train, epochs=10, validation_split=0.2):
    """
    Train the CNN model on the provided data.

    Args:
        model (keras.Model): Compiled model.
        X_train (numpy.ndarray): Training data.
        y_train (numpy.ndarray): Target values.
        epochs (int): Number of epochs to train.
        validation_split (float): Fraction of data to use for validation.

    Returns:
        keras.callbacks.History: Training history.
    """
    history = model.fit(X_train, y_train, epochs=epochs, validation_split=validation_split)
    return history
```
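Since Keras holds out the last samples of `X_train`/`y_train` before shuffling, arrays ordered by label or by source recording can yield an unrepresentative validation set. The sketch below mirrors that split arithmetic, assuming (as in recent Keras releases) that the first `floor(n * (1 - validation_split))` samples are used for training, and then picks the best epoch from the `History.history` dictionary that `train_model` returns. Both `validation_split_indices` and `best_epoch` are hypothetical helpers for illustration, not part of this module or of Keras.

```python
import math
from typing import Dict, List, Tuple


def validation_split_indices(n_samples: int, validation_split: float) -> Tuple[range, range]:
    """Hypothetical helper: index ranges for a Keras-style validation_split.

    Assumes the first floor(n * (1 - validation_split)) samples train and the
    remainder validate, with no shuffling before the split.
    """
    if not 0.0 < validation_split < 1.0:
        raise ValueError("validation_split must be strictly between 0 and 1")
    split_at = int(math.floor(n_samples * (1.0 - validation_split)))
    return range(0, split_at), range(split_at, n_samples)


def best_epoch(history: Dict[str, List[float]], monitor: str = "val_loss") -> Tuple[int, float]:
    """Hypothetical helper: 1-based epoch with the lowest value of `monitor`."""
    values = history[monitor]
    idx = min(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]
```

With the default `validation_split=0.2` and 10 samples, indices 0 to 7 train and 8 to 9 validate. After training, `best_epoch(train_model(model, X, y).history)` would report the epoch with the lowest `val_loss`; shuffling `X_train` and `y_train` together beforehand avoids an ordered validation tail.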