595 lines
20 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Import notwendiger Bibliotheken\n",
"import nussl\n",
"from common import data, viz\n",
"from IPython.display import Audio\n",
"import IPython\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers, models\n",
"import numpy as np\n",
"import librosa\n",
"import os\n",
"from scipy import signal\n",
"import matplotlib.pyplot as plt\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# STFT parameters shared by loading, preprocessing and resynthesis\n",
"stft_params = nussl.STFTParams(window_length=512, hop_length=128, window_type='sqrt_hann')\n",
"# Coarser alternative: nussl.STFTParams(window_length=1024, hop_length=256, window_type='sqrt_hann')\n",
"\n",
"# Training data directory. Set the NUSSL_TRAIN_DIR environment variable to use\n",
"# another location instead of editing this machine-specific default.\n",
"fg_path = os.environ.get(\"NUSSL_TRAIN_DIR\", \"C:\\\\Users\\\\Lukas\\\\nussl_tutorial\\\\train\")\n",
"\n",
"# Load the training mixtures with nussl via common.data.on_the_fly\n",
"train_data = data.on_the_fly(stft_params, transform=None, fg_path=fg_path, num_mixtures=100, coherent_prob=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Helpers to plot waveforms and spectrograms with nussl's visualization utilities\n",
"\n",
"\n",
"def _label_sources(sources):\n",
"    \"\"\"Map a list of signals to {'Source i': signal}; anything else is returned as is.\"\"\"\n",
"    if not isinstance(sources, list):\n",
"        return sources\n",
"    return {f'Source {idx}': src for idx, src in enumerate(sources)}\n",
"\n",
"\n",
"def _open_figure():\n",
"    \"\"\"Start a new 10x5 figure and create its axes (empty plt.plot call).\"\"\"\n",
"    plt.figure(figsize=(10, 5))\n",
"    plt.plot()\n",
"\n",
"\n",
"def show_wav(sources):\n",
"    \"\"\"Plot several sources (list or dict of signals) as waveforms in one figure.\"\"\"\n",
"    labeled = _label_sources(sources)\n",
"    _open_figure()\n",
"    nussl.core.utils.visualize_sources_as_waveform(labeled)\n",
"    plt.show()\n",
"\n",
"\n",
"def show_1wav(data):\n",
"    \"\"\"Plot the waveform of a single signal.\"\"\"\n",
"    _open_figure()\n",
"    nussl.core.utils.visualize_waveform(data)\n",
"    plt.show()\n",
"\n",
"\n",
"def show_fre(sources):\n",
"    \"\"\"Plot several sources (list or dict of signals) with visualize_sources_as_masks, db_cutoff=-80.\"\"\"\n",
"    labeled = _label_sources(sources)\n",
"    _open_figure()\n",
"    nussl.core.utils.visualize_sources_as_masks(labeled, db_cutoff=-80)\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"\n",
"def show_1fre(data):\n",
"    \"\"\"Plot the spectrogram of a single signal.\"\"\"\n",
"    _open_figure()\n",
"    nussl.core.utils.visualize_spectrogram(data)\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: inspect, listen to and plot the first training example\n",
"item = train_data[0]\n",
"print(item.keys())\n",
"display(Audio(item['mix'].audio_data, rate=item['mix'].sample_rate))\n",
"show_wav(item['sources'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preprocessing: magnitude spectrograms resized to the U-Net input size\n",
"\n",
"\n",
"def preprocess_spectrogram(magnitude, target_size=(512, 128)):\n",
"    \"\"\"Resize a magnitude spectrogram to target_size with tf.image.resize.\"\"\"\n",
"    return tf.image.resize(magnitude, target_size)\n",
"\n",
"\n",
"def prepare_data(data_item, stft_params):\n",
"    \"\"\"Return (mixture magnitude, vocals magnitude) for one dataset item, both resized.\n",
"\n",
"    data_item needs a 'mix' signal and a 'sources' mapping with a 'vocals' entry.\n",
"    Only window_length and hop_length are taken from stft_params.\n",
"    \"\"\"\n",
"    mix = data_item['mix']\n",
"    vocals = data_item['sources']['vocals']\n",
"    win, hop = stft_params.window_length, stft_params.hop_length\n",
"\n",
"    mix_mag = np.abs(mix.stft(window_length=win, hop_length=hop))\n",
"    vocals_mag = np.abs(vocals.stft(window_length=win, hop_length=hop))\n",
"    return preprocess_spectrogram(mix_mag), preprocess_spectrogram(vocals_mag)\n",
"\n",
"\n",
"# Example: preprocess the first training item\n",
"mix_data, vocals_data = prepare_data(train_data[0], stft_params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# SDR metric: energy ratio of the target to the error, in dB\n",
"\n",
"\n",
"def sdr_metric(y_true, y_pred, eps=1e-8):\n",
"    \"\"\"Signal-to-distortion ratio in dB between y_true and y_pred.\n",
"\n",
"    Energies are summed over all elements, i.e. over the whole batch when used\n",
"    as a Keras metric. eps keeps the value finite when y_pred equals y_true\n",
"    (zero error) or y_true is silent (zero energy); without it the result is\n",
"    +inf, -inf or NaN.\n",
"    \"\"\"\n",
"    y_pred = tf.convert_to_tensor(y_pred)\n",
"    y_true = tf.cast(y_true, y_pred.dtype)\n",
"    signal_power = tf.reduce_sum(tf.square(y_true))\n",
"    noise_power = tf.reduce_sum(tf.square(y_true - y_pred))\n",
"    ln10 = tf.math.log(tf.constant(10.0, dtype=y_pred.dtype))\n",
"    return 10.0 * tf.math.log((signal_power + eps) / (noise_power + eps)) / ln10"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Predict on the test data and report the mean SDR\n",
"\n",
"\n",
"def evaluate_model_with_sdr(model, test_data, stft_params, apply_mask=False):\n",
"    \"\"\"Return the mean SDR (dB) of the model's predictions over test_data.\n",
"\n",
"    Each item is preprocessed with prepare_data, and the model output is compared\n",
"    with the resized vocals magnitude. Set apply_mask=True when the network output\n",
"    is a mask, so that mask * mixture magnitude is scored instead of the raw output.\n",
"\n",
"    Raises:\n",
"        ValueError: if test_data yields no items.\n",
"    \"\"\"\n",
"    sdr_scores = []\n",
"    for data_item in test_data:\n",
"        X_test, y_test = prepare_data(data_item, stft_params)\n",
"\n",
"        # Single-item batch; verbose=0 avoids one progress bar per test item\n",
"        y_pred = model.predict(np.expand_dims(X_test, axis=0), verbose=0)[0]\n",
"        if apply_mask:\n",
"            y_pred = y_pred * np.asarray(X_test)\n",
"\n",
"        sdr_scores.append(float(sdr_metric(y_test, y_pred)))\n",
"\n",
"    if not sdr_scores:\n",
"        raise ValueError(\"test_data contains no items\")\n",
"\n",
"    avg_sdr = float(np.mean(sdr_scores))\n",
"    print(f\"Durchschnittlicher SDR: {avg_sdr:.2f} dB\")\n",
"    return avg_sdr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _encoder_block(x, filters, kernel_size, dropout_rate=None):\n",
"    \"\"\"Stride-2 Conv2D (halves height/width) -> BatchNorm -> ReLU, plus optional Dropout.\"\"\"\n",
"    x = layers.Conv2D(filters, kernel_size, strides=(2, 2), padding='same')(x)\n",
"    x = layers.BatchNormalization()(x)\n",
"    x = layers.ReLU()(x)\n",
"    if dropout_rate:\n",
"        x = layers.Dropout(dropout_rate)(x)\n",
"    return x\n",
"\n",
"\n",
"def _decoder_block(x, skip, filters, kernel_size):\n",
"    \"\"\"Stride-2 Conv2DTranspose (doubles height/width) -> BatchNorm -> ReLU, concatenated with skip.\"\"\"\n",
"    x = layers.Conv2DTranspose(filters, kernel_size, strides=(2, 2), padding='same')(x)\n",
"    x = layers.BatchNormalization()(x)\n",
"    x = layers.ReLU()(x)\n",
"    return layers.Concatenate()([x, skip])\n",
"\n",
"\n",
"def build_unet(input_shape=(512, 128, 1), base_filters=64, kernel_size=(5, 5), dropout_rate=0.5):\n",
"    \"\"\"Build the U-Net used for vocal separation.\n",
"\n",
"    There are five stride-2 encoder stages with base_filters * 1, 2, 4, 8 and 16\n",
"    filters; the first three use dropout. Four stride-2 decoder stages with skip\n",
"    connections follow, and a final stride-2 Conv2DTranspose has one sigmoid\n",
"    channel, so the output has the input's height and width.\n",
"\n",
"    Height and width should be divisible by 32. Otherwise the upsampled tensors\n",
"    do not match the skip connections.\n",
"    \"\"\"\n",
"    inputs = tf.keras.Input(shape=input_shape)\n",
"\n",
"    # Encoder\n",
"    conv1 = _encoder_block(inputs, base_filters, kernel_size, dropout_rate)\n",
"    conv2 = _encoder_block(conv1, base_filters * 2, kernel_size, dropout_rate)\n",
"    conv3 = _encoder_block(conv2, base_filters * 4, kernel_size, dropout_rate)\n",
"    conv4 = _encoder_block(conv3, base_filters * 8, kernel_size)\n",
"    conv5 = _encoder_block(conv4, base_filters * 16, kernel_size)\n",
"\n",
"    # Decoder with skip connections\n",
"    up6 = _decoder_block(conv5, conv4, base_filters * 8, kernel_size)\n",
"    up7 = _decoder_block(up6, conv3, base_filters * 4, kernel_size)\n",
"    up8 = _decoder_block(up7, conv2, base_filters * 2, kernel_size)\n",
"    up9 = _decoder_block(up8, conv1, base_filters, kernel_size)\n",
"\n",
"    # Last upsampling back to the input resolution; sigmoid keeps values in [0, 1]\n",
"    outputs = layers.Conv2DTranspose(1, kernel_size, strides=(2, 2), padding='same', activation='sigmoid')(up9)\n",
"    return tf.keras.Model(inputs, outputs)\n",
"\n",
"\n",
"# Initialize the model\n",
"model = build_unet(input_shape=(512, 128, 1))\n",
"\n",
"# Optimizer with an explicit learning rate\n",
"from tensorflow.keras.optimizers import Adam\n",
"\n",
"optimizer = Adam(learning_rate=0.001)\n",
"model.compile(optimizer=optimizer, loss='mse', metrics=[sdr_metric])\n",
"\n",
"# Show the model summary\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the training arrays from the dataset\n",
"\n",
"\n",
"def create_training_data(data, stft_params):\n",
"    \"\"\"Run prepare_data on every item of data and stack the results.\n",
"\n",
"    Returns (inputs, targets) as numpy arrays built from the (mixture, vocals)\n",
"    pairs returned by prepare_data. Progress is printed on one line as\n",
"    'i / data.num_mixtures'.\n",
"    \"\"\"\n",
"    mixes, vocal_targets = [], []\n",
"    for count, data_item in enumerate(data, start=1):\n",
"        print(count, \"/\", data.num_mixtures, end='\\r')\n",
"        mix_mag, vocals_mag = prepare_data(data_item, stft_params)\n",
"        mixes.append(mix_mag)\n",
"        vocal_targets.append(vocals_mag)\n",
"    return np.array(mixes), np.array(vocal_targets)\n",
"\n",
"\n",
"# Training arrays from train_data (the nussl dataset loaded above)\n",
"X_train, y_train = create_training_data(train_data, stft_params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mask_loss(mixture=None):\n",
"    \"\"\"Build an MSE loss between the target and mask * mixture.\n",
"\n",
"    Recommended: mixture=None. y_true must then hold the target magnitude in\n",
"    channel 0 and the matching mixture magnitude in channel 1, e.g.\n",
"    np.concatenate([y_train, X_train], axis=-1). This way every sample in a\n",
"    batch is paired with its own mixture.\n",
"\n",
"    Legacy: with a fixed mixture array, the loss always uses that array's\n",
"    first batch_size rows, whatever batch is currently being processed. This\n",
"    only lines up for the first batch of unshuffled data.\n",
"    \"\"\"\n",
"    # Convert once, not on every loss call\n",
"    mixture_t = None if mixture is None else tf.convert_to_tensor(mixture)\n",
"\n",
"    def loss(y_true, y_pred):\n",
"        if mixture_t is None:\n",
"            target = y_true[..., :1]\n",
"            current_mixture = y_true[..., 1:2]\n",
"        else:\n",
"            target = y_true\n",
"            current_mixture = mixture_t[:tf.shape(y_pred)[0]]\n",
"        after_mask = y_pred * tf.cast(current_mixture, y_pred.dtype)\n",
"        return tf.reduce_mean(tf.square(tf.cast(target, y_pred.dtype) - after_mask))\n",
"\n",
"    return loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sdr_metric2(mixture=None, eps=1e-8):\n",
"    \"\"\"Build an SDR metric (dB) for a mask model: compares mask * mixture with the target.\n",
"\n",
"    With mixture=None, y_true must hold the target in channel 0 and the mixture\n",
"    magnitude in channel 1, so each sample is paired with its own mixture.\n",
"    A fixed mixture array is only cut to its first batch_size rows, which pairs\n",
"    correctly only for the first batch of unshuffled data.\n",
"    eps keeps the value finite for silent targets or a perfect reconstruction.\n",
"    \"\"\"\n",
"    # A tensor (not a numpy array) so it can be sliced with a symbolic batch size\n",
"    mixture_t = None if mixture is None else tf.convert_to_tensor(mixture)\n",
"\n",
"    def metric(y_true, y_pred):\n",
"        if mixture_t is None:\n",
"            target = y_true[..., :1]\n",
"            current_mixture = y_true[..., 1:2]\n",
"        else:\n",
"            target = y_true\n",
"            current_mixture = mixture_t[:tf.shape(y_pred)[0]]\n",
"        target = tf.cast(target, y_pred.dtype)\n",
"        after_mask = y_pred * tf.cast(current_mixture, y_pred.dtype)\n",
"        signal_power = tf.reduce_sum(tf.square(target))\n",
"        noise_power = tf.reduce_sum(tf.square(target - after_mask))\n",
"        ln10 = tf.math.log(tf.constant(10.0, dtype=y_pred.dtype))\n",
"        return 10.0 * tf.math.log((signal_power + eps) / (noise_power + eps)) / ln10\n",
"\n",
"    return metric"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def perceptual_loss(y_true, y_pred, extractor, mixture=None):\n",
"    \"\"\"Mean squared distance between extractor(y_true) and extractor(mask * mixture).\n",
"\n",
"    With mixture=None, y_pred itself is compared (no masking); this matches the\n",
"    three-argument usage shown below. The magnitude of the feature difference is\n",
"    squared so that complex features, such as an STFT, give a real-valued loss.\n",
"    \"\"\"\n",
"    after_mask = y_pred if mixture is None else y_pred * mixture\n",
"    features_true = extractor(y_true)\n",
"    features_pred = extractor(after_mask)\n",
"    return tf.reduce_mean(tf.square(tf.abs(features_true - features_pred)))\n",
"\n",
"\n",
"def stft_extractor(x, frame_length=64, frame_step=16):\n",
"    \"\"\"Example extractor: STFT along the time axis of each frequency row.\n",
"\n",
"    Assumes x is a (batch, freq, time, 1) batch shaped like the U-Net input/output\n",
"    -- TODO confirm. tf.signal.stft frames the last axis, so the channel axis is\n",
"    dropped first. frame_length must not exceed the time length (128 after\n",
"    preprocess_spectrogram); otherwise no frames are produced.\n",
"    \"\"\"\n",
"    return tf.signal.stft(x[..., 0], frame_length=frame_length, frame_step=frame_step)\n",
"\n",
"\n",
"# Usage: model.compile(optimizer='adam', loss=lambda y_true, y_pred: perceptual_loss(y_true, y_pred, stft_extractor))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LossPlotCallback(tf.keras.callbacks.Callback):\n",
"    \"\"\"Record logs['loss'] after every training batch and plot the values\n",
"    on a logarithmic y-axis when training ends.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.batch_losses = []  # one entry per training batch\n",
"\n",
"    def on_train_batch_end(self, batch, logs=None):\n",
"        # NOTE(review): this is whatever Keras reports as 'loss' at batch end; recent\n",
"        # tf.keras versions report the running epoch average -- verify.\n",
"        self.batch_losses.append(logs['loss'])\n",
"\n",
"    def on_train_end(self, logs=None):\n",
"        fig, ax = plt.subplots(figsize=(10, 6))\n",
"        ax.plot(self.batch_losses, 'b-', label='Batch Loss')\n",
"        ax.set_yscale('log')\n",
"        ax.set_title('Loss-Verlauf während des Trainings (logarithmisch)')\n",
"        ax.set_xlabel('Batch')\n",
"        ax.set_ylabel('Loss (log)')\n",
"        ax.legend()\n",
"        plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fresh Adam optimizer with the same learning rate as in the model cell\n",
"optimizer = Adam(learning_rate=0.001)\n",
"\n",
"# Mixture magnitudes; only needed by the mask-based alternative below\n",
"mixture = X_train\n",
"# Mask-based alternative: model.compile(optimizer='adam', loss=mask_loss(mixture), metrics=[sdr_metric2(mixture)])\n",
"\n",
"# NOTE(review): loss='mse' fits the sigmoid output (range 0..1) directly to y_train,\n",
"# i.e. to the vocals magnitudes from prepare_data. The inference cell, however, uses\n",
"# the output as a mask on the mixture STFT. Confirm the intended target; the\n",
"# mask-based alternative is mask_loss / sdr_metric2.\n",
"model.compile(optimizer=optimizer, loss='mse', metrics=[sdr_metric])\n",
"\n",
"loss_plot_callback = LossPlotCallback()\n",
"\n",
"# Train the model\n",
"history = model.fit(\n",
"    X_train,\n",
"    y_train,\n",
"    batch_size=4,\n",
"    epochs=15,\n",
"    validation_split=0.1,\n",
"    callbacks=[loss_plot_callback],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Training vs. validation loss per epoch\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"epochs = range(1, len(loss) + 1)\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"for values, style, label in ((loss, 'bo-', 'Training Loss'), (val_loss, 'ro-', 'Validation Loss')):\n",
"    ax.plot(epochs, values, style, label=label)\n",
"ax.set_title('Loss-Verlauf während des Trainings')\n",
"ax.set_xlabel('Epoche')\n",
"ax.set_ylabel('Loss')\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test data directory. Set NUSSL_TEST_DIR to use another location.\n",
"test_path = os.environ.get(\"NUSSL_TEST_DIR\", \"C:\\\\Users\\\\Lukas\\\\nussl_tutorial\\\\test\")\n",
"test_data = data.on_the_fly(stft_params, transform=None, fg_path=test_path, num_mixtures=100)\n",
"\n",
"# Pick one example (swap in train_item to inspect a training mixture instead)\n",
"song_id = 20\n",
"test_item = test_data[song_id]\n",
"train_item = train_data[song_id]\n",
"\n",
"# Preprocess mixture and target (vocals) of the selected item\n",
"mix_mag, vocals_mag = prepare_data(test_item, stft_params)\n",
"\n",
"# Model prediction for a batch of one\n",
"predicted_mask = model.predict(tf.expand_dims(mix_mag, axis=0))\n",
"predicted_mask = np.squeeze(predicted_mask, axis=0)\n",
"\n",
"# Complex mixture STFT, computed with the shared STFT parameters\n",
"mix_signal = test_item['mix']\n",
"mix_stft = mix_signal.stft(window_length=stft_params.window_length, hop_length=stft_params.hop_length)\n",
"\n",
"# Resize the prediction back to the first two axes of the mixture STFT\n",
"# (nussl layout assumed to be (freq, time, channel))\n",
"predicted_mask_resized = tf.image.resize(predicted_mask, mix_stft.shape[:2])\n",
"\n",
"# Apply the mask in numpy: TF does not promote float32 x complex64, so\n",
"# multiplying the tf tensor with the complex numpy STFT does not give a complex tensor.\n",
"predicted_stft = predicted_mask_resized.numpy() * mix_stft\n",
"\n",
"# Resynthesize with the mixture's sample rate and STFT settings, trimmed to the mixture length\n",
"audio_vocal = nussl.AudioSignal(stft=predicted_stft, sample_rate=mix_signal.sample_rate, stft_params=mix_signal.stft_params)\n",
"audio_vocal.istft(window_length=stft_params.window_length, hop_length=stft_params.hop_length, truncate_to_length=mix_signal.signal_length)\n",
"\n",
"# Listen to and plot exactly the item that was separated\n",
"item = test_item"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot the predicted mask itself. Running an ISTFT on the real-valued mask and\n",
"# re-analyzing it would not show the mask that was applied.\n",
"mask = np.asarray(predicted_mask_resized)[..., 0]\n",
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"image = ax.imshow(mask, origin='lower', aspect='auto', vmin=0.0, vmax=1.0)\n",
"ax.set_title('Vorhergesagte Maske')\n",
"ax.set_xlabel('Frame')\n",
"ax.set_ylabel('Frequenz-Bin')\n",
"fig.colorbar(image, ax=ax)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Mean SDR of the model over the whole test set\n",
"evaluate_model_with_sdr(model=model, test_data=test_data, stft_params=stft_params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Original data of the selected item: play mixture and vocals (at the mixture's rate), then plot all sources\n",
"for caption, sig in (('Mix:', item['mix']), ('Vocals:', item['sources']['vocals'])):\n",
"    print(caption)\n",
"    display(Audio(data=sig.audio_data, rate=item['mix'].sample_rate))\n",
"show_wav(item['sources'])\n",
"show_fre(item['sources'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model output: play the estimated vocals, then compare waveform and spectrogram with the target\n",
"print('nach model:')\n",
"display(Audio(data=audio_vocal.audio_data, rate=item['mix'].sample_rate))\n",
"\n",
"for caption, plot_fn, signal_ in (\n",
"    ('Ziel Amplitudenverlauf:', show_1wav, item['sources']['vocals']),\n",
"    ('Amplitudenverlauf nach Model:', show_1wav, audio_vocal),\n",
"    ('Ziel Spektogram:', show_1fre, item['sources']['vocals']),\n",
"    ('Spektogram nach Model:', show_1fre, audio_vocal),\n",
"):\n",
"    print(caption)\n",
"    plot_fn(signal_)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the trained model to t.h5\n",
"model_path = \"t.h5\"\n",
"model.save(model_path)\n",
"print(\"Modell erfolgreich gespeichert!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load a previously saved model; the custom sdr_metric is passed via custom_objects\n",
"custom_objects = {'sdr_metric': sdr_metric}\n",
"model = tf.keras.models.load_model(\"24_11_18_unet_4.34.h5\", custom_objects=custom_objects)\n",
"print(\"Modell erfolgreich geladen!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}