571 lines
19 KiB
Plaintext
571 lines
19 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%conda install -c conda-forge sox"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Install the Python audio / source-separation dependencies.
# NOTE(review): versions are unpinned — consider pinning for reproducibility.
%pip install scaper
%pip install nussl
%pip install git+https://github.com/source-separation/tutorial
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Ersetze WINDOW_HAMMING\n",
|
|
"!sed -i 's/scipy.signal.hamming/scipy.signal.windows.hamming/' /opt/conda/lib/python3.11/site-packages/nussl/core/constants.py\n",
|
|
"\n",
|
|
"# Ersetze WINDOW_HANN\n",
|
|
"!sed -i 's/scipy.signal.hann/scipy.signal.windows.hann/' /opt/conda/lib/python3.11/site-packages/nussl/core/constants.py\n",
|
|
"\n",
|
|
"# Ersetze WINDOW_BLACKMAN\n",
|
|
"!sed -i 's/scipy.signal.blackman/scipy.signal.windows.blackman/' /opt/conda/lib/python3.11/site-packages/nussl/core/constants.py\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Import required libraries.
import nussl
from common import data, viz
from IPython.display import Audio
import IPython
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import librosa
import os
from scipy import signal
import matplotlib.pyplot as plt
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Prepare the MUSDB dataset under musdb/.
# NOTE(review): exact behavior (download vs. unpack) depends on
# common.data.prepare_musdb — not visible here, verify against that module.
data.prepare_musdb('musdb/')
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Set the STFT parameters shared by data loading and preprocessing.
stft_params = nussl.STFTParams(window_length=512, hop_length=128, window_type='sqrt_hann')
#stft_params = nussl.STFTParams(window_length=1024, hop_length=256, window_type='sqrt_hann')

# Path to the training data.
fg_path = "musdb/train"

# Load training data with `nussl`: 2000 mixtures generated on the fly,
# with coherent_prob=0.5.
train_data = data.on_the_fly(stft_params, transform=None, fg_path=fg_path, num_mixtures=2000, coherent_prob=0.5)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Helper functions for visualizing waveforms and spectrograms.
# Cleanup vs. original: the no-op `plt.plot()` calls (no arguments, draws
# nothing) were removed, and the duplicated list->dict normalization in
# show_wav/show_fre was factored into a single private helper.

def _as_source_dict(sources):
    """Normalize a list of sources into a labeled dict for nussl's visualizers."""
    if isinstance(sources, list):
        return {f'Source {i}': s for i, s in enumerate(sources)}
    return sources

def show_wav(sources):
    """Plot the waveforms of several sources (list or dict of AudioSignals)."""
    sources = _as_source_dict(sources)
    plt.figure(figsize=(10, 5))
    nussl.core.utils.visualize_sources_as_waveform(sources)
    plt.show()

def show_1wav(data):
    """Plot the waveform of a single AudioSignal."""
    plt.figure(figsize=(10, 5))
    nussl.core.utils.visualize_waveform(data)
    plt.show()

def show_fre(sources):
    """Plot several sources as time-frequency masks (list or dict of AudioSignals)."""
    sources = _as_source_dict(sources)
    plt.figure(figsize=(10, 5))
    nussl.core.utils.visualize_sources_as_masks(sources, db_cutoff=-80)
    plt.tight_layout()
    plt.show()

def show_1fre(data):
    """Plot the spectrogram of a single AudioSignal."""
    plt.figure(figsize=(10, 5))
    nussl.core.utils.visualize_spectrogram(data)
    plt.show()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Sanity check: list the item's keys, play the mixture, show the waveforms.
item = train_data[0]
print(item.keys())
# display(...) is needed here: a bare Audio(...) that is not the cell's last
# expression renders nothing, so the original cell never produced a player.
display(Audio(data=item['mix'].audio_data, rate=item['mix'].sample_rate))
show_wav(item['sources'])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Preprocessing of the STFT spectrograms for the U-Net.

