diff --git a/machine/corpora/corpora_utils.py b/machine/corpora/corpora_utils.py index 04d7852..2c1e0bc 100644 --- a/machine/corpora/corpora_utils.py +++ b/machine/corpora/corpora_utils.py @@ -54,7 +54,11 @@ def get_files(file_patterns: Iterable[str]) -> Iterable[Tuple[str, str]]: if len(file_patterns) == 1 and os.path.isfile(file_patterns[0]): yield ("*all*", file_patterns[0]) else: - for file_pattern in file_patterns: + for i, file_pattern in enumerate(file_patterns): + if os.path.isfile(file_pattern): + yield (str(i), file_pattern) + continue + if "*" not in file_pattern and "?" not in file_pattern and not os.path.exists(file_pattern): raise FileNotFoundError(f"The specified path does not exist: {file_pattern}.") diff --git a/machine/jobs/translation_file_service.py b/machine/jobs/translation_file_service.py index e1b2794..a545ccf 100644 --- a/machine/jobs/translation_file_service.py +++ b/machine/jobs/translation_file_service.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from pathlib import Path -from typing import Any, Generator, Iterator, List, TypedDict +from typing import Any, Generator, Iterator, List, Optional, TypedDict, Union import json_stream @@ -21,40 +21,56 @@ class PretranslationInfo(TypedDict): alignment: str -SOURCE_FILENAME = "train.src.txt" -TARGET_FILENAME = "train.trg.txt" -SOURCE_PRETRANSLATION_FILENAME = "pretranslate.src.json" -TARGET_PRETRANSLATION_FILENAME = "pretranslate.trg.json" - - class TranslationFileService: def __init__( self, type: SharedFileServiceType, config: Any, + source_filenames: Optional[Union[str, List[str]]] = None, + target_filenames: Optional[Union[str, List[str]]] = None, + source_pretranslation_filename: str = "pretranslate.src.json", + target_pretranslation_filename: str = "pretranslate.trg.json", ) -> None: + if source_filenames is None: + source_filenames = ["train.src.txt", "train.key-terms.src.txt"] + if target_filenames is None: + target_filenames = ["train.trg.txt", "train.key-terms.trg.txt"] + + self._source_filenames = [source_filenames] if isinstance(source_filenames, str) else list(source_filenames) + self._target_filenames = [target_filenames] if isinstance(target_filenames, str) else list(target_filenames) + self._source_pretranslation_filename = source_pretranslation_filename + self._target_pretranslation_filename = target_pretranslation_filename + self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config) def create_source_corpus(self) -> TextCorpus: return TextFileTextCorpus( - self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{SOURCE_FILENAME}") + self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{source_filename}") + for source_filename in self._source_filenames ) def create_target_corpus(self) -> TextCorpus: return TextFileTextCorpus( - self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{TARGET_FILENAME}") + self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{target_filename}") + for target_filename in self._target_filenames ) def exists_source_corpus(self) -> bool: - return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{SOURCE_FILENAME}") + return all( + self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{source_filename}") + for source_filename in self._source_filenames + ) def exists_target_corpus(self) -> bool: - return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{TARGET_FILENAME}") + return all( + self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{target_filename}") + for target_filename in self._target_filenames + ) def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]: src_pretranslate_path = self.shared_file_service.download_file( - f"{self.shared_file_service.build_path}/{SOURCE_PRETRANSLATION_FILENAME}" + f"{self.shared_file_service.build_path}/{self._source_pretranslation_filename}" ) def generator() -> Generator[PretranslationInfo, None, None]: @@ -77,4 +93,4 @@ def save_model(self, model_path: Path, destination: str) -> None: @contextmanager def open_target_pretranslation_writer(self) -> Iterator[DictToJsonWriter]: - return self.shared_file_service.open_target_writer(TARGET_PRETRANSLATION_FILENAME) + return self.shared_file_service.open_target_writer(self._target_pretranslation_filename) diff --git a/machine/jobs/word_alignment_file_service.py b/machine/jobs/word_alignment_file_service.py index 851ab1c..831d07b 100644 --- a/machine/jobs/word_alignment_file_service.py +++ b/machine/jobs/word_alignment_file_service.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from pathlib import Path -from typing import Any, Iterator, List, TypedDict +from typing import Any, Iterator, List, Optional, TypedDict, Union import json_stream @@ -23,14 +23,19 @@ def __init__( self, type: SharedFileServiceType, config: Any, - source_filename: str = "train.src.txt", - target_filename: str = "train.trg.txt", + source_filenames: Optional[Union[str, List[str]]] = None, + target_filenames: Optional[Union[str, List[str]]] = None, word_alignment_input_filename: str = "word_alignments.inputs.json", word_alignment_output_filename: str = "word_alignments.outputs.json", ) -> None: - self._source_filename = source_filename - self._target_filename = target_filename + if source_filenames is None: + source_filenames = ["train.src.txt", "train.key-terms.src.txt"] + if target_filenames is None: + target_filenames = ["train.trg.txt", "train.key-terms.trg.txt"] + + self._source_filenames = [source_filenames] if isinstance(source_filenames, str) else list(source_filenames) + self._target_filenames = [target_filenames] if isinstance(target_filenames, str) else list(target_filenames) self._word_alignment_input_filename = word_alignment_input_filename self._word_alignment_output_filename = word_alignment_output_filename @@ -38,12 +43,14 @@ def __init__( def create_source_corpus(self) -> TextCorpus: return TextFileTextCorpus( - self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._source_filename}") + self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{source_filename}") + for source_filename in self._source_filenames ) def create_target_corpus(self) -> TextCorpus: return TextFileTextCorpus( - self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_filename}") + self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{target_filename}") + for target_filename in self._target_filenames ) def get_word_alignment_inputs(self) -> List[WordAlignmentInput]: @@ -64,10 +71,16 @@ def get_word_alignment_inputs(self) -> List[WordAlignmentInput]: return wa_inputs def exists_source_corpus(self) -> bool: - return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_filename}") + return all( + self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{source_filename}") + for source_filename in self._source_filenames + ) def exists_target_corpus(self) -> bool: - return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_filename}") + return all( + self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{target_filename}") + for target_filename in self._target_filenames + ) def exists_word_alignment_inputs(self) -> bool: return self.shared_file_service._exists_file( diff --git a/tests/corpora/test_text_file_text_corpus.py b/tests/corpora/test_text_file_text_corpus.py index ef6fc0e..dc7cbd9 100644 --- a/tests/corpora/test_text_file_text_corpus.py +++ b/tests/corpora/test_text_file_text_corpus.py @@ -14,6 +14,17 @@ def test_folder() -> None: assert [t.id for t in corpus.texts] == ["Test1", "Test2", "Test3"] +def test_multiple_files() -> None: + corpus = TextFileTextCorpus( + [ + TEXT_TEST_PROJECT_PATH / "Test1.txt", + TEXT_TEST_PROJECT_PATH / "Test2.txt", + TEXT_TEST_PROJECT_PATH / "Test3.txt", + ] + ) + assert [t.id for t in corpus.texts] == ["0", "1", "2"] + + def test_single_file() -> None: corpus = TextFileTextCorpus(TEXT_TEST_PROJECT_PATH / "Test1.txt") assert [t.id for t in corpus.texts] == ["*all*"]