diff --git a/app.py b/app.py
index c74e412b..8b869a2f 100644
--- a/app.py
+++ b/app.py
@@ -3,7 +3,9 @@
from tabs.full_inference import full_inference_tab
from tabs.download_model import download_model_tab
from tabs.settings import theme_tab, lang_tab, restart_tab
-from programs.applio_code.rvc.lib.tools.prerequisites_download import prequisites_download_pipeline
+from programs.applio_code.rvc.lib.tools.prerequisites_download import (
+ prequisites_download_pipeline,
+)
from tabs.presence import load_config_presence, presence_tab
now_dir = os.getcwd()
@@ -15,7 +17,7 @@
prequisites_download_pipeline(
False,
False,
- True,
+ True,
False,
)
@@ -32,15 +34,14 @@
RPCManager.start_presence()
-
rvc_theme = loadThemes.load_theme() or "NoCrypt/miku"
-with gr.Blocks(
- theme=rvc_theme, title="Advanced RVC Inference"
-) as rvc:
+with gr.Blocks(theme=rvc_theme, title="Advanced RVC Inference") as rvc:
 gr.Markdown('Advanced RVC Inference')
- gr.Markdown('this project Maintained by NeoDev')
-
+ gr.Markdown(
+ 'this project Maintained by NeoDev'
+ )
+
with gr.Tab(i18n("Full Inference")):
full_inference_tab()
with gr.Tab(i18n("Download Model")):
diff --git a/assets/discord_presence.py b/assets/discord_presence.py
index 45bc6873..ab52ec87 100644
--- a/assets/discord_presence.py
+++ b/assets/discord_presence.py
@@ -30,8 +30,14 @@ def update_presence(self):
state="Advanced-RVC",
details="Advaced voice cloning with UVR5 feature",
buttons=[
- {"label": "Home", "url": "https://github.com/ArkanDash/Advanced-RVC-Inference"},
- {"label": "Download", "url": "https://github.com/ArkanDash/Advanced-RVC-Inference/archive/refs/heads/master.zip"},
+ {
+ "label": "Home",
+ "url": "https://github.com/ArkanDash/Advanced-RVC-Inference",
+ },
+ {
+ "label": "Download",
+ "url": "https://github.com/ArkanDash/Advanced-RVC-Inference/archive/refs/heads/master.zip",
+ },
],
large_image="logo",
large_text="Experimenting with Advanced-RVC",
diff --git a/assets/i18n/languages/id_ID.py b/assets/i18n/languages/id_ID.py
index 29c08f35..ff760676 100644
--- a/assets/i18n/languages/id_ID.py
+++ b/assets/i18n/languages/id_ID.py
@@ -85,5 +85,5 @@
"Export Audio": "Ekspor Audio",
"Music URL": "URL Musik",
"Download": "Unduh",
- "Model URL": "URL Model"
+ "Model URL": "URL Model",
}
diff --git a/assets/themes/Grheme.py b/assets/themes/Grheme.py
index b1c68db1..66e36ec5 100644
--- a/assets/themes/Grheme.py
+++ b/assets/themes/Grheme.py
@@ -1,5 +1,4 @@
from __future__ import annotations
-import time
from typing import Iterable
import gradio as gr
diff --git a/core.py b/core.py
index e86f6799..42cb21e0 100644
--- a/core.py
+++ b/core.py
@@ -15,8 +15,7 @@
from programs.music_separation_code.inference import proc_file
logging.basicConfig(
- level=logging.INFO,
- format="%(asctime)s - %(levelname)s - %(message)s"
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -1015,7 +1014,6 @@ def download_model(link):
return "Model downloaded with success"
-
def download_music(link):
if not link or not isinstance(link, str):
logging.error("Invalid link provided.")
@@ -1035,9 +1033,11 @@ def download_music(link):
command = [
"yt-dlp",
"-x",
- "--audio-format", "wav",
- "--output", output_template,
- link
+ "--audio-format",
+ "wav",
+ "--output",
+ output_template,
+ link,
]
try:
diff --git a/programs/music_separation_code/ensemble.py b/programs/music_separation_code/ensemble.py
index 76fec7b7..1ad1ff9f 100644
--- a/programs/music_separation_code/ensemble.py
+++ b/programs/music_separation_code/ensemble.py
@@ -1,5 +1,5 @@
# coding: utf-8
-__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
import os
import librosa
@@ -81,42 +81,42 @@ def average_waveforms(pred_track, weights, algorithm):
mod_track = []
for i in range(pred_track.shape[0]):
- if algorithm == 'avg_wave':
+ if algorithm == "avg_wave":
mod_track.append(pred_track[i] * weights[i])
- elif algorithm in ['median_wave', 'min_wave', 'max_wave']:
+ elif algorithm in ["median_wave", "min_wave", "max_wave"]:
mod_track.append(pred_track[i])
- elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
+ elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]:
spec = stft(pred_track[i], nfft=2048, hl=1024)
- if algorithm in ['avg_fft']:
+ if algorithm in ["avg_fft"]:
mod_track.append(spec * weights[i])
else:
mod_track.append(spec)
pred_track = np.array(mod_track)
- if algorithm in ['avg_wave']:
+ if algorithm in ["avg_wave"]:
pred_track = pred_track.sum(axis=0)
pred_track /= np.array(weights).sum().T
- elif algorithm in ['median_wave']:
+ elif algorithm in ["median_wave"]:
pred_track = np.median(pred_track, axis=0)
- elif algorithm in ['min_wave']:
+ elif algorithm in ["min_wave"]:
pred_track = np.array(pred_track)
pred_track = lambda_min(pred_track, axis=0, key=np.abs)
- elif algorithm in ['max_wave']:
+ elif algorithm in ["max_wave"]:
pred_track = np.array(pred_track)
pred_track = lambda_max(pred_track, axis=0, key=np.abs)
- elif algorithm in ['avg_fft']:
+ elif algorithm in ["avg_fft"]:
pred_track = pred_track.sum(axis=0)
pred_track /= np.array(weights).sum()
pred_track = istft(pred_track, 1024, final_length)
- elif algorithm in ['min_fft']:
+ elif algorithm in ["min_fft"]:
pred_track = np.array(pred_track)
pred_track = lambda_min(pred_track, axis=0, key=np.abs)
pred_track = istft(pred_track, 1024, final_length)
- elif algorithm in ['max_fft']:
+ elif algorithm in ["max_fft"]:
pred_track = np.array(pred_track)
pred_track = absmax(pred_track, axis=0)
pred_track = istft(pred_track, 1024, final_length)
- elif algorithm in ['median_fft']:
+ elif algorithm in ["median_fft"]:
pred_track = np.array(pred_track)
pred_track = np.median(pred_track, axis=0)
pred_track = istft(pred_track, 1024, final_length)
@@ -125,37 +125,58 @@ def average_waveforms(pred_track, weights, algorithm):
def ensemble_files(args):
parser = argparse.ArgumentParser()
- parser.add_argument("--files", type=str, required=True, nargs='+', help="Path to all audio-files to ensemble")
- parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
- parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
- parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
+ parser.add_argument(
+ "--files",
+ type=str,
+ required=True,
+ nargs="+",
+ help="Path to all audio-files to ensemble",
+ )
+ parser.add_argument(
+ "--type",
+ type=str,
+ default="avg_wave",
+ help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft",
+ )
+ parser.add_argument(
+ "--weights",
+ type=float,
+ nargs="+",
+ help="Weights to create ensemble. Number of weights must be equal to number of files",
+ )
+ parser.add_argument(
+ "--output",
+ default="res.wav",
+ type=str,
+ help="Path to wav file where ensemble result will be stored",
+ )
if args is None:
args = parser.parse_args()
else:
args = parser.parse_args(args)
- print('Ensemble type: {}'.format(args.type))
- print('Number of input files: {}'.format(len(args.files)))
+ print("Ensemble type: {}".format(args.type))
+ print("Number of input files: {}".format(len(args.files)))
if args.weights is not None:
weights = args.weights
else:
weights = np.ones(len(args.files))
- print('Weights: {}'.format(weights))
- print('Output file: {}'.format(args.output))
+ print("Weights: {}".format(weights))
+ print("Output file: {}".format(args.output))
data = []
for f in args.files:
if not os.path.isfile(f):
- print('Error. Can\'t find file: {}. Check paths.'.format(f))
+ print("Error. Can't find file: {}. Check paths.".format(f))
exit()
- print('Reading file: {}'.format(f))
+ print("Reading file: {}".format(f))
wav, sr = librosa.load(f, sr=None, mono=False)
# wav, sr = sf.read(f)
print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
data.append(wav)
data = np.array(data)
res = average_waveforms(data, weights, args.type)
- print('Result shape: {}'.format(res.shape))
- sf.write(args.output, res.T, sr, 'FLOAT')
+ print("Result shape: {}".format(res.shape))
+ sf.write(args.output, res.T, sr, "FLOAT")
if __name__ == "__main__":
diff --git a/programs/music_separation_code/inference.py b/programs/music_separation_code/inference.py
index 8c991d45..8ddfcb77 100644
--- a/programs/music_separation_code/inference.py
+++ b/programs/music_separation_code/inference.py
@@ -7,7 +7,6 @@
from tqdm import tqdm
import sys
import os
-import glob
import torch
import numpy as np
import soundfile as sf
diff --git a/programs/music_separation_code/models/bandit/core/__init__.py b/programs/music_separation_code/models/bandit/core/__init__.py
index a4d6d795..86e1557c 100644
--- a/programs/music_separation_code/models/bandit/core/__init__.py
+++ b/programs/music_separation_code/models/bandit/core/__init__.py
@@ -1,20 +1,14 @@
import os.path
from collections import defaultdict
from itertools import chain, combinations
-from typing import (
- Any,
- Dict,
- Iterator,
- Mapping, Optional,
- Tuple, Type,
- TypedDict
-)
+from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
import pytorch_lightning as pl
import torch
import torchaudio as ta
import torchmetrics as tm
from asteroid import losses as asteroid_losses
+
# from deepspeed.ops.adam import DeepSpeedCPUAdam
# from geoopt import optim as gooptim
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -30,7 +24,7 @@
# from pandas.io.json._normalize import nested_to_record
-ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]})
+ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
class SchedulerConfigDict(ConfigDict):
@@ -38,9 +32,9 @@ class SchedulerConfigDict(ConfigDict):
OptimizerSchedulerConfigDict = TypedDict(
- 'OptimizerSchedulerConfigDict',
- {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
- total=False
+ "OptimizerSchedulerConfigDict",
+ {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
+ total=False,
)
@@ -71,14 +65,13 @@ def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
def parse_optimizer_config(
- config: OptimizerSchedulerConfigDict,
- parameters: Iterator[nn.Parameter]
+ config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
) -> ConfigureOptimizerReturnDict:
optim_class = get_optimizer_class(config["optimizer"]["name"])
optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
optim_dict: ConfigureOptimizerReturnDict = {
- "optimizer": optimizer,
+ "optimizer": optimizer,
}
if "scheduler" in config:
@@ -86,10 +79,7 @@ def parse_optimizer_config(
lr_scheduler_class_ = config["scheduler"]["name"]
lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
lr_scheduler_dict: LRSchedulerReturnDict = {
- "scheduler": lr_scheduler_class(
- optimizer,
- **config["scheduler"]["kwargs"]
- )
+ "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
}
if lr_scheduler_class_ == "ReduceLROnPlateau":
@@ -169,29 +159,26 @@ class LightningSystem(pl.LightningModule):
_BG_STEMS = ["background", "effects", "mne"]
def __init__(
- self,
- config: Dict,
- loss_adjustment: float = 1.0,
- attach_fader: bool = False
- ) -> None:
+ self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
+ ) -> None:
super().__init__()
self.optimizer_config = config["optimizer"]
self.model = parse_model_config(config["model"])
self.loss = parse_loss_config(config["loss"])
self.metrics = nn.ModuleDict(
- {
- stem: parse_metric_config(config["metrics"]["dev"])
- for stem in self.model.stems
- }
+ {
+ stem: parse_metric_config(config["metrics"]["dev"])
+ for stem in self.model.stems
+ }
)
self.metrics.disallow_fsdp = True
self.test_metrics = nn.ModuleDict(
- {
- stem: parse_metric_config(config["metrics"]["test"])
- for stem in self.model.stems
- }
+ {
+ stem: parse_metric_config(config["metrics"]["test"])
+ for stem in self.model.stems
+ }
)
self.test_metrics.disallow_fsdp = True
@@ -216,22 +203,18 @@ def __init__(
self.val_prefix = None
self.test_prefix = None
-
def configure_optimizers(self) -> Any:
return parse_optimizer_config(
- self.optimizer_config,
- self.trainer.model.parameters()
- )
+ self.optimizer_config, self.trainer.model.parameters()
+ )
- def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[
- str, torch.Tensor]:
+ def compute_loss(
+ self, batch: BatchedDataDict, output: OutputType
+ ) -> Dict[str, torch.Tensor]:
return {"loss": self.loss(output, batch)}
def update_metrics(
- self,
- batch: BatchedDataDict,
- output: OutputType,
- mode: str
+ self, batch: BatchedDataDict, output: OutputType, mode: str
) -> None:
if mode == "test":
@@ -247,9 +230,9 @@ def update_metrics(
# print(f"matching for {stem}")
if mode == "train":
metric.update(
- output["audio"][stem],#.cpu(),
- batch["audio"][stem],#.cpu()
- )
+ output["audio"][stem], # .cpu(),
+ batch["audio"][stem], # .cpu()
+ )
else:
if stem not in batch["audio"]:
matched = False
@@ -273,16 +256,18 @@ def update_metrics(
if matched:
# print(f"matched {stem}!")
if stem == "mne" and "mne" not in output["audio"]:
- output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"]
-
+ output["audio"]["mne"] = (
+ output["audio"]["music"] + output["audio"]["effects"]
+ )
+
metric.update(
- output["audio"][stem],#.cpu(),
- batch["audio"][stem],#.cpu(),
+ output["audio"][stem], # .cpu(),
+ batch["audio"][stem], # .cpu(),
)
# print(metric.compute())
- def compute_metrics(self, mode: str="dev") -> Dict[
- str, torch.Tensor]:
+
+ def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
if mode == "test":
metrics = self.test_metrics
@@ -293,10 +278,8 @@ def compute_metrics(self, mode: str="dev") -> Dict[
for stem, metric in metrics.items():
md = metric.compute()
- metric_dict.update(
- {f"{stem}/{k}": v for k, v in md.items()}
- )
-
+ metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
+
self.log_dict(metric_dict, prog_bar=True, logger=False)
return metric_dict
@@ -311,10 +294,8 @@ def reset_metrics(self, test_mode: bool = False) -> None:
for _, metric in metrics.items():
metric.reset()
-
def forward(self, batch: BatchedDataDict) -> Any:
batch, output = self.model(batch)
-
return batch, output
@@ -332,7 +313,6 @@ def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
return output, loss_dict
-
def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
if self.augmentation is not None:
@@ -343,9 +323,7 @@ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
with torch.inference_mode():
self.log_dict_with_prefix(
- loss_dict,
- "train",
- batch_size=batch["audio"]["mixture"].shape[0]
+ loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
)
loss_dict["loss"] *= self.loss_adjustment
@@ -353,7 +331,7 @@ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
return loss_dict
def on_train_batch_end(
- self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
+ self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
) -> None:
metric_dict = self.compute_metrics()
@@ -361,10 +339,7 @@ def on_train_batch_end(
self.reset_metrics()
def validation_step(
- self,
- batch: BatchedDataDict,
- batch_idx: int,
- dataloader_idx: int = 0
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
) -> Dict[str, Any]:
with torch.inference_mode():
@@ -378,11 +353,11 @@ def validation_step(
_, loss_dict = self.common_step(batch, mode="val")
self.log_dict_with_prefix(
- loss_dict,
- self.val_prefix,
- batch_size=batch["audio"]["mixture"].shape[0],
- prog_bar=True,
- add_dataloader_idx=False
+ loss_dict,
+ self.val_prefix,
+ batch_size=batch["audio"]["mixture"].shape[0],
+ prog_bar=True,
+ add_dataloader_idx=False,
)
return loss_dict
@@ -392,29 +367,23 @@ def on_validation_epoch_end(self) -> None:
def _on_validation_epoch_end(self) -> None:
metric_dict = self.compute_metrics()
- self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True,
- add_dataloader_idx=False)
+ self.log_dict_with_prefix(
+ metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
+ )
# self.logger.save()
# print(self.val_prefix, "Validation metrics:", metric_dict)
self.reset_metrics()
-
def old_predtest_step(
- self,
- batch: BatchedDataDict,
- batch_idx: int,
- dataloader_idx: int = 0
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
) -> Tuple[BatchedDataDict, OutputType]:
audio_batch = batch["audio"]["mixture"]
track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
output_list_of_dicts = [
- self.fader(
- audio[None, ...],
- lambda a: self.test_forward(a, track)
- )
- for audio, track in zip(audio_batch, track_batch)
+ self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
+ for audio, track in zip(audio_batch, track_batch)
]
output_dict_of_lists = defaultdict(list)
@@ -424,19 +393,16 @@ def old_predtest_step(
output_dict_of_lists[stem].append(audio)
output = {
- "audio": {
- stem: torch.concat(output_list, dim=0)
- for stem, output_list in output_dict_of_lists.items()
- }
+ "audio": {
+ stem: torch.concat(output_list, dim=0)
+ for stem, output_list in output_dict_of_lists.items()
+ }
}
return batch, output
def predtest_step(
- self,
- batch: BatchedDataDict,
- batch_idx: int = -1,
- dataloader_idx: int = 0
+ self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
) -> Tuple[BatchedDataDict, OutputType]:
if getattr(self.model, "bypass_fader", False):
@@ -444,17 +410,13 @@ def predtest_step(
else:
audio_batch = batch["audio"]["mixture"]
output = self.fader(
- audio_batch,
- lambda a: self.test_forward(a, "", batch=batch)
+ audio_batch, lambda a: self.test_forward(a, "", batch=batch)
)
return batch, output
def test_forward(
- self,
- audio: torch.Tensor,
- track: str = "",
- batch: BatchedDataDict = None
+ self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
) -> torch.Tensor:
if self.fader is None:
@@ -466,10 +428,11 @@ def test_forward(
cond = cond.repeat(audio.shape[0], 1)
_, output = self.forward(
- {"audio": {"mixture": audio},
- "track": track,
- "condition": cond,
- }
+ {
+ "audio": {"mixture": audio},
+ "track": track,
+ "condition": cond,
+ }
) # TODO: support track properly
return output["audio"]
@@ -478,10 +441,7 @@ def on_test_epoch_start(self) -> None:
self.attach_fader(force_reattach=True)
def test_step(
- self,
- batch: BatchedDataDict,
- batch_idx: int,
- dataloader_idx: int = 0
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
) -> Any:
curr_test_prefix = f"test{dataloader_idx}"
@@ -505,22 +465,23 @@ def on_test_epoch_end(self) -> None:
def _on_test_epoch_end(self) -> None:
metric_dict = self.compute_metrics(mode="test")
- self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True,
- add_dataloader_idx=False)
+ self.log_dict_with_prefix(
+ metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
+ )
# self.logger.save()
# print(self.test_prefix, "Test metrics:", metric_dict)
self.reset_metrics()
def predict_step(
- self,
- batch: BatchedDataDict,
- batch_idx: int = 0,
- dataloader_idx: int = 0,
- include_track_name: Optional[bool] = None,
- get_no_vox_combinations: bool = True,
- get_residual: bool = False,
- treat_batch_as_channels: bool = False,
- fs: Optional[int] = None,
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int = 0,
+ dataloader_idx: int = 0,
+ include_track_name: Optional[bool] = None,
+ get_no_vox_combinations: bool = True,
+ get_residual: bool = False,
+ treat_batch_as_channels: bool = False,
+ fs: Optional[int] = None,
) -> Any:
assert self.predict_output_path is not None
@@ -531,7 +492,7 @@ def predict_step(
with torch.inference_mode():
batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
- print('Pred test finished...')
+ print("Pred test finished...")
torch.cuda.empty_cache()
metric_dict = {}
@@ -545,24 +506,22 @@ def predict_step(
if get_no_vox_combinations:
no_vox_stems = [
- stem for stem in output["audio"] if
- stem not in self._VOX_STEMS
+ stem for stem in output["audio"] if stem not in self._VOX_STEMS
]
no_vox_combinations = chain.from_iterable(
- combinations(no_vox_stems, r) for r in
- range(2, len(no_vox_stems) + 1)
+ combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
)
for combination in no_vox_combinations:
combination_ = list(combination)
output["audio"]["+".join(combination_)] = sum(
- [output["audio"][stem] for stem in combination_]
+ [output["audio"][stem] for stem in combination_]
)
if treat_batch_as_channels:
for stem in output["audio"]:
output["audio"][stem] = output["audio"][stem].reshape(
- 1, -1, output["audio"][stem].shape[-1]
+ 1, -1, output["audio"][stem].shape[-1]
)
batch_size = 1
@@ -575,28 +534,24 @@ def predict_step(
if batch.get("audio", {}).get(stem, None) is not None:
self.test_metrics[stem].reset()
metrics = self.test_metrics[stem](
- batch["audio"][stem][[b], ...],
- output["audio"][stem][[b], ...]
+ batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
)
snr = metrics["snr"]
sisnr = metrics["sisnr"]
sdr = metrics["sdr"]
metric_dict[stem] = metrics
print(
- track_name,
- f"snr={snr:2.2f} dB",
- f"sisnr={sisnr:2.2f}",
- f"sdr={sdr:2.2f} dB",
+ track_name,
+ f"snr={snr:2.2f} dB",
+ f"sisnr={sisnr:2.2f}",
+ f"sdr={sdr:2.2f} dB",
)
filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
else:
filename = f"{stem}.wav"
if include_track_name:
- output_dir = os.path.join(
- self.predict_output_path,
- track_name
- )
+ output_dir = os.path.join(self.predict_output_path, track_name)
else:
output_dir = self.predict_output_path
@@ -606,23 +561,23 @@ def predict_step(
fs = self.fs
ta.save(
- os.path.join(output_dir, filename),
- output["audio"][stem][b, ...].cpu(),
- fs,
+ os.path.join(output_dir, filename),
+ output["audio"][stem][b, ...].cpu(),
+ fs,
)
return metric_dict
def get_stems(
- self,
- batch: BatchedDataDict,
- batch_idx: int = 0,
- dataloader_idx: int = 0,
- include_track_name: Optional[bool] = None,
- get_no_vox_combinations: bool = True,
- get_residual: bool = False,
- treat_batch_as_channels: bool = False,
- fs: Optional[int] = None,
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int = 0,
+ dataloader_idx: int = 0,
+ include_track_name: Optional[bool] = None,
+ get_no_vox_combinations: bool = True,
+ get_residual: bool = False,
+ treat_batch_as_channels: bool = False,
+ fs: Optional[int] = None,
) -> Any:
assert self.predict_output_path is not None
@@ -646,24 +601,22 @@ def get_stems(
if get_no_vox_combinations:
no_vox_stems = [
- stem for stem in output["audio"] if
- stem not in self._VOX_STEMS
+ stem for stem in output["audio"] if stem not in self._VOX_STEMS
]
no_vox_combinations = chain.from_iterable(
- combinations(no_vox_stems, r) for r in
- range(2, len(no_vox_stems) + 1)
+ combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
)
for combination in no_vox_combinations:
combination_ = list(combination)
output["audio"]["+".join(combination_)] = sum(
- [output["audio"][stem] for stem in combination_]
+ [output["audio"][stem] for stem in combination_]
)
if treat_batch_as_channels:
for stem in output["audio"]:
output["audio"][stem] = output["audio"][stem].reshape(
- 1, -1, output["audio"][stem].shape[-1]
+ 1, -1, output["audio"][stem].shape[-1]
)
batch_size = 1
@@ -675,28 +628,24 @@ def get_stems(
if batch.get("audio", {}).get(stem, None) is not None:
self.test_metrics[stem].reset()
metrics = self.test_metrics[stem](
- batch["audio"][stem][[b], ...],
- output["audio"][stem][[b], ...]
+ batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
)
snr = metrics["snr"]
sisnr = metrics["sisnr"]
sdr = metrics["sdr"]
metric_dict[stem] = metrics
print(
- track_name,
- f"snr={snr:2.2f} dB",
- f"sisnr={sisnr:2.2f}",
- f"sdr={sdr:2.2f} dB",
+ track_name,
+ f"snr={snr:2.2f} dB",
+ f"sisnr={sisnr:2.2f}",
+ f"sdr={sdr:2.2f} dB",
)
filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
else:
filename = f"{stem}.wav"
if include_track_name:
- output_dir = os.path.join(
- self.predict_output_path,
- track_name
- )
+ output_dir = os.path.join(self.predict_output_path, track_name)
else:
output_dir = self.predict_output_path
@@ -710,12 +659,11 @@ def get_stems(
return result
def load_state_dict(
- self, state_dict: Mapping[str, Any], strict: bool = False
+ self, state_dict: Mapping[str, Any], strict: bool = False
) -> Any:
return super().load_state_dict(state_dict, strict=False)
-
def set_predict_output_path(self, path: str) -> None:
self.predict_output_path = path
os.makedirs(self.predict_output_path, exist_ok=True)
@@ -727,18 +675,17 @@ def attach_fader(self, force_reattach=False) -> None:
self.fader = parse_fader_config(self.fader_config)
self.fader.to(self.device)
-
def log_dict_with_prefix(
- self,
- dict_: Dict[str, torch.Tensor],
- prefix: str,
- batch_size: Optional[int] = None,
- **kwargs: Any
+ self,
+ dict_: Dict[str, torch.Tensor],
+ prefix: str,
+ batch_size: Optional[int] = None,
+ **kwargs: Any,
) -> None:
self.log_dict(
- {f"{prefix}/{k}": v for k, v in dict_.items()},
- batch_size=batch_size,
- logger=True,
- sync_dist=True,
- **kwargs,
- )
\ No newline at end of file
+ {f"{prefix}/{k}": v for k, v in dict_.items()},
+ batch_size=batch_size,
+ logger=True,
+ sync_dist=True,
+ **kwargs,
+ )
diff --git a/programs/music_separation_code/models/bandit/core/data/__init__.py b/programs/music_separation_code/models/bandit/core/data/__init__.py
index 1087fe2c..a9d4d672 100644
--- a/programs/music_separation_code/models/bandit/core/data/__init__.py
+++ b/programs/music_separation_code/models/bandit/core/data/__init__.py
@@ -1,2 +1,2 @@
from .dnr.datamodule import DivideAndRemasterDataModule
-from .musdb.datamodule import MUSDB18DataModule
\ No newline at end of file
+from .musdb.datamodule import MUSDB18DataModule
diff --git a/programs/music_separation_code/models/bandit/core/data/_types.py b/programs/music_separation_code/models/bandit/core/data/_types.py
index 9499f9a8..65e4607a 100644
--- a/programs/music_separation_code/models/bandit/core/data/_types.py
+++ b/programs/music_separation_code/models/bandit/core/data/_types.py
@@ -4,11 +4,10 @@
AudioDict = Dict[str, torch.Tensor]
-DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str})
+DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
BatchedDataDict = TypedDict(
- 'BatchedDataDict',
- {'audio': AudioDict, 'track': Sequence[str]}
+ "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
)
diff --git a/programs/music_separation_code/models/bandit/core/data/augmentation.py b/programs/music_separation_code/models/bandit/core/data/augmentation.py
index 238214bf..1aa2a9cf 100644
--- a/programs/music_separation_code/models/bandit/core/data/augmentation.py
+++ b/programs/music_separation_code/models/bandit/core/data/augmentation.py
@@ -9,18 +9,19 @@
class BaseAugmentor(nn.Module, ABC):
- def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
- DataDict, BatchedDataDict]:
+ def forward(
+ self, item: Union[DataDict, BatchedDataDict]
+ ) -> Union[DataDict, BatchedDataDict]:
raise NotImplementedError
class StemAugmentor(BaseAugmentor):
def __init__(
- self,
- audiomentations: Dict[str, Dict[str, Any]],
- fix_clipping: bool = True,
- scaler_margin: float = 0.5,
- apply_both_default_and_common: bool = False,
+ self,
+ audiomentations: Dict[str, Dict[str, Any]],
+ fix_clipping: bool = True,
+ scaler_margin: float = 0.5,
+ apply_both_default_and_common: bool = False,
) -> None:
super().__init__()
@@ -32,23 +33,16 @@ def __init__(
for stem in audiomentations:
if audiomentations[stem]["name"] == "Compose":
- augmentations[stem] = getattr(
- tam,
- audiomentations[stem]["name"]
- )(
- [
- getattr(tam, aug["name"])(**aug["kwargs"])
- for aug in
- audiomentations[stem]["kwargs"]["transforms"]
- ],
- **audiomentations[stem]["kwargs"]["kwargs"],
+ augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
+ [
+ getattr(tam, aug["name"])(**aug["kwargs"])
+ for aug in audiomentations[stem]["kwargs"]["transforms"]
+ ],
+ **audiomentations[stem]["kwargs"]["kwargs"],
)
else:
- augmentations[stem] = getattr(
- tam,
- audiomentations[stem]["name"]
- )(
- **audiomentations[stem]["kwargs"]
+ augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
+ **audiomentations[stem]["kwargs"]
)
self.augmentations = nn.ModuleDict(augmentations)
@@ -56,7 +50,7 @@ def __init__(
self.scaler_margin = scaler_margin
def check_and_fix_clipping(
- self, item: Union[DataDict, BatchedDataDict]
+ self, item: Union[DataDict, BatchedDataDict]
) -> Union[DataDict, BatchedDataDict]:
max_abs = []
@@ -64,18 +58,20 @@ def check_and_fix_clipping(
max_abs.append(item["audio"][stem].abs().max().item())
if max(max_abs) > 1.0:
- scaler = 1.0 / (max(max_abs) + torch.rand(
- (1,),
- device=item["audio"]["mixture"].device
- ) * self.scaler_margin)
+ scaler = 1.0 / (
+ max(max_abs)
+ + torch.rand((1,), device=item["audio"]["mixture"].device)
+ * self.scaler_margin
+ )
for stem in item["audio"]:
item["audio"][stem] *= scaler
return item
- def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
- DataDict, BatchedDataDict]:
+ def forward(
+ self, item: Union[DataDict, BatchedDataDict]
+ ) -> Union[DataDict, BatchedDataDict]:
for stem in item["audio"]:
if stem == "mixture":
@@ -83,22 +79,21 @@ def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
if self.has_common:
item["audio"][stem] = self.augmentations["[common]"](
- item["audio"][stem]
+ item["audio"][stem]
).samples
if stem in self.augmentations:
item["audio"][stem] = self.augmentations[stem](
- item["audio"][stem]
+ item["audio"][stem]
).samples
elif self.has_default:
if not self.has_common or self.apply_both_default_and_common:
item["audio"][stem] = self.augmentations["[default]"](
- item["audio"][stem]
+ item["audio"][stem]
).samples
item["audio"]["mixture"] = sum(
- [item["audio"][stem] for stem in item["audio"]
- if stem != "mixture"]
+ [item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
) # type: ignore[call-overload, assignment]
if self.fix_clipping:
diff --git a/programs/music_separation_code/models/bandit/core/data/augmented.py b/programs/music_separation_code/models/bandit/core/data/augmented.py
index 84d19599..3c052440 100644
--- a/programs/music_separation_code/models/bandit/core/data/augmented.py
+++ b/programs/music_separation_code/models/bandit/core/data/augmented.py
@@ -8,15 +8,15 @@
class AugmentedDataset(data.Dataset):
def __init__(
- self,
- dataset: data.Dataset,
- augmentation: nn.Module = nn.Identity(),
- target_length: Optional[int] = None,
+ self,
+ dataset: data.Dataset,
+ augmentation: nn.Module = nn.Identity(),
+ target_length: Optional[int] = None,
) -> None:
warnings.warn(
- "This class is no longer used. Attach augmentation to "
- "the LightningSystem instead.",
- DeprecationWarning,
+ "This class is no longer used. Attach augmentation to "
+ "the LightningSystem instead.",
+ DeprecationWarning,
)
self.dataset = dataset
@@ -25,8 +25,7 @@ def __init__(
self.ds_length: int = len(dataset) # type: ignore[arg-type]
self.length = target_length if target_length is not None else self.ds_length
- def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str,
- torch.Tensor]]]:
+ def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
item = self.dataset[index % self.ds_length]
item = self.augmentation(item)
return item
diff --git a/programs/music_separation_code/models/bandit/core/data/base.py b/programs/music_separation_code/models/bandit/core/data/base.py
index a7b6c33a..18e37393 100644
--- a/programs/music_separation_code/models/bandit/core/data/base.py
+++ b/programs/music_separation_code/models/bandit/core/data/base.py
@@ -1,6 +1,5 @@
-import os
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
import numpy as np
import pedalboard as pb
@@ -13,14 +12,15 @@
class BaseSourceSeparationDataset(data.Dataset, ABC):
def __init__(
- self, split: str,
- stems: List[str],
- files: List[str],
- data_path: str,
- fs: int,
- npy_memmap: bool,
- recompute_mixture: bool
- ):
+ self,
+ split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int,
+ npy_memmap: bool,
+ recompute_mixture: bool,
+ ):
self.split = split
self.stems = stems
self.stems_no_mixture = [s for s in stems if s != "mixture"]
@@ -31,12 +31,7 @@ def __init__(
self.recompute_mixture = recompute_mixture
@abstractmethod
- def get_stem(
- self,
- *,
- stem: str,
- identifier: Dict[str, Any]
- ) -> torch.Tensor:
+ def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
raise NotImplementedError
def _get_audio(self, stems, identifier: Dict[str, Any]):
@@ -49,10 +44,7 @@ def _get_audio(self, stems, identifier: Dict[str, Any]):
def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
if self.recompute_mixture:
- audio = self._get_audio(
- self.stems_no_mixture,
- identifier=identifier
- )
+ audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
audio["mixture"] = self.compute_mixture(audio)
return audio
else:
@@ -64,6 +56,4 @@ def get_identifier(self, index: int) -> Dict[str, Any]:
def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
- return sum(
- audio[stem] for stem in audio if stem != "mixture"
- )
+ return sum(audio[stem] for stem in audio if stem != "mixture")
diff --git a/programs/music_separation_code/models/bandit/core/data/dnr/datamodule.py b/programs/music_separation_code/models/bandit/core/data/dnr/datamodule.py
index dc555060..2971d419 100644
--- a/programs/music_separation_code/models/bandit/core/data/dnr/datamodule.py
+++ b/programs/music_separation_code/models/bandit/core/data/dnr/datamodule.py
@@ -7,20 +7,20 @@
DivideAndRemasterDataset,
DivideAndRemasterDeterministicChunkDataset,
DivideAndRemasterRandomChunkDataset,
- DivideAndRemasterRandomChunkDatasetWithSpeechReverb
+ DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
)
def DivideAndRemasterDataModule(
- data_root: str = "$DATA_ROOT/DnR/v2",
- batch_size: int = 2,
- num_workers: int = 8,
- train_kwargs: Optional[Mapping] = None,
- val_kwargs: Optional[Mapping] = None,
- test_kwargs: Optional[Mapping] = None,
- datamodule_kwargs: Optional[Mapping] = None,
- use_speech_reverb: bool = False
- # augmentor=None
+ data_root: str = "$DATA_ROOT/DnR/v2",
+ batch_size: int = 2,
+ num_workers: int = 8,
+ train_kwargs: Optional[Mapping] = None,
+ val_kwargs: Optional[Mapping] = None,
+ test_kwargs: Optional[Mapping] = None,
+ datamodule_kwargs: Optional[Mapping] = None,
+ use_speech_reverb: bool = False,
+ # augmentor=None
) -> pl.LightningDataModule:
if train_kwargs is None:
train_kwargs = {}
@@ -47,26 +47,20 @@ def DivideAndRemasterDataModule(
else:
train_cls = DivideAndRemasterRandomChunkDataset
- train_dataset = train_cls(
- data_root, "train", **train_kwargs
- )
+ train_dataset = train_cls(data_root, "train", **train_kwargs)
# if augmentor is not None:
# train_dataset = AugmentedDataset(train_dataset, augmentor)
datamodule = pl.LightningDataModule.from_datasets(
- train_dataset=train_dataset,
- val_dataset=DivideAndRemasterDeterministicChunkDataset(
- data_root, "val", **val_kwargs
- ),
- test_dataset=DivideAndRemasterDataset(
- data_root,
- "test",
- **test_kwargs
- ),
- batch_size=batch_size,
- num_workers=num_workers,
- **datamodule_kwargs
+ train_dataset=train_dataset,
+ val_dataset=DivideAndRemasterDeterministicChunkDataset(
+ data_root, "val", **val_kwargs
+ ),
+ test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **datamodule_kwargs
)
datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
diff --git a/programs/music_separation_code/models/bandit/core/data/dnr/dataset.py b/programs/music_separation_code/models/bandit/core/data/dnr/dataset.py
index 639290d8..00142c7b 100644
--- a/programs/music_separation_code/models/bandit/core/data/dnr/dataset.py
+++ b/programs/music_separation_code/models/bandit/core/data/dnr/dataset.py
@@ -15,10 +15,10 @@
class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
STEM_NAME_MAP = {
- "mixture": "mix",
- "speech": "speech",
- "music": "music",
- "effects": "sfx",
+ "mixture": "mix",
+ "speech": "speech",
+ "music": "music",
+ "effects": "sfx",
}
SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
@@ -26,52 +26,42 @@ class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
def __init__(
- self,
- split: str,
- stems: List[str],
- files: List[str],
- data_path: str,
- fs: int = 44100,
- npy_memmap: bool = True,
- recompute_mixture: bool = False,
+ self,
+ split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ recompute_mixture: bool = False,
) -> None:
super().__init__(
- split=split,
- stems=stems,
- files=files,
- data_path=data_path,
- fs=fs,
- npy_memmap=npy_memmap,
- recompute_mixture=recompute_mixture
+ split=split,
+ stems=stems,
+ files=files,
+ data_path=data_path,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ recompute_mixture=recompute_mixture,
)
- def get_stem(
- self,
- *,
- stem: str,
- identifier: Dict[str, Any]
- ) -> torch.Tensor:
-
+ def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
+
if stem == "mne":
- return self.get_stem(
- stem="music",
- identifier=identifier) + self.get_stem(
- stem="effects",
- identifier=identifier)
+ return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
+ stem="effects", identifier=identifier
+ )
track = identifier["track"]
path = os.path.join(self.data_path, track)
if self.npy_memmap:
audio = np.load(
- os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"),
- mmap_mode="r"
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
)
else:
# noinspection PyUnresolvedReferences
- audio, _ = ta.load(
- os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav")
- )
+ audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))
return audio
@@ -87,12 +77,12 @@ def __getitem__(self, index: int) -> DataDict:
class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
def __init__(
- self,
- data_root: str,
- split: str,
- stems: Optional[List[str]] = None,
- fs: int = 44100,
- npy_memmap: bool = True,
+ self,
+ data_root: str,
+ split: str,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
) -> None:
if stems is None:
@@ -103,11 +93,9 @@ def __init__(
files = sorted(os.listdir(data_path))
files = [
- f
- for f in files
- if (not f.startswith(".")) and os.path.isdir(
- os.path.join(data_path, f)
- )
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
]
# pprint(list(enumerate(files)))
if split == "train":
@@ -120,12 +108,12 @@ def __init__(
self.n_tracks = len(files)
super().__init__(
- data_path=data_path,
- split=split,
- stems=stems,
- files=files,
- fs=fs,
- npy_memmap=npy_memmap,
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
)
def __len__(self) -> int:
@@ -134,14 +122,14 @@ def __len__(self) -> int:
class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
def __init__(
- self,
- data_root: str,
- split: str,
- target_length: int,
- chunk_size_second: float,
- stems: Optional[List[str]] = None,
- fs: int = 44100,
- npy_memmap: bool = True,
+ self,
+ data_root: str,
+ split: str,
+ target_length: int,
+ chunk_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
) -> None:
if stems is None:
@@ -152,11 +140,9 @@ def __init__(
files = sorted(os.listdir(data_path))
files = [
- f
- for f in files
- if (not f.startswith(".")) and os.path.isdir(
- os.path.join(data_path, f)
- )
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
]
if split == "train":
@@ -172,12 +158,12 @@ def __init__(
self.chunk_size = int(chunk_size_second * fs)
super().__init__(
- data_path=data_path,
- split=split,
- stems=stems,
- files=files,
- fs=fs,
- npy_memmap=npy_memmap,
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
)
def __len__(self) -> int:
@@ -187,22 +173,18 @@ def get_identifier(self, index):
return super().get_identifier(index % self.n_tracks)
def get_stem(
- self,
- *,
- stem: str,
- identifier: Dict[str, Any],
- chunk_here: bool = False,
- ) -> torch.Tensor:
-
- stem = super().get_stem(
- stem=stem,
- identifier=identifier
- )
+ self,
+ *,
+ stem: str,
+ identifier: Dict[str, Any],
+ chunk_here: bool = False,
+ ) -> torch.Tensor:
+
+ stem = super().get_stem(stem=stem, identifier=identifier)
if chunk_here:
start = np.random.randint(
- 0,
- self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
+ 0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
)
end = start + self.chunk_size
@@ -216,29 +198,24 @@ def __getitem__(self, index: int) -> DataDict:
audio = self.get_audio(identifier)
# self.index_lock = None
- start = np.random.randint(
- 0,
- self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
- )
+ start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
end = start + self.chunk_size
- audio = {
- k: v[:, start:end] for k, v in audio.items()
- }
+ audio = {k: v[:, start:end] for k, v in audio.items()}
return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
def __init__(
- self,
- data_root: str,
- split: str,
- chunk_size_second: float,
- hop_size_second: float,
- stems: Optional[List[str]] = None,
- fs: int = 44100,
- npy_memmap: bool = True,
+ self,
+ data_root: str,
+ split: str,
+ chunk_size_second: float,
+ hop_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
) -> None:
if stems is None:
@@ -249,11 +226,9 @@ def __init__(
files = sorted(os.listdir(data_path))
files = [
- f
- for f in files
- if (not f.startswith(".")) and os.path.isdir(
- os.path.join(data_path, f)
- )
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
]
# pprint(list(enumerate(files)))
if split == "train":
@@ -268,19 +243,18 @@ def __init__(
self.chunk_size = int(chunk_size_second * fs)
self.hop_size = int(hop_size_second * fs)
self.n_chunks_per_track = int(
- (
- self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
+ (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
)
self.length = self.n_tracks * self.n_chunks_per_track
super().__init__(
- data_path=data_path,
- split=split,
- stems=stems,
- files=files,
- fs=fs,
- npy_memmap=npy_memmap,
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
)
def get_identifier(self, index):
@@ -308,17 +282,17 @@ def __getitem__(self, item: int) -> DataDict:
class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
- DivideAndRemasterRandomChunkDataset
+ DivideAndRemasterRandomChunkDataset
):
def __init__(
- self,
- data_root: str,
- split: str,
- target_length: int,
- chunk_size_second: float,
- stems: Optional[List[str]] = None,
- fs: int = 44100,
- npy_memmap: bool = True,
+ self,
+ data_root: str,
+ split: str,
+ target_length: int,
+ chunk_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
) -> None:
if stems is None:
@@ -327,13 +301,13 @@ def __init__(
stems_no_mixture = [s for s in stems if s != "mixture"]
super().__init__(
- data_root=data_root,
- split=split,
- target_length=target_length,
- chunk_size_second=chunk_size_second,
- stems=stems_no_mixture,
- fs=fs,
- npy_memmap=npy_memmap,
+ data_root=data_root,
+ split=split,
+ target_length=target_length,
+ chunk_size_second=chunk_size_second,
+ stems=stems_no_mixture,
+ fs=fs,
+ npy_memmap=npy_memmap,
)
self.stems = stems
@@ -349,17 +323,17 @@ def __getitem__(self, index: int) -> DataDict:
wet_level = np.random.rand()
speech = pb.Reverb(
- room_size=np.random.rand(),
- damping=np.random.rand(),
- wet_level=wet_level,
- dry_level=(1 - wet_level),
- width=np.random.rand()
+ room_size=np.random.rand(),
+ damping=np.random.rand(),
+ wet_level=wet_level,
+ dry_level=(1 - wet_level),
+ width=np.random.rand(),
).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
data_["audio"]["speech"] = speech
data_["audio"]["mixture"] = sum(
- [data_["audio"][s] for s in self.stems_no_mixture]
+ [data_["audio"][s] for s in self.stems_no_mixture]
)
return data_
@@ -375,10 +349,10 @@ def __len__(self) -> int:
for split_ in ["train", "val", "test"]:
ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
- data_root="$DATA_ROOT/DnR/v2np",
- split=split_,
- target_length=100,
- chunk_size_second=6.0
+ data_root="$DATA_ROOT/DnR/v2np",
+ split=split_,
+ target_length=100,
+ chunk_size_second=6.0,
)
print(split_, len(ds))
diff --git a/programs/music_separation_code/models/bandit/core/data/dnr/preprocess.py b/programs/music_separation_code/models/bandit/core/data/dnr/preprocess.py
index 9d0b5869..18d68b18 100644
--- a/programs/music_separation_code/models/bandit/core/data/dnr/preprocess.py
+++ b/programs/music_separation_code/models/bandit/core/data/dnr/preprocess.py
@@ -16,7 +16,9 @@ def process_one(inputs: Tuple[str, str, int]) -> None:
data, fs = ta.load(infile)
if fs != target_fs:
- data = ta.functional.resample(data, fs, target_fs, resampling_method="sinc_interp_kaiser")
+ data = ta.functional.resample(
+ data, fs, target_fs, resampling_method="sinc_interp_kaiser"
+ )
fs = target_fs
data = data.numpy()
@@ -30,16 +32,11 @@ def process_one(inputs: Tuple[str, str, int]) -> None:
np.save(outfile, data)
-def preprocess(
- data_path: str,
- output_path: str,
- fs: int
-) -> None:
+def preprocess(data_path: str, output_path: str, fs: int) -> None:
files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
print(files)
outfiles = [
- f.replace(data_path, output_path).replace(".wav", ".npy") for f in
- files
+ f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
]
os.makedirs(output_path, exist_ok=True)
diff --git a/programs/music_separation_code/models/bandit/core/data/musdb/datamodule.py b/programs/music_separation_code/models/bandit/core/data/musdb/datamodule.py
index a8984dae..7b3c25e5 100644
--- a/programs/music_separation_code/models/bandit/core/data/musdb/datamodule.py
+++ b/programs/music_separation_code/models/bandit/core/data/musdb/datamodule.py
@@ -7,21 +7,21 @@
MUSDB18BaseDataset,
MUSDB18FullTrackDataset,
MUSDB18SadDataset,
- MUSDB18SadOnTheFlyAugmentedDataset
+ MUSDB18SadOnTheFlyAugmentedDataset,
)
def MUSDB18DataModule(
- data_root: str = "$DATA_ROOT/MUSDB18/HQ",
- target_stem: str = "vocals",
- batch_size: int = 2,
- num_workers: int = 8,
- train_kwargs: Optional[Mapping] = None,
- val_kwargs: Optional[Mapping] = None,
- test_kwargs: Optional[Mapping] = None,
- datamodule_kwargs: Optional[Mapping] = None,
- use_on_the_fly: bool = True,
- npy_memmap: bool = True
+ data_root: str = "$DATA_ROOT/MUSDB18/HQ",
+ target_stem: str = "vocals",
+ batch_size: int = 2,
+ num_workers: int = 8,
+ train_kwargs: Optional[Mapping] = None,
+ val_kwargs: Optional[Mapping] = None,
+ test_kwargs: Optional[Mapping] = None,
+ datamodule_kwargs: Optional[Mapping] = None,
+ use_on_the_fly: bool = True,
+ npy_memmap: bool = True,
) -> pl.LightningDataModule:
if train_kwargs is None:
train_kwargs = {}
@@ -39,39 +39,37 @@ def MUSDB18DataModule(
if use_on_the_fly:
train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
- data_root=os.path.join(data_root, "saded-np"),
- split="train",
- target_stem=target_stem,
- **train_kwargs
+ data_root=os.path.join(data_root, "saded-np"),
+ split="train",
+ target_stem=target_stem,
+ **train_kwargs
)
else:
train_dataset = MUSDB18SadDataset(
- data_root=os.path.join(data_root, "saded-np"),
- split="train",
- target_stem=target_stem,
- **train_kwargs
+ data_root=os.path.join(data_root, "saded-np"),
+ split="train",
+ target_stem=target_stem,
+ **train_kwargs
)
datamodule = pl.LightningDataModule.from_datasets(
- train_dataset=train_dataset,
- val_dataset=MUSDB18SadDataset(
- data_root=os.path.join(data_root, "saded-np"),
- split="val",
- target_stem=target_stem,
- **val_kwargs
- ),
- test_dataset=MUSDB18FullTrackDataset(
- data_root=os.path.join(data_root, "canonical"),
- split="test",
- **test_kwargs
- ),
- batch_size=batch_size,
- num_workers=num_workers,
- **datamodule_kwargs
+ train_dataset=train_dataset,
+ val_dataset=MUSDB18SadDataset(
+ data_root=os.path.join(data_root, "saded-np"),
+ split="val",
+ target_stem=target_stem,
+ **val_kwargs
+ ),
+ test_dataset=MUSDB18FullTrackDataset(
+ data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
+ ),
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **datamodule_kwargs
)
datamodule.predict_dataloader = ( # type: ignore[method-assign]
- datamodule.test_dataloader
+ datamodule.test_dataloader
)
return datamodule
diff --git a/programs/music_separation_code/models/bandit/core/data/musdb/dataset.py b/programs/music_separation_code/models/bandit/core/data/musdb/dataset.py
index c59a07d0..f66319f0 100644
--- a/programs/music_separation_code/models/bandit/core/data/musdb/dataset.py
+++ b/programs/music_separation_code/models/bandit/core/data/musdb/dataset.py
@@ -16,22 +16,22 @@ class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
def __init__(
- self,
- split: str,
- stems: List[str],
- files: List[str],
- data_path: str,
- fs: int = 44100,
- npy_memmap=False,
+ self,
+ split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int = 44100,
+ npy_memmap=False,
) -> None:
super().__init__(
- split=split,
- stems=stems,
- files=files,
- data_path=data_path,
- fs=fs,
- npy_memmap=npy_memmap,
- recompute_mixture=False
+ split=split,
+ stems=stems,
+ files=files,
+ data_path=data_path,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ recompute_mixture=False,
)
def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
@@ -61,25 +61,24 @@ class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
N_TRAIN_TRACKS = 100
N_TEST_TRACKS = 50
VALIDATION_FILES = [
- "Actions - One Minute Smile",
- "Clara Berry And Wooldog - Waltz For My Victims",
- "Johnny Lokke - Promises & Lies",
- "Patrick Talbot - A Reason To Leave",
- "Triviul - Angelsaint",
- "Alexander Ross - Goodbye Bolero",
- "Fergessen - Nos Palpitants",
- "Leaf - Summerghost",
- "Skelpolu - Human Mistakes",
- "Young Griffo - Pennies",
- "ANiMAL - Rockshow",
- "James May - On The Line",
- "Meaxic - Take A Step",
- "Traffic Experiment - Sirens",
+ "Actions - One Minute Smile",
+ "Clara Berry And Wooldog - Waltz For My Victims",
+ "Johnny Lokke - Promises & Lies",
+ "Patrick Talbot - A Reason To Leave",
+ "Triviul - Angelsaint",
+ "Alexander Ross - Goodbye Bolero",
+ "Fergessen - Nos Palpitants",
+ "Leaf - Summerghost",
+ "Skelpolu - Human Mistakes",
+ "Young Griffo - Pennies",
+ "ANiMAL - Rockshow",
+ "James May - On The Line",
+ "Meaxic - Take A Step",
+ "Traffic Experiment - Sirens",
]
def __init__(
- self, data_root: str, split: str, stems: Optional[List[
- str]] = None
+ self, data_root: str, split: str, stems: Optional[List[str]] = None
) -> None:
if stems is None:
@@ -112,25 +111,21 @@ def __init__(
self.n_tracks = len(files)
- super().__init__(
- data_path=data_path,
- split=split,
- stems=stems,
- files=files
- )
+ super().__init__(data_path=data_path, split=split, stems=stems, files=files)
def __len__(self) -> int:
return self.n_tracks
+
class MUSDB18SadDataset(MUSDB18BaseDataset):
def __init__(
- self,
- data_root: str,
- split: str,
- target_stem: str,
- stems: Optional[List[str]] = None,
- target_length: Optional[int] = None,
- npy_memmap=False,
+ self,
+ data_root: str,
+ split: str,
+ target_stem: str,
+ stems: Optional[List[str]] = None,
+ target_length: Optional[int] = None,
+ npy_memmap=False,
) -> None:
if stems is None:
@@ -142,16 +137,16 @@ def __init__(
files = [f for f in files if not f.startswith(".")]
super().__init__(
- data_path=data_path,
- split=split,
- stems=stems,
- files=files,
- npy_memmap=npy_memmap
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ npy_memmap=npy_memmap,
)
self.n_segments = len(files)
self.target_stem = target_stem
self.target_length = (
- target_length if target_length is not None else self.n_segments
+ target_length if target_length is not None else self.n_segments
)
def __len__(self) -> int:
@@ -169,23 +164,22 @@ def get_identifier(self, index):
class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
def __init__(
- self,
- data_root: str,
- split: str,
- target_stem: str,
- stems: Optional[List[str]] = None,
- target_length: int = 20000,
- apply_probability: Optional[float] = None,
- chunk_size_second: float = 3.0,
- random_scale_range_db: Tuple[float, float] = (-10, 10),
- drop_probability: float = 0.1,
- rescale: bool = True,
+ self,
+ data_root: str,
+ split: str,
+ target_stem: str,
+ stems: Optional[List[str]] = None,
+ target_length: int = 20000,
+ apply_probability: Optional[float] = None,
+ chunk_size_second: float = 3.0,
+ random_scale_range_db: Tuple[float, float] = (-10, 10),
+ drop_probability: float = 0.1,
+ rescale: bool = True,
) -> None:
super().__init__(data_root, split, target_stem, stems)
if apply_probability is None:
- apply_probability = (
- target_length - self.n_segments) / target_length
+ apply_probability = (target_length - self.n_segments) / target_length
self.apply_probability = apply_probability
self.drop_probability = drop_probability
@@ -226,7 +220,7 @@ def __getitem__(self, index: int) -> DataDict:
if self.chunk_size_sample < audio[stem].shape[-1]:
chunk_start = np.random.randint(
- audio[stem].shape[-1] - self.chunk_size_sample
+ audio[stem].shape[-1] - self.chunk_size_sample
)
else:
chunk_start = 0
@@ -239,18 +233,16 @@ def __getitem__(self, index: int) -> DataDict:
linear_scale = np.power(10, db_scale / 20)
# db_scale = f"{db_scale:+2.1f}"
# print(linear_scale)
- audio[stem][...,
- chunk_start: chunk_start + self.chunk_size_sample] = (
- linear_scale
- * audio[stem][...,
- chunk_start: chunk_start + self.chunk_size_sample]
+ audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
+ linear_scale
+ * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
)
audio["mixture"] = self.compute_mixture(audio)
if self.rescale:
max_abs_val = max(
- [torch.max(torch.abs(audio[stem])) for stem in self.stems]
+ [torch.max(torch.abs(audio[stem])) for stem in self.stems]
) # type: ignore[type-var]
if max_abs_val > 1:
audio = {k: v / max_abs_val for k, v in audio.items()}
@@ -259,6 +251,7 @@ def __getitem__(self, index: int) -> DataDict:
return {"audio": audio, "track": f"{self.split}/{track}"}
+
# if __name__ == "__main__":
#
# from pprint import pprint
diff --git a/programs/music_separation_code/models/bandit/core/data/musdb/preprocess.py b/programs/music_separation_code/models/bandit/core/data/musdb/preprocess.py
index 45b3fe40..bbc02b14 100644
--- a/programs/music_separation_code/models/bandit/core/data/musdb/preprocess.py
+++ b/programs/music_separation_code/models/bandit/core/data/musdb/preprocess.py
@@ -12,20 +12,21 @@
from core.data.musdb.dataset import MUSDB18FullTrackDataset
import pyloudnorm as pyln
+
class SourceActivityDetector(nn.Module):
def __init__(
- self,
- analysis_stem: str,
- output_path: str,
- fs: int = 44100,
- segment_length_second: float = 6.0,
- hop_length_second: float = 3.0,
- n_chunks: int = 10,
- chunk_epsilon: float = 1e-5,
- energy_threshold_quantile: float = 0.15,
- segment_epsilon: float = 1e-3,
- salient_proportion_threshold: float = 0.5,
- target_lufs: float = -24
+ self,
+ analysis_stem: str,
+ output_path: str,
+ fs: int = 44100,
+ segment_length_second: float = 6.0,
+ hop_length_second: float = 3.0,
+ n_chunks: int = 10,
+ chunk_epsilon: float = 1e-5,
+ energy_threshold_quantile: float = 0.15,
+ segment_epsilon: float = 1e-3,
+ salient_proportion_threshold: float = 0.5,
+ target_lufs: float = -24,
) -> None:
super().__init__()
@@ -48,8 +49,7 @@ def __init__(
def forward(self, data: DataDict) -> None:
- stem_ = self.analysis_stem if (
- self.analysis_stem != "none") else "mixture"
+ stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
x = data["audio"][stem_]
@@ -69,9 +69,7 @@ def forward(self, data: DataDict) -> None:
n_chan, n_samples = x.shape
n_segments = (
- int(
- np.ceil((n_samples - self.segment_length) / self.hop_length)
- ) + 1
+ int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
)
segments = torch.zeros((n_segments, n_chan, self.segment_length))
@@ -84,16 +82,12 @@ def forward(self, data: DataDict) -> None:
if end - start < self.segment_length:
xseg = F.pad(
- xseg,
- pad=(0, self.segment_length - (end - start)),
- value=torch.nan
+ xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
)
segments[i, :, :] = xseg
- chunks = segments.reshape(
- (n_segments, n_chan, self.n_chunks, self.chunk_size)
- )
+ chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
if self.analysis_stem != "none":
chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
@@ -101,7 +95,7 @@ def forward(self, data: DataDict) -> None:
chunk_energies[chunk_energies == 0] = self.chunk_epsilon
energy_threshold = torch.nanquantile(
- chunk_energies, q=self.energy_threshold_quantile
+ chunk_energies, q=self.energy_threshold_quantile
)
if energy_threshold < self.segment_epsilon:
@@ -109,11 +103,11 @@ def forward(self, data: DataDict) -> None:
chunks_above_threshold = chunk_energies > energy_threshold
n_chunks_above_threshold = torch.mean(
- chunks_above_threshold.to(torch.float), dim=-1
+ chunks_above_threshold.to(torch.float), dim=-1
)
segment_above_threshold = (
- n_chunks_above_threshold > self.salient_proportion_threshold
+ n_chunks_above_threshold > self.salient_proportion_threshold
)
if torch.sum(segment_above_threshold) == 0:
@@ -127,9 +121,9 @@ def forward(self, data: DataDict) -> None:
continue
outpath = os.path.join(
- self.output_path,
- self.analysis_stem,
- f"{data['track']} - {self.analysis_stem}{i:03d}",
+ self.output_path,
+ self.analysis_stem,
+ f"{data['track']} - {self.analysis_stem}{i:03d}",
)
os.makedirs(outpath, exist_ok=True)
@@ -145,8 +139,7 @@ def forward(self, data: DataDict) -> None:
if end - start < self.segment_length:
segment = F.pad(
- segment,
- (0, self.segment_length - (end - start))
+ segment, (0, self.segment_length - (end - start))
)
assert segment.shape[-1] == self.segment_length, segment.shape
@@ -157,35 +150,35 @@ def forward(self, data: DataDict) -> None:
def preprocess(
- analysis_stem: str,
- output_path: str = "/data/MUSDB18/HQ/saded-np",
- fs: int = 44100,
- segment_length_second: float = 6.0,
- hop_length_second: float = 3.0,
- n_chunks: int = 10,
- chunk_epsilon: float = 1e-5,
- energy_threshold_quantile: float = 0.15,
- segment_epsilon: float = 1e-3,
- salient_proportion_threshold: float = 0.5,
+ analysis_stem: str,
+ output_path: str = "/data/MUSDB18/HQ/saded-np",
+ fs: int = 44100,
+ segment_length_second: float = 6.0,
+ hop_length_second: float = 3.0,
+ n_chunks: int = 10,
+ chunk_epsilon: float = 1e-5,
+ energy_threshold_quantile: float = 0.15,
+ segment_epsilon: float = 1e-3,
+ salient_proportion_threshold: float = 0.5,
) -> None:
sad = SourceActivityDetector(
- analysis_stem=analysis_stem,
- output_path=output_path,
- fs=fs,
- segment_length_second=segment_length_second,
- hop_length_second=hop_length_second,
- n_chunks=n_chunks,
- chunk_epsilon=chunk_epsilon,
- energy_threshold_quantile=energy_threshold_quantile,
- segment_epsilon=segment_epsilon,
- salient_proportion_threshold=salient_proportion_threshold,
+ analysis_stem=analysis_stem,
+ output_path=output_path,
+ fs=fs,
+ segment_length_second=segment_length_second,
+ hop_length_second=hop_length_second,
+ n_chunks=n_chunks,
+ chunk_epsilon=chunk_epsilon,
+ energy_threshold_quantile=energy_threshold_quantile,
+ segment_epsilon=segment_epsilon,
+ salient_proportion_threshold=salient_proportion_threshold,
)
for split in ["train", "val", "test"]:
ds = MUSDB18FullTrackDataset(
- data_root="/data/MUSDB18/HQ/canonical",
- split=split,
+ data_root="/data/MUSDB18/HQ/canonical",
+ split=split,
)
tracks = []
@@ -196,9 +189,8 @@ def preprocess(
tracks.append(track)
process_map(sad, tracks, max_workers=8)
-def loudness_norm_one(
- inputs
-):
+
+def loudness_norm_one(inputs):
infile, outfile, target_lufs = inputs
audio, fs = ta.load(infile)
@@ -211,25 +203,21 @@ def loudness_norm_one(
os.makedirs(os.path.dirname(outfile), exist_ok=True)
np.save(outfile, audio.T)
+
def loudness_norm(
- data_path: str,
- # output_path: str,
- target_lufs = -17.0,
+ data_path: str,
+ # output_path: str,
+ target_lufs=-17.0,
):
- files = glob.glob(
- os.path.join(data_path, "**", "*.wav"), recursive=True
- )
+ files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
- outfiles = [
- f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files
- ]
+ outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files]
files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
process_map(loudness_norm_one, files, chunksize=2)
-
if __name__ == "__main__":
from tqdm import tqdm
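Aside: the loudness-normalization step itself is only partly visible in the hunks above; a minimal pyloudnorm sketch of BS.1770 normalization to a target LUFS could look like the following (the helper name and exact steps are illustrative assumptions, not the repository's code):

    import numpy as np
    import pyloudnorm as pyln
    import torchaudio as ta

    def normalize_to_lufs(infile: str, target_lufs: float = -17.0) -> np.ndarray:
        # Load as (channels, samples), then put channels last for pyloudnorm.
        audio, fs = ta.load(infile)
        audio = audio.numpy().T
        # Measure integrated loudness (ITU-R BS.1770) and rescale to the target.
        meter = pyln.Meter(fs)
        loudness = meter.integrated_loudness(audio)
        return pyln.normalize.loudness(audio, loudness, target_lufs)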
diff --git a/programs/music_separation_code/models/bandit/core/loss/__init__.py b/programs/music_separation_code/models/bandit/core/loss/__init__.py
index 0ab803ae..993be521 100644
--- a/programs/music_separation_code/models/bandit/core/loss/__init__.py
+++ b/programs/music_separation_code/models/bandit/core/loss/__init__.py
@@ -1,2 +1,8 @@
from ._multistem import MultiStemWrapperFromConfig
-from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, TimeFreqSignalNoisePNormRatioLoss
+from ._timefreq import (
+ ReImL1Loss,
+ ReImL2Loss,
+ TimeFreqL1Loss,
+ TimeFreqL2Loss,
+ TimeFreqSignalNoisePNormRatioLoss,
+)
diff --git a/programs/music_separation_code/models/bandit/core/loss/_complex.py b/programs/music_separation_code/models/bandit/core/loss/_complex.py
index 1d97e5d8..68c82f20 100644
--- a/programs/music_separation_code/models/bandit/core/loss/_complex.py
+++ b/programs/music_separation_code/models/bandit/core/loss/_complex.py
@@ -11,15 +11,8 @@ def __init__(self, module: _Loss) -> None:
super().__init__()
self.module = module
- def forward(
- self,
- preds: torch.Tensor,
- target: torch.Tensor
- ) -> torch.Tensor:
- return self.module(
- torch.view_as_real(preds),
- torch.view_as_real(target)
- )
+ def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ return self.module(torch.view_as_real(preds), torch.view_as_real(target))
class ReImL1Loss(ReImLossWrapper):
diff --git a/programs/music_separation_code/models/bandit/core/loss/_multistem.py b/programs/music_separation_code/models/bandit/core/loss/_multistem.py
index 675e0ffb..e9c4a4f7 100644
--- a/programs/music_separation_code/models/bandit/core/loss/_multistem.py
+++ b/programs/music_separation_code/models/bandit/core/loss/_multistem.py
@@ -24,16 +24,14 @@ def __init__(self, module: _Loss, modality: str = "audio") -> None:
self.modality = modality
def forward(
- self,
- preds: Dict[str, Dict[str, torch.Tensor]],
- target: Dict[str, Dict[str, torch.Tensor]],
+ self,
+ preds: Dict[str, Dict[str, torch.Tensor]],
+ target: Dict[str, Dict[str, torch.Tensor]],
) -> torch.Tensor:
loss = {
- stem: self.loss(
- preds[self.modality][stem],
- target[self.modality][stem]
- )
- for stem in preds[self.modality] if stem in target[self.modality]
+ stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
+ for stem in preds[self.modality]
+ if stem in target[self.modality]
}
return sum(list(loss.values()))
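Aside: MultiStemWrapper above only dispatches one base loss per stem and sums the results; the same idea inlined, with made-up stem names:

    import torch
    import torch.nn as nn

    base_loss = nn.L1Loss()
    preds = {"audio": {"vocals": torch.randn(2, 1000), "drums": torch.randn(2, 1000)}}
    target = {"audio": {"vocals": torch.randn(2, 1000), "drums": torch.randn(2, 1000)}}

    # One loss term per stem present in both dicts, summed into a single scalar.
    total = sum(
        base_loss(preds["audio"][stem], target["audio"][stem])
        for stem in preds["audio"]
        if stem in target["audio"]
    )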
diff --git a/programs/music_separation_code/models/bandit/core/loss/_timefreq.py b/programs/music_separation_code/models/bandit/core/loss/_timefreq.py
index 6ea9d599..96080e85 100644
--- a/programs/music_separation_code/models/bandit/core/loss/_timefreq.py
+++ b/programs/music_separation_code/models/bandit/core/loss/_timefreq.py
@@ -8,14 +8,15 @@
from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
from models.bandit.core.loss.snr import SignalNoisePNormRatio
+
class TimeFreqWrapper(_Loss):
def __init__(
- self,
- time_module: _Loss,
- freq_module: Optional[_Loss] = None,
- time_weight: float = 1.0,
- freq_weight: float = 1.0,
- multistem: bool = True,
+ self,
+ time_module: _Loss,
+ freq_module: Optional[_Loss] = None,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ multistem: bool = True,
) -> None:
super().__init__()
@@ -36,42 +37,36 @@ def __init__(
def forward(self, preds: Any, target: Any) -> torch.Tensor:
return self.time_weight * self.time_module(
- preds, target
+ preds, target
) + self.freq_weight * self.freq_module(preds, target)
class TimeFreqL1Loss(TimeFreqWrapper):
def __init__(
- self,
- time_weight: float = 1.0,
- freq_weight: float = 1.0,
- tkwargs: Optional[Dict[str, Any]] = None,
- fkwargs: Optional[Dict[str, Any]] = None,
- multistem: bool = True,
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
) -> None:
if tkwargs is None:
tkwargs = {}
if fkwargs is None:
fkwargs = {}
- time_module = (nn.L1Loss(**tkwargs))
+ time_module = nn.L1Loss(**tkwargs)
freq_module = ReImL1Loss(**fkwargs)
- super().__init__(
- time_module,
- freq_module,
- time_weight,
- freq_weight,
- multistem
- )
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
class TimeFreqL2Loss(TimeFreqWrapper):
def __init__(
- self,
- time_weight: float = 1.0,
- freq_weight: float = 1.0,
- tkwargs: Optional[Dict[str, Any]] = None,
- fkwargs: Optional[Dict[str, Any]] = None,
- multistem: bool = True,
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
) -> None:
if tkwargs is None:
tkwargs = {}
@@ -79,24 +74,17 @@ def __init__(
fkwargs = {}
time_module = nn.MSELoss(**tkwargs)
freq_module = ReImL2Loss(**fkwargs)
- super().__init__(
- time_module,
- freq_module,
- time_weight,
- freq_weight,
- multistem
- )
-
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
def __init__(
- self,
- time_weight: float = 1.0,
- freq_weight: float = 1.0,
- tkwargs: Optional[Dict[str, Any]] = None,
- fkwargs: Optional[Dict[str, Any]] = None,
- multistem: bool = True,
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
) -> None:
if tkwargs is None:
tkwargs = {}
@@ -104,10 +92,4 @@ def __init__(
fkwargs = {}
time_module = SignalNoisePNormRatio(**tkwargs)
freq_module = SignalNoisePNormRatio(**fkwargs)
- super().__init__(
- time_module,
- freq_module,
- time_weight,
- freq_weight,
- multistem
- )
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
diff --git a/programs/music_separation_code/models/bandit/core/loss/snr.py b/programs/music_separation_code/models/bandit/core/loss/snr.py
index 2996dd57..8d712a52 100644
--- a/programs/music_separation_code/models/bandit/core/loss/snr.py
+++ b/programs/music_separation_code/models/bandit/core/loss/snr.py
@@ -2,15 +2,16 @@
from torch.nn.modules.loss import _Loss
from torch.nn import functional as F
+
class SignalNoisePNormRatio(_Loss):
def __init__(
- self,
- p: float = 1.0,
- scale_invariant: bool = False,
- zero_mean: bool = False,
- take_log: bool = True,
- reduction: str = "mean",
- EPS: float = 1e-3,
+ self,
+ p: float = 1.0,
+ scale_invariant: bool = False,
+ zero_mean: bool = False,
+ take_log: bool = True,
+ reduction: str = "mean",
+ EPS: float = 1e-3,
) -> None:
assert reduction != "sum", NotImplementedError
super().__init__(reduction=reduction)
@@ -23,23 +24,21 @@ def __init__(
self.scale_invariant = scale_invariant
- def forward(
- self,
- est_target: torch.Tensor,
- target: torch.Tensor
- ) -> torch.Tensor:
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target_ = target
if self.scale_invariant:
ndim = target.ndim
dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
- s_target_energy = (
- torch.sum(target * torch.conj(target), dim=-1, keepdim=True)
+ s_target_energy = torch.sum(
+ target * torch.conj(target), dim=-1, keepdim=True
)
if ndim > 2:
dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
- s_target_energy = torch.sum(s_target_energy, dim=list(range(1, ndim)), keepdim=True)
+ s_target_energy = torch.sum(
+ s_target_energy, dim=list(range(1, ndim)), keepdim=True
+ )
target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
target = target_ * target_scaler
@@ -48,25 +47,26 @@ def forward(
est_target = torch.view_as_real(est_target)
target = torch.view_as_real(target)
-
batch_size = est_target.shape[0]
est_target = est_target.reshape(batch_size, -1)
target = target.reshape(batch_size, -1)
# target_ = target_.reshape(batch_size, -1)
if self.p == 1:
- e_error = torch.abs(est_target-target).mean(dim=-1)
+ e_error = torch.abs(est_target - target).mean(dim=-1)
e_target = torch.abs(target).mean(dim=-1)
elif self.p == 2:
- e_error = torch.square(est_target-target).mean(dim=-1)
+ e_error = torch.square(est_target - target).mean(dim=-1)
e_target = torch.square(target).mean(dim=-1)
else:
raise NotImplementedError
-
+
if self.take_log:
- loss = 10*(torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS))
+ loss = 10 * (
+ torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
+ )
else:
- loss = (e_error + self.EPS)/(e_target + self.EPS)
+ loss = (e_error + self.EPS) / (e_target + self.EPS)
if self.reduction == "mean":
loss = loss.mean()
@@ -75,17 +75,16 @@ def forward(
return loss
-
class MultichannelSingleSrcNegSDR(_Loss):
def __init__(
- self,
- sdr_type: str,
- p: float = 2.0,
- zero_mean: bool = True,
- take_log: bool = True,
- reduction: str = "mean",
- EPS: float = 1e-8,
+ self,
+ sdr_type: str,
+ p: float = 2.0,
+ zero_mean: bool = True,
+ take_log: bool = True,
+ reduction: str = "mean",
+ EPS: float = 1e-8,
) -> None:
assert reduction != "sum", NotImplementedError
super().__init__(reduction=reduction)
@@ -98,14 +97,10 @@ def __init__(
self.p = p
- def forward(
- self,
- est_target: torch.Tensor,
- target: torch.Tensor
- ) -> torch.Tensor:
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.size() != est_target.size() or target.ndim != 3:
raise TypeError(
- f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
+ f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
)
# Step 1. Zero-mean norm
if self.zero_mean:
@@ -118,9 +113,7 @@ def forward(
# [batch, 1]
dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
# [batch, 1]
- s_target_energy = (
- torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS
- )
+ s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
# [batch, time]
scaled_target = dot * target / s_target_energy
else:
@@ -133,12 +126,12 @@ def forward(
# [batch]
if self.p == 2.0:
- losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / (
- torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS
+ losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
+ torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
)
else:
losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
- torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
+ torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
)
if self.take_log:
losses = 10 * torch.log10(losses + self.EPS)
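Aside: for reference, the take_log branch of SignalNoisePNormRatio is a signal-to-error ratio expressed in dB; the same arithmetic in plain NumPy, with shapes collapsed to a single mean for brevity:

    import numpy as np

    def snr_pnorm_db(est, target, p=1.0, eps=1e-3):
        # Mean p-norm energy of the error and of the target.
        if p == 1:
            e_error = np.abs(est - target).mean()
            e_target = np.abs(target).mean()
        elif p == 2:
            e_error = np.square(est - target).mean()
            e_target = np.square(target).mean()
        else:
            raise NotImplementedError
        # Error energy over target energy in dB, so lower values are better.
        return 10 * (np.log10(e_error + eps) - np.log10(e_target + eps))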
diff --git a/programs/music_separation_code/models/bandit/core/metrics/_squim.py b/programs/music_separation_code/models/bandit/core/metrics/_squim.py
index ec76b5fb..71c993a2 100644
--- a/programs/music_separation_code/models/bandit/core/metrics/_squim.py
+++ b/programs/music_separation_code/models/bandit/core/metrics/_squim.py
@@ -40,7 +40,10 @@ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
self.sigmoid: nn.modules.Module = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
- out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
+ out = (
+ self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
+ + self.val_range[0]
+ )
return out
@@ -72,7 +75,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class SingleRNN(nn.Module):
- def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
+ def __init__(
+ self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
+ ) -> None:
super(SingleRNN, self).__init__()
self.rnn_type = rnn_type
@@ -144,7 +149,10 @@ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
# input shape: (B, N, T)
seq_len = x.shape[-1]
- rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+ rest = (
+ self.chunk_size
+ - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+ )
out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
return out, rest
@@ -153,18 +161,42 @@ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
out, rest = self.pad_chunk(x)
batch_size, feat_dim, seq_len = out.shape
- segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
- segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+ segments1 = (
+ out[:, :, : -self.chunk_stride]
+ .contiguous()
+ .view(batch_size, feat_dim, -1, self.chunk_size)
+ )
+ segments2 = (
+ out[:, :, self.chunk_stride :]
+ .contiguous()
+ .view(batch_size, feat_dim, -1, self.chunk_size)
+ )
out = torch.cat([segments1, segments2], dim=3)
- out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
+ out = (
+ out.view(batch_size, feat_dim, -1, self.chunk_size)
+ .transpose(2, 3)
+ .contiguous()
+ )
return out, rest
def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
batch_size, dim, _, _ = x.shape
- out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
- out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
- out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+ out = (
+ x.transpose(2, 3)
+ .contiguous()
+ .view(batch_size, dim, -1, self.chunk_size * 2)
+ )
+ out1 = (
+ out[:, :, :, : self.chunk_size]
+ .contiguous()
+ .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
+ )
+ out2 = (
+ out[:, :, :, self.chunk_size :]
+ .contiguous()
+ .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+ )
out = out1 + out2
if rest > 0:
out = out[:, :, :-rest]
@@ -175,16 +207,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x, rest = self.chunking(x)
batch_size, _, dim1, dim2 = x.shape
out = x
- for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
- row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
+ for row_rnn, row_norm, col_rnn, col_norm in zip(
+ self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
+ ):
+ row_in = (
+ out.permute(0, 3, 2, 1)
+ .contiguous()
+ .view(batch_size * dim2, dim1, -1)
+ .contiguous()
+ )
row_out = row_rnn(row_in)
- row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
+ row_out = (
+ row_out.view(batch_size, dim2, dim1, -1)
+ .permute(0, 3, 2, 1)
+ .contiguous()
+ )
row_out = row_norm(row_out)
out = out + row_out
- col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
+ col_in = (
+ out.permute(0, 2, 3, 1)
+ .contiguous()
+ .view(batch_size * dim1, dim2, -1)
+ .contiguous()
+ )
col_out = col_rnn(col_in)
- col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
+ col_out = (
+ col_out.view(batch_size, dim1, dim2, -1)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
col_out = col_norm(col_out)
out = out + col_out
out = self.conv(out)
@@ -236,7 +288,9 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
List(torch.Tensor): List of score Tenosrs. Each Tensor is with dimension `(batch,)`.
"""
if x.ndim != 2:
- raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
+ raise ValueError(
+ f"The input must be a 2D Tensor. Found dimension {x.ndim}."
+ )
x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
out = self.encoder(x)
out = self.dprnn(out)
@@ -257,7 +311,9 @@ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
Returns:
(nn.Module): Returned module to predict corresponding metric score.
"""
- layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
+ layer1 = nn.TransformerEncoderLayer(
+ d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
+ )
layer2 = AutoPool()
if metric == "stoi":
layer3 = nn.Sequential(
@@ -274,7 +330,9 @@ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
RangeSigmoid(val_range=PESQRange),
)
else:
- layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
+ layer3: nn.modules.Module = nn.Sequential(
+ nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
+ )
return nn.Sequential(layer1, layer2, layer3)
@@ -305,7 +363,9 @@ def squim_objective_model(
if chunk_stride is None:
chunk_stride = chunk_size // 2
encoder = Encoder(feat_dim, win_len)
- dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
+ dprnn = DPRNN(
+ feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
+ )
branches = nn.ModuleList(
[
_create_branch(d_model, nhead, "stoi"),
@@ -329,6 +389,7 @@ def squim_objective_base() -> SquimObjective:
chunk_size=71,
)
+
@dataclass
class SquimObjectiveBundle:
@@ -380,4 +441,3 @@ def sample_rate(self):
Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""
-
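Aside: this vendored file mirrors torchaudio's reference-free SQUIM objective model; a hedged usage sketch against the upstream torchaudio pipeline rather than this copy (availability depends on the installed torchaudio version):

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.SQUIM_OBJECTIVE
    model = bundle.get_model()

    waveform = torch.randn(1, bundle.sample_rate)  # one second of mono audio
    # Reference-free estimates of STOI, PESQ and SI-SDR for the input waveform.
    stoi_hyp, pesq_hyp, si_sdr_hyp = model(waveform)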
diff --git a/programs/music_separation_code/models/bandit/core/metrics/snr.py b/programs/music_separation_code/models/bandit/core/metrics/snr.py
index d2830b2c..6b7a1687 100644
--- a/programs/music_separation_code/models/bandit/core/metrics/snr.py
+++ b/programs/music_separation_code/models/bandit/core/metrics/snr.py
@@ -25,11 +25,11 @@ def compute(self) -> Any:
class BaseChunkMedianSignalRatio(tm.Metric):
def __init__(
- self,
- func: Callable,
- window_size: int,
- hop_size: int = None,
- zero_mean: bool = False,
+ self,
+ func: Callable,
+ window_size: int,
+ hop_size: int = None,
+ zero_mean: bool = False,
) -> None:
super().__init__()
@@ -40,20 +40,14 @@ def __init__(
hop_size = window_size
self.hop_size = hop_size
- self.add_state(
- "sum_snr",
- default=torch.tensor(0.0),
- dist_reduce_fx="sum"
- )
+ self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
n_samples = target.shape[-1]
- n_chunks = int(
- np.ceil((n_samples - self.window_size) / self.hop_size) + 1
- )
+ n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)
snr_chunk = []
@@ -66,10 +60,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
end = start + self.window_size
try:
- chunk_snr = self.func(
- preds[..., start:end],
- target[..., start:end]
- )
+ chunk_snr = self.func(preds[..., start:end], target[..., start:end])
# print(preds.shape, chunk_snr.shape)
@@ -90,61 +81,47 @@ def compute(self) -> Any:
class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
def __init__(
- self,
- window_size: int,
- hop_size: int = None,
- zero_mean: bool = False
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
) -> None:
super().__init__(
- func=tmF.signal_noise_ratio,
- window_size=window_size,
- hop_size=hop_size,
- zero_mean=zero_mean,
+ func=tmF.signal_noise_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
)
class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
def __init__(
- self,
- window_size: int,
- hop_size: int = None,
- zero_mean: bool = False
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
) -> None:
super().__init__(
- func=tmF.scale_invariant_signal_noise_ratio,
- window_size=window_size,
- hop_size=hop_size,
- zero_mean=zero_mean,
+ func=tmF.scale_invariant_signal_noise_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
)
class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
def __init__(
- self,
- window_size: int,
- hop_size: int = None,
- zero_mean: bool = False
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
) -> None:
super().__init__(
- func=tmF.signal_distortion_ratio,
- window_size=window_size,
- hop_size=hop_size,
- zero_mean=zero_mean,
+ func=tmF.signal_distortion_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
)
-class ChunkMedianScaleInvariantSignalDistortionRatio(
- BaseChunkMedianSignalRatio
- ):
+class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
def __init__(
- self,
- window_size: int,
- hop_size: int = None,
- zero_mean: bool = False
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
) -> None:
super().__init__(
- func=tmF.scale_invariant_signal_distortion_ratio,
- window_size=window_size,
- hop_size=hop_size,
- zero_mean=zero_mean,
+ func=tmF.scale_invariant_signal_distortion_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
)
diff --git a/programs/music_separation_code/models/bandit/core/model/_spectral.py b/programs/music_separation_code/models/bandit/core/model/_spectral.py
index 564cd286..6af5cbd0 100644
--- a/programs/music_separation_code/models/bandit/core/model/_spectral.py
+++ b/programs/music_separation_code/models/bandit/core/model/_spectral.py
@@ -7,18 +7,18 @@
class _SpectralComponent(nn.Module):
def __init__(
- self,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- **kwargs,
+ self,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ **kwargs,
) -> None:
super().__init__()
@@ -26,33 +26,29 @@ def __init__(
window_fn = torch.__dict__[window_fn]
- self.stft = (
- ta.transforms.Spectrogram(
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- pad_mode=pad_mode,
- pad=0,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- normalized=normalized,
- center=center,
- onesided=onesided,
- )
+ self.stft = ta.transforms.Spectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
)
- self.istft = (
- ta.transforms.InverseSpectrogram(
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- pad_mode=pad_mode,
- pad=0,
- window_fn=window_fn,
- wkwargs=wkwargs,
- normalized=normalized,
- center=center,
- onesided=onesided,
- )
+ self.istft = ta.transforms.InverseSpectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
)
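Aside: _SpectralComponent is a thin wrapper over torchaudio's STFT transforms; a minimal round-trip sketch using the same defaults as above:

    import torch
    import torchaudio as ta

    stft = ta.transforms.Spectrogram(
        n_fft=2048, win_length=2048, hop_length=512,
        power=None,  # None keeps the complex-valued spectrogram
        normalized=True, center=True, pad_mode="constant", onesided=True,
    )
    istft = ta.transforms.InverseSpectrogram(
        n_fft=2048, win_length=2048, hop_length=512,
        normalized=True, center=True, pad_mode="constant", onesided=True,
    )

    x = torch.randn(2, 44100)                # (channels, samples)
    spec = stft(x)                           # (channels, n_freq, n_frames), complex
    recon = istft(spec, length=x.shape[-1])  # back to (channels, samples)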
diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/bandsplit.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/bandsplit.py
index 63e62558..43217655 100644
--- a/programs/music_separation_code/models/bandit/core/model/bsrnn/bandsplit.py
+++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/bandsplit.py
@@ -13,12 +13,12 @@
class NormFC(nn.Module):
def __init__(
- self,
- emb_dim: int,
- bandwidth: int,
- in_channel: int,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
+ self,
+ emb_dim: int,
+ bandwidth: int,
+ in_channel: int,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
) -> None:
super().__init__()
@@ -67,14 +67,14 @@ def forward(self, xb):
class BandSplitModule(nn.Module):
def __init__(
- self,
- band_specs: List[Tuple[float, float]],
- emb_dim: int,
- in_channel: int,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ in_channel: int,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
) -> None:
super().__init__()
@@ -94,18 +94,18 @@ def __init__(
self.emb_dim = emb_dim
self.norm_fc_modules = nn.ModuleList(
- [ # type: ignore
- (
- NormFC(
- emb_dim=emb_dim,
- bandwidth=bw,
- in_channel=in_channel,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- )
- )
- for bw in self.band_widths
- ]
+ [ # type: ignore
+ (
+ NormFC(
+ emb_dim=emb_dim,
+ bandwidth=bw,
+ in_channel=in_channel,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ )
+ )
+ for bw in self.band_widths
+ ]
)
def forward(self, x: torch.Tensor):
@@ -114,15 +114,11 @@ def forward(self, x: torch.Tensor):
batch, in_chan, _, n_time = x.shape
z = torch.zeros(
- size=(batch, self.n_bands, n_time, self.emb_dim),
- device=x.device
+ size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
)
xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2
- xr = torch.permute(
- xr,
- (0, 3, 1, 4, 2)
- ) # batch, n_time, in_chan, 2, n_freq
+ xr = torch.permute(xr, (0, 3, 1, 4, 2)) # batch, n_time, in_chan, 2, n_freq
batch, n_time, in_chan, reim, band_width = xr.shape
for i, nfm in enumerate(self.norm_fc_modules):
# print(f"bandsplit/band{i:02d}")
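Aside: each NormFC is a per-band LayerNorm plus Linear applied to the real/imaginary view of one sub-band; a simplified single-band sketch (the exact normalized shape inside NormFC is an assumption, and the dimensions are made up):

    import torch
    import torch.nn as nn

    batch, in_chan, bandwidth, n_time, emb_dim = 8, 2, 64, 100, 128
    norm = nn.LayerNorm(in_chan * 2 * bandwidth)  # assumes treat_channel_as_feature=True
    fc = nn.Linear(in_chan * 2 * bandwidth, emb_dim)

    xb = torch.randn(batch, in_chan, bandwidth, n_time, dtype=torch.cfloat)
    xr = torch.view_as_real(xb)           # (batch, chan, freq, time, 2)
    xr = xr.permute(0, 3, 1, 4, 2)        # (batch, time, chan, 2, freq), as in forward() above
    flat = xr.reshape(batch, n_time, -1)  # flatten (chan, reim, freq) per frame
    emb = fc(norm(flat))                  # (batch, time, emb_dim)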
diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/core.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/core.py
index 7fd36259..1dbfb32b 100644
--- a/programs/music_separation_code/models/bandit/core/model/bsrnn/core.py
+++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/core.py
@@ -8,12 +8,12 @@
from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
from models.bandit.core.model.bsrnn.maskestim import (
MaskEstimationModule,
- OverlappingMaskEstimationModule
+ OverlappingMaskEstimationModule,
)
from models.bandit.core.model.bsrnn.tfmodel import (
ConvolutionalTimeFreqModule,
SeqBandModellingModule,
- TransformerTimeFreqModule
+ TransformerTimeFreqModule,
)
@@ -36,7 +36,6 @@ def forward(self, x, cond=None, compute_residual: bool = True):
q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
# print(q)
-
# if torch.any(torch.isnan(q)):
# raise ValueError("q nan")
@@ -54,25 +53,23 @@ def forward(self, x, cond=None, compute_residual: bool = True):
return {"spectrogram": out}
-
-
- def instantiate_mask_estim(self,
- in_channel: int,
- stems: List[str],
- band_specs: List[Tuple[float, float]],
- emb_dim: int,
- mlp_dim: int,
- cond_dim: int,
- hidden_activation: str,
-
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- overlapping_band: bool = False,
- freq_weights: Optional[List[torch.Tensor]] = None,
- n_freq: Optional[int] = None,
- use_freq_weights: bool = True,
- mult_add_mask: bool = False
- ):
+ def instantiate_mask_estim(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ cond_dim: int,
+ hidden_activation: str,
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ mult_add_mask: bool = False,
+ ):
if hidden_activation_kwargs is None:
hidden_activation_kwargs = {}
@@ -86,75 +83,77 @@ def instantiate_mask_estim(self,
if mult_add_mask:
self.mask_estim = nn.ModuleDict(
- {
- stem: MultAddMaskEstimationModule(
- band_specs=band_specs,
- freq_weights=freq_weights,
- n_freq=n_freq,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- in_channel=in_channel,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- use_freq_weights=use_freq_weights,
- )
- for stem in stems
- }
+ {
+ stem: MultAddMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ use_freq_weights=use_freq_weights,
+ )
+ for stem in stems
+ }
)
else:
self.mask_estim = nn.ModuleDict(
- {
- stem: OverlappingMaskEstimationModule(
- band_specs=band_specs,
- freq_weights=freq_weights,
- n_freq=n_freq,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- in_channel=in_channel,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- use_freq_weights=use_freq_weights,
- )
- for stem in stems
- }
+ {
+ stem: OverlappingMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ use_freq_weights=use_freq_weights,
+ )
+ for stem in stems
+ }
)
else:
self.mask_estim = nn.ModuleDict(
- {
- stem: MaskEstimationModule(
- band_specs=band_specs,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- cond_dim=cond_dim,
- in_channel=in_channel,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- )
- for stem in stems
- }
+ {
+ stem: MaskEstimationModule(
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for stem in stems
+ }
)
- def instantiate_bandsplit(self,
- in_channel: int,
- band_specs: List[Tuple[float, float]],
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- emb_dim: int = 128
- ):
+ def instantiate_bandsplit(
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ emb_dim: int = 128,
+ ):
self.band_split = BandSplitModule(
- in_channel=in_channel,
- band_specs=band_specs,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- emb_dim=emb_dim,
- )
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+
class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
def __init__(self, **kwargs) -> None:
@@ -172,169 +171,166 @@ def forward(self, x):
class SingleMaskBandsplitCoreRNN(
- SingleMaskBandsplitCoreBase,
+ SingleMaskBandsplitCoreBase,
):
def __init__(
- self,
- in_channel: int,
- band_specs: List[Tuple[float, float]],
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
) -> None:
super().__init__()
- self.band_split = (BandSplitModule(
- in_channel=in_channel,
- band_specs=band_specs,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- emb_dim=emb_dim,
- ))
- self.tf_model = (SeqBandModellingModule(
- n_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- ))
- self.mask_estim = (MaskEstimationModule(
- in_channel=in_channel,
- band_specs=band_specs,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- ))
+ self.band_split = BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+ self.tf_model = SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+ self.mask_estim = MaskEstimationModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
class SingleMaskBandsplitCoreTransformer(
- SingleMaskBandsplitCoreBase,
+ SingleMaskBandsplitCoreBase,
):
def __init__(
- self,
- in_channel: int,
- band_specs: List[Tuple[float, float]],
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- tf_dropout: float = 0.0,
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
) -> None:
super().__init__()
self.band_split = BandSplitModule(
- in_channel=in_channel,
- band_specs=band_specs,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- emb_dim=emb_dim,
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
)
self.tf_model = TransformerTimeFreqModule(
- n_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- dropout=tf_dropout,
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
)
self.mask_estim = MaskEstimationModule(
- in_channel=in_channel,
- band_specs=band_specs,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
+ in_channel=in_channel,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
)
class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: List[Tuple[float, float]],
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- cond_dim: int = 0,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- overlapping_band: bool = False,
- freq_weights: Optional[List[torch.Tensor]] = None,
- n_freq: Optional[int] = None,
- use_freq_weights: bool = True,
- mult_add_mask: bool = False
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ cond_dim: int = 0,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ mult_add_mask: bool = False,
) -> None:
super().__init__()
self.instantiate_bandsplit(
- in_channel=in_channel,
- band_specs=band_specs,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- emb_dim=emb_dim
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
)
-
- self.tf_model = (
- SeqBandModellingModule(
- n_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- )
+ self.tf_model = SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
)
self.mult_add_mask = mult_add_mask
self.instantiate_mask_estim(
- in_channel=in_channel,
- stems=stems,
- band_specs=band_specs,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- cond_dim=cond_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- overlapping_band=overlapping_band,
- freq_weights=freq_weights,
- n_freq=n_freq,
- use_freq_weights=use_freq_weights,
- mult_add_mask=mult_add_mask
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
)
@staticmethod
@@ -358,133 +354,132 @@ def mask(self, x, m):
class MultiSourceMultiMaskBandSplitCoreTransformer(
- MultiMaskBandSplitCoreBase,
+ MultiMaskBandSplitCoreBase,
):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: List[Tuple[float, float]],
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- tf_dropout: float = 0.0,
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- overlapping_band: bool = False,
- freq_weights: Optional[List[torch.Tensor]] = None,
- n_freq: Optional[int] = None,
- use_freq_weights:bool=True,
- rnn_type: str = "LSTM",
- cond_dim: int = 0,
- mult_add_mask: bool = False
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ rnn_type: str = "LSTM",
+ cond_dim: int = 0,
+ mult_add_mask: bool = False,
) -> None:
super().__init__()
self.instantiate_bandsplit(
- in_channel=in_channel,
- band_specs=band_specs,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- emb_dim=emb_dim
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
)
self.tf_model = TransformerTimeFreqModule(
- n_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- dropout=tf_dropout,
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
)
-
+
self.instantiate_mask_estim(
- in_channel=in_channel,
- stems=stems,
- band_specs=band_specs,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- cond_dim=cond_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- overlapping_band=overlapping_band,
- freq_weights=freq_weights,
- n_freq=n_freq,
- use_freq_weights=use_freq_weights,
- mult_add_mask=mult_add_mask
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
)
-
class MultiSourceMultiMaskBandSplitCoreConv(
- MultiMaskBandSplitCoreBase,
+ MultiMaskBandSplitCoreBase,
):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: List[Tuple[float, float]],
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- tf_dropout: float = 0.0,
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- overlapping_band: bool = False,
- freq_weights: Optional[List[torch.Tensor]] = None,
- n_freq: Optional[int] = None,
- use_freq_weights:bool=True,
- rnn_type: str = "LSTM",
- cond_dim: int = 0,
- mult_add_mask: bool = False
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ rnn_type: str = "LSTM",
+ cond_dim: int = 0,
+ mult_add_mask: bool = False,
) -> None:
super().__init__()
self.instantiate_bandsplit(
- in_channel=in_channel,
- band_specs=band_specs,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- emb_dim=emb_dim
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
)
self.tf_model = ConvolutionalTimeFreqModule(
- n_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- dropout=tf_dropout,
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
)
-
+
self.instantiate_mask_estim(
- in_channel=in_channel,
- stems=stems,
- band_specs=band_specs,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- cond_dim=cond_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- overlapping_band=overlapping_band,
- freq_weights=freq_weights,
- n_freq=n_freq,
- use_freq_weights=use_freq_weights,
- mult_add_mask=mult_add_mask
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
)
@@ -500,40 +495,40 @@ def mask(self, x, m):
padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
xf = F.unfold(
- x,
- kernel_size=(kernel_freq, kernel_time),
- padding=padding,
- stride=(1, 1),
+ x,
+ kernel_size=(kernel_freq, kernel_time),
+ padding=padding,
+ stride=(1, 1),
)
xf = xf.view(
- -1,
- n_channel,
- kernel_freq,
- kernel_time,
- n_freq,
- n_time,
+ -1,
+ n_channel,
+ kernel_freq,
+ kernel_time,
+ n_freq,
+ n_time,
)
sf = xf * m
sf = sf.view(
- -1,
- n_channel * kernel_freq * kernel_time,
- n_freq * n_time,
+ -1,
+ n_channel * kernel_freq * kernel_time,
+ n_freq * n_time,
)
s = F.fold(
- sf,
- output_size=(n_freq, n_time),
- kernel_size=(kernel_freq, kernel_time),
- padding=padding,
- stride=(1, 1),
+ sf,
+ output_size=(n_freq, n_time),
+ kernel_size=(kernel_freq, kernel_time),
+ padding=padding,
+ stride=(1, 1),
).view(
- -1,
- n_channel,
- n_freq,
- n_time,
+ -1,
+ n_channel,
+ n_freq,
+ n_time,
)
return s
@@ -570,64 +565,59 @@ def old_mask(self, x, m):
fslice = slice(max(0, df), min(n_freq, n_freq + df))
tslice = slice(max(0, dt), min(n_time, n_time + dt))
- s[:, :, fslice, tslice] += x[:, :, fslice, tslice] * m[ifreq,
- itime, :,
- :, fslice,
- tslice]
+ s[:, :, fslice, tslice] += (
+ x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
+ )
return s
-class MultiSourceMultiPatchingMaskBandSplitCoreRNN(
- PatchingMaskBandsplitCoreBase
-):
+class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: List[Tuple[float, float]],
- mask_kernel_freq: int,
- mask_kernel_time: int,
- conv_kernel_freq: int,
- conv_kernel_time: int,
- kernel_norm_mlp_version: int,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- overlapping_band: bool = False,
- freq_weights: Optional[List[torch.Tensor]] = None,
- n_freq: Optional[int] = None,
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ mask_kernel_freq: int,
+ mask_kernel_time: int,
+ conv_kernel_freq: int,
+ conv_kernel_time: int,
+ kernel_norm_mlp_version: int,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
) -> None:
super().__init__()
self.band_split = BandSplitModule(
- in_channel=in_channel,
- band_specs=band_specs,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- emb_dim=emb_dim,
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
)
- self.tf_model = (
- SeqBandModellingModule(
- n_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- )
+ self.tf_model = SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
)
if hidden_activation_kwargs is None:
@@ -637,25 +627,25 @@ def __init__(
assert freq_weights is not None
assert n_freq is not None
self.mask_estim = nn.ModuleDict(
- {
- stem: PatchingMaskEstimationModule(
- band_specs=band_specs,
- freq_weights=freq_weights,
- n_freq=n_freq,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- in_channel=in_channel,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- mask_kernel_freq=mask_kernel_freq,
- mask_kernel_time=mask_kernel_time,
- conv_kernel_freq=conv_kernel_freq,
- conv_kernel_time=conv_kernel_time,
- kernel_norm_mlp_version=kernel_norm_mlp_version
- )
- for stem in stems
- }
+ {
+ stem: PatchingMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ mask_kernel_freq=mask_kernel_freq,
+ mask_kernel_time=mask_kernel_time,
+ conv_kernel_freq=conv_kernel_freq,
+ conv_kernel_time=conv_kernel_time,
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
+ )
+ for stem in stems
+ }
)
else:
raise NotImplementedError
diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/maskestim.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/maskestim.py
index 0b9289df..6049596c 100644
--- a/programs/music_separation_code/models/bandit/core/model/bsrnn/maskestim.py
+++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/maskestim.py
@@ -1,4 +1,3 @@
-import warnings
from typing import Dict, List, Optional, Tuple, Type
import torch
@@ -15,26 +14,27 @@
class BaseNormMLP(nn.Module):
def __init__(
- self,
- emb_dim: int,
- mlp_dim: int,
- bandwidth: int,
- in_channel: Optional[int],
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs=None,
- complex_mask: bool = True, ):
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ):
super().__init__()
if hidden_activation_kwargs is None:
hidden_activation_kwargs = {}
self.hidden_activation_kwargs = hidden_activation_kwargs
self.norm = nn.LayerNorm(emb_dim)
- self.hidden = torch.jit.script(nn.Sequential(
+ self.hidden = torch.jit.script(
+ nn.Sequential(
nn.Linear(in_features=emb_dim, out_features=mlp_dim),
- activation.__dict__[hidden_activation](
- **self.hidden_activation_kwargs
- ),
- ))
+ activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
+ )
+ )
self.bandwidth = bandwidth
self.in_channel = in_channel
@@ -46,33 +46,33 @@ def __init__(
class NormMLP(BaseNormMLP):
def __init__(
- self,
- emb_dim: int,
- mlp_dim: int,
- bandwidth: int,
- in_channel: Optional[int],
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs=None,
- complex_mask: bool = True,
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
) -> None:
super().__init__(
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- bandwidth=bandwidth,
- in_channel=in_channel,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ bandwidth=bandwidth,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
)
self.output = torch.jit.script(
- nn.Sequential(
- nn.Linear(
- in_features=mlp_dim,
- out_features=bandwidth * in_channel * self.reim * 2,
- ),
- nn.GLU(dim=-1),
- )
+ nn.Sequential(
+ nn.Linear(
+ in_features=mlp_dim,
+ out_features=bandwidth * in_channel * self.reim * 2,
+ ),
+ nn.GLU(dim=-1),
+ )
)
def reshape_output(self, mb):
@@ -80,23 +80,14 @@ def reshape_output(self, mb):
batch, n_time, _ = mb.shape
if self.complex_mask:
mb = mb.reshape(
- batch,
- n_time,
- self.in_channel,
- self.bandwidth,
- self.reim
+ batch, n_time, self.in_channel, self.bandwidth, self.reim
).contiguous()
# print(mb.shape)
- mb = torch.view_as_complex(
- mb
- ) # (batch, n_time, in_channel, bandwidth)
+ mb = torch.view_as_complex(mb) # (batch, n_time, in_channel, bandwidth)
else:
mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
- mb = torch.permute(
- mb,
- (0, 2, 3, 1)
- ) # (batch, in_channel, bandwidth, n_time)
+ mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channel, bandwidth, n_time)
return mb
@@ -106,7 +97,6 @@ def forward(self, qb):
# if torch.any(torch.isnan(qb)):
# raise ValueError("qb0")
-
qb = self.norm(qb) # (batch, n_time, emb_dim)
# if torch.any(torch.isnan(qb)):
@@ -124,17 +114,34 @@ def forward(self, qb):
class MultAddNormMLP(NormMLP):
- def __init__(self, emb_dim: int, mlp_dim: int, bandwidth: int, in_channel: "int | None", hidden_activation: str = "Tanh", hidden_activation_kwargs=None, complex_mask: bool = True) -> None:
- super().__init__(emb_dim, mlp_dim, bandwidth, in_channel, hidden_activation, hidden_activation_kwargs, complex_mask)
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channel: "int | None",
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__(
+ emb_dim,
+ mlp_dim,
+ bandwidth,
+ in_channel,
+ hidden_activation,
+ hidden_activation_kwargs,
+ complex_mask,
+ )
self.output2 = torch.jit.script(
- nn.Sequential(
- nn.Linear(
- in_features=mlp_dim,
- out_features=bandwidth * in_channel * self.reim * 2,
- ),
- nn.GLU(dim=-1),
- )
+ nn.Sequential(
+ nn.Linear(
+ in_features=mlp_dim,
+ out_features=bandwidth * in_channel * self.reim * 2,
+ ),
+ nn.GLU(dim=-1),
+ )
)
def forward(self, qb):
@@ -155,16 +162,16 @@ class MaskEstimationModuleSuperBase(nn.Module):
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
def __init__(
- self,
- band_specs: List[Tuple[float, float]],
- emb_dim: int,
- mlp_dim: int,
- in_channel: Optional[int],
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Dict = None,
- complex_mask: bool = True,
- norm_mlp_cls: Type[nn.Module] = NormMLP,
- norm_mlp_kwargs: Dict = None,
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
+ norm_mlp_kwargs: Dict = None,
) -> None:
super().__init__()
@@ -178,21 +185,21 @@ def __init__(
norm_mlp_kwargs = {}
self.norm_mlp = nn.ModuleList(
- [
- (
- norm_mlp_cls(
- bandwidth=self.band_widths[b],
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- in_channel=in_channel,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- **norm_mlp_kwargs,
- )
- )
- for b in range(self.n_bands)
- ]
+ [
+ (
+ norm_mlp_cls(
+ bandwidth=self.band_widths[b],
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ **norm_mlp_kwargs,
+ )
+ )
+ for b in range(self.n_bands)
+ ]
)
def compute_masks(self, q):
@@ -209,23 +216,22 @@ def compute_masks(self, q):
return masks
-
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
def __init__(
- self,
- in_channel: int,
- band_specs: List[Tuple[float, float]],
- freq_weights: List[torch.Tensor],
- n_freq: int,
- emb_dim: int,
- mlp_dim: int,
- cond_dim: int = 0,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Dict = None,
- complex_mask: bool = True,
- norm_mlp_cls: Type[nn.Module] = NormMLP,
- norm_mlp_kwargs: Dict = None,
- use_freq_weights: bool = True,
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ freq_weights: List[torch.Tensor],
+ n_freq: int,
+ emb_dim: int,
+ mlp_dim: int,
+ cond_dim: int = 0,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
+ norm_mlp_kwargs: Dict = None,
+ use_freq_weights: bool = True,
) -> None:
check_nonzero_bandwidth(band_specs)
check_no_gap(band_specs)
@@ -234,15 +240,15 @@ def __init__(
# raise NotImplementedError
super().__init__(
- band_specs=band_specs,
- emb_dim=emb_dim + cond_dim,
- mlp_dim=mlp_dim,
- in_channel=in_channel,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- norm_mlp_cls=norm_mlp_cls,
- norm_mlp_kwargs=norm_mlp_kwargs,
+ band_specs=band_specs,
+ emb_dim=emb_dim + cond_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ norm_mlp_cls=norm_mlp_cls,
+ norm_mlp_kwargs=norm_mlp_kwargs,
)
self.n_freq = n_freq
@@ -276,22 +282,22 @@ def forward(self, q, cond=None):
q = torch.cat([q, cond], dim=-1)
elif self.cond_dim > 0:
cond = torch.ones(
- (batch, n_bands, n_time, self.cond_dim),
- device=q.device,
- dtype=q.dtype,
+ (batch, n_bands, n_time, self.cond_dim),
+ device=q.device,
+ dtype=q.dtype,
)
q = torch.cat([q, cond], dim=-1)
else:
pass
mask_list = self.compute_masks(
- q
+ q
) # [n_bands * (batch, in_channel, bandwidth, n_time)]
masks = torch.zeros(
- (batch, self.in_channel, self.n_freq, n_time),
- device=q.device,
- dtype=mask_list[0].dtype,
+ (batch, self.in_channel, self.n_freq, n_time),
+ device=q.device,
+ dtype=mask_list[0].dtype,
)
for im, mask in enumerate(mask_list):
@@ -306,42 +312,39 @@ def forward(self, q, cond=None):
class MaskEstimationModule(OverlappingMaskEstimationModule):
def __init__(
- self,
- band_specs: List[Tuple[float, float]],
- emb_dim: int,
- mlp_dim: int,
- in_channel: Optional[int],
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Dict = None,
- complex_mask: bool = True,
- **kwargs,
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ **kwargs,
) -> None:
check_nonzero_bandwidth(band_specs)
check_no_gap(band_specs)
check_no_overlap(band_specs)
super().__init__(
- in_channel=in_channel,
- band_specs=band_specs,
- freq_weights=None,
- n_freq=None,
- emb_dim=emb_dim,
- mlp_dim=mlp_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
+ in_channel=in_channel,
+ band_specs=band_specs,
+ freq_weights=None,
+ n_freq=None,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
)
def forward(self, q, cond=None):
# q = (batch, n_bands, n_time, emb_dim)
masks = self.compute_masks(
- q
+ q
) # [n_bands * (batch, in_channel, bandwidth, n_time)]
# TODO: currently this requires band specs to have no gap and no overlap
- masks = torch.concat(
- masks,
- dim=2
- ) # (batch, in_channel, n_freq, n_time)
+ masks = torch.concat(masks, dim=2) # (batch, in_channel, n_freq, n_time)
return masks
diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/tfmodel.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/tfmodel.py
index ba710798..f482a118 100644
--- a/programs/music_separation_code/models/bandit/core/model/bsrnn/tfmodel.py
+++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/tfmodel.py
@@ -15,13 +15,13 @@ def __init__(self) -> None:
class ResidualRNN(nn.Module):
def __init__(
- self,
- emb_dim: int,
- rnn_dim: int,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- use_batch_trick: bool = True,
- use_layer_norm: bool = True,
+ self,
+ emb_dim: int,
+ rnn_dim: int,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ use_batch_trick: bool = True,
+ use_layer_norm: bool = True,
) -> None:
# n_group is the size of the 2nd dim
super().__init__()
@@ -33,16 +33,15 @@ def __init__(
self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
self.rnn = rnn.__dict__[rnn_type](
- input_size=emb_dim,
- hidden_size=rnn_dim,
- num_layers=1,
- batch_first=True,
- bidirectional=bidirectional,
+ input_size=emb_dim,
+ hidden_size=rnn_dim,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=bidirectional,
)
self.fc = nn.Linear(
- in_features=rnn_dim * (2 if bidirectional else 1),
- out_features=emb_dim
+ in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
)
self.use_batch_trick = use_batch_trick
@@ -60,13 +59,13 @@ def forward(self, z):
z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim)
else:
z = torch.permute(
- z, (0, 3, 1, 2)
+ z, (0, 3, 1, 2)
) # (batch, emb_dim, n_uncrossed, n_across)
z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across)
z = torch.permute(
- z, (0, 2, 3, 1)
+ z, (0, 2, 3, 1)
) # (batch, n_uncrossed, n_across, emb_dim)
batch, n_uncrossed, n_across, emb_dim = z.shape
@@ -74,7 +73,9 @@ def forward(self, z):
if self.use_batch_trick:
z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
- z = self.rnn(z.contiguous())[0] # (batch * n_uncrossed, n_across, dir_rnn_dim)
+ z = self.rnn(z.contiguous())[
+ 0
+ ] # (batch * n_uncrossed, n_across, dir_rnn_dim)
z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
# (batch, n_uncrossed, n_across, dir_rnn_dim)
@@ -85,10 +86,7 @@ def forward(self, z):
zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim)
zlist.append(zi)
- z = torch.stack(
- zlist,
- dim=1
- ) # (batch, n_uncrossed, n_across, dir_rnn_dim)
+ z = torch.stack(zlist, dim=1) # (batch, n_uncrossed, n_across, dir_rnn_dim)
z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
@@ -99,13 +97,13 @@ def forward(self, z):
class SeqBandModellingModule(TimeFrequencyModellingModule):
def __init__(
- self,
- n_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- parallel_mode=False,
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ parallel_mode=False,
) -> None:
super().__init__()
self.seqband = nn.ModuleList([])
@@ -113,31 +111,33 @@ def __init__(
if parallel_mode:
for _ in range(n_modules):
self.seqband.append(
- nn.ModuleList(
- [ResidualRNN(
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- ),
- ResidualRNN(
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- )]
- )
+ nn.ModuleList(
+ [
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ]
+ )
)
else:
for _ in range(2 * n_modules):
self.seqband.append(
- ResidualRNN(
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- )
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
)
self.parallel_mode = parallel_mode
@@ -149,8 +149,8 @@ def forward(self, z):
for sbm_pair in self.seqband:
# z: (batch, n_bands, n_time, emb_dim)
sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
- zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
- zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
+ zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
+ zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
z = zt + zf.transpose(1, 2)
else:
for sbm in self.seqband:
@@ -169,20 +169,17 @@ def forward(self, z):
class ResidualTransformer(nn.Module):
def __init__(
- self,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- dropout: float = 0.0,
+ self,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
) -> None:
# n_group is the size of the 2nd dim
super().__init__()
self.tf = nn.TransformerEncoderLayer(
- d_model=emb_dim,
- nhead=4,
- dim_feedforward=rnn_dim,
- batch_first=True
+ d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
)
self.is_causal = not bidirectional
@@ -191,7 +188,9 @@ def __init__(
def forward(self, z):
batch, n_uncrossed, n_across, emb_dim = z.shape
z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
- z = self.tf(z, is_causal=self.is_causal) # (batch, n_uncrossed, n_across, emb_dim)
+ z = self.tf(
+ z, is_causal=self.is_causal
+ ) # (batch, n_uncrossed, n_across, emb_dim)
z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
return z
@@ -199,12 +198,12 @@ def forward(self, z):
class TransformerTimeFreqModule(TimeFrequencyModellingModule):
def __init__(
- self,
- n_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- dropout: float = 0.0,
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
) -> None:
super().__init__()
self.norm = nn.LayerNorm(emb_dim)
@@ -212,12 +211,12 @@ def __init__(
for _ in range(2 * n_modules):
self.seqband.append(
- ResidualTransformer(
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- dropout=dropout,
- )
+ ResidualTransformer(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=dropout,
+ )
)
def forward(self, z):
@@ -238,14 +237,13 @@ def forward(self, z):
return q # (batch, n_bands, n_time, emb_dim)
-
class ResidualConvolution(nn.Module):
def __init__(
- self,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- dropout: float = 0.0,
+ self,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
) -> None:
# n_group is the size of the 2nd dim
super().__init__()
@@ -258,22 +256,21 @@ def __init__(
kernel_size=(3, 3),
padding="same",
stride=(1, 1),
- ),
- nn.Tanhshrink()
+ ),
+ nn.Tanhshrink(),
)
self.is_causal = not bidirectional
self.dropout = dropout
self.fc = nn.Conv2d(
- in_channels=rnn_dim,
- out_channels=emb_dim,
- kernel_size=(1, 1),
- padding="same",
- stride=(1, 1),
+ in_channels=rnn_dim,
+ out_channels=emb_dim,
+ kernel_size=(1, 1),
+ padding="same",
+ stride=(1, 1),
)
-
def forward(self, z):
# z = (batch, n_uncrossed, n_across, emb_dim)
@@ -289,29 +286,35 @@ def forward(self, z):
class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
def __init__(
- self,
- n_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- dropout: float = 0.0,
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
) -> None:
super().__init__()
- self.seqband = torch.jit.script(nn.Sequential(
- *[ResidualConvolution(
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- dropout=dropout,
- ) for _ in range(2 * n_modules) ]))
+ self.seqband = torch.jit.script(
+ nn.Sequential(
+ *[
+ ResidualConvolution(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=dropout,
+ )
+ for _ in range(2 * n_modules)
+ ]
+ )
+ )
def forward(self, z):
# z = (batch, n_bands, n_time, emb_dim)
- z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time)
+ z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time)
- z = self.seqband(z) # (batch, emb_dim, n_bands, n_time)
+ z = self.seqband(z) # (batch, emb_dim, n_bands, n_time)
- z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim)
+ z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim)
return z
diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/utils.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/utils.py
index bf8636e6..d5f32bad 100644
--- a/programs/music_separation_code/models/bandit/core/model/bsrnn/utils.py
+++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/utils.py
@@ -1,6 +1,6 @@
import os
from abc import abstractmethod
-from typing import Any, Callable
+from typing import Callable
import numpy as np
import torch
@@ -70,12 +70,7 @@ def hertz_to_index(self, hz: float, round: bool = True):
return index
- def get_band_specs_with_bandwidth(
- self,
- start_index,
- end_index,
- bandwidth_hz
- ):
+ def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
band_specs = []
lower = start_index
@@ -105,110 +100,84 @@ def get_band_specs(self):
@property
def version1(self):
return self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.max_index, bandwidth_hz=1000
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
)
def version2(self):
below16k = self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.split16k, bandwidth_hz=1000
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
)
below20k = self.get_band_specs_with_bandwidth(
- start_index=self.split16k,
- end_index=self.split20k,
- bandwidth_hz=2000
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
)
return below16k + below20k + self.above20k
def version3(self):
below8k = self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.split8k, bandwidth_hz=1000
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
)
below16k = self.get_band_specs_with_bandwidth(
- start_index=self.split8k,
- end_index=self.split16k,
- bandwidth_hz=2000
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
)
return below8k + below16k + self.above16k
def version4(self):
below1k = self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.split1k, bandwidth_hz=100
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
)
below8k = self.get_band_specs_with_bandwidth(
- start_index=self.split1k,
- end_index=self.split8k,
- bandwidth_hz=1000
+ start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
)
below16k = self.get_band_specs_with_bandwidth(
- start_index=self.split8k,
- end_index=self.split16k,
- bandwidth_hz=2000
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
)
return below1k + below8k + below16k + self.above16k
def version5(self):
below1k = self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.split1k, bandwidth_hz=100
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
)
below16k = self.get_band_specs_with_bandwidth(
- start_index=self.split1k,
- end_index=self.split16k,
- bandwidth_hz=1000
+ start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
)
below20k = self.get_band_specs_with_bandwidth(
- start_index=self.split16k,
- end_index=self.split20k,
- bandwidth_hz=2000
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
)
return below1k + below16k + below20k + self.above20k
def version6(self):
below1k = self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.split1k, bandwidth_hz=100
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
)
below4k = self.get_band_specs_with_bandwidth(
- start_index=self.split1k,
- end_index=self.split4k,
- bandwidth_hz=500
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
)
below8k = self.get_band_specs_with_bandwidth(
- start_index=self.split4k,
- end_index=self.split8k,
- bandwidth_hz=1000
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
)
below16k = self.get_band_specs_with_bandwidth(
- start_index=self.split8k,
- end_index=self.split16k,
- bandwidth_hz=2000
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
)
return below1k + below4k + below8k + below16k + self.above16k
def version7(self):
below1k = self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.split1k, bandwidth_hz=100
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
)
below4k = self.get_band_specs_with_bandwidth(
- start_index=self.split1k,
- end_index=self.split4k,
- bandwidth_hz=250
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
)
below8k = self.get_band_specs_with_bandwidth(
- start_index=self.split4k,
- end_index=self.split8k,
- bandwidth_hz=500
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
)
below16k = self.get_band_specs_with_bandwidth(
- start_index=self.split8k,
- end_index=self.split16k,
- bandwidth_hz=1000
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
)
below20k = self.get_band_specs_with_bandwidth(
- start_index=self.split16k,
- end_index=self.split20k,
- bandwidth_hz=2000
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
)
return below1k + below4k + below8k + below16k + below20k + self.above20k
@@ -224,27 +193,19 @@ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
def get_band_specs(self):
below500 = self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.split500, bandwidth_hz=50
+ start_index=0, end_index=self.split500, bandwidth_hz=50
)
below1k = self.get_band_specs_with_bandwidth(
- start_index=self.split500,
- end_index=self.split1k,
- bandwidth_hz=100
+ start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
)
below4k = self.get_band_specs_with_bandwidth(
- start_index=self.split1k,
- end_index=self.split4k,
- bandwidth_hz=500
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
)
below8k = self.get_band_specs_with_bandwidth(
- start_index=self.split4k,
- end_index=self.split8k,
- bandwidth_hz=1000
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
)
below16k = self.get_band_specs_with_bandwidth(
- start_index=self.split8k,
- end_index=self.split16k,
- bandwidth_hz=2000
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
)
above16k = [(self.split16k, self.max_index)]
@@ -257,59 +218,43 @@ def __init__(self, nfft: int, fs: int) -> None:
def get_band_specs(self):
below1k = self.get_band_specs_with_bandwidth(
- start_index=0, end_index=self.split1k, bandwidth_hz=50
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
)
below2k = self.get_band_specs_with_bandwidth(
- start_index=self.split1k,
- end_index=self.split2k,
- bandwidth_hz=100
+ start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
)
below4k = self.get_band_specs_with_bandwidth(
- start_index=self.split2k,
- end_index=self.split4k,
- bandwidth_hz=250
+ start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
)
below8k = self.get_band_specs_with_bandwidth(
- start_index=self.split4k,
- end_index=self.split8k,
- bandwidth_hz=500
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
)
below16k = self.get_band_specs_with_bandwidth(
- start_index=self.split8k,
- end_index=self.split16k,
- bandwidth_hz=1000
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
)
above16k = [(self.split16k, self.max_index)]
return below1k + below2k + below4k + below8k + below16k + above16k
-
-
class PerceptualBandsplitSpecification(BandsplitSpecification):
def __init__(
- self,
- nfft: int,
- fs: int,
- fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
- n_bands: int,
- f_min: float = 0.0,
- f_max: float = None
+ self,
+ nfft: int,
+ fs: int,
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None,
) -> None:
super().__init__(nfft=nfft, fs=fs)
self.n_bands = n_bands
if f_max is None:
f_max = fs / 2
- self.filterbank = fbank_fn(
- n_bands, fs, f_min, f_max, self.max_index
- )
+ self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
- weight_per_bin = torch.sum(
- self.filterbank,
- dim=0,
- keepdim=True
- ) # (1, n_freqs)
+ weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs)
normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs)
freq_weights = []
@@ -342,22 +287,23 @@ def save_to_file(self, dir_path: str) -> None:
with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
pickle.dump(
- {
- "band_specs": self.band_specs,
- "freq_weights": self.freq_weights,
- "filterbank": self.filterbank,
- },
- f,
+ {
+ "band_specs": self.band_specs,
+ "freq_weights": self.freq_weights,
+ "filterbank": self.filterbank,
+ },
+ f,
)
+
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
fb = taF.melscale_fbanks(
- n_mels=n_bands,
- sample_rate=fs,
- f_min=f_min,
- f_max=f_max,
- n_freqs=n_freqs,
- ).T
+ n_mels=n_bands,
+ sample_rate=fs,
+ f_min=f_min,
+ f_max=f_max,
+ n_freqs=n_freqs,
+ ).T
fb[0, 0] = 1.0
@@ -366,17 +312,19 @@ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
def __init__(
- self,
- nfft: int,
- fs: int,
- n_bands: int,
- f_min: float = 0.0,
- f_max: float = None
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
) -> None:
- super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+ super().__init__(
+ fbank_fn=mel_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
-def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs,
- scale="constant"):
+def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
nfft = 2 * (n_freqs - 1)
df = fs / nfft
@@ -403,55 +351,57 @@ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs,
fb = np.zeros((n_bands, n_freqs))
for i in range(n_bands):
- fb[i, low_bins[i]:high_bins[i]+1] = 1.0
+ fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
- fb[0, :low_bins[0]] = 1.0
- fb[-1, high_bins[-1]+1:] = 1.0
+ fb[0, : low_bins[0]] = 1.0
+ fb[-1, high_bins[-1] + 1 :] = 1.0
return torch.as_tensor(fb)
+
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
def __init__(
- self,
- nfft: int,
- fs: int,
- n_bands: int,
- f_min: float = 0.0,
- f_max: float = None
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
) -> None:
- super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+ super().__init__(
+ fbank_fn=musical_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
-def bark_filterbank(
- n_bands, fs, f_min, f_max, n_freqs
-):
- nfft = 2 * (n_freqs -1)
+def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+ nfft = 2 * (n_freqs - 1)
fb, _ = bark_fbanks.bark_filter_banks(
- nfilts=n_bands,
- nfft=nfft,
- fs=fs,
- low_freq=f_min,
- high_freq=f_max,
- scale="constant"
+ nfilts=n_bands,
+ nfft=nfft,
+ fs=fs,
+ low_freq=f_min,
+ high_freq=f_max,
+ scale="constant",
)
return torch.as_tensor(fb)
+
class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
def __init__(
- self,
- nfft: int,
- fs: int,
- n_bands: int,
- f_min: float = 0.0,
- f_max: float = None
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
) -> None:
- super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+ super().__init__(
+ fbank_fn=bark_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
-def triangular_bark_filterbank(
- n_bands, fs, f_min, f_max, n_freqs
-):
+def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
all_freqs = torch.linspace(0, fs // 2, n_freqs)
@@ -474,47 +424,41 @@ def triangular_bark_filterbank(
return fb
+
class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
def __init__(
- self,
- nfft: int,
- fs: int,
- n_bands: int,
- f_min: float = 0.0,
- f_max: float = None
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
) -> None:
- super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
-
+ super().__init__(
+ fbank_fn=triangular_bark_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
-def minibark_filterbank(
- n_bands, fs, f_min, f_max, n_freqs
-):
- fb = bark_filterbank(
- n_bands,
- fs,
- f_min,
- f_max,
- n_freqs
- )
+def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+ fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
fb[fb < np.sqrt(0.5)] = 0.0
return fb
+
class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
def __init__(
- self,
- nfft: int,
- fs: int,
- n_bands: int,
- f_min: float = 0.0,
- f_max: float = None
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
) -> None:
- super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
-
-
-
+ super().__init__(
+ fbank_fn=minibark_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
def erb_filterbank(
@@ -533,14 +477,13 @@ def erb_filterbank(
m_max = hz2erb(f_max)
m_pts = torch.linspace(m_min, m_max, n_bands + 2)
- f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
+ f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
# create filterbank
fb = _create_triangular_filterbank(all_freqs, f_pts)
fb = fb.T
-
first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
@@ -549,35 +492,34 @@ def erb_filterbank(
return fb
-
class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
def __init__(
- self,
- nfft: int,
- fs: int,
- n_bands: int,
- f_min: float = 0.0,
- f_max: float = None
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
) -> None:
- super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+ super().__init__(
+ fbank_fn=erb_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
if __name__ == "__main__":
import pandas as pd
band_defs = []
- for bands in [VocalBandsplitSpecification]:
+ for bands in [VocalBandsplitSpecification]:
band_name = bands.__name__.replace("BandsplitSpecification", "")
mbs = bands(nfft=2048, fs=44100).get_band_specs()
for i, (f_min, f_max) in enumerate(mbs):
- band_defs.append({
- "band": band_name,
- "band_index": i,
- "f_min": f_min,
- "f_max": f_max
- })
+ band_defs.append(
+ {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
+ )
df = pd.DataFrame(band_defs)
- df.to_csv("vox7bands.csv", index=False)
\ No newline at end of file
+ df.to_csv("vox7bands.csv", index=False)
diff --git a/programs/music_separation_code/models/bandit/core/model/bsrnn/wrapper.py b/programs/music_separation_code/models/bandit/core/model/bsrnn/wrapper.py
index a31c087d..6f26e9d9 100644
--- a/programs/music_separation_code/models/bandit/core/model/bsrnn/wrapper.py
+++ b/programs/music_separation_code/models/bandit/core/model/bsrnn/wrapper.py
@@ -1,4 +1,3 @@
-from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union
import torch
@@ -6,76 +5,62 @@
from models.bandit.core.model._spectral import _SpectralComponent
from models.bandit.core.model.bsrnn.utils import (
- BarkBandsplitSpecification, BassBandsplitSpecification,
+ BarkBandsplitSpecification,
+ BassBandsplitSpecification,
DrumBandsplitSpecification,
- EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification,
- MusicalBandsplitSpecification, OtherBandsplitSpecification,
- TriangularBarkBandsplitSpecification, VocalBandsplitSpecification,
+ EquivalentRectangularBandsplitSpecification,
+ MelBandsplitSpecification,
+ MusicalBandsplitSpecification,
+ OtherBandsplitSpecification,
+ TriangularBarkBandsplitSpecification,
+ VocalBandsplitSpecification,
)
from .core import (
MultiSourceMultiMaskBandSplitCoreConv,
MultiSourceMultiMaskBandSplitCoreRNN,
MultiSourceMultiMaskBandSplitCoreTransformer,
- MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN,
+ MultiSourceMultiPatchingMaskBandSplitCoreRNN,
+ SingleMaskBandsplitCoreRNN,
SingleMaskBandsplitCoreTransformer,
)
import pytorch_lightning as pl
+
def get_band_specs(band_specs, n_fft, fs, n_bands=None):
if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
- bsm = VocalBandsplitSpecification(
- nfft=n_fft, fs=fs
- ).get_band_specs()
+ bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
freq_weights = None
overlapping_band = False
elif "tribark" in band_specs:
assert n_bands is not None
- specs = TriangularBarkBandsplitSpecification(
- nfft=n_fft,
- fs=fs,
- n_bands=n_bands
- )
+ specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
bsm = specs.get_band_specs()
freq_weights = specs.get_freq_weights()
overlapping_band = True
elif "bark" in band_specs:
assert n_bands is not None
- specs = BarkBandsplitSpecification(
- nfft=n_fft,
- fs=fs,
- n_bands=n_bands
- )
+ specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
bsm = specs.get_band_specs()
freq_weights = specs.get_freq_weights()
overlapping_band = True
elif "erb" in band_specs:
assert n_bands is not None
specs = EquivalentRectangularBandsplitSpecification(
- nfft=n_fft,
- fs=fs,
- n_bands=n_bands
+ nfft=n_fft, fs=fs, n_bands=n_bands
)
bsm = specs.get_band_specs()
freq_weights = specs.get_freq_weights()
overlapping_band = True
elif "musical" in band_specs:
assert n_bands is not None
- specs = MusicalBandsplitSpecification(
- nfft=n_fft,
- fs=fs,
- n_bands=n_bands
- )
+ specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
bsm = specs.get_band_specs()
freq_weights = specs.get_freq_weights()
overlapping_band = True
elif band_specs == "dnr:mel" or "mel" in band_specs:
assert n_bands is not None
- specs = MelBandsplitSpecification(
- nfft=n_fft,
- fs=fs,
- n_bands=n_bands
- )
+ specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
bsm = specs.get_band_specs()
freq_weights = specs.get_freq_weights()
overlapping_band = True
@@ -88,38 +73,24 @@ def get_band_specs(band_specs, n_fft, fs, n_bands=None):
def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
if band_specs_map == "musdb:all":
bsm = {
- "vocals": VocalBandsplitSpecification(
- nfft=n_fft, fs=fs
- ).get_band_specs(),
- "drums": DrumBandsplitSpecification(
- nfft=n_fft, fs=fs
- ).get_band_specs(),
- "bass": BassBandsplitSpecification(
- nfft=n_fft, fs=fs
- ).get_band_specs(),
- "other": OtherBandsplitSpecification(
- nfft=n_fft, fs=fs
- ).get_band_specs(),
+ "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+ "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+ "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+ "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
}
freq_weights = None
overlapping_band = False
elif band_specs_map == "dnr:vox7":
bsm_, freq_weights, overlapping_band = get_band_specs(
- "dnr:speech", n_fft, fs, n_bands
+ "dnr:speech", n_fft, fs, n_bands
)
- bsm = {
- "speech": bsm_,
- "music": bsm_,
- "effects": bsm_
- }
+ bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
elif "dnr:vox7:" in band_specs_map:
stem = band_specs_map.split(":")[-1]
bsm_, freq_weights, overlapping_band = get_band_specs(
- "dnr:speech", n_fft, fs, n_bands
+ "dnr:speech", n_fft, fs, n_bands
)
- bsm = {
- stem: bsm_
- }
+ bsm = {stem: bsm_}
else:
raise NameError
@@ -128,51 +99,45 @@ def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
class BandSplitWrapperBase(pl.LightningModule):
bsrnn: nn.Module
-
+
def __init__(self, **kwargs):
super().__init__()
-class SingleMaskMultiSourceBandSplitBase(
- BandSplitWrapperBase,
- _SpectralComponent
-):
+class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
def __init__(
- self,
- band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
- fs: int = 44100,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- n_bands: int = None,
+ self,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
) -> None:
super().__init__(
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
)
if isinstance(band_specs_map, str):
- self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map(
- band_specs_map,
- n_fft,
- fs,
- n_bands=n_bands
- )
+ self.band_specs_map, self.freq_weights, self.overlapping_band = (
+ get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
+ )
self.stems = list(self.band_specs_map.keys())
@@ -180,8 +145,7 @@ def forward(self, batch):
audio = batch["audio"]
with torch.no_grad():
- batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in
- audio}
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
X = batch["spectrogram"]["mixture"]
length = batch["audio"]["mixture"].shape[-1]
@@ -197,47 +161,41 @@ def forward(self, batch):
return batch, output
-class MultiMaskMultiSourceBandSplitBase(
- BandSplitWrapperBase,
- _SpectralComponent
-):
+class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
def __init__(
- self,
- stems: List[str],
- band_specs: Union[str, List[Tuple[float, float]]],
- fs: int = 44100,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- n_bands: int = None,
+ self,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
) -> None:
super().__init__(
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
)
if isinstance(band_specs, str):
self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
- band_specs,
- n_fft,
- fs,
- n_bands
- )
+ band_specs, n_fft, fs, n_bands
+ )
self.stems = stems
@@ -246,8 +204,7 @@ def forward(self, batch):
audio = batch["audio"]
cond = batch.get("condition", None)
with torch.no_grad():
- batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in
- audio}
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
X = batch["spectrogram"]["mixture"]
length = batch["audio"]["mixture"].shape[-1]
@@ -262,47 +219,41 @@ def forward(self, batch):
return batch, output
-class MultiMaskMultiSourceBandSplitBaseSimple(
- BandSplitWrapperBase,
- _SpectralComponent
-):
+class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
def __init__(
- self,
- stems: List[str],
- band_specs: Union[str, List[Tuple[float, float]]],
- fs: int = 44100,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- n_bands: int = None,
+ self,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
) -> None:
super().__init__(
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
)
if isinstance(band_specs, str):
self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
- band_specs,
- n_fft,
- fs,
- n_bands
- )
+ band_specs, n_fft, fs, n_bands
+ )
self.stems = stems
@@ -321,221 +272,219 @@ def forward(self, batch):
class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
def __init__(
- self,
- in_channel: int,
- band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
- fs: int = 44100,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
+ self,
+ in_channel: int,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
) -> None:
super().__init__(
- band_specs_map=band_specs_map,
- fs=fs,
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
+ band_specs_map=band_specs_map,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
)
self.bsrnn = nn.ModuleDict(
- {
- src: SingleMaskBandsplitCoreRNN(
- band_specs=specs,
- in_channel=in_channel,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- n_sqm_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- mlp_dim=mlp_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- )
- for src, specs in self.band_specs_map.items()
- }
+ {
+ src: SingleMaskBandsplitCoreRNN(
+ band_specs=specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for src, specs in self.band_specs_map.items()
+ }
)
-class SingleMaskMultiSourceBandSplitTransformer(
- SingleMaskMultiSourceBandSplitBase
-):
+class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
def __init__(
- self,
- in_channel: int,
- band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
- fs: int = 44100,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- tf_dropout: float = 0.0,
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
+ self,
+ in_channel: int,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
) -> None:
super().__init__(
- band_specs_map=band_specs_map,
- fs=fs,
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
+ band_specs_map=band_specs_map,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
)
self.bsrnn = nn.ModuleDict(
- {
- src: SingleMaskBandsplitCoreTransformer(
- band_specs=specs,
- in_channel=in_channel,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- n_sqm_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- tf_dropout=tf_dropout,
- mlp_dim=mlp_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- )
- for src, specs in self.band_specs_map.items()
- }
+ {
+ src: SingleMaskBandsplitCoreTransformer(
+ band_specs=specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ tf_dropout=tf_dropout,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for src, specs in self.band_specs_map.items()
+ }
)
class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: Union[str, List[Tuple[float, float]]],
- fs: int = 44100,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- cond_dim: int = 0,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- n_bands: int = None,
- use_freq_weights: bool = True,
- normalize_input: bool = False,
- mult_add_mask: bool = False,
- freeze_encoder: bool = False,
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
+ freeze_encoder: bool = False,
) -> None:
super().__init__(
- stems=stems,
- band_specs=band_specs,
- fs=fs,
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
- n_bands=n_bands,
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
)
self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
- stems=stems,
- band_specs=self.band_specs,
- in_channel=in_channel,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- n_sqm_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- mlp_dim=mlp_dim,
- cond_dim=cond_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- overlapping_band=self.overlapping_band,
- freq_weights=self.freq_weights,
- n_freq=n_fft // 2 + 1,
- use_freq_weights=use_freq_weights,
- mult_add_mask=mult_add_mask
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
)
self.normalize_input = normalize_input
@@ -551,81 +500,81 @@ def __init__(
class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: Union[str, List[Tuple[float, float]]],
- fs: int = 44100,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- cond_dim: int = 0,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- n_bands: int = None,
- use_freq_weights: bool = True,
- normalize_input: bool = False,
- mult_add_mask: bool = False,
- freeze_encoder: bool = False,
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
+ freeze_encoder: bool = False,
) -> None:
super().__init__(
- stems=stems,
- band_specs=band_specs,
- fs=fs,
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
- n_bands=n_bands,
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
)
self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
- stems=stems,
- band_specs=self.band_specs,
- in_channel=in_channel,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- n_sqm_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- mlp_dim=mlp_dim,
- cond_dim=cond_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- overlapping_band=self.overlapping_band,
- freq_weights=self.freq_weights,
- n_freq=n_fft // 2 + 1,
- use_freq_weights=use_freq_weights,
- mult_add_mask=mult_add_mask
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
)
self.normalize_input = normalize_input
@@ -639,244 +588,241 @@ def __init__(
param.requires_grad = False
-class MultiMaskMultiSourceBandSplitTransformer(
- MultiMaskMultiSourceBandSplitBase
-):
+class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: Union[str, List[Tuple[float, float]]],
- fs: int = 44100,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- cond_dim: int = 0,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- n_bands: int = None,
- use_freq_weights: bool = True,
- normalize_input: bool = False,
- mult_add_mask: bool = False
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
) -> None:
super().__init__(
- stems=stems,
- band_specs=band_specs,
- fs=fs,
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
- n_bands=n_bands,
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
)
self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
- stems=stems,
- band_specs=self.band_specs,
- in_channel=in_channel,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- n_sqm_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- mlp_dim=mlp_dim,
- cond_dim=cond_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- overlapping_band=self.overlapping_band,
- freq_weights=self.freq_weights,
- n_freq=n_fft // 2 + 1,
- use_freq_weights=use_freq_weights,
- mult_add_mask=mult_add_mask
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
)
-
-class MultiMaskMultiSourceBandSplitConv(
- MultiMaskMultiSourceBandSplitBase
-):
+class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: Union[str, List[Tuple[float, float]]],
- fs: int = 44100,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- cond_dim: int = 0,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- n_bands: int = None,
- use_freq_weights: bool = True,
- normalize_input: bool = False,
- mult_add_mask: bool = False
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
) -> None:
super().__init__(
- stems=stems,
- band_specs=band_specs,
- fs=fs,
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
- n_bands=n_bands,
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
)
self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
- stems=stems,
- band_specs=self.band_specs,
- in_channel=in_channel,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- n_sqm_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- mlp_dim=mlp_dim,
- cond_dim=cond_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- overlapping_band=self.overlapping_band,
- freq_weights=self.freq_weights,
- n_freq=n_fft // 2 + 1,
- use_freq_weights=use_freq_weights,
- mult_add_mask=mult_add_mask
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
)
+
+
class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
def __init__(
- self,
- in_channel: int,
- stems: List[str],
- band_specs: Union[str, List[Tuple[float, float]]],
- kernel_norm_mlp_version: int = 1,
- mask_kernel_freq: int = 3,
- mask_kernel_time: int = 3,
- conv_kernel_freq: int = 1,
- conv_kernel_time: int = 1,
- fs: int = 44100,
- require_no_overlap: bool = False,
- require_no_gap: bool = True,
- normalize_channel_independently: bool = False,
- treat_channel_as_feature: bool = True,
- n_sqm_modules: int = 12,
- emb_dim: int = 128,
- rnn_dim: int = 256,
- bidirectional: bool = True,
- rnn_type: str = "LSTM",
- mlp_dim: int = 512,
- hidden_activation: str = "Tanh",
- hidden_activation_kwargs: Optional[Dict] = None,
- complex_mask: bool = True,
- n_fft: int = 2048,
- win_length: Optional[int] = 2048,
- hop_length: int = 512,
- window_fn: str = "hann_window",
- wkwargs: Optional[Dict] = None,
- power: Optional[int] = None,
- center: bool = True,
- normalized: bool = True,
- pad_mode: str = "constant",
- onesided: bool = True,
- n_bands: int = None,
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ kernel_norm_mlp_version: int = 1,
+ mask_kernel_freq: int = 3,
+ mask_kernel_time: int = 3,
+ conv_kernel_freq: int = 1,
+ conv_kernel_time: int = 1,
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
) -> None:
super().__init__(
- stems=stems,
- band_specs=band_specs,
- fs=fs,
- n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
- window_fn=window_fn,
- wkwargs=wkwargs,
- power=power,
- center=center,
- normalized=normalized,
- pad_mode=pad_mode,
- onesided=onesided,
- n_bands=n_bands,
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
)
self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
- stems=stems,
- band_specs=self.band_specs,
- in_channel=in_channel,
- require_no_overlap=require_no_overlap,
- require_no_gap=require_no_gap,
- normalize_channel_independently=normalize_channel_independently,
- treat_channel_as_feature=treat_channel_as_feature,
- n_sqm_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- mlp_dim=mlp_dim,
- hidden_activation=hidden_activation,
- hidden_activation_kwargs=hidden_activation_kwargs,
- complex_mask=complex_mask,
- overlapping_band=self.overlapping_band,
- freq_weights=self.freq_weights,
- n_freq=n_fft // 2 + 1,
- mask_kernel_freq=mask_kernel_freq,
- mask_kernel_time=mask_kernel_time,
- conv_kernel_freq=conv_kernel_freq,
- conv_kernel_time=conv_kernel_time,
- kernel_norm_mlp_version=kernel_norm_mlp_version,
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ mask_kernel_freq=mask_kernel_freq,
+ mask_kernel_time=mask_kernel_time,
+ conv_kernel_freq=conv_kernel_freq,
+ conv_kernel_time=conv_kernel_time,
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
)
diff --git a/programs/music_separation_code/models/bandit/core/utils/audio.py b/programs/music_separation_code/models/bandit/core/utils/audio.py
index e4066d7d..6bdea55b 100644
--- a/programs/music_separation_code/models/bandit/core/utils/audio.py
+++ b/programs/music_separation_code/models/bandit/core/utils/audio.py
@@ -1,7 +1,7 @@
from collections import defaultdict
from tqdm import tqdm
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, Tuple
import numpy as np
import torch
@@ -11,19 +11,17 @@
@torch.jit.script
def merge(
- combined: torch.Tensor,
- original_batch_size: int,
- n_channel: int,
- n_chunks: int,
- chunk_size: int, ):
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_chunks: int,
+ chunk_size: int,
+):
combined = torch.reshape(
- combined,
- (original_batch_size, n_chunks, n_channel, chunk_size)
+ combined, (original_batch_size, n_chunks, n_channel, chunk_size)
)
combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
- original_batch_size * n_channel,
- chunk_size,
- n_chunks
+ original_batch_size * n_channel, chunk_size, n_chunks
)
return combined
@@ -31,33 +29,23 @@ def merge(
@torch.jit.script
def unfold(
- padded_audio: torch.Tensor,
- original_batch_size: int,
- n_channel: int,
- chunk_size: int,
- hop_size: int
- ) -> torch.Tensor:
+ padded_audio: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ chunk_size: int,
+ hop_size: int,
+) -> torch.Tensor:
unfolded_input = F.unfold(
- padded_audio[:, :, None, :],
- kernel_size=(1, chunk_size),
- stride=(1, hop_size)
+ padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size)
)
_, _, n_chunks = unfolded_input.shape
unfolded_input = unfolded_input.view(
- original_batch_size,
- n_channel,
- chunk_size,
- n_chunks
+ original_batch_size, n_channel, chunk_size, n_chunks
)
- unfolded_input = torch.permute(
- unfolded_input,
- (0, 3, 1, 2)
- ).reshape(
- original_batch_size * n_chunks,
- n_channel,
- chunk_size
+ unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape(
+ original_batch_size * n_chunks, n_channel, chunk_size
)
return unfolded_input
@@ -66,40 +54,31 @@ def unfold(
@torch.jit.script
# @torch.compile
def merge_chunks_all(
- combined: torch.Tensor,
- original_batch_size: int,
- n_channel: int,
- n_samples: int,
- n_padded_samples: int,
- n_chunks: int,
- chunk_size: int,
- hop_size: int,
- edge_frame_pad_sizes: Tuple[int, int],
- standard_window: torch.Tensor,
- first_window: torch.Tensor,
- last_window: torch.Tensor
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_samples: int,
+ n_padded_samples: int,
+ n_chunks: int,
+ chunk_size: int,
+ hop_size: int,
+ edge_frame_pad_sizes: Tuple[int, int],
+ standard_window: torch.Tensor,
+ first_window: torch.Tensor,
+ last_window: torch.Tensor,
):
- combined = merge(
- combined,
- original_batch_size,
- n_channel,
- n_chunks,
- chunk_size
- )
+ combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
combined = combined * standard_window[:, None].to(combined.device)
combined = F.fold(
- combined.to(torch.float32), output_size=(1, n_padded_samples),
- kernel_size=(1, chunk_size),
- stride=(1, hop_size)
+ combined.to(torch.float32),
+ output_size=(1, n_padded_samples),
+ kernel_size=(1, chunk_size),
+ stride=(1, hop_size),
)
- combined = combined.view(
- original_batch_size,
- n_channel,
- n_padded_samples
- )
+ combined = combined.view(original_batch_size, n_channel, n_padded_samples)
pad_front, pad_back = edge_frame_pad_sizes
combined = combined[..., pad_front:-pad_back]
@@ -112,43 +91,33 @@ def merge_chunks_all(
def merge_chunks_edge(
- combined: torch.Tensor,
- original_batch_size: int,
- n_channel: int,
- n_samples: int,
- n_padded_samples: int,
- n_chunks: int,
- chunk_size: int,
- hop_size: int,
- edge_frame_pad_sizes: Tuple[int, int],
- standard_window: torch.Tensor,
- first_window: torch.Tensor,
- last_window: torch.Tensor
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_samples: int,
+ n_padded_samples: int,
+ n_chunks: int,
+ chunk_size: int,
+ hop_size: int,
+ edge_frame_pad_sizes: Tuple[int, int],
+ standard_window: torch.Tensor,
+ first_window: torch.Tensor,
+ last_window: torch.Tensor,
):
- combined = merge(
- combined,
- original_batch_size,
- n_channel,
- n_chunks,
- chunk_size
- )
+ combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
combined[..., 0] = combined[..., 0] * first_window
combined[..., -1] = combined[..., -1] * last_window
- combined[..., 1:-1] = combined[...,
- 1:-1] * standard_window[:, None]
+ combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]
combined = F.fold(
- combined, output_size=(1, n_padded_samples),
- kernel_size=(1, chunk_size),
- stride=(1, hop_size)
+ combined,
+ output_size=(1, n_padded_samples),
+ kernel_size=(1, chunk_size),
+ stride=(1, hop_size),
)
- combined = combined.view(
- original_batch_size,
- n_channel,
- n_padded_samples
- )
+ combined = combined.view(original_batch_size, n_channel, n_padded_samples)
combined = combined[..., :n_samples]
@@ -157,12 +126,12 @@ def merge_chunks_edge(
class BaseFader(nn.Module):
def __init__(
- self,
- chunk_size_second: float,
- hop_size_second: float,
- fs: int,
- fade_edge_frames: bool,
- batch_size: int,
+ self,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ fade_edge_frames: bool,
+ batch_size: int,
) -> None:
super().__init__()
@@ -179,9 +148,7 @@ def prepare(self, audio):
audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")
n_samples = audio.shape[-1]
- n_chunks = int(
- np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1
- )
+ n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)
padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
pad_size = padded_size - n_samples
@@ -191,9 +158,9 @@ def prepare(self, audio):
return padded_audio, n_chunks
def forward(
- self,
- audio: torch.Tensor,
- model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
+ self,
+ audio: torch.Tensor,
+ model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
):
original_dtype = audio.dtype
@@ -208,14 +175,11 @@ def forward(
if n_channel > 1:
padded_audio = padded_audio.view(
- original_batch_size * n_channel, 1, n_padded_samples
+ original_batch_size * n_channel, 1, n_padded_samples
)
unfolded_input = unfold(
- padded_audio,
- original_batch_size,
- n_channel,
- self.chunk_size, self.hop_size
+ padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
)
n_total_chunks, n_channel, chunk_size = unfolded_input.shape
@@ -223,15 +187,12 @@ def forward(
n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)
chunks_in = [
- unfolded_input[
- b * self.batch_size:(b + 1) * self.batch_size, ...].clone()
- for b in range(n_batch)
+ unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
+ for b in range(n_batch)
]
all_chunks_out = defaultdict(
- lambda: torch.zeros_like(
- unfolded_input, device="cpu"
- )
+ lambda: torch.zeros_like(unfolded_input, device="cpu")
)
# for b, cin in enumerate(tqdm(chunks_in)):
@@ -243,8 +204,9 @@ def forward(
chunks_out = model_fn(cin.to(original_device))
del cin
for s, c in chunks_out.items():
- all_chunks_out[s][b * self.batch_size:(b + 1) * self.batch_size,
- ...] = c.cpu()
+ all_chunks_out[s][
+ b * self.batch_size : (b + 1) * self.batch_size, ...
+ ] = c.cpu()
del chunks_out
del unfolded_input
@@ -260,28 +222,24 @@ def forward(
for s, c in all_chunks_out.items():
combined: torch.Tensor = fn(
- c,
- original_batch_size,
- n_channel,
- n_samples,
- n_padded_samples,
- n_chunks,
- self.chunk_size,
- self.hop_size,
- self.edge_frame_pad_sizes,
- self.standard_window,
- self.__dict__.get("first_window", self.standard_window),
- self.__dict__.get("last_window", self.standard_window)
+ c,
+ original_batch_size,
+ n_channel,
+ n_samples,
+ n_padded_samples,
+ n_chunks,
+ self.chunk_size,
+ self.hop_size,
+ self.edge_frame_pad_sizes,
+ self.standard_window,
+ self.__dict__.get("first_window", self.standard_window),
+ self.__dict__.get("last_window", self.standard_window),
)
- outputs[s] = combined.to(
- dtype=original_dtype,
- device=original_device
- )
+ outputs[s] = combined.to(dtype=original_dtype, device=original_device)
+
+ return {"audio": outputs}
- return {
- "audio": outputs
- }
#
# def old_forward(
# self,
@@ -366,22 +324,22 @@ def forward(
class LinearFader(BaseFader):
def __init__(
- self,
- chunk_size_second: float,
- hop_size_second: float,
- fs: int,
- fade_edge_frames: bool = False,
- batch_size: int = 1,
+ self,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ fade_edge_frames: bool = False,
+ batch_size: int = 1,
) -> None:
assert hop_size_second >= chunk_size_second / 2
super().__init__(
- chunk_size_second=chunk_size_second,
- hop_size_second=hop_size_second,
- fs=fs,
- fade_edge_frames=fade_edge_frames,
- batch_size=batch_size,
+ chunk_size_second=chunk_size_second,
+ hop_size_second=hop_size_second,
+ fs=fs,
+ fade_edge_frames=fade_edge_frames,
+ batch_size=batch_size,
)
in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
@@ -391,8 +349,7 @@ def __init__(
# using nn.Parameters allows lightning to take care of devices for us
self.register_buffer(
- "standard_window",
- torch.concat([in_fade, center_ones, out_fade])
+ "standard_window", torch.concat([in_fade, center_ones, out_fade])
)
self.fade_edge_frames = fade_edge_frames
@@ -400,23 +357,21 @@ def __init__(
if not self.fade_edge_frames:
self.first_window = nn.Parameter(
- torch.concat([inout_ones, center_ones, out_fade]),
- requires_grad=False
+ torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
)
self.last_window = nn.Parameter(
- torch.concat([in_fade, center_ones, inout_ones]),
- requires_grad=False
+ torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
)
class OverlapAddFader(BaseFader):
def __init__(
- self,
- window_type: str,
- chunk_size_second: float,
- hop_size_second: float,
- fs: int,
- batch_size: int = 1,
+ self,
+ window_type: str,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ batch_size: int = 1,
) -> None:
assert (chunk_size_second / hop_size_second) % 2 == 0
assert int(chunk_size_second * fs) % 2 == 0
@@ -432,31 +387,25 @@ def __init__(
self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
# print(f"hop multiplier: {self.hop_multiplier}")
- self.edge_frame_pad_sizes = (
- 2 * self.overlap_size,
- 2 * self.overlap_size
- )
+ self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)
self.register_buffer(
- "standard_window", torch.windows.__dict__[window_type](
- self.chunk_size, sym=False, # dtype=torch.float64
- ) / self.hop_multiplier
+ "standard_window",
+ torch.windows.__dict__[window_type](
+ self.chunk_size,
+ sym=False, # dtype=torch.float64
+ )
+ / self.hop_multiplier,
)
if __name__ == "__main__":
import torchaudio as ta
+
fs = 44100
- ola = OverlapAddFader(
- "hann",
- 6.0,
- 1.0,
- fs,
- batch_size=16
- )
+ ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
audio_, _ = ta.load(
- "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too "
- "Much/vocals.wav"
+ "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
)
audio_ = audio_[None, ...]
out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
diff --git a/programs/music_separation_code/models/bandit/model_from_config.py b/programs/music_separation_code/models/bandit/model_from_config.py
index 00ea586d..9735bda0 100644
--- a/programs/music_separation_code/models/bandit/model_from_config.py
+++ b/programs/music_separation_code/models/bandit/model_from_config.py
@@ -2,7 +2,7 @@
import os.path
import torch
-code_path = os.path.dirname(os.path.abspath(__file__)) + '/'
+code_path = os.path.dirname(os.path.abspath(__file__)) + "/"
sys.path.append(code_path)
import yaml
@@ -22,10 +22,8 @@ def get_model(
config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
f.close()
- model = MultiMaskMultiSourceBandSplitRNNSimple(
- **config.model
- )
- d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt')
+ model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
+ d = torch.load(code_path + "model_bandit_plus_dnr_sdr_11.47.chpt")
model.load_state_dict(d)
model.to(device)
return model, config
diff --git a/programs/music_separation_code/models/bandit_v2/bandit.py b/programs/music_separation_code/models/bandit_v2/bandit.py
index ac4e13f4..fba32962 100644
--- a/programs/music_separation_code/models/bandit_v2/bandit.py
+++ b/programs/music_separation_code/models/bandit_v2/bandit.py
@@ -11,7 +11,6 @@
from .utils import MusicalBandsplitSpecification
-
class BaseEndToEndModule(pl.LightningModule):
def __init__(
self,
@@ -178,12 +177,12 @@ def instantiate_tf_modelling(
)
except Exception as e:
self.tf_model = SeqBandModellingModule(
- n_modules=n_sqm_modules,
- emb_dim=emb_dim,
- rnn_dim=rnn_dim,
- bidirectional=bidirectional,
- rnn_type=rnn_type,
- )
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
def mask(self, x, m):
return x * m
@@ -193,11 +192,7 @@ def forward(self, batch, mode="train"):
init_shape = batch.shape
if not isinstance(batch, dict):
mono = batch.view(-1, 1, batch.shape[-1])
- batch = {
- "mixture": {
- "audio": mono
- }
- }
+ batch = {"mixture": {"audio": mono}}
with torch.no_grad():
mixture = batch["mixture"]["audio"]
@@ -217,7 +212,9 @@ def forward(self, batch, mode="train"):
b = []
for s in self.stems:
# We need to obtain stereo again
- r = batch['estimates'][s]['audio'].view(-1, init_shape[1], init_shape[2])
+ r = batch["estimates"][s]["audio"].view(
+ -1, init_shape[1], init_shape[2]
+ )
b.append(r)
# And we need to return back tensor and not independent stems
batch = torch.stack(b, dim=1)
@@ -364,4 +361,3 @@ def separate(self, batch):
}
return batch
-
diff --git a/programs/music_separation_code/models/bandit_v2/film.py b/programs/music_separation_code/models/bandit_v2/film.py
index e3079533..253594ad 100644
--- a/programs/music_separation_code/models/bandit_v2/film.py
+++ b/programs/music_separation_code/models/bandit_v2/film.py
@@ -1,10 +1,11 @@
from torch import nn
import torch
+
class FiLM(nn.Module):
def __init__(self):
super().__init__()
-
+
def forward(self, x, gamma, beta):
return gamma * x + beta
@@ -13,13 +14,10 @@ class BTFBroadcastedFiLM(nn.Module):
def __init__(self):
super().__init__()
self.film = FiLM()
-
+
def forward(self, x, gamma, beta):
-
+
gamma = gamma[None, None, None, :]
beta = beta[None, None, None, :]
-
+
return self.film(x, gamma, beta)
-
-
-
\ No newline at end of file
diff --git a/programs/music_separation_code/models/bs_roformer/attend.py b/programs/music_separation_code/models/bs_roformer/attend.py
index d6dc4b30..9ebb7c93 100644
--- a/programs/music_separation_code/models/bs_roformer/attend.py
+++ b/programs/music_separation_code/models/bs_roformer/attend.py
@@ -11,18 +11,24 @@
# constants
-FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+FlashAttentionConfig = namedtuple(
+ "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
+)
# helpers
+
def exists(val):
return val is not None
+
def default(v, d):
return v if exists(v) else d
+
def once(fn):
called = False
+
@wraps(fn)
def inner(x):
nonlocal called
@@ -30,26 +36,26 @@ def inner(x):
return
called = True
return fn(x)
+
return inner
+
print_once = once(print)
# main class
+
class Attend(nn.Module):
- def __init__(
- self,
- dropout = 0.,
- flash = False,
- scale = None
- ):
+ def __init__(self, dropout=0.0, flash=False, scale=None):
super().__init__()
self.scale = scale
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.flash = flash
- assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+ assert not (
+ flash and version.parse(torch.__version__) < version.parse("2.0.0")
+ ), "in order to use flash attention, you must be using pytorch 2.0 or above"
# determine efficient attention configs for cuda and cpu
@@ -59,22 +65,35 @@ def __init__(
if not torch.cuda.is_available() or not flash:
return
- device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
- device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
+ device_version = version.parse(
+ f"{device_properties.major}.{device_properties.minor}"
+ )
- if device_version >= version.parse('8.0'):
- if os.name == 'nt':
- print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
+ if device_version >= version.parse("8.0"):
+ if os.name == "nt":
+ print_once(
+ "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
+ )
self.cuda_config = FlashAttentionConfig(False, True, True)
else:
- print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
+ print_once(
+ "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
+ )
self.cuda_config = FlashAttentionConfig(True, False, False)
else:
- print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
+ print_once(
+ "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
+ )
self.cuda_config = FlashAttentionConfig(False, True, True)
def flash_attn(self, q, k, v):
- _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+ _, heads, q_len, _, k_len, is_cuda, device = (
+ *q.shape,
+ k.shape[-2],
+ q.is_cuda,
+ q.device,
+ )
if exists(self.scale):
default_scale = q.shape[-1] ** -0.5
@@ -88,8 +107,7 @@ def flash_attn(self, q, k, v):
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
- q, k, v,
- dropout_p = self.dropout if self.training else 0.
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
)
return out
diff --git a/programs/music_separation_code/models/bs_roformer/bs_roformer.py b/programs/music_separation_code/models/bs_roformer/bs_roformer.py
index 2fda0cc9..3ed15445 100644
--- a/programs/music_separation_code/models/bs_roformer/bs_roformer.py
+++ b/programs/music_separation_code/models/bs_roformer/bs_roformer.py
@@ -17,6 +17,7 @@
# helper functions
+
def exists(val):
return val is not None
@@ -35,14 +36,15 @@ def unpack_one(t, ps, pattern):
# norm
+
def l2norm(t):
- return F.normalize(t, dim = -1, p = 2)
+ return F.normalize(t, dim=-1, p=2)
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
- self.scale = dim ** 0.5
+ self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
@@ -51,13 +53,9 @@ def forward(self, x):
# attention
+
class FeedForward(Module):
- def __init__(
- self,
- dim,
- mult=4,
- dropout=0.
- ):
+ def __init__(self, dim, mult=4, dropout=0.0):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
@@ -66,7 +64,7 @@ def __init__(
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
- nn.Dropout(dropout)
+ nn.Dropout(dropout),
)
def forward(self, x):
@@ -75,17 +73,11 @@ def forward(self, x):
class Attention(Module):
def __init__(
- self,
- dim,
- heads=8,
- dim_head=64,
- dropout=0.,
- rotary_embed=None,
- flash=True
+ self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True
):
super().__init__()
self.heads = heads
- self.scale = dim_head ** -0.5
+ self.scale = dim_head**-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
@@ -98,14 +90,15 @@ def __init__(
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
- nn.Linear(dim_inner, dim, bias=False),
- nn.Dropout(dropout)
+ nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
- q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+ q, k, v = rearrange(
+ self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
+ )
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
@@ -114,9 +107,9 @@ def forward(self, x):
out = self.attend(q, k, v)
gates = self.to_gates(x)
- out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
- out = rearrange(out, 'b h n d -> b n (h d)')
+ out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
@@ -126,42 +119,25 @@ class LinearAttention(Module):
"""
@beartype
- def __init__(
- self,
- *,
- dim,
- dim_head=32,
- heads=8,
- scale=8,
- flash=False,
- dropout=0.
- ):
+ def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
- Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
- self.attend = Attend(
- scale=scale,
- dropout=dropout,
- flash=flash
- )
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(
- Rearrange('b h d n -> b n (h d)'),
- nn.Linear(dim_inner, dim, bias=False)
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
)
- def forward(
- self,
- x
- ):
+ def forward(self, x):
x = self.norm(x)
q, k, v = self.to_qkv(x)
@@ -176,34 +152,47 @@ def forward(
class Transformer(Module):
def __init__(
- self,
- *,
- dim,
- depth,
- dim_head=64,
- heads=8,
- attn_dropout=0.,
- ff_dropout=0.,
- ff_mult=4,
- norm_output=True,
- rotary_embed=None,
- flash_attn=True,
- linear_attn=False
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False,
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
if linear_attn:
- attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
+ attn = LinearAttention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ )
else:
- attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
- rotary_embed=rotary_embed, flash=flash_attn)
-
- self.layers.append(ModuleList([
- attn,
- FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
- ]))
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ rotary_embed=rotary_embed,
+ flash=flash_attn,
+ )
+
+ self.layers.append(
+ ModuleList(
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
+ )
+ )
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
@@ -218,22 +207,16 @@ def forward(self, x):
# bandsplit module
+
class BandSplit(Module):
@beartype
- def __init__(
- self,
- dim,
- dim_inputs: Tuple[int, ...]
- ):
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
- net = nn.Sequential(
- RMSNorm(dim_in),
- nn.Linear(dim_in, dim)
- )
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
self.to_features.append(net)
@@ -248,13 +231,7 @@ def forward(self, x):
return torch.stack(outs, dim=-2)
-def MLP(
- dim_in,
- dim_out,
- dim_hidden=None,
- depth=1,
- activation=nn.Tanh
-):
+def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
dim_hidden = default(dim_hidden, dim_in)
net = []
@@ -275,13 +252,7 @@ def MLP(
class MaskEstimator(Module):
@beartype
- def __init__(
- self,
- dim,
- dim_inputs: Tuple[int, ...],
- depth,
- mlp_expansion_factor=4
- ):
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
@@ -291,8 +262,7 @@ def __init__(
net = []
mlp = nn.Sequential(
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
- nn.GLU(dim=-1)
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
)
self.to_freqs.append(mlp)
@@ -312,14 +282,68 @@ def forward(self, x):
# main class
DEFAULT_FREQS_PER_BANDS = (
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2,
- 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
- 12, 12, 12, 12, 12, 12, 12, 12,
- 24, 24, 24, 24, 24, 24, 24, 24,
- 48, 48, 48, 48, 48, 48, 48, 48,
- 128, 129,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 12,
+ 12,
+ 12,
+ 12,
+ 12,
+ 12,
+ 12,
+ 12,
+ 24,
+ 24,
+ 24,
+ 24,
+ 24,
+ 24,
+ 24,
+ 24,
+ 48,
+ 48,
+ 48,
+ 48,
+ 48,
+ 48,
+ 48,
+ 48,
+ 128,
+ 129,
)
@@ -327,35 +351,41 @@ class BSRoformer(Module):
@beartype
def __init__(
- self,
- dim,
- *,
- depth,
- stereo=False,
- num_stems=1,
- time_transformer_depth=2,
- freq_transformer_depth=2,
- linear_transformer_depth=0,
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
- # in the paper, they divide into ~60 bands, test with 1 for starters
- dim_head=64,
- heads=8,
- attn_dropout=0.,
- ff_dropout=0.,
- flash_attn=True,
- dim_freqs_in=1025,
- stft_n_fft=2048,
- stft_hop_length=512,
- # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
- stft_win_length=2048,
- stft_normalized=False,
- stft_window_fn: Optional[Callable] = None,
- mask_estimator_depth=2,
- multi_stft_resolution_loss_weight=1.,
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
- multi_stft_hop_size=147,
- multi_stft_normalized=False,
- multi_stft_window_fn: Callable = torch.hann_window
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+ # in the paper, they divide into ~60 bands, test with 1 for starters
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=2,
+ multi_stft_resolution_loss_weight=1.0,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
+ 4096,
+ 2048,
+ 1024,
+ 512,
+ 256,
+ ),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
):
super().__init__()
@@ -372,7 +402,7 @@ def __init__(
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn,
- norm_output=False
+ norm_output=False,
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -381,12 +411,26 @@ def __init__(
for _ in range(depth):
tran_modules = []
if linear_transformer_depth > 0:
- tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
+ tran_modules.append(
+ Transformer(
+ depth=linear_transformer_depth,
+ linear_attn=True,
+ **transformer_kwargs,
+ )
+ )
tran_modules.append(
- Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+ Transformer(
+ depth=time_transformer_depth,
+ rotary_embed=time_rotary_embed,
+ **transformer_kwargs,
+ )
)
tran_modules.append(
- Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+ Transformer(
+ depth=freq_transformer_depth,
+ rotary_embed=freq_rotary_embed,
+ **transformer_kwargs,
+ )
)
self.layers.append(nn.ModuleList(tran_modules))
@@ -396,31 +440,38 @@ def __init__(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
- normalized=stft_normalized
+ normalized=stft_normalized,
)
- self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+ self.stft_window_fn = partial(
+ default(stft_window_fn, torch.hann_window), stft_win_length
+ )
- freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
+ freqs = torch.stft(
+ torch.randn(1, 4096),
+ **self.stft_kwargs,
+ window=torch.ones(stft_n_fft),
+ return_complex=True,
+ ).shape[1]
assert len(freqs_per_bands) > 1
- assert sum(
- freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
+ assert (
+ sum(freqs_per_bands) == freqs
+ ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
- freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
-
- self.band_split = BandSplit(
- dim=dim,
- dim_inputs=freqs_per_bands_with_complex
+ freqs_per_bands_with_complex = tuple(
+ 2 * f * self.audio_channels for f in freqs_per_bands
)
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
+
self.mask_estimators = nn.ModuleList([])
for _ in range(num_stems):
mask_estimator = MaskEstimator(
dim=dim,
dim_inputs=freqs_per_bands_with_complex,
- depth=mask_estimator_depth
+ depth=mask_estimator_depth,
)
self.mask_estimators.append(mask_estimator)
@@ -433,16 +484,10 @@ def __init__(
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
- hop_length=multi_stft_hop_size,
- normalized=multi_stft_normalized
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
)
- def forward(
- self,
- raw_audio,
- target=None,
- return_loss_breakdown=False
- ):
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
"""
einops
@@ -461,32 +506,41 @@ def forward(
x_is_mps = True if device.type == "mps" else False
if raw_audio.ndim == 2:
- raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
channels = raw_audio.shape[1]
assert (not self.stereo and channels == 1) or (
- self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+ self.stereo and channels == 2
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
# to stft
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
stft_window = self.stft_window_fn(device=device)
# RuntimeError: FFT operations are only supported on MacOS 14+
# Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
try:
- stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+ stft_repr = torch.stft(
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
+ )
except:
- stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
+ stft_repr = torch.stft(
+ raw_audio.cpu() if x_is_mps else raw_audio,
+ **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window,
+ return_complex=True,
+ ).to(device)
stft_repr = torch.view_as_real(stft_repr)
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
- stft_repr = rearrange(stft_repr,
- 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
+ stft_repr = rearrange(
+ stft_repr, "b s f t c -> b (f s) t c"
+ ) # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
- x = rearrange(stft_repr, 'b f t c -> b t (f c)')
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
x = self.band_split(x)
@@ -495,37 +549,39 @@ def forward(
for transformer_block in self.layers:
if len(transformer_block) == 3:
- linear_transformer, time_transformer, freq_transformer = transformer_block
+ linear_transformer, time_transformer, freq_transformer = (
+ transformer_block
+ )
- x, ft_ps = pack([x], 'b * d')
+ x, ft_ps = pack([x], "b * d")
x = linear_transformer(x)
- x, = unpack(x, ft_ps, 'b * d')
+ (x,) = unpack(x, ft_ps, "b * d")
else:
time_transformer, freq_transformer = transformer_block
- x = rearrange(x, 'b t f d -> b f t d')
- x, ps = pack([x], '* t d')
+ x = rearrange(x, "b t f d -> b f t d")
+ x, ps = pack([x], "* t d")
x = time_transformer(x)
- x, = unpack(x, ps, '* t d')
- x = rearrange(x, 'b f t d -> b t f d')
- x, ps = pack([x], '* f d')
+ (x,) = unpack(x, ps, "* t d")
+ x = rearrange(x, "b f t d -> b t f d")
+ x, ps = pack([x], "* f d")
x = freq_transformer(x)
- x, = unpack(x, ps, '* f d')
+ (x,) = unpack(x, ps, "* f d")
x = self.final_norm(x)
num_stems = len(self.mask_estimators)
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
- mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
# modulate frequency representation
- stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
# complex number multiplication
@@ -536,18 +592,29 @@ def forward(
# istft
- stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+ stft_repr = rearrange(
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
+ )
# same as torch.stft() fix for MacOS MPS above
try:
- recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
+ recon_audio = torch.istft(
+ stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False
+ )
except:
- recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False).to(device)
-
- recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
+ recon_audio = torch.istft(
+ stft_repr.cpu() if x_is_mps else stft_repr,
+ **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window,
+ return_complex=False,
+ ).to(device)
+
+ recon_audio = rearrange(
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
+ )
if num_stems == 1:
- recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
# if a target is passed in, calculate loss for learning
@@ -558,33 +625,45 @@ def forward(
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
- target = rearrange(target, '... t -> ... 1 t')
+ target = rearrange(target, "... t -> ... 1 t")
- target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
+ target = target[
+ ..., : recon_audio.shape[-1]
+ ] # protect against lost length on istft
loss = F.l1_loss(recon_audio, target)
- multi_stft_resolution_loss = 0.
+ multi_stft_resolution_loss = 0.0
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
- n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
+ n_fft=max(
+ window_size, self.multi_stft_n_fft
+ ), # not sure what n_fft is across multi resolution stft
win_length=window_size,
return_complex=True,
window=self.multi_stft_window_fn(window_size, device=device),
**self.multi_stft_kwargs,
)
- recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
- target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+ recon_Y = torch.stft(
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+ target_Y = torch.stft(
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
+ )
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
+ recon_Y, target_Y
+ )
- weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+ weighted_multi_resolution_loss = (
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+ )
total_loss = loss + weighted_multi_resolution_loss
if not return_loss_breakdown:
return total_loss
- return total_loss, (loss, multi_stft_resolution_loss)
\ No newline at end of file
+ return total_loss, (loss, multi_stft_resolution_loss)
diff --git a/programs/music_separation_code/models/bs_roformer/mel_band_roformer.py b/programs/music_separation_code/models/bs_roformer/mel_band_roformer.py
index 3ce7fe14..105ced15 100644
--- a/programs/music_separation_code/models/bs_roformer/mel_band_roformer.py
+++ b/programs/music_separation_code/models/bs_roformer/mel_band_roformer.py
@@ -20,6 +20,7 @@
# helper functions
+
def exists(val):
return val is not None
@@ -36,9 +37,9 @@ def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
-def pad_at_dim(t, pad, dim=-1, value=0.):
- dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
- zeros = ((0, 0) * dims_from_right)
+def pad_at_dim(t, pad, dim=-1, value=0.0):
+ dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+ zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), value=value)
@@ -48,10 +49,11 @@ def l2norm(t):
# norm
+
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
- self.scale = dim ** 0.5
+ self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
@@ -60,13 +62,9 @@ def forward(self, x):
# attention
+
class FeedForward(Module):
- def __init__(
- self,
- dim,
- mult=4,
- dropout=0.
- ):
+ def __init__(self, dim, mult=4, dropout=0.0):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
@@ -75,7 +73,7 @@ def __init__(
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
- nn.Dropout(dropout)
+ nn.Dropout(dropout),
)
def forward(self, x):
@@ -84,17 +82,11 @@ def forward(self, x):
class Attention(Module):
def __init__(
- self,
- dim,
- heads=8,
- dim_head=64,
- dropout=0.,
- rotary_embed=None,
- flash=True
+ self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True
):
super().__init__()
self.heads = heads
- self.scale = dim_head ** -0.5
+ self.scale = dim_head**-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
@@ -107,14 +99,15 @@ def __init__(
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
- nn.Linear(dim_inner, dim, bias=False),
- nn.Dropout(dropout)
+ nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
- q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+ q, k, v = rearrange(
+ self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
+ )
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
@@ -123,9 +116,9 @@ def forward(self, x):
out = self.attend(q, k, v)
gates = self.to_gates(x)
- out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
- out = rearrange(out, 'b h n d -> b n (h d)')
+ out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
@@ -135,42 +128,25 @@ class LinearAttention(Module):
"""
@beartype
- def __init__(
- self,
- *,
- dim,
- dim_head=32,
- heads=8,
- scale=8,
- flash=False,
- dropout=0.
- ):
+ def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
- Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
- self.attend = Attend(
- scale=scale,
- dropout=dropout,
- flash=flash
- )
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(
- Rearrange('b h d n -> b n (h d)'),
- nn.Linear(dim_inner, dim, bias=False)
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
)
- def forward(
- self,
- x
- ):
+ def forward(self, x):
x = self.norm(x)
q, k, v = self.to_qkv(x)
@@ -185,34 +161,47 @@ def forward(
class Transformer(Module):
def __init__(
- self,
- *,
- dim,
- depth,
- dim_head=64,
- heads=8,
- attn_dropout=0.,
- ff_dropout=0.,
- ff_mult=4,
- norm_output=True,
- rotary_embed=None,
- flash_attn=True,
- linear_attn=False
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False,
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
if linear_attn:
- attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
+ attn = LinearAttention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ )
else:
- attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
- rotary_embed=rotary_embed, flash=flash_attn)
-
- self.layers.append(ModuleList([
- attn,
- FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
- ]))
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ rotary_embed=rotary_embed,
+ flash=flash_attn,
+ )
+
+ self.layers.append(
+ ModuleList(
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
+ )
+ )
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
@@ -227,22 +216,16 @@ def forward(self, x):
# bandsplit module
+
class BandSplit(Module):
@beartype
- def __init__(
- self,
- dim,
- dim_inputs: Tuple[int, ...]
- ):
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
- net = nn.Sequential(
- RMSNorm(dim_in),
- nn.Linear(dim_in, dim)
- )
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
self.to_features.append(net)
@@ -257,13 +240,7 @@ def forward(self, x):
return torch.stack(outs, dim=-2)
-def MLP(
- dim_in,
- dim_out,
- dim_hidden=None,
- depth=1,
- activation=nn.Tanh
-):
+def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
dim_hidden = default(dim_hidden, dim_in)
net = []
@@ -284,13 +261,7 @@ def MLP(
class MaskEstimator(Module):
@beartype
- def __init__(
- self,
- dim,
- dim_inputs: Tuple[int, ...],
- depth,
- mlp_expansion_factor=4
- ):
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
@@ -300,8 +271,7 @@ def __init__(
net = []
mlp = nn.Sequential(
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
- nn.GLU(dim=-1)
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
)
self.to_freqs.append(mlp)
@@ -320,40 +290,47 @@ def forward(self, x):
# main class
+
class MelBandRoformer(Module):
@beartype
def __init__(
- self,
- dim,
- *,
- depth,
- stereo=False,
- num_stems=1,
- time_transformer_depth=2,
- freq_transformer_depth=2,
- linear_transformer_depth=0,
- num_bands=60,
- dim_head=64,
- heads=8,
- attn_dropout=0.1,
- ff_dropout=0.1,
- flash_attn=True,
- dim_freqs_in=1025,
- sample_rate=44100, # needed for mel filter bank from librosa
- stft_n_fft=2048,
- stft_hop_length=512,
- # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
- stft_win_length=2048,
- stft_normalized=False,
- stft_window_fn: Optional[Callable] = None,
- mask_estimator_depth=1,
- multi_stft_resolution_loss_weight=1.,
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
- multi_stft_hop_size=147,
- multi_stft_normalized=False,
- multi_stft_window_fn: Callable = torch.hann_window,
- match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ num_bands=60,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.1,
+ ff_dropout=0.1,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ sample_rate=44100, # needed for mel filter bank from librosa
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=1,
+ multi_stft_resolution_loss_weight=1.0,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
+ 4096,
+ 2048,
+ 1024,
+ 512,
+ 256,
+ ),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
):
super().__init__()
@@ -369,7 +346,7 @@ def __init__(
dim_head=dim_head,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
- flash_attn=flash_attn
+ flash_attn=flash_attn,
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -378,80 +355,104 @@ def __init__(
for _ in range(depth):
tran_modules = []
if linear_transformer_depth > 0:
- tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
+ tran_modules.append(
+ Transformer(
+ depth=linear_transformer_depth,
+ linear_attn=True,
+ **transformer_kwargs,
+ )
+ )
tran_modules.append(
- Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+ Transformer(
+ depth=time_transformer_depth,
+ rotary_embed=time_rotary_embed,
+ **transformer_kwargs,
+ )
)
tran_modules.append(
- Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+ Transformer(
+ depth=freq_transformer_depth,
+ rotary_embed=freq_rotary_embed,
+ **transformer_kwargs,
+ )
)
self.layers.append(nn.ModuleList(tran_modules))
- self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+ self.stft_window_fn = partial(
+ default(stft_window_fn, torch.hann_window), stft_win_length
+ )
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
- normalized=stft_normalized
+ normalized=stft_normalized,
)
- freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
+ freqs = torch.stft(
+ torch.randn(1, 4096),
+ **self.stft_kwargs,
+ window=torch.ones(stft_n_fft),
+ return_complex=True,
+ ).shape[1]
# create mel filter bank
# with librosa.filters.mel as in section 2 of paper
- mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
+ mel_filter_bank_numpy = filters.mel(
+ sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands
+ )
mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
# for some reason, it doesn't include the first freq? just force a value for now
- mel_filter_bank[0][0] = 1.
+ mel_filter_bank[0][0] = 1.0
# In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
# so let's force a positive value
- mel_filter_bank[-1, -1] = 1.
+ mel_filter_bank[-1, -1] = 1.0
# binary as in paper (then estimated masks are averaged for overlapping regions)
freqs_per_band = mel_filter_bank > 0
- assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
+ assert freqs_per_band.any(
+ dim=0
+ ).all(), "all frequencies need to be covered by all bands for now"
- repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
+ repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
freq_indices = repeated_freq_indices[freqs_per_band]
if stereo:
- freq_indices = repeat(freq_indices, 'f -> f s', s=2)
+ freq_indices = repeat(freq_indices, "f -> f s", s=2)
freq_indices = freq_indices * 2 + torch.arange(2)
- freq_indices = rearrange(freq_indices, 'f s -> (f s)')
+ freq_indices = rearrange(freq_indices, "f s -> (f s)")
- self.register_buffer('freq_indices', freq_indices, persistent=False)
- self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
+ self.register_buffer("freq_indices", freq_indices, persistent=False)
+ self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)
- num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
- num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
+ num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
+ num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
- self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
- self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
+ self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
+ self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)
# band split and mask estimator
- freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
-
- self.band_split = BandSplit(
- dim=dim,
- dim_inputs=freqs_per_bands_with_complex
+ freqs_per_bands_with_complex = tuple(
+ 2 * f * self.audio_channels for f in num_freqs_per_band.tolist()
)
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
+
self.mask_estimators = nn.ModuleList([])
for _ in range(num_stems):
mask_estimator = MaskEstimator(
dim=dim,
dim_inputs=freqs_per_bands_with_complex,
- depth=mask_estimator_depth
+ depth=mask_estimator_depth,
)
self.mask_estimators.append(mask_estimator)
@@ -464,18 +465,12 @@ def __init__(
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
- hop_length=multi_stft_hop_size,
- normalized=multi_stft_normalized
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
)
self.match_input_audio_length = match_input_audio_length
- def forward(
- self,
- raw_audio,
- target=None,
- return_loss_breakdown=False
- ):
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
"""
einops
@@ -491,27 +486,31 @@ def forward(
device = raw_audio.device
if raw_audio.ndim == 2:
- raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
batch, channels, raw_audio_length = raw_audio.shape
istft_length = raw_audio_length if self.match_input_audio_length else None
assert (not self.stereo and channels == 1) or (
- self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+ self.stereo and channels == 2
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
# to stft
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
stft_window = self.stft_window_fn(device=device)
- stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+ stft_repr = torch.stft(
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
+ )
stft_repr = torch.view_as_real(stft_repr)
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
- stft_repr = rearrange(stft_repr,
- 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
+ stft_repr = rearrange(
+ stft_repr, "b s f t c -> b (f s) t c"
+ ) # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
# index out all frequencies for all frequency ranges across bands ascending in one go
@@ -523,7 +522,7 @@ def forward(
# fold the complex (real and imag) into the frequencies dimension
- x = rearrange(x, 'b f t c -> b t (f c)')
+ x = rearrange(x, "b f t c -> b t (f c)")
x = self.band_split(x)
@@ -532,35 +531,37 @@ def forward(
for transformer_block in self.layers:
if len(transformer_block) == 3:
- linear_transformer, time_transformer, freq_transformer = transformer_block
+ linear_transformer, time_transformer, freq_transformer = (
+ transformer_block
+ )
- x, ft_ps = pack([x], 'b * d')
+ x, ft_ps = pack([x], "b * d")
x = linear_transformer(x)
- x, = unpack(x, ft_ps, 'b * d')
+ (x,) = unpack(x, ft_ps, "b * d")
else:
time_transformer, freq_transformer = transformer_block
- x = rearrange(x, 'b t f d -> b f t d')
- x, ps = pack([x], '* t d')
+ x = rearrange(x, "b t f d -> b f t d")
+ x, ps = pack([x], "* t d")
x = time_transformer(x)
- x, = unpack(x, ps, '* t d')
- x = rearrange(x, 'b f t d -> b t f d')
- x, ps = pack([x], '* f d')
+ (x,) = unpack(x, ps, "* t d")
+ x = rearrange(x, "b f t d -> b t f d")
+ x, ps = pack([x], "* f d")
x = freq_transformer(x)
- x, = unpack(x, ps, '* f d')
+ (x,) = unpack(x, ps, "* f d")
num_stems = len(self.mask_estimators)
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
- masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
+ masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)
# modulate frequency representation
- stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
# complex number multiplication
@@ -571,12 +572,20 @@ def forward(
# need to average the estimated mask for the overlapped frequencies
- scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
+ scatter_indices = repeat(
+ self.freq_indices,
+ "f -> b n f t",
+ b=batch,
+ n=num_stems,
+ t=stft_repr.shape[-1],
+ )
- stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
- masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
+ stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
+ masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
+ 2, scatter_indices, masks
+ )
- denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
+ denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)
masks_averaged = masks_summed / denom.clamp(min=1e-8)
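Because mel bands overlap, the scatter_add_ above sums every band's estimate into its frequency bins and then divides by how many bands cover each bin. A toy one-dimensional version of that averaging, with made-up indices and values:

import torch

freq_indices = torch.tensor([0, 1, 1, 2, 2, 3])             # band-expanded bin index
band_values = torch.tensor([1.0, 2.0, 4.0, 6.0, 8.0, 3.0])  # per-band estimates
summed = torch.zeros(4).scatter_add_(0, freq_indices, band_values)
counts = torch.zeros(4).scatter_add_(0, freq_indices, torch.ones_like(band_values))
averaged = summed / counts.clamp(min=1e-8)
print(averaged)   # tensor([1., 3., 7., 3.])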
@@ -586,15 +595,28 @@ def forward(
# istft
- stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+ stft_repr = rearrange(
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
+ )
- recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
- length=istft_length)
+ recon_audio = torch.istft(
+ stft_repr,
+ **self.stft_kwargs,
+ window=stft_window,
+ return_complex=False,
+ length=istft_length,
+ )
- recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
+ recon_audio = rearrange(
+ recon_audio,
+ "(b n s) t -> b n s t",
+ b=batch,
+ s=self.audio_channels,
+ n=num_stems,
+ )
if num_stems == 1:
- recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
# if a target is passed in, calculate loss for learning
@@ -605,29 +627,41 @@ def forward(
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
- target = rearrange(target, '... t -> ... 1 t')
+ target = rearrange(target, "... t -> ... 1 t")
- target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
+ target = target[
+ ..., : recon_audio.shape[-1]
+ ] # protect against lost length on istft
loss = F.l1_loss(recon_audio, target)
- multi_stft_resolution_loss = 0.
+ multi_stft_resolution_loss = 0.0
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
- n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
+ n_fft=max(
+ window_size, self.multi_stft_n_fft
+ ), # not sure what n_fft is across multi resolution stft
win_length=window_size,
return_complex=True,
window=self.multi_stft_window_fn(window_size, device=device),
**self.multi_stft_kwargs,
)
- recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
- target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+ recon_Y = torch.stft(
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+ target_Y = torch.stft(
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
+ )
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
+ recon_Y, target_Y
+ )
- weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+ weighted_multi_resolution_loss = (
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+ )
total_loss = loss + weighted_multi_resolution_loss
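The auxiliary loss formatted in this hunk is an L1 distance between STFTs of the reconstruction and the target at several window sizes. A minimal standalone sketch of the same idea; the window sizes, hop, and the simplification n_fft = win_length are illustrative only:

import torch
import torch.nn.functional as F

def multi_resolution_stft_l1(recon, target, window_sizes=(1024, 512, 256), hop=147):
    # recon, target: (batch, time) waveforms
    loss = 0.0
    for win in window_sizes:
        kwargs = dict(n_fft=win, win_length=win, hop_length=hop,
                      window=torch.hann_window(win), return_complex=True)
        # F.l1_loss supports complex inputs, so the spectrograms can be compared directly.
        loss = loss + F.l1_loss(torch.stft(recon, **kwargs), torch.stft(target, **kwargs))
    return loss

recon = torch.randn(2, 44100)
target = torch.randn(2, 44100)
print(multi_resolution_stft_l1(recon, target))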
diff --git a/programs/music_separation_code/models/demucs4ht.py b/programs/music_separation_code/models/demucs4ht.py
index 06c279c3..bf87cb16 100644
--- a/programs/music_separation_code/models/demucs4ht.py
+++ b/programs/music_separation_code/models/demucs4ht.py
@@ -1,11 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from functools import partial
import numpy as np
import torch
-import json
from omegaconf import OmegaConf
from demucs.demucs import Demucs
from demucs.hdemucs import HDemucs
@@ -317,7 +315,7 @@ def __init__(
dconv=dconv_mode & 1,
context=context_enc,
empty=last_freq,
- **kwt
+ **kwt,
)
self.tencoder.append(tenc)
@@ -337,7 +335,7 @@ def __init__(
dconv=dconv_mode & 2,
last=index == 0,
context=context,
- **kw_dec
+ **kw_dec,
)
if multi:
dec = MultiWrap(dec, multi_freqs)
@@ -349,7 +347,7 @@ def __init__(
empty=last_freq,
last=index == 0,
context=context,
- **kwt
+ **kwt,
)
self.tdecoder.insert(0, tdec)
self.decoder.insert(0, dec)
@@ -443,7 +441,7 @@ def _spec(self, x):
z = spectro(x, nfft, hl)[..., :-1, :]
assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
- z = z[..., 2: 2 + le]
+ z = z[..., 2 : 2 + le]
return z
def _ispec(self, z, length=None, scale=0):
@@ -453,7 +451,7 @@ def _ispec(self, z, length=None, scale=0):
pad = hl // 2 * 3
le = hl * int(math.ceil(length / hl)) + 2 * pad
x = ispectro(z, hl, length=le)
- x = x[..., pad: pad + length]
+ x = x[..., pad : pad + length]
return x
def _magnitude(self, z):
@@ -527,8 +525,9 @@ def valid_length(self, length: int):
training_length = int(self.segment * self.samplerate)
if training_length < length:
raise ValueError(
- f"Given length {length} is longer than "
- f"training length {training_length}")
+ f"Given length {length} is longer than "
+ f"training length {training_length}"
+ )
return training_length
def cac2cws(self, x):
@@ -695,19 +694,17 @@ def forward(self, mix):
def get_model(args):
extra = {
- 'sources': list(args.training.instruments),
- 'audio_channels': args.training.channels,
- 'samplerate': args.training.samplerate,
+ "sources": list(args.training.instruments),
+ "audio_channels": args.training.channels,
+ "samplerate": args.training.samplerate,
# 'segment': args.model_segment or 4 * args.dset.segment,
- 'segment': args.training.segment,
+ "segment": args.training.segment,
}
klass = {
- 'demucs': Demucs,
- 'hdemucs': HDemucs,
- 'htdemucs': HTDemucs,
+ "demucs": Demucs,
+ "hdemucs": HDemucs,
+ "htdemucs": HTDemucs,
}[args.model]
kw = OmegaConf.to_container(getattr(args, args.model), resolve=True)
model = klass(**extra, **kw)
return model
-
-
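get_model above assembles its shared constructor arguments once and picks the class from a dict keyed by args.model, so an unknown name fails with a KeyError. A tiny, hypothetical illustration of that dispatch pattern (DummySeparator and the argument values are invented for the example):

from dataclasses import dataclass

@dataclass
class DummySeparator:
    sources: list
    audio_channels: int
    samplerate: int
    segment: float

def build_model(name, extra):
    klass = {"dummy": DummySeparator}[name]   # unknown names raise KeyError, as above
    return klass(**extra)

model = build_model(
    "dummy",
    {"sources": ["vocals", "other"], "audio_channels": 2, "samplerate": 44100, "segment": 4.0},
)
print(model)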
diff --git a/programs/music_separation_code/models/mdx23c_tfc_tdf_v3.py b/programs/music_separation_code/models/mdx23c_tfc_tdf_v3.py
index caa818cf..ad89c85b 100644
--- a/programs/music_separation_code/models/mdx23c_tfc_tdf_v3.py
+++ b/programs/music_separation_code/models/mdx23c_tfc_tdf_v3.py
@@ -22,12 +22,14 @@ def __call__(self, x):
hop_length=self.hop_length,
window=window,
center=True,
- return_complex=True
+ return_complex=True,
)
x = torch.view_as_real(x)
x = x.permute([0, 3, 1, 2])
- x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
- return x[..., :self.dim_f, :]
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape(
+ [*batch_dims, c * 2, -1, x.shape[-1]]
+ )
+ return x[..., : self.dim_f, :]
def inverse(self, x):
window = self.window.to(x.device)
@@ -38,20 +40,22 @@ def inverse(self, x):
x = torch.cat([x, f_pad], -2)
x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
x = x.permute([0, 2, 3, 1])
- x = x[..., 0] + x[..., 1] * 1.j
- x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
+ x = x[..., 0] + x[..., 1] * 1.0j
+ x = torch.istft(
+ x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True
+ )
x = x.reshape([*batch_dims, 2, -1])
return x
def get_norm(norm_type):
def norm(c, norm_type):
- if norm_type == 'BatchNorm':
+ if norm_type == "BatchNorm":
return nn.BatchNorm2d(c)
- elif norm_type == 'InstanceNorm':
+ elif norm_type == "InstanceNorm":
return nn.InstanceNorm2d(c, affine=True)
- elif 'GroupNorm' in norm_type:
- g = int(norm_type.replace('GroupNorm', ''))
+ elif "GroupNorm" in norm_type:
+ g = int(norm_type.replace("GroupNorm", ""))
return nn.GroupNorm(num_groups=g, num_channels=c)
else:
return nn.Identity()
@@ -60,12 +64,12 @@ def norm(c, norm_type):
def get_act(act_type):
- if act_type == 'gelu':
+ if act_type == "gelu":
return nn.GELU()
- elif act_type == 'relu':
+ elif act_type == "relu":
return nn.ReLU()
- elif act_type[:3] == 'elu':
- alpha = float(act_type.replace('elu', ''))
+ elif act_type[:3] == "elu":
+ alpha = float(act_type.replace("elu", ""))
return nn.ELU(alpha)
else:
raise Exception
@@ -77,7 +81,13 @@ def __init__(self, in_c, out_c, scale, norm, act):
self.conv = nn.Sequential(
norm(in_c),
act,
- nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ nn.ConvTranspose2d(
+ in_channels=in_c,
+ out_channels=out_c,
+ kernel_size=scale,
+ stride=scale,
+ bias=False,
+ ),
)
def forward(self, x):
@@ -90,7 +100,13 @@ def __init__(self, in_c, out_c, scale, norm, act):
self.conv = nn.Sequential(
norm(in_c),
act,
- nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ nn.Conv2d(
+ in_channels=in_c,
+ out_channels=out_c,
+ kernel_size=scale,
+ stride=scale,
+ bias=False,
+ ),
)
def forward(self, x):
@@ -146,7 +162,9 @@ def __init__(self, config):
norm = get_norm(norm_type=config.model.norm)
act = get_act(act_type=config.model.act)
- self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
+ self.num_target_instruments = (
+ 1 if config.training.target_instrument else len(config.training.instruments)
+ )
self.num_subbands = config.model.num_subbands
dim_c = self.num_subbands * config.audio.num_channels * 2
@@ -183,7 +201,7 @@ def __init__(self, config):
self.final_conv = nn.Sequential(
nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
act,
- nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
)
self.stft = STFT(config.audio)
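The STFT helper reshaped in this file folds the real and imaginary parts of a stereo spectrogram into the channel dimension, producing a (batch, channels * 2, freq, frames) tensor. A short sketch of that folding with illustrative sizes:

import torch

x = torch.randn(1, 2, 44100)   # (batch, channels, time), stereo
n_fft, hop = 2048, 512
window = torch.hann_window(n_fft)

b, c, t = x.shape
spec = torch.stft(x.reshape(-1, t), n_fft=n_fft, hop_length=hop,
                  window=window, center=True, return_complex=True)
spec = torch.view_as_real(spec)   # (b*c, freq, frames, 2)
spec = spec.permute(0, 3, 1, 2)   # (b*c, 2, freq, frames)
spec = spec.reshape(b, c * 2, spec.shape[-2], spec.shape[-1])
print(spec.shape)                 # torch.Size([1, 4, 1025, 87])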
diff --git a/programs/music_separation_code/models/scnet/scnet.py b/programs/music_separation_code/models/scnet/scnet.py
index b27704dc..37bdaad0 100644
--- a/programs/music_separation_code/models/scnet/scnet.py
+++ b/programs/music_separation_code/models/scnet/scnet.py
@@ -3,7 +3,6 @@
import torch.nn.functional as F
from collections import deque
from .separation import SeparationNet
-import typing as tp
import math
@@ -21,7 +20,7 @@ class ConvolutionModule(nn.Module):
depth (int): number of layers in the residual branch. Each layer has its own
compress (float): amount of channel compression.
kernel (int): kernel size for the convolutions.
- """
+ """
def __init__(self, channels, depth=2, compress=4, kernel=3):
super().__init__()
@@ -31,12 +30,18 @@ def __init__(self, channels, depth=2, compress=4, kernel=3):
norm = lambda d: nn.GroupNorm(1, d)
self.layers = nn.ModuleList([])
for _ in range(self.depth):
- padding = (kernel // 2)
+ padding = kernel // 2
mods = [
norm(channels),
nn.Conv1d(channels, hidden_size * 2, kernel, padding=padding),
nn.GLU(1),
- nn.Conv1d(hidden_size, hidden_size, kernel, padding=padding, groups=hidden_size),
+ nn.Conv1d(
+ hidden_size,
+ hidden_size,
+ kernel,
+ padding=padding,
+ groups=hidden_size,
+ ),
norm(hidden_size),
Swish(),
nn.Conv1d(hidden_size, channels, 1),
@@ -63,7 +68,9 @@ class FusionLayer(nn.Module):
def __init__(self, channels, kernel_size=3, stride=1, padding=1):
super(FusionLayer, self).__init__()
- self.conv = nn.Conv2d(channels * 2, channels * 2, kernel_size, stride=stride, padding=padding)
+ self.conv = nn.Conv2d(
+ channels * 2, channels * 2, kernel_size, stride=stride, padding=padding
+ )
def forward(self, x, skip=None):
if skip is not None:
@@ -96,13 +103,20 @@ def __init__(self, channels_in, channels_out, band_configs):
self.kernels = []
for config in band_configs.values():
self.convs.append(
- nn.Conv2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1), (0, 0)))
- self.strides.append(config['stride'])
- self.kernels.append(config['kernel'])
+ nn.Conv2d(
+ channels_in,
+ channels_out,
+ (config["kernel"], 1),
+ (config["stride"], 1),
+ (0, 0),
+ )
+ )
+ self.strides.append(config["stride"])
+ self.kernels.append(config["kernel"])
# Saving rate proportions for determining splits
- self.SR_low = band_configs['low']['SR']
- self.SR_mid = band_configs['mid']['SR']
+ self.SR_low = band_configs["low"]["SR"]
+ self.SR_mid = band_configs["mid"]["SR"]
def forward(self, x):
B, C, Fr, T = x.shape
@@ -110,13 +124,15 @@ def forward(self, x):
splits = [
(0, math.ceil(Fr * self.SR_low)),
(math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))),
- (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr)
+ (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr),
]
# Processing each band with the corresponding convolution
outputs = []
original_lengths = []
- for conv, stride, kernel, (start, end) in zip(self.convs, self.strides, self.kernels, splits):
+ for conv, stride, kernel, (start, end) in zip(
+ self.convs, self.strides, self.kernels, splits
+ ):
extracted = x[:, :, start:end, :]
original_lengths.append(end - start)
current_length = extracted.shape[2]
@@ -151,10 +167,17 @@ def __init__(self, channels_in, channels_out, band_configs):
super(SUlayer, self).__init__()
# Initializing convolutional layers for each band
- self.convtrs = nn.ModuleList([
- nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1])
- for _, config in band_configs.items()
- ])
+ self.convtrs = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ channels_in,
+ channels_out,
+ [config["kernel"], 1],
+ [config["stride"], 1],
+ )
+ for _, config in band_configs.items()
+ ]
+ )
def forward(self, x, lengths, origin_lengths):
B, C, Fr, T = x.shape
@@ -162,7 +185,7 @@ def forward(self, x, lengths, origin_lengths):
splits = [
(0, lengths[0]),
(lengths[0], lengths[0] + lengths[1]),
- (lengths[0] + lengths[1], None)
+ (lengths[0] + lengths[1], None),
]
# Processing each band with the corresponding convolution
outputs = []
@@ -173,7 +196,7 @@ def forward(self, x, lengths, origin_lengths):
dist = abs(origin_lengths[idx] - current_Fr_length) // 2
# Trim the output to the original length symmetrically
- trimmed_out = out[:, :, dist:dist + origin_lengths[idx], :]
+ trimmed_out = out[:, :, dist : dist + origin_lengths[idx], :]
outputs.append(trimmed_out)
@@ -195,16 +218,26 @@ class SDblock(nn.Module):
- depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands.
"""
- def __init__(self, channels_in, channels_out, band_configs={}, conv_config={}, depths=[3, 2, 1], kernel_size=3):
+ def __init__(
+ self,
+ channels_in,
+ channels_out,
+ band_configs={},
+ conv_config={},
+ depths=[3, 2, 1],
+ kernel_size=3,
+ ):
super(SDblock, self).__init__()
self.SDlayer = SDlayer(channels_in, channels_out, band_configs)
# Dynamically create convolution modules for each band based on depths
- self.conv_modules = nn.ModuleList([
- ConvolutionModule(channels_out, depth, **conv_config) for depth in depths
- ])
+ self.conv_modules = nn.ModuleList(
+ [ConvolutionModule(channels_out, depth, **conv_config) for depth in depths]
+ )
# Set the kernel_size to an odd number.
- self.globalconv = nn.Conv2d(channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2)
+ self.globalconv = nn.Conv2d(
+ channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2
+ )
def forward(self, x):
bands, original_lengths = self.SDlayer(x)
@@ -216,7 +249,6 @@ def forward(self, x):
.permute(0, 2, 1, 3)
)
for conv, band in zip(self.conv_modules, bands)
-
]
lengths = [band.size(-2) for band in bands]
full_band = torch.cat(bands, dim=2)
@@ -250,47 +282,54 @@ class SCNet(nn.Module):
"""
- def __init__(self,
- sources=['drums', 'bass', 'other', 'vocals'],
- audio_channels=2,
- # Main structure
- dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large
- # STFT
- nfft=4096,
- hop_size=1024,
- win_size=4096,
- normalized=True,
- # SD/SU layer
- band_SR=[0.175, 0.392, 0.433],
- band_stride=[1, 4, 16],
- band_kernel=[3, 4, 16],
- # Convolution Module
- conv_depths=[3, 2, 1],
- compress=4,
- conv_kernel=3,
- # Dual-path RNN
- num_dplayer=6,
- expand=1,
- ):
+ def __init__(
+ self,
+ sources=["drums", "bass", "other", "vocals"],
+ audio_channels=2,
+ # Main structure
+ dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large
+ # STFT
+ nfft=4096,
+ hop_size=1024,
+ win_size=4096,
+ normalized=True,
+ # SD/SU layer
+ band_SR=[0.175, 0.392, 0.433],
+ band_stride=[1, 4, 16],
+ band_kernel=[3, 4, 16],
+ # Convolution Module
+ conv_depths=[3, 2, 1],
+ compress=4,
+ conv_kernel=3,
+ # Dual-path RNN
+ num_dplayer=6,
+ expand=1,
+ ):
super().__init__()
self.sources = sources
self.audio_channels = audio_channels
self.dims = dims
- band_keys = ['low', 'mid', 'high']
- self.band_configs = {band_keys[i]: {'SR': band_SR[i], 'stride': band_stride[i], 'kernel': band_kernel[i]} for i
- in range(len(band_keys))}
+ band_keys = ["low", "mid", "high"]
+ self.band_configs = {
+ band_keys[i]: {
+ "SR": band_SR[i],
+ "stride": band_stride[i],
+ "kernel": band_kernel[i],
+ }
+ for i in range(len(band_keys))
+ }
self.hop_length = hop_size
self.conv_config = {
- 'compress': compress,
- 'kernel': conv_kernel,
+ "compress": compress,
+ "kernel": conv_kernel,
}
self.stft_config = {
- 'n_fft': nfft,
- 'hop_length': hop_size,
- 'win_length': win_size,
- 'center': True,
- 'normalized': normalized
+ "n_fft": nfft,
+ "hop_length": hop_size,
+ "win_length": win_size,
+ "center": True,
+ "normalized": normalized,
}
self.encoder = nn.ModuleList()
@@ -302,7 +341,7 @@ def __init__(self,
channels_out=dims[index + 1],
band_configs=self.band_configs,
conv_config=self.conv_config,
- depths=conv_depths
+ depths=conv_depths,
)
self.encoder.append(enc)
@@ -310,9 +349,11 @@ def __init__(self,
FusionLayer(channels=dims[index + 1]),
SUlayer(
channels_in=dims[index + 1],
- channels_out=dims[index] if index != 0 else dims[index] * len(sources),
+ channels_out=(
+ dims[index] if index != 0 else dims[index] * len(sources)
+ ),
band_configs=self.band_configs,
- )
+ ),
)
self.decoder.insert(0, dec)
@@ -337,8 +378,12 @@ def forward(self, x):
x = x.reshape(-1, L)
x = torch.stft(x, **self.stft_config, return_complex=True)
x = torch.view_as_real(x)
- x = x.permute(0, 3, 1, 2).reshape(x.shape[0] // self.audio_channels, x.shape[3] * self.audio_channels,
- x.shape[1], x.shape[2])
+ x = x.permute(0, 3, 1, 2).reshape(
+ x.shape[0] // self.audio_channels,
+ x.shape[3] * self.audio_channels,
+ x.shape[1],
+ x.shape[2],
+ )
B, C, Fr, T = x.shape
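SDlayer splits the frequency axis into low, mid, and high bands using the proportions configured above (band_SR defaults to [0.175, 0.392, 0.433]). The split points come from a simple ceil computation, sketched here:

import math

def band_splits(n_freq, sr_low, sr_mid):
    lo = math.ceil(n_freq * sr_low)
    mid = math.ceil(n_freq * (sr_low + sr_mid))
    return [(0, lo), (lo, mid), (mid, n_freq)]

# With nfft=4096 the spectrogram has 2049 bins; the defaults above give roughly:
print(band_splits(2049, 0.175, 0.392))   # [(0, 359), (359, 1162), (1162, 2049)]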
diff --git a/programs/music_separation_code/models/scnet/separation.py b/programs/music_separation_code/models/scnet/separation.py
index d902dac4..8965e2c8 100644
--- a/programs/music_separation_code/models/scnet/separation.py
+++ b/programs/music_separation_code/models/scnet/separation.py
@@ -21,8 +21,8 @@ def forward(self, x):
# B, C, F, T = x.shape
if self.inverse:
x = x.float()
- x_r = x[:, :self.channels // 2, :, :]
- x_i = x[:, self.channels // 2:, :, :]
+ x_r = x[:, : self.channels // 2, :, :]
+ x_i = x[:, self.channels // 2 :, :, :]
x = torch.complex(x_r, x_i)
x = torch.fft.irfft(x, dim=3, norm="ortho")
else:
@@ -51,12 +51,22 @@ def __init__(self, d_model, expand, bidirectional=True):
self.hidden_size = d_model * expand
self.bidirectional = bidirectional
# Initialize LSTM layers and normalization layers
- self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)])
- self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)])
+ self.lstm_layers = nn.ModuleList(
+ [self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)]
+ )
+ self.linear_layers = nn.ModuleList(
+ [nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)]
+ )
self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)])
def _init_lstm_layer(self, d_model, hidden_size):
- return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True)
+ return LSTM(
+ d_model,
+ hidden_size,
+ num_layers=1,
+ bidirectional=self.bidirectional,
+ batch_first=True,
+ )
def forward(self, x):
B, C, F, T = x.shape
@@ -98,13 +108,19 @@ def __init__(self, channels, expand=1, num_layers=6):
self.num_layers = num_layers
- self.dp_modules = nn.ModuleList([
- DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers)
- ])
-
- self.feature_conversion = nn.ModuleList([
- FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) for i in range(num_layers)
- ])
+ self.dp_modules = nn.ModuleList(
+ [
+ DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand)
+ for i in range(num_layers)
+ ]
+ )
+
+ self.feature_conversion = nn.ModuleList(
+ [
+ FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True)
+ for i in range(num_layers)
+ ]
+ )
def forward(self, x):
for i in range(self.num_layers):
diff --git a/programs/music_separation_code/models/scnet_unofficial/__init__.py b/programs/music_separation_code/models/scnet_unofficial/__init__.py
index 6d034d38..298d9939 100644
--- a/programs/music_separation_code/models/scnet_unofficial/__init__.py
+++ b/programs/music_separation_code/models/scnet_unofficial/__init__.py
@@ -1 +1 @@
-from models.scnet_unofficial.scnet import SCNet
\ No newline at end of file
+from models.scnet_unofficial.scnet import SCNet
diff --git a/programs/music_separation_code/models/scnet_unofficial/modules/dualpath_rnn.py b/programs/music_separation_code/models/scnet_unofficial/modules/dualpath_rnn.py
index 2dfcdbcf..644d05a1 100644
--- a/programs/music_separation_code/models/scnet_unofficial/modules/dualpath_rnn.py
+++ b/programs/music_separation_code/models/scnet_unofficial/modules/dualpath_rnn.py
@@ -2,10 +2,11 @@
import torch.nn as nn
import torch.nn.functional as Func
+
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
- self.scale = dim ** 0.5
+ self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
@@ -17,11 +18,8 @@ def __init__(self, d_model, d_state, d_conv, d_expand):
super().__init__()
self.norm = RMSNorm(dim=d_model)
self.mamba = Mamba(
- d_model=d_model,
- d_state=d_state,
- d_conv=d_conv,
- d_expand=d_expand
- )
+ d_model=d_model, d_state=d_state, d_conv=d_conv, d_expand=d_expand
+ )
def forward(self, x):
x = x + self.mamba(self.norm(x))
@@ -128,7 +126,7 @@ def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor:
x = x.reshape(B, F, T, D // 2, 2)
x = torch.view_as_complex(x)
x = torch.fft.irfft(x, n=time_dim, dim=2)
-
+
x = x.to(dtype)
return x
@@ -166,11 +164,10 @@ def __init__(
n_layers: int,
input_dim: int,
hidden_dim: int,
-
use_mamba: bool = False,
d_state: int = 16,
d_conv: int = 4,
- d_expand: int = 2
+ d_expand: int = 2,
):
"""
Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension.
@@ -179,9 +176,20 @@ def __init__(
if use_mamba:
from mamba_ssm.modules.mamba_simple import Mamba
+
net = MambaModule
- dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand}
- ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2}
+ dkwargs = {
+ "d_model": input_dim,
+ "d_state": d_state,
+ "d_conv": d_conv,
+ "d_expand": d_expand,
+ }
+ ukwargs = {
+ "d_model": input_dim * 2,
+ "d_state": d_state,
+ "d_conv": d_conv,
+ "d_expand": d_expand * 2,
+ }
else:
net = RNNModule
dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim}
@@ -190,13 +198,15 @@ def __init__(
self.layers = nn.ModuleList()
for i in range(1, n_layers + 1):
kwargs = dkwargs if i % 2 == 1 else ukwargs
- layer = nn.ModuleList([
- net(**kwargs),
- net(**kwargs),
- RFFTModule(inverse=(i % 2 == 0)),
- ])
+ layer = nn.ModuleList(
+ [
+ net(**kwargs),
+ net(**kwargs),
+ RFFTModule(inverse=(i % 2 == 0)),
+ ]
+ )
self.layers.append(layer)
-
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs forward pass through the DualPathRNN.
@@ -224,5 +234,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 1, 3)
x = rfft_layer(x, time_dim)
-
+
return x
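Both SCNet variants rely on the RMSNorm whose constructor appears above (scale = dim**0.5 plus a learned gamma). Its forward pass is not visible in these hunks, so the sketch below follows the usual formulation of this norm and should be read as an assumption rather than the file's exact code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Unit-normalize the last dimension, then restore magnitude with sqrt(dim)
        # and a learned per-channel gain.
        return F.normalize(x, dim=-1) * self.scale * self.gamma

print(RMSNormSketch(64)(torch.randn(2, 10, 64)).shape)   # torch.Size([2, 10, 64])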
diff --git a/programs/music_separation_code/models/scnet_unofficial/scnet.py b/programs/music_separation_code/models/scnet_unofficial/scnet.py
index d076f85f..d6dcf728 100644
--- a/programs/music_separation_code/models/scnet_unofficial/scnet.py
+++ b/programs/music_separation_code/models/scnet_unofficial/scnet.py
@@ -1,8 +1,8 @@
-'''
+"""
SCNet - great paper, great implementation
https://arxiv.org/pdf/2401.13276.pdf
https://github.com/amanteur/SCNet-PyTorch
-'''
+"""
from typing import List
@@ -20,6 +20,7 @@
from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype
+
def exists(val):
return val is not None
@@ -39,7 +40,7 @@ def unpack_one(t, ps, pattern):
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
- self.scale = dim ** 0.5
+ self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
@@ -48,20 +49,13 @@ def forward(self, x):
class BandSplit(nn.Module):
@beartype
- def __init__(
- self,
- dim,
- dim_inputs: Tuple[int, ...]
- ):
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
- net = nn.Sequential(
- RMSNorm(dim_in),
- nn.Linear(dim_in, dim)
- )
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
self.to_features.append(net)
@@ -107,6 +101,7 @@ class SCNet(nn.Module):
C is channel dim (mono / stereo),
T is sequence length,
"""
+
@beartype
def __init__(
self,
@@ -122,7 +117,7 @@ def __init__(
win_length: int = 4096,
stft_window_fn: Optional[Callable] = None,
stft_normalized: bool = False,
- **kwargs
+ **kwargs,
):
"""
Initializes SCNet with input parameters.
@@ -156,7 +151,7 @@ def __init__(
n_layers=n_rnn_layers,
input_dim=dims[-1],
hidden_dim=rnn_hidden_dim,
- **kwargs
+ **kwargs,
)
self.su_blocks = nn.ModuleList(
SUBlock(
@@ -174,10 +169,12 @@ def __init__(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
- normalized=stft_normalized
+ normalized=stft_normalized,
)
- self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), win_length)
+ self.stft_window_fn = partial(
+ default(stft_window_fn, torch.hann_window), win_length
+ )
self.n_sources = n_sources
self.hop_length = hop_length
@@ -208,19 +205,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
stft_window = self.stft_window_fn(device=device)
if x.ndim == 2:
- x = rearrange(x, 'b t -> b 1 t')
+ x = rearrange(x, "b t -> b 1 t")
c = x.shape[1]
-
+
stft_pad = self.hop_length - x.shape[-1] % self.hop_length
x = F.pad(x, (0, stft_pad))
# stft
- x, ps = pack_one(x, '* t')
+ x, ps = pack_one(x, "* t")
x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True)
x = torch.view_as_real(x)
- x = unpack_one(x, ps, '* c f t')
- x = rearrange(x, 'b c f t r -> b f t (c r)')
+ x = unpack_one(x, ps, "* c f t")
+ x = rearrange(x, "b c f t r -> b f t (c r)")
# encoder part
x_skips = []
@@ -236,14 +233,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = su_block(x, x_skip)
# istft
- x = rearrange(x, 'b f t (c r n) -> b n c f t r', c=c, n=self.n_sources, r=2)
+ x = rearrange(x, "b f t (c r n) -> b n c f t r", c=c, n=self.n_sources, r=2)
x = x.contiguous()
- x = torch.view_as_complex(x)
- x = rearrange(x, 'b n c f t -> (b n c) f t')
+ x = torch.view_as_complex(x)
+ x = rearrange(x, "b n c f t -> (b n c) f t")
x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False)
- x = rearrange(x, '(b n c) t -> b n c t', c=c, n=self.n_sources)
+ x = rearrange(x, "(b n c) t -> b n c t", c=c, n=self.n_sources)
- x = x[..., :-stft_pad]
+ x = x[..., :-stft_pad]
return x
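The forward pass above pads the waveform so its length is a multiple of the hop size before the STFT, then trims the same number of samples after the ISTFT. A minimal sketch of that bookkeeping, with an illustrative hop size and length:

import torch
import torch.nn.functional as F

hop = 1024
x = torch.randn(1, 2, 44100)
stft_pad = hop - x.shape[-1] % hop
x = F.pad(x, (0, stft_pad))
assert x.shape[-1] % hop == 0
# ... STFT, separation, and ISTFT would run here ...
x = x[..., :-stft_pad]
assert x.shape[-1] == 44100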
diff --git a/programs/music_separation_code/models/scnet_unofficial/utils.py b/programs/music_separation_code/models/scnet_unofficial/utils.py
index aae1afcd..d236d499 100644
--- a/programs/music_separation_code/models/scnet_unofficial/utils.py
+++ b/programs/music_separation_code/models/scnet_unofficial/utils.py
@@ -1,8 +1,8 @@
-'''
+"""
SCNet - great paper, great implementation
https://arxiv.org/pdf/2401.13276.pdf
https://github.com/amanteur/SCNet-PyTorch
-'''
+"""
from typing import List, Tuple, Union
@@ -10,7 +10,7 @@
def create_intervals(
- splits: List[Union[float, int]]
+ splits: List[Union[float, int]],
) -> List[Union[Tuple[float, float], Tuple[int, int]]]:
"""
Create intervals based on splits provided.
@@ -132,4 +132,4 @@ def compute_gcr(subband_shapes: List[List[int]]) -> float:
gcr = torch.stack(
[(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)]
).mean()
- return float(gcr)
\ No newline at end of file
+ return float(gcr)
diff --git a/programs/music_separation_code/models/segm_models.py b/programs/music_separation_code/models/segm_models.py
index cf858ec2..537d94af 100644
--- a/programs/music_separation_code/models/segm_models.py
+++ b/programs/music_separation_code/models/segm_models.py
@@ -21,12 +21,14 @@ def __call__(self, x):
hop_length=self.hop_length,
window=window,
center=True,
- return_complex=True
+ return_complex=True,
)
x = torch.view_as_real(x)
x = x.permute([0, 3, 1, 2])
- x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
- return x[..., :self.dim_f, :]
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape(
+ [*batch_dims, c * 2, -1, x.shape[-1]]
+ )
+ return x[..., : self.dim_f, :]
def inverse(self, x):
window = self.window.to(x.device)
@@ -37,25 +39,21 @@ def inverse(self, x):
x = torch.cat([x, f_pad], -2)
x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
x = x.permute([0, 2, 3, 1])
- x = x[..., 0] + x[..., 1] * 1.j
+ x = x[..., 0] + x[..., 1] * 1.0j
x = torch.istft(
- x,
- n_fft=self.n_fft,
- hop_length=self.hop_length,
- window=window,
- center=True
+ x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True
)
x = x.reshape([*batch_dims, 2, -1])
return x
def get_act(act_type):
- if act_type == 'gelu':
+ if act_type == "gelu":
return nn.GELU()
- elif act_type == 'relu':
+ elif act_type == "relu":
return nn.ReLU()
- elif act_type[:3] == 'elu':
- alpha = float(act_type.replace('elu', ''))
+ elif act_type[:3] == "elu":
+ alpha = float(act_type.replace("elu", ""))
return nn.ELU(alpha)
else:
raise Exception
@@ -64,7 +62,7 @@ def get_act(act_type):
def get_decoder(config, c):
decoder = None
decoder_options = dict()
- if config.model.decoder_type == 'unet':
+ if config.model.decoder_type == "unet":
try:
decoder_options = dict(config.decoder_unet)
except:
@@ -76,7 +74,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'fpn':
+ elif config.model.decoder_type == "fpn":
try:
decoder_options = dict(config.decoder_fpn)
except:
@@ -88,7 +86,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'unet++':
+ elif config.model.decoder_type == "unet++":
try:
decoder_options = dict(config.decoder_unet_plus_plus)
except:
@@ -100,7 +98,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'manet':
+ elif config.model.decoder_type == "manet":
try:
decoder_options = dict(config.decoder_manet)
except:
@@ -112,7 +110,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'linknet':
+ elif config.model.decoder_type == "linknet":
try:
decoder_options = dict(config.decoder_linknet)
except:
@@ -124,7 +122,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'pspnet':
+ elif config.model.decoder_type == "pspnet":
try:
decoder_options = dict(config.decoder_pspnet)
except:
@@ -136,7 +134,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'pspnet':
+ elif config.model.decoder_type == "pspnet":
try:
decoder_options = dict(config.decoder_pspnet)
except:
@@ -148,7 +146,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'pan':
+ elif config.model.decoder_type == "pan":
try:
decoder_options = dict(config.decoder_pan)
except:
@@ -160,7 +158,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'deeplabv3':
+ elif config.model.decoder_type == "deeplabv3":
try:
decoder_options = dict(config.decoder_deeplabv3)
except:
@@ -172,7 +170,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'deeplabv3plus':
+ elif config.model.decoder_type == "deeplabv3plus":
try:
decoder_options = dict(config.decoder_deeplabv3plus)
except:
@@ -194,7 +192,9 @@ def __init__(self, config):
act = get_act(act_type=config.model.act)
- self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
+ self.num_target_instruments = (
+ 1 if config.training.target_instrument else len(config.training.instruments)
+ )
self.num_subbands = config.model.num_subbands
dim_c = self.num_subbands * config.audio.num_channels * 2
@@ -208,7 +208,7 @@ def __init__(self, config):
self.final_conv = nn.Sequential(
nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
act,
- nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
)
self.stft = STFT(config.audio)
diff --git a/programs/music_separation_code/models/torchseg_models.py b/programs/music_separation_code/models/torchseg_models.py
index fb4bd9fb..92fec692 100644
--- a/programs/music_separation_code/models/torchseg_models.py
+++ b/programs/music_separation_code/models/torchseg_models.py
@@ -21,12 +21,14 @@ def __call__(self, x):
hop_length=self.hop_length,
window=window,
center=True,
- return_complex=True
+ return_complex=True,
)
x = torch.view_as_real(x)
x = x.permute([0, 3, 1, 2])
- x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
- return x[..., :self.dim_f, :]
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape(
+ [*batch_dims, c * 2, -1, x.shape[-1]]
+ )
+ return x[..., : self.dim_f, :]
def inverse(self, x):
window = self.window.to(x.device)
@@ -37,25 +39,21 @@ def inverse(self, x):
x = torch.cat([x, f_pad], -2)
x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
x = x.permute([0, 2, 3, 1])
- x = x[..., 0] + x[..., 1] * 1.j
+ x = x[..., 0] + x[..., 1] * 1.0j
x = torch.istft(
- x,
- n_fft=self.n_fft,
- hop_length=self.hop_length,
- window=window,
- center=True
+ x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True
)
x = x.reshape([*batch_dims, 2, -1])
return x
def get_act(act_type):
- if act_type == 'gelu':
+ if act_type == "gelu":
return nn.GELU()
- elif act_type == 'relu':
+ elif act_type == "relu":
return nn.ReLU()
- elif act_type[:3] == 'elu':
- alpha = float(act_type.replace('elu', ''))
+ elif act_type[:3] == "elu":
+ alpha = float(act_type.replace("elu", ""))
return nn.ELU(alpha)
else:
raise Exception
@@ -64,7 +62,7 @@ def get_act(act_type):
def get_decoder(config, c):
decoder = None
decoder_options = dict()
- if config.model.decoder_type == 'unet':
+ if config.model.decoder_type == "unet":
try:
decoder_options = dict(config.decoder_unet)
except:
@@ -76,7 +74,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'fpn':
+ elif config.model.decoder_type == "fpn":
try:
decoder_options = dict(config.decoder_fpn)
except:
@@ -88,7 +86,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'unet++':
+ elif config.model.decoder_type == "unet++":
try:
decoder_options = dict(config.decoder_unet_plus_plus)
except:
@@ -100,7 +98,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'manet':
+ elif config.model.decoder_type == "manet":
try:
decoder_options = dict(config.decoder_manet)
except:
@@ -112,7 +110,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'linknet':
+ elif config.model.decoder_type == "linknet":
try:
decoder_options = dict(config.decoder_linknet)
except:
@@ -124,7 +122,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'pspnet':
+ elif config.model.decoder_type == "pspnet":
try:
decoder_options = dict(config.decoder_pspnet)
except:
@@ -136,7 +134,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'pspnet':
+ elif config.model.decoder_type == "pspnet":
try:
decoder_options = dict(config.decoder_pspnet)
except:
@@ -148,7 +146,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'pan':
+ elif config.model.decoder_type == "pan":
try:
decoder_options = dict(config.decoder_pan)
except:
@@ -160,7 +158,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'deeplabv3':
+ elif config.model.decoder_type == "deeplabv3":
try:
decoder_options = dict(config.decoder_deeplabv3)
except:
@@ -172,7 +170,7 @@ def get_decoder(config, c):
classes=c,
**decoder_options,
)
- elif config.model.decoder_type == 'deeplabv3plus':
+ elif config.model.decoder_type == "deeplabv3plus":
try:
decoder_options = dict(config.decoder_deeplabv3plus)
except:
@@ -194,7 +192,9 @@ def __init__(self, config):
act = get_act(act_type=config.model.act)
- self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
+ self.num_target_instruments = (
+ 1 if config.training.target_instrument else len(config.training.instruments)
+ )
self.num_subbands = config.model.num_subbands
dim_c = self.num_subbands * config.audio.num_channels * 2
@@ -208,7 +208,7 @@ def __init__(self, config):
self.final_conv = nn.Sequential(
nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
act,
- nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
)
self.stft = STFT(config.audio)
diff --git a/programs/music_separation_code/models/upernet_swin_transformers.py b/programs/music_separation_code/models/upernet_swin_transformers.py
index d20e289b..27f32f41 100644
--- a/programs/music_separation_code/models/upernet_swin_transformers.py
+++ b/programs/music_separation_code/models/upernet_swin_transformers.py
@@ -22,12 +22,14 @@ def __call__(self, x):
hop_length=self.hop_length,
window=window,
center=True,
- return_complex=True
+ return_complex=True,
)
x = torch.view_as_real(x)
x = x.permute([0, 3, 1, 2])
- x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
- return x[..., :self.dim_f, :]
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape(
+ [*batch_dims, c * 2, -1, x.shape[-1]]
+ )
+ return x[..., : self.dim_f, :]
def inverse(self, x):
window = self.window.to(x.device)
@@ -38,13 +40,9 @@ def inverse(self, x):
x = torch.cat([x, f_pad], -2)
x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
x = x.permute([0, 2, 3, 1])
- x = x[..., 0] + x[..., 1] * 1.j
+ x = x[..., 0] + x[..., 1] * 1.0j
x = torch.istft(
- x,
- n_fft=self.n_fft,
- hop_length=self.hop_length,
- window=window,
- center=True
+ x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True
)
x = x.reshape([*batch_dims, 2, -1])
return x
@@ -52,12 +50,12 @@ def inverse(self, x):
def get_norm(norm_type):
def norm(c, norm_type):
- if norm_type == 'BatchNorm':
+ if norm_type == "BatchNorm":
return nn.BatchNorm2d(c)
- elif norm_type == 'InstanceNorm':
+ elif norm_type == "InstanceNorm":
return nn.InstanceNorm2d(c, affine=True)
- elif 'GroupNorm' in norm_type:
- g = int(norm_type.replace('GroupNorm', ''))
+ elif "GroupNorm" in norm_type:
+ g = int(norm_type.replace("GroupNorm", ""))
return nn.GroupNorm(num_groups=g, num_channels=c)
else:
return nn.Identity()
@@ -66,12 +64,12 @@ def norm(c, norm_type):
def get_act(act_type):
- if act_type == 'gelu':
+ if act_type == "gelu":
return nn.GELU()
- elif act_type == 'relu':
+ elif act_type == "relu":
return nn.ReLU()
- elif act_type[:3] == 'elu':
- alpha = float(act_type.replace('elu', ''))
+ elif act_type[:3] == "elu":
+ alpha = float(act_type.replace("elu", ""))
return nn.ELU(alpha)
else:
raise Exception
@@ -83,7 +81,13 @@ def __init__(self, in_c, out_c, scale, norm, act):
self.conv = nn.Sequential(
norm(in_c),
act,
- nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ nn.ConvTranspose2d(
+ in_channels=in_c,
+ out_channels=out_c,
+ kernel_size=scale,
+ stride=scale,
+ bias=False,
+ ),
)
def forward(self, x):
@@ -96,7 +100,13 @@ def __init__(self, in_c, out_c, scale, norm, act):
self.conv = nn.Sequential(
norm(in_c),
act,
- nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ nn.Conv2d(
+ in_channels=in_c,
+ out_channels=out_c,
+ kernel_size=scale,
+ stride=scale,
+ bias=False,
+ ),
)
def forward(self, x):
@@ -151,7 +161,9 @@ def __init__(self, config):
act = get_act(act_type=config.model.act)
- self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
+ self.num_target_instruments = (
+ 1 if config.training.target_instrument else len(config.training.instruments)
+ )
self.num_subbands = config.model.num_subbands
dim_c = self.num_subbands * config.audio.num_channels * 2
@@ -160,16 +172,24 @@ def __init__(self, config):
self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
- self.swin_upernet_model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-large")
+ self.swin_upernet_model = UperNetForSemanticSegmentation.from_pretrained(
+ "openmmlab/upernet-swin-large"
+ )
- self.swin_upernet_model.auxiliary_head.classifier = nn.Conv2d(256, c, kernel_size=(1, 1), stride=(1, 1))
- self.swin_upernet_model.decode_head.classifier = nn.Conv2d(512, c, kernel_size=(1, 1), stride=(1, 1))
- self.swin_upernet_model.backbone.embeddings.patch_embeddings.projection = nn.Conv2d(c, 192, kernel_size=(4, 4), stride=(4, 4))
+ self.swin_upernet_model.auxiliary_head.classifier = nn.Conv2d(
+ 256, c, kernel_size=(1, 1), stride=(1, 1)
+ )
+ self.swin_upernet_model.decode_head.classifier = nn.Conv2d(
+ 512, c, kernel_size=(1, 1), stride=(1, 1)
+ )
+ self.swin_upernet_model.backbone.embeddings.patch_embeddings.projection = (
+ nn.Conv2d(c, 192, kernel_size=(4, 4), stride=(4, 4))
+ )
self.final_conv = nn.Sequential(
nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
act,
- nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
)
self.stft = STFT(config.audio)
@@ -217,7 +237,9 @@ def forward(self, x):
if __name__ == "__main__":
- model = UperNetForSemanticSegmentation.from_pretrained("./results/", ignore_mismatched_sizes=True)
+ model = UperNetForSemanticSegmentation.from_pretrained(
+ "./results/", ignore_mismatched_sizes=True
+ )
print(model)
print(model.auxiliary_head.classifier)
print(model.decode_head.classifier)
@@ -225,4 +247,4 @@ def forward(self, x):
x = torch.zeros((2, 16, 512, 512), dtype=torch.float32)
res = model(x)
print(res.logits.shape)
- model.save_pretrained('./results/')
\ No newline at end of file
+ model.save_pretrained("./results/")
diff --git a/programs/music_separation_code/utils.py b/programs/music_separation_code/utils.py
index 711af16b..1daee57f 100644
--- a/programs/music_separation_code/utils.py
+++ b/programs/music_separation_code/utils.py
@@ -1,7 +1,6 @@
# coding: utf-8
-__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
-import time
import numpy as np
import torch
import torch.nn as nn
@@ -12,64 +11,65 @@
from numpy.typing import NDArray
from typing import Dict
+
def get_model_from_config(model_type, config_path):
with open(config_path) as f:
- if model_type == 'htdemucs':
+ if model_type == "htdemucs":
config = OmegaConf.load(config_path)
else:
config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
- if model_type == 'mdx23c':
+ if model_type == "mdx23c":
from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
+
model = TFC_TDF_net(config)
- elif model_type == 'htdemucs':
+ elif model_type == "htdemucs":
from models.demucs4ht import get_model
+
model = get_model(config)
- elif model_type == 'segm_models':
+ elif model_type == "segm_models":
from models.segm_models import Segm_Models_Net
+
model = Segm_Models_Net(config)
- elif model_type == 'torchseg':
+ elif model_type == "torchseg":
from models.torchseg_models import Torchseg_Net
+
model = Torchseg_Net(config)
- elif model_type == 'mel_band_roformer':
+ elif model_type == "mel_band_roformer":
from models.bs_roformer import MelBandRoformer
- model = MelBandRoformer(
- **dict(config.model)
- )
- elif model_type == 'bs_roformer':
+
+ model = MelBandRoformer(**dict(config.model))
+ elif model_type == "bs_roformer":
from models.bs_roformer import BSRoformer
- model = BSRoformer(
- **dict(config.model)
- )
- elif model_type == 'swin_upernet':
+
+ model = BSRoformer(**dict(config.model))
+ elif model_type == "swin_upernet":
from models.upernet_swin_transformers import Swin_UperNet_Model
+
model = Swin_UperNet_Model(config)
- elif model_type == 'bandit':
+ elif model_type == "bandit":
from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
- model = MultiMaskMultiSourceBandSplitRNNSimple(
- **config.model
- )
- elif model_type == 'bandit_v2':
+
+ model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
+ elif model_type == "bandit_v2":
from models.bandit_v2.bandit import Bandit
- model = Bandit(
- **config.kwargs
- )
- elif model_type == 'scnet_unofficial':
+
+ model = Bandit(**config.kwargs)
+ elif model_type == "scnet_unofficial":
from models.scnet_unofficial import SCNet
- model = SCNet(
- **config.model
- )
- elif model_type == 'scnet':
+
+ model = SCNet(**config.model)
+ elif model_type == "scnet":
from models.scnet import SCNet
- model = SCNet(
- **config.model
- )
+
+ model = SCNet(**config.model)
else:
- print('Unknown model: {}'.format(model_type))
+ print("Unknown model: {}".format(model_type))
model = None
return model, config
+
def _getWindowingArray(window_size, fade_size):
fadein = torch.linspace(0, 1, fade_size)
fadeout = torch.linspace(1, 0, fade_size)
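Note: only the fade ramps of _getWindowingArray fall inside this hunk. A plausible completion (an assumption, since the unchanged remainder is not shown) is a flat window with a linear fade-in and fade-out, which is what lets overlapping chunks crossfade instead of clicking:

import torch

def _get_windowing_array_sketch(window_size: int, fade_size: int) -> torch.Tensor:
    # Hypothetical reconstruction of the unchanged body of _getWindowingArray.
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)
    window = torch.ones(window_size)
    window[:fade_size] *= fadein    # ramp up at the start of each chunk
    window[-fade_size:] *= fadeout  # ramp down at the end of each chunk
    return window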
@@ -91,16 +91,16 @@ def demix_track(config, model, mix, device, pbar=False):
# Do pad from the beginning and end to account floating window results better
if length_init > 2 * border and (border > 0):
- mix = nn.functional.pad(mix, (border, border), mode='reflect')
+ mix = nn.functional.pad(mix, (border, border), mode="reflect")
# windowingArray crossfades at segment boundaries to mitigate clicking artifacts
windowingArray = _getWindowingArray(C, fade_size)
with torch.cuda.amp.autocast(enabled=config.training.use_amp):
- use_amp = getattr(config.training, 'use_amp', False)
+ use_amp = getattr(config.training, "use_amp", False)
with torch.inference_mode():
if config.training.target_instrument is not None:
- req_shape = (1, ) + tuple(mix.shape)
+ req_shape = (1,) + tuple(mix.shape)
else:
req_shape = (len(config.training.instruments),) + tuple(mix.shape)
@@ -109,17 +109,28 @@ def demix_track(config, model, mix, device, pbar=False):
i = 0
batch_data = []
batch_locations = []
- progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None
+ progress_bar = (
+ tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False)
+ if pbar
+ else None
+ )
while i < mix.shape[1]:
# print(i, i + C, mix.shape[1])
- part = mix[:, i:i + C].to(device)
+ part = mix[:, i : i + C].to(device)
length = part.shape[-1]
if length < C:
if length > C // 2 + 1:
- part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
+ part = nn.functional.pad(
+ input=part, pad=(0, C - length), mode="reflect"
+ )
else:
- part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
+ part = nn.functional.pad(
+ input=part,
+ pad=(0, C - length, 0, 0),
+ mode="constant",
+ value=0,
+ )
batch_data.append(part)
batch_locations.append((i, length))
i += step
@@ -136,8 +147,10 @@ def demix_track(config, model, mix, device, pbar=False):
for j in range(len(batch_locations)):
start, l = batch_locations[j]
- result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
- counter[..., start:start+l] += window[..., :l]
+ result[..., start : start + l] += (
+ x[j][..., :l].cpu() * window[..., :l]
+ )
+ counter[..., start : start + l] += window[..., :l]
batch_data = []
batch_locations = []
@@ -159,7 +172,9 @@ def demix_track(config, model, mix, device, pbar=False):
if config.training.target_instrument is None:
return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
else:
- return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}
+ return {
+ k: v for k, v in zip([config.training.target_instrument], estimated_sources)
+ }
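Note: the accumulation rewritten above is windowed overlap-add: each chunk's output is weighted by the crossfade window and summed into result, while counter collects the window weights. A self-contained sketch of that idea (the final division by counter is assumed, since it sits outside these hunks):

import torch

def overlap_add_sketch(chunks, locations, window, out_shape):
    # chunks: list of model outputs shaped like out_shape[:-1] + (chunk_len,)
    # locations: list of (start, length) pairs, matching batch_locations above
    result = torch.zeros(out_shape)
    counter = torch.zeros(out_shape)
    for x, (start, length) in zip(chunks, locations):
        result[..., start : start + length] += x[..., :length] * window[..., :length]
        counter[..., start : start + length] += window[..., :length]
    # Normalizing by the accumulated window weights blends the overlapping chunks.
    return result / counter.clamp(min=1e-8)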
def demix_track_demucs(config, model, mix, device, pbar=False):
@@ -172,32 +187,37 @@ def demix_track_demucs(config, model, mix, device, pbar=False):
with torch.cuda.amp.autocast(enabled=config.training.use_amp):
with torch.inference_mode():
- req_shape = (S, ) + tuple(mix.shape)
+ req_shape = (S,) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)
i = 0
batch_data = []
batch_locations = []
- progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None
+ progress_bar = (
+ tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False)
+ if pbar
+ else None
+ )
while i < mix.shape[1]:
# print(i, i + C, mix.shape[1])
- part = mix[:, i:i + C].to(device)
+ part = mix[:, i : i + C].to(device)
length = part.shape[-1]
if length < C:
- part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
+ part = nn.functional.pad(
+ input=part, pad=(0, C - length, 0, 0), mode="constant", value=0
+ )
batch_data.append(part)
batch_locations.append((i, length))
i += step
-
if len(batch_data) >= batch_size or (i >= mix.shape[1]):
arr = torch.stack(batch_data, dim=0)
x = model(arr)
for j in range(len(batch_locations)):
start, l = batch_locations[j]
- result[..., start:start+l] += x[j][..., :l].cpu()
- counter[..., start:start+l] += 1.
+ result[..., start : start + l] += x[j][..., :l].cpu()
+ counter[..., start : start + l] += 1.0
batch_data = []
batch_locations = []
@@ -226,9 +246,12 @@ def sdr(references, estimates):
den += delta
return 10 * np.log10(num / den)
-def demix(config, model, mix: NDArray, device, pbar=False, model_type: str = None) -> Dict[str, NDArray]:
+
+def demix(
+ config, model, mix: NDArray, device, pbar=False, model_type: str = None
+) -> Dict[str, NDArray]:
mix = torch.tensor(mix, dtype=torch.float32)
- if model_type == 'htdemucs':
+ if model_type == "htdemucs":
return demix_track_demucs(config, model, mix, device, pbar=pbar)
else:
return demix_track(config, model, mix, device, pbar=pbar)
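Note: get_model_from_config and demix form this module's public surface: build a model from a YAML/OmegaConf config, then separate a mix with automatic routing to the htdemucs code path. A usage sketch, where the file paths, checkpoint loading, and audio I/O are assumptions rather than part of this patch:

import numpy as np
import soundfile as sf  # assumed audio I/O library
import torch

model_type = "mel_band_roformer"
model, config = get_model_from_config(model_type, "configs/model.yaml")  # hypothetical path
model.load_state_dict(torch.load("checkpoints/model.ckpt", map_location="cpu"))  # hypothetical path

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

mix, sr = sf.read("input.wav")  # (samples, channels)
mix = mix.T                     # demix expects (channels, samples)

# demix() sends "htdemucs" models to demix_track_demucs and everything else to
# demix_track, returning {instrument_name: separated_waveform}.
sources = demix(config, model, mix, device, pbar=True, model_type=model_type)
for name, waveform in sources.items():
    out = waveform if isinstance(waveform, np.ndarray) else waveform.cpu().numpy()
    sf.write(f"{name}.wav", out.T, sr)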
diff --git a/tabs/full_inference.py b/tabs/full_inference.py
index 08eaafef..bb53dc5a 100644
--- a/tabs/full_inference.py
+++ b/tabs/full_inference.py
@@ -19,191 +19,113 @@
from assets.i18n.i18n import I18nAuto
-
-
-
i18n = I18nAuto()
-
now_dir = os.getcwd()
sys.path.append(now_dir)
-
model_root = os.path.join(now_dir, "logs")
audio_root = os.path.join(now_dir, "audio_files", "original_files")
-
model_root_relative = os.path.relpath(model_root, now_dir)
audio_root_relative = os.path.relpath(audio_root, now_dir)
-
sup_audioext = {
-
"wav",
-
"mp3",
-
"flac",
-
"ogg",
-
"opus",
-
"m4a",
-
"mp4",
-
"aac",
-
"alac",
-
"wma",
-
"aiff",
-
"webm",
-
"ac3",
-
}
-
names = [
-
os.path.join(root, file)
-
for root, _, files in os.walk(model_root_relative, topdown=False)
-
for file in files
-
if (
-
file.endswith((".pth", ".onnx"))
-
and not (file.startswith("G_") or file.startswith("D_"))
-
)
-
]
-
indexes_list = [
-
os.path.join(root, name)
-
for root, _, files in os.walk(model_root_relative, topdown=False)
-
for name in files
-
if name.endswith(".index") and "trained" not in name
-
]
-
audio_paths = [
-
os.path.join(root, name)
-
for root, _, files in os.walk(audio_root_relative, topdown=False)
-
for name in files
-
if name.endswith(tuple(sup_audioext))
-
and root == audio_root_relative
-
and "_output" not in name
-
]
-
vocals_model_names = [
-
"Mel-Roformer by KimberleyJSN",
-
"BS-Roformer by ViperX",
-
"MDX23C",
-
]
-
karaoke_models_names = [
-
"Mel-Roformer Karaoke by aufr33 and viperx",
-
"UVR-BVE",
-
]
-
denoise_models_names = [
-
"Mel-Roformer Denoise Normal by aufr33",
-
"Mel-Roformer Denoise Aggressive by aufr33",
-
"UVR Denoise",
-
]
-
dereverb_models_names = [
-
"MDX23C DeReverb by aufr33 and jarredou",
-
"UVR-Deecho-Dereverb",
-
"MDX Reverb HQ by FoxJoy",
-
"BS-Roformer Dereverb by anvuew",
-
]
-
deeecho_models_names = ["UVR-Deecho-Normal", "UVR-Deecho-Aggressive"]
-
-
-
def get_indexes():
indexes_list = [
-
os.path.join(dirpath, filename)
-
for dirpath, _, filenames in os.walk(model_root_relative)
-
for filename in filenames
-
if filename.endswith(".index") and "trained" not in filename
-
]
-
-
return indexes_list if indexes_list else ""
-
-
-
def match_index(model_file_value):
if model_file_value:
@@ -235,15 +157,10 @@ def match_index(model_file_value):
return ""
-
-
-
def output_path_fn(input_audio_path):
original_name_without_extension = os.path.basename(input_audio_path).rsplit(".", 1)[
-
0
-
]
new_name = original_name_without_extension + "_output.wav"
@@ -253,9 +170,6 @@ def output_path_fn(input_audio_path):
return output_path
-
-
-
def get_number_of_gpus():
if torch.cuda.is_available():
@@ -269,9 +183,6 @@ def get_number_of_gpus():
return "-"
-
-
-
def max_vram_gpu(gpu):
if torch.cuda.is_available():
@@ -287,15 +198,10 @@ def max_vram_gpu(gpu):
return "0"
-
-
-
def format_title(title):
formatted_title = (
-
unicodedata.normalize("NFKD", title).encode("ascii", "ignore").decode("utf-8")
-
)
formatted_title = re.sub(r"[\u2500-\u257F]+", "", formatted_title)
@@ -307,9 +213,6 @@ def format_title(title):
return formatted_title
-
-
-
def save_to_wav(upload_audio):
file_path = upload_audio
@@ -318,14 +221,10 @@ def save_to_wav(upload_audio):
target_path = os.path.join(audio_root_relative, formated_name)
-
-
if os.path.exists(target_path):
os.remove(target_path)
-
-
os.makedirs(os.path.dirname(target_path), exist_ok=True)
shutil.copy(file_path, target_path)
@@ -333,9 +232,6 @@ def save_to_wav(upload_audio):
return target_path, output_path_fn(target_path)
-
-
-
def delete_outputs():
gr.Info(f"Outputs cleared!")
@@ -349,115 +245,64 @@ def delete_outputs():
os.remove(os.path.join(root, name))
-
-
-
def change_choices():
names = [
-
os.path.join(root, file)
-
for root, _, files in os.walk(model_root_relative, topdown=False)
-
for file in files
-
if (
-
file.endswith((".pth", ".onnx"))
-
and not (file.startswith("G_") or file.startswith("D_"))
-
)
-
]
-
-
indexes_list = [
-
os.path.join(root, name)
-
for root, _, files in os.walk(model_root_relative, topdown=False)
-
for name in files
-
if name.endswith(".index") and "trained" not in name
-
]
-
-
audio_paths = [
-
os.path.join(root, name)
-
for root, _, files in os.walk(audio_root_relative, topdown=False)
-
for name in files
-
if name.endswith(tuple(sup_audioext))
-
and root == audio_root_relative
-
and "_output" not in name
-
]
-
-
return (
-
{"choices": sorted(names), "__type__": "update"},
-
{"choices": sorted(indexes_list), "__type__": "update"},
-
{"choices": sorted(audio_paths), "__type__": "update"},
-
)
-
-
-
def download_music_tab():
with gr.Row():
link = gr.Textbox(
-
label=i18n("Music URL"),
-
lines=1,
-
)
output = gr.Textbox(
-
label=i18n("Output Information"),
-
info=i18n("The output information will be displayed here."),
-
)
download = gr.Button(i18n("Download"))
-
- download.click(
+ download.click(
download_music,
-
inputs=[link],
-
outputs=[output],
-
)
-
-
-
-
-
def full_inference_tab():
default_weight = names[0] if names else None
@@ -467,67 +312,41 @@ def full_inference_tab():
with gr.Row():
model_file = gr.Dropdown(
-
label=i18n("Voice Model"),
-
info=i18n("Select the voice model to use for the conversion."),
-
choices=sorted(names, key=lambda path: os.path.getsize(path)),
-
interactive=True,
-
value=default_weight,
-
allow_custom_value=True,
-
)
index_file = gr.Dropdown(
-
label=i18n("Index File"),
-
info=i18n("Select the index file to use for the conversion."),
-
choices=get_indexes(),
-
value=match_index(default_weight) if default_weight else "",
-
interactive=True,
-
allow_custom_value=True,
-
)
with gr.Column():
with gr.Row():
unload_button = gr.Button(i18n("Unload Voice"))
- refresh_button = gr.Button(i18n("Refresh"))
-
+ refresh_button = gr.Button(i18n("Refresh"))
unload_button.click(
-
fn=lambda: (
-
{"value": "", "__type__": "update"},
-
{"value": "", "__type__": "update"},
-
),
-
inputs=[],
-
outputs=[model_file, index_file],
-
)
model_file.select(
-
fn=lambda model_file_value: match_index(model_file_value),
-
inputs=[model_file],
-
outputs=[index_file],
-
)
with gr.Tab(i18n("Single")):
@@ -535,33 +354,21 @@ def full_inference_tab():
with gr.Column():
upload_audio = gr.Audio(
-
label=i18n("Upload Audio"),
-
type="filepath",
-
editable=False,
-
sources="upload",
-
)
with gr.Row():
audio = gr.Dropdown(
-
label=i18n("Select Audio"),
-
info=i18n("Select the audio to convert."),
-
choices=sorted(audio_paths),
-
value=audio_paths[0] if audio_paths else "",
-
interactive=True,
-
allow_custom_value=True,
-
)
with gr.Accordion(i18n("Advanced Settings"), open=False):
@@ -569,955 +376,538 @@ def full_inference_tab():
with gr.Accordion(i18n("RVC Settings"), open=False):
output_path = gr.Textbox(
-
label=i18n("Output Path"),
-
placeholder=i18n("Enter output path"),
-
info=i18n(
-
"The path where the output audio will be saved, by default in audio_files/rvc/output.wav"
-
),
-
value=os.path.join(now_dir, "audio_files", "rvc"),
-
interactive=False,
-
visible=False,
-
)
infer_backing_vocals = gr.Checkbox(
-
label=i18n("Infer Backing Vocals"),
-
info=i18n("Infer the bakcing vocals too."),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
with gr.Row():
infer_backing_vocals_model = gr.Dropdown(
-
label=i18n("Backing Vocals Model"),
-
info=i18n(
-
"Select the backing vocals model to use for the conversion."
-
),
-
choices=sorted(names, key=lambda path: os.path.getsize(path)),
-
interactive=True,
-
value=default_weight,
-
visible=False,
-
allow_custom_value=False,
-
)
infer_backing_vocals_index = gr.Dropdown(
-
label=i18n("Backing Vocals Index File"),
-
info=i18n(
-
"Select the backing vocals index file to use for the conversion."
-
),
-
choices=get_indexes(),
-
value=match_index(default_weight) if default_weight else "",
-
interactive=True,
-
visible=False,
-
allow_custom_value=True,
-
)
with gr.Column():
refresh_button_infer_backing_vocals = gr.Button(
-
i18n("Refresh"),
-
visible=False,
-
)
unload_button_infer_backing_vocals = gr.Button(
-
i18n("Unload Voice"),
-
visible=False,
-
)
-
-
unload_button_infer_backing_vocals.click(
-
fn=lambda: (
-
{"value": "", "__type__": "update"},
-
{"value": "", "__type__": "update"},
-
),
-
inputs=[],
-
outputs=[
-
infer_backing_vocals_model,
-
infer_backing_vocals_index,
-
],
-
)
infer_backing_vocals_model.select(
-
fn=lambda model_file_value: match_index(model_file_value),
-
inputs=[infer_backing_vocals_model],
-
outputs=[infer_backing_vocals_index],
-
)
with gr.Accordion(
-
i18n("RVC Settings for Backing vocals"), open=False, visible=False
-
) as back_rvc_settings:
export_format_rvc_back = gr.Radio(
-
label=i18n("Export Format"),
-
info=i18n("Select the format to export the audio."),
-
choices=["WAV", "MP3", "FLAC", "OGG", "M4A"],
-
value="MP3",
-
interactive=True,
-
visible=False,
-
)
split_audio_back = gr.Checkbox(
-
label=i18n("Split Audio"),
-
info=i18n(
-
"Split the audio into chunks for inference to obtain better results in some cases."
-
),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
pitch_extract_back = gr.Radio(
-
label=i18n("Pitch Extractor"),
-
info=i18n("Pitch extract Algorith."),
-
choices=["rmvpe", "crepe", "crepe-tiny", "fcpe"],
-
value="rmvpe",
-
interactive=True,
-
)
hop_length_back = gr.Slider(
-
label=i18n("Hop Length"),
-
info=i18n("Hop length for pitch extraction."),
-
minimum=1,
-
maximum=512,
-
step=1,
-
value=64,
-
visible=False,
-
)
embedder_model_back = gr.Radio(
-
label=i18n("Embedder Model"),
-
info=i18n("Model used for learning speaker embedding."),
-
choices=[
-
"contentvec",
-
"chinese-hubert-base",
-
"japanese-hubert-base",
-
"korean-hubert-base",
-
],
-
value="contentvec",
-
interactive=True,
-
)
autotune_back = gr.Checkbox(
-
label=i18n("Autotune"),
-
info=i18n(
-
"Apply a soft autotune to your inferences, recommended for singing conversions."
-
),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
pitch_back = gr.Slider(
-
label=i18n("Pitch"),
-
info=i18n("Adjust the pitch of the audio."),
-
minimum=-12,
-
maximum=12,
-
step=1,
-
value=0,
-
interactive=True,
-
)
filter_radius_back = gr.Slider(
-
minimum=0,
-
maximum=7,
-
label=i18n("Filter Radius"),
-
info=i18n(
-
"If the number is greater than or equal to three, employing median filtering on the collected tone results has the potential to decrease respiration."
-
),
-
value=3,
-
step=1,
-
interactive=True,
-
)
index_rate_back = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Search Feature Ratio"),
-
info=i18n(
-
"Influence exerted by the index file; a higher value corresponds to greater influence. However, opting for lower values can help mitigate artifacts present in the audio."
-
),
-
value=0.75,
-
interactive=True,
-
)
rms_mix_rate_back = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Volume Envelope"),
-
info=i18n(
-
"Substitute or blend with the volume envelope of the output. The closer the ratio is to 1, the more the output envelope is employed."
-
),
-
value=0.25,
-
interactive=True,
-
)
protect_back = gr.Slider(
-
minimum=0,
-
maximum=0.5,
-
label=i18n("Protect Voiceless Consonants"),
-
info=i18n(
-
"Safeguard distinct consonants and breathing sounds to prevent electro-acoustic tearing and other artifacts. Pulling the parameter to its maximum value of 0.5 offers comprehensive protection. However, reducing this value might decrease the extent of protection while potentially mitigating the indexing effect."
-
),
-
value=0.33,
-
interactive=True,
-
)
clear_outputs_infer = gr.Button(
-
i18n("Clear Outputs (Deletes all audios in assets/audios)")
-
)
export_format_rvc = gr.Radio(
-
label=i18n("Export Format"),
-
info=i18n("Select the format to export the audio."),
-
choices=["WAV", "MP3", "FLAC", "OGG", "M4A"],
-
value="FLAC",
-
interactive=True,
-
visible=False,
-
)
split_audio = gr.Checkbox(
-
label=i18n("Split Audio"),
-
info=i18n(
-
"Split the audio into chunks for inference to obtain better results in some cases."
-
),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
pitch_extract = gr.Radio(
-
label=i18n("Pitch Extractor"),
-
info=i18n("Pitch extract Algorith."),
-
choices=["rmvpe", "crepe", "crepe-tiny", "fcpe"],
-
value="rmvpe",
-
interactive=True,
-
)
hop_length = gr.Slider(
-
label=i18n("Hop Length"),
-
info=i18n("Hop length for pitch extraction."),
-
minimum=1,
-
maximum=512,
-
step=1,
-
value=64,
-
visible=False,
-
)
embedder_model = gr.Radio(
-
label=i18n("Embedder Model"),
-
info=i18n("Model used for learning speaker embedding."),
-
choices=[
-
"contentvec",
-
"chinese-hubert-base",
-
"japanese-hubert-base",
-
"korean-hubert-base",
-
],
-
value="contentvec",
-
interactive=True,
-
)
autotune = gr.Checkbox(
-
label=i18n("Autotune"),
-
info=i18n(
-
"Apply a soft autotune to your inferences, recommended for singing conversions."
-
),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
pitch = gr.Slider(
-
label=i18n("Pitch"),
-
info=i18n("Adjust the pitch of the audio."),
-
minimum=-12,
-
maximum=12,
-
step=1,
-
value=0,
-
interactive=True,
-
)
filter_radius = gr.Slider(
-
minimum=0,
-
maximum=7,
-
label=i18n("Filter Radius"),
-
info=i18n(
-
"If the number is greater than or equal to three, employing median filtering on the collected tone results has the potential to decrease respiration."
-
),
-
value=3,
-
step=1,
-
interactive=True,
-
)
index_rate = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Search Feature Ratio"),
-
info=i18n(
-
"Influence exerted by the index file; a higher value corresponds to greater influence. However, opting for lower values can help mitigate artifacts present in the audio."
-
),
-
value=0.75,
-
interactive=True,
-
)
rms_mix_rate = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Volume Envelope"),
-
info=i18n(
-
"Substitute or blend with the volume envelope of the output. The closer the ratio is to 1, the more the output envelope is employed."
-
),
-
value=0.25,
-
interactive=True,
-
)
protect = gr.Slider(
-
minimum=0,
-
maximum=0.5,
-
label=i18n("Protect Voiceless Consonants"),
-
info=i18n(
-
"Safeguard distinct consonants and breathing sounds to prevent electro-acoustic tearing and other artifacts. Pulling the parameter to its maximum value of 0.5 offers comprehensive protection. However, reducing this value might decrease the extent of protection while potentially mitigating the indexing effect."
-
),
-
value=0.33,
-
interactive=True,
-
)
with gr.Accordion(i18n("Audio Separation Settings"), open=False):
use_tta = gr.Checkbox(
-
label=i18n("Use TTA"),
-
info=i18n("Use Test Time Augmentation."),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
batch_size = gr.Slider(
-
minimum=1,
-
maximum=24,
-
step=1,
-
label=i18n("Batch Size"),
-
info=i18n("Set the batch size for the separation."),
-
value=1,
-
interactive=True,
-
)
vocal_model = gr.Dropdown(
-
label=i18n("Vocals Model"),
-
info=i18n("Select the vocals model to use for the separation."),
-
choices=sorted(vocals_model_names),
-
interactive=True,
-
value="Mel-Roformer by KimberleyJSN",
-
allow_custom_value=False,
-
)
karaoke_model = gr.Dropdown(
-
label=i18n("Karaoke Model"),
-
info=i18n("Select the karaoke model to use for the separation."),
-
choices=sorted(karaoke_models_names),
-
interactive=True,
-
value="Mel-Roformer Karaoke by aufr33 and viperx",
-
allow_custom_value=False,
-
)
dereverb_model = gr.Dropdown(
-
label=i18n("Dereverb Model"),
-
info=i18n("Select the dereverb model to use for the separation."),
-
choices=sorted(dereverb_models_names),
-
interactive=True,
-
value="UVR-Deecho-Dereverb",
-
allow_custom_value=False,
-
)
deecho = gr.Checkbox(
-
label=i18n("Deeecho"),
-
info=i18n("Apply deeecho to the audio."),
-
visible=True,
-
value=True,
-
interactive=True,
-
)
deeecho_model = gr.Dropdown(
-
label=i18n("Deeecho Model"),
-
info=i18n("Select the deeecho model to use for the separation."),
-
choices=sorted(deeecho_models_names),
-
interactive=True,
-
value="UVR-Deecho-Normal",
-
allow_custom_value=False,
-
)
denoise = gr.Checkbox(
-
label=i18n("Denoise"),
-
info=i18n("Apply denoise to the audio."),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
denoise_model = gr.Dropdown(
-
label=i18n("Denoise Model"),
-
info=i18n("Select the denoise model to use for the separation."),
-
choices=sorted(denoise_models_names),
-
interactive=True,
-
value="Mel-Roformer Denoise Normal by aufr33",
-
allow_custom_value=False,
-
visible=False,
-
)
with gr.Accordion(i18n("Audio post-process Settings"), open=False):
change_inst_pitch = gr.Slider(
-
label=i18n("Change Instrumental Pitch"),
-
info=i18n("Change the pitch of the instrumental."),
-
minimum=-12,
-
maximum=12,
-
step=1,
-
value=0,
-
interactive=True,
-
)
delete_audios = gr.Checkbox(
-
label=i18n("Delete Audios"),
-
info=i18n("Delete the audios after the conversion."),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
reverb = gr.Checkbox(
-
label=i18n("Reverb"),
-
info=i18n("Apply reverb to the audio."),
-
visible=True,
-
value=False,
-
interactive=True,
-
)
reverb_room_size = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Reverb Room Size"),
-
info=i18n("Set the room size of the reverb."),
-
value=0.5,
-
interactive=True,
-
visible=False,
-
)
-
-
reverb_damping = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Reverb Damping"),
-
info=i18n("Set the damping of the reverb."),
-
value=0.5,
-
interactive=True,
-
visible=False,
-
)
-
-
reverb_wet_gain = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Reverb Wet Gain"),
-
info=i18n("Set the wet gain of the reverb."),
-
value=0.33,
-
interactive=True,
-
visible=False,
-
)
-
-
reverb_dry_gain = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Reverb Dry Gain"),
-
info=i18n("Set the dry gain of the reverb."),
-
value=0.4,
-
interactive=True,
-
visible=False,
-
)
-
-
reverb_width = gr.Slider(
-
minimum=0,
-
maximum=1,
-
label=i18n("Reverb Width"),
-
info=i18n("Set the width of the reverb."),
-
value=1.0,
-
interactive=True,
-
visible=False,
-
)
vocals_volume = gr.Slider(
-
label=i18n("Vocals Volume"),
-
info=i18n("Adjust the volume of the vocals."),
-
minimum=-10,
-
maximum=0,
-
step=1,
-
value=-3,
-
interactive=True,
-
)
instrumentals_volume = gr.Slider(
-
label=i18n("Instrumentals Volume"),
-
info=i18n("Adjust the volume of the Instrumentals."),
-
minimum=-10,
-
maximum=0,
-
step=1,
-
value=-3,
-
interactive=True,
-
)
backing_vocals_volume = gr.Slider(
-
label=i18n("Backing Vocals Volume"),
-
info=i18n("Adjust the volume of the backing vocals."),
-
minimum=-10,
-
maximum=0,
-
step=1,
-
value=-3,
-
interactive=True,
-
)
export_format_final = gr.Radio(
-
label=i18n("Export Format"),
-
info=i18n("Select the format to export the audio."),
-
choices=["WAV", "MP3", "FLAC", "OGG", "M4A"],
-
value="FLAC",
-
interactive=True,
-
)
with gr.Accordion(i18n("Device Settings"), open=False):
devices = gr.Textbox(
-
label=i18n("Device"),
-
info=i18n(
-
"Select the device to use for the conversion. 0 to ∞ separated by - and for CPU leave only an -"
-
),
-
value=get_number_of_gpus(),
-
interactive=True,
-
)
-
-
+
with gr.Row():
convert_button = gr.Button(i18n("Convert"))
-
+
with gr.Row():
vc_output1 = gr.Textbox(
-
- label=i18n("Output Information"),
-
- info=i18n("The output information will be displayed here."),
+ label=i18n("Output Information"),
+ info=i18n("The output information will be displayed here."),
)
vc_output2 = gr.Audio(label=i18n("Export Audio"))
- with gr.Tab(i18n("Download Music")):
+ with gr.Tab(i18n("Download Music")):
download_music_tab()
@@ -1525,285 +915,152 @@ def update_dropdown_visibility(checkbox):
return gr.update(visible=checkbox)
-
-
def update_reverb_sliders_visibility(reverb_checked):
return {
-
reverb_room_size: gr.update(visible=reverb_checked),
-
reverb_damping: gr.update(visible=reverb_checked),
-
reverb_wet_gain: gr.update(visible=reverb_checked),
-
reverb_dry_gain: gr.update(visible=reverb_checked),
-
reverb_width: gr.update(visible=reverb_checked),
-
}
-
-
def update_visibility_infer_backing(infer_backing_vocals):
visible = infer_backing_vocals
return (
-
{"visible": visible, "__type__": "update"},
-
{"visible": visible, "__type__": "update"},
-
{"visible": visible, "__type__": "update"},
-
{"visible": visible, "__type__": "update"},
-
{"visible": visible, "__type__": "update"},
-
)
-
-
def update_hop_length_visibility(pitch_extract_value):
return gr.update(visible=pitch_extract_value in ["crepe", "crepe-tiny"])
-
-
-
refresh_button.click(
-
fn=change_choices,
-
inputs=[],
-
outputs=[model_file, index_file, audio],
-
)
refresh_button_infer_backing_vocals.click(
-
fn=change_choices,
-
inputs=[],
-
outputs=[infer_backing_vocals_model, infer_backing_vocals_index],
-
)
upload_audio.upload(
-
fn=save_to_wav,
-
inputs=[upload_audio],
-
outputs=[audio, output_path],
-
)
clear_outputs_infer.click(
-
fn=delete_outputs,
-
inputs=[],
-
outputs=[],
-
)
convert_button.click(
-
full_inference_program,
-
inputs=[
-
model_file,
-
index_file,
-
audio,
-
output_path,
-
export_format_rvc,
-
split_audio,
-
autotune,
-
vocal_model,
-
karaoke_model,
-
dereverb_model,
-
deecho,
-
deeecho_model,
-
denoise,
-
denoise_model,
-
reverb,
-
vocals_volume,
-
instrumentals_volume,
-
backing_vocals_volume,
-
export_format_final,
-
devices,
-
pitch,
-
filter_radius,
-
index_rate,
-
rms_mix_rate,
-
protect,
-
pitch_extract,
-
hop_length,
-
reverb_room_size,
-
reverb_damping,
-
reverb_wet_gain,
-
reverb_dry_gain,
-
reverb_width,
-
embedder_model,
-
delete_audios,
-
use_tta,
-
batch_size,
-
infer_backing_vocals,
-
infer_backing_vocals_model,
-
infer_backing_vocals_index,
-
change_inst_pitch,
-
pitch_back,
-
filter_radius_back,
-
index_rate_back,
-
rms_mix_rate_back,
-
protect_back,
-
pitch_extract_back,
-
hop_length_back,
-
export_format_rvc_back,
-
split_audio_back,
-
autotune_back,
-
embedder_model_back,
-
],
-
outputs=[vc_output1, vc_output2],
-
)
-
-
deecho.change(
-
fn=update_dropdown_visibility,
-
inputs=deecho,
-
outputs=deeecho_model,
-
)
-
-
denoise.change(
-
fn=update_dropdown_visibility,
-
inputs=denoise,
-
outputs=denoise_model,
-
)
-
-
reverb.change(
-
fn=update_reverb_sliders_visibility,
-
inputs=reverb,
-
outputs=[
-
reverb_room_size,
-
reverb_damping,
-
reverb_wet_gain,
-
reverb_dry_gain,
-
reverb_width,
-
],
-
)
pitch_extract.change(
-
fn=update_hop_length_visibility,
-
inputs=pitch_extract,
-
outputs=hop_length,
-
)
-
-
infer_backing_vocals.change(
-
fn=update_visibility_infer_backing,
-
inputs=[infer_backing_vocals],
-
outputs=[
-
infer_backing_vocals_model,
-
infer_backing_vocals_index,
-
refresh_button_infer_backing_vocals,
-
unload_button_infer_backing_vocals,
-
back_rvc_settings,
-
],
-
)
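Note: the full_inference tab above relies on one Gradio pattern throughout: a checkbox's .change event calls a function that returns gr.update(visible=...) for each dependent control. A minimal, self-contained sketch of that pattern (component names here are illustrative, not the tab's own):

import gradio as gr

def toggle_visibility(checked: bool):
    # gr.update(visible=...) shows or hides whichever component is wired as the output.
    return gr.update(visible=checked)

with gr.Blocks() as demo:
    use_reverb = gr.Checkbox(label="Reverb", value=False)
    room_size = gr.Slider(0, 1, value=0.5, label="Reverb Room Size", visible=False)
    use_reverb.change(fn=toggle_visibility, inputs=use_reverb, outputs=room_size)

demo.launch()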
diff --git a/tabs/settings.py b/tabs/settings.py
index 6ab6bec3..2734130c 100644
--- a/tabs/settings.py
+++ b/tabs/settings.py
@@ -39,7 +39,6 @@ def save_lang_settings(selected_language):
json.dump(config, file, indent=2)
-
def restart_applio():
if os.name != "nt":
os.system("clear")
@@ -49,11 +48,6 @@ def restart_applio():
os.execl(python, python, *sys.argv)
-
-
-
-
-
def lang_tab():
with gr.Column():
selected_language = gr.Dropdown(