def preprocess_spectrogram(magnitude, target_size=(512, 128)):
    """Resize a magnitude spectrogram to the fixed U-Net input size."""
    return tf.image.resize(magnitude, target_size)

def prepare_data(data_item, stft_params):
    """Extract the mix and vocal magnitude spectrograms from one dataset item.

    Computes the STFT of both signals with the shared STFT settings and
    resizes the magnitudes for the network. Returns (mix_mag, vocals_mag).
    """
    mix = data_item['mix']
    vocals = data_item['sources']['vocals']

    # Magnitude STFTs of mix and vocals.
    mix_mag = np.abs(mix.stft(window_length=stft_params.window_length,
                              hop_length=stft_params.hop_length))
    vocals_mag = np.abs(vocals.stft(window_length=stft_params.window_length,
                                    hop_length=stft_params.hop_length))

    # Resize both for the U-Net input.
    return preprocess_spectrogram(mix_mag), preprocess_spectrogram(vocals_mag)

# Example data (train_data is what the nussl loader above returns).
mix_data, vocals_data = prepare_data(train_data[0], stft_params)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# SDR metric: ratio of target-signal power to residual (error) power, in dB.
def sdr_metric(y_true, y_pred):
    """Signal-to-Distortion Ratio in dB.

    A small epsilon guards both numerator and denominator so the metric
    stays finite when the target is silent (signal_power == 0) or the
    prediction is exact (noise_power == 0) — the original divided by zero
    and produced inf/NaN in those cases.
    """
    eps = 1e-8
    signal_power = tf.reduce_sum(y_true ** 2)
    noise_power = tf.reduce_sum((y_true - y_pred) ** 2)
    return 10 * tf.math.log((signal_power + eps) / (noise_power + eps)) / tf.math.log(10.0)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Compute predictions and the SDR over the test data.
def evaluate_model_with_sdr(model, test_data, stft_params):
    """Run `model` over every item in `test_data` and report the mean SDR in dB."""
    sdr_scores = []
    for data_item in test_data:
        # Preprocess the test sample into (input, target) spectrograms.
        X_test, y_test = prepare_data(data_item, stft_params)

        # Model prediction (add a batch dimension, then strip it again).
        y_pred = model.predict(np.expand_dims(X_test, axis=0))[0]

        # Per-sample SDR; cast the tf scalar to a plain float so the list
        # holds numbers rather than tensors when np.mean runs below.
        sdr_scores.append(float(sdr_metric(y_test, y_pred)))

    # Average SDR over the whole test set.
    avg_sdr = np.mean(sdr_scores)
    print(f"Durchschnittlicher SDR: {avg_sdr:.2f} dB")
    return avg_sdr
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def build_unet(input_shape=(512, 128, 1)):
    """Build the vocal-separation U-Net.

    Encoder: five strided 5x5 convolutions (64 -> 1024 filters), each with
    batch normalization and ReLU; the first three also apply dropout.
    Decoder: transposed convolutions mirroring the encoder, with skip
    connections concatenating the matching encoder activations.
    Output: a sigmoid soft mask with the input's spatial size.
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Encoder
    conv1 = layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(inputs)
    conv1 = layers.BatchNormalization()(conv1)
    conv1 = layers.ReLU()(conv1)
    conv1 = layers.Dropout(0.5)(conv1)

    conv2 = layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')(conv1)
    conv2 = layers.BatchNormalization()(conv2)
    conv2 = layers.ReLU()(conv2)
    conv2 = layers.Dropout(0.5)(conv2)

    conv3 = layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same')(conv2)
    conv3 = layers.BatchNormalization()(conv3)
    conv3 = layers.ReLU()(conv3)
    conv3 = layers.Dropout(0.5)(conv3)

    conv4 = layers.Conv2D(512, (5, 5), strides=(2, 2), padding='same')(conv3)
    conv4 = layers.BatchNormalization()(conv4)
    conv4 = layers.ReLU()(conv4)

    conv5 = layers.Conv2D(1024, (5, 5), strides=(2, 2), padding='same')(conv4)
    conv5 = layers.BatchNormalization()(conv5)
    conv5 = layers.ReLU()(conv5)

    # Decoder with skip connections
    up6 = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same')(conv5)
    up6 = layers.BatchNormalization()(up6)
    up6 = layers.ReLU()(up6)
    up6 = layers.Concatenate()([up6, conv4])

    up7 = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same')(up6)
    up7 = layers.BatchNormalization()(up7)
    up7 = layers.ReLU()(up7)
    up7 = layers.Concatenate()([up7, conv3])

    up8 = layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same')(up7)
    up8 = layers.BatchNormalization()(up8)
    up8 = layers.ReLU()(up8)
    up8 = layers.Concatenate()([up8, conv2])

    up9 = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same')(up8)
    up9 = layers.BatchNormalization()(up9)
    up9 = layers.ReLU()(up9)
    up9 = layers.Concatenate()([up9, conv1])

    # Sigmoid output: a soft mask in [0, 1] per time-frequency bin.
    outputs = layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='sigmoid')(up9)

    model = tf.keras.Model(inputs, outputs)
    return model


# Initialize the model.
# (The large commented-out alternative U-Net definition and the duplicate
# commented-out sdr_metric/compile lines were dead code and have been removed.)
model = build_unet(input_shape=(512, 128, 1))

# Optimizer with an explicit learning rate.
from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=0.001)

model.compile(optimizer=optimizer, loss='mse', metrics=[sdr_metric])

# Show the model summary.
model.summary()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Build the training arrays from the on-the-fly dataset.
def create_training_data(data, stft_params):
    """Collect (mix, vocals) spectrogram pairs for every mixture in `data`.

    Returns two numpy arrays: network inputs (mix magnitudes) and
    targets (vocal magnitudes).
    """
    inputs = []
    targets = []

    # 1-based progress counter printed in place on one line.
    for count, data_item in enumerate(data, start=1):
        print(count, "/", data.num_mixtures, end='\r')

        mix, vocals = prepare_data(data_item, stft_params)
        inputs.append(mix)
        targets.append(vocals)

    # Convert the collected lists into arrays.
    return np.array(inputs), np.array(targets)

# Example for training data (train_data is the dataset collection).
X_train, y_train = create_training_data(train_data, stft_params)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
class LossPlotCallback(tf.keras.callbacks.Callback):
    """Keras callback that records every batch loss during training and
    plots the full curve (logarithmic y-axis) once training ends."""

    def __init__(self):
        super().__init__()
        self.batch_losses = []  # per-batch training losses, in order

    def on_train_batch_end(self, batch, logs=None):
        # Append the loss of the batch that just finished.
        self.batch_losses.append(logs['loss'])

    def on_train_end(self, logs=None):
        # Plot the loss history with a logarithmic y-axis.
        plt.figure(figsize=(10, 6))
        plt.plot(self.batch_losses, 'b-', label='Batch Loss')
        plt.yscale('log')  # log scale on the y-axis
        plt.title('Loss-Verlauf während des Trainings (logarithmisch)')
        plt.xlabel('Batch')
        plt.ylabel('Loss (log)')
        plt.legend()
        plt.show()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Compile with an explicitly configured learning rate before training.
# (The unused `mixture = X_train` variable and the commented-out compile
# line that referenced it were dead code and have been removed.)
optimizer = Adam(learning_rate=0.001)

model.compile(optimizer=optimizer, loss='mse', metrics=[sdr_metric])

# Callback that plots the per-batch loss curve after training.
loss_plot_callback = LossPlotCallback()

# Train the model; 10% of the data is held out for validation.
history = model.fit(X_train, y_train, batch_size=4, epochs=15, validation_split=0.1, callbacks=[loss_plot_callback])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Plot training vs. validation loss per epoch from the fit history.
training_loss = history.history['loss']
validation_loss = history.history['val_loss']
epoch_numbers = range(1, len(training_loss) + 1)

# Explicit figure/axes interface; same curves and labels as before.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_numbers, training_loss, 'bo-', label='Training Loss')
ax.plot(epoch_numbers, validation_loss, 'ro-', label='Validation Loss')
ax.set_title('Loss-Verlauf während des Trainings')
ax.set_xlabel('Epoche')
ax.set_ylabel('Loss')
ax.legend()
plt.show()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Load the test set.
test_data = data.on_the_fly(stft_params, transform=None, fg_path="musdb/test", num_mixtures=100)

# Pick one test mixture.
song_id = 0
test_item = test_data[song_id]

# Extract and preprocess the mix and target (vocals) spectrograms.
mix_mag, vocals_mag = prepare_data(test_item, stft_params)

# Model prediction: soft mask for the vocals (strip the batch dimension).
predicted_mask = model.predict(tf.expand_dims(mix_mag, axis=0))
predicted_mask = np.squeeze(predicted_mask, axis=0)

# Complex STFT of the mixed signal. The STFT settings are taken from
# stft_params instead of repeating the magic numbers 512/128, so this cell
# stays consistent if the parameters above are changed.
mix_signal = test_item['mix']
mix_stft = mix_signal.stft(window_length=stft_params.window_length, hop_length=stft_params.hop_length)

# Resize the predicted mask to the full spectrogram size.
predicted_mask_resized = tf.image.resize(predicted_mask, mix_stft.shape[:2])

# Apply the mask to the complex mixture spectrogram.
predicted_stft = predicted_mask_resized * mix_stft

# Reconstruct the estimated vocal audio from the masked STFT.
audio_vocal = nussl.AudioSignal()
audio_vocal.stft_data = predicted_stft.numpy()
audio_vocal.istft(window_length=stft_params.window_length, hop_length=stft_params.hop_length)

# Item used by the playback/visualization cells below.
item = test_data[song_id]
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Inspect the predicted mask itself: wrap it in an AudioSignal so the
# spectrogram visualizer can render it.
mask = nussl.AudioSignal()
mask.stft_data = predicted_mask_resized.numpy()
mask.istft(window_length=512, hop_length=128)
show_1fre(mask)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Reference data: play the mixture and the true vocals, then show the
# waveforms and time-frequency masks of all sources.
print('Mix:')
display(Audio(data=item['mix'].audio_data, rate=item['mix'].sample_rate))
print('Vocals:')
display(Audio(data=item['sources']['vocals'].audio_data, rate=item['mix'].sample_rate))
show_wav(item['sources'])
show_fre(item['sources'])
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Model output: play the estimated vocals, then compare target vs. estimate
# as waveforms and as spectrograms.
print('nach model:')
display(Audio(data=audio_vocal.audio_data, rate=item['mix'].sample_rate))

print('Ziel Amplitudenverlauf:')
show_1wav(data=item['sources']['vocals'])
print('Amplitudenverlauf nach Model:')
show_1wav(audio_vocal)

print('Ziel Spektogram:')
show_1fre(data=item['sources']['vocals'])
print('Spektogram nach Model:')
show_1fre(audio_vocal)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Save the trained model (HDF5 format; the filename encodes date and final loss).
model.save("24_11_16_unet_4.1846.h5")
print("Modell erfolgreich gespeichert!")
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Load a previously saved model. The model was compiled with the notebook's
# custom `sdr_metric`, so it must be supplied via `custom_objects` —
# without it, deserialization fails with an unknown-metric error.
loaded_model = tf.keras.models.load_model(
    "unet_music_source_separation.h5",
    custom_objects={"sdr_metric": sdr_metric},
)
print("Modell erfolgreich geladen!")
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|