Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion machine/corpora/corpora_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def get_files(file_patterns: Iterable[str]) -> Iterable[Tuple[str, str]]:
if len(file_patterns) == 1 and os.path.isfile(file_patterns[0]):
yield ("*all*", file_patterns[0])
else:
for file_pattern in file_patterns:
for i, file_pattern in enumerate(file_patterns):
if os.path.isfile(file_pattern):
yield (str(i), file_pattern)
continue

if "*" not in file_pattern and "?" not in file_pattern and not os.path.exists(file_pattern):
raise FileNotFoundError(f"The specified path does not exist: {file_pattern}.")

Expand Down
42 changes: 29 additions & 13 deletions machine/jobs/translation_file_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator, Iterator, List, TypedDict
from typing import Any, Generator, Iterator, List, Optional, TypedDict, Union

import json_stream

Expand All @@ -21,40 +21,56 @@ class PretranslationInfo(TypedDict):
alignment: str


SOURCE_FILENAME = "train.src.txt"
TARGET_FILENAME = "train.trg.txt"
SOURCE_PRETRANSLATION_FILENAME = "pretranslate.src.json"
TARGET_PRETRANSLATION_FILENAME = "pretranslate.trg.json"


class TranslationFileService:
def __init__(
    self,
    type: SharedFileServiceType,
    config: Any,
    source_filenames: Optional[Union[str, List[str]]] = None,
    target_filenames: Optional[Union[str, List[str]]] = None,
    source_pretranslation_filename: str = "pretranslate.src.json",
    target_pretranslation_filename: str = "pretranslate.trg.json",
) -> None:
    """Store the file names this service works with and create the shared file service.

    Args:
        type: Passed through to ``get_shared_file_service``.
        config: Passed through to ``get_shared_file_service``.
        source_filenames: One file name or a list of file names for the
            source side. ``None`` selects the two default training files.
        target_filenames: One file name or a list of file names for the
            target side. ``None`` selects the two default training files.
        source_pretranslation_filename: Name of the source pretranslation file.
        target_pretranslation_filename: Name of the target pretranslation file.
    """
    # A bare string means a single file name; it must not be iterated
    # character by character, so it is wrapped rather than passed to list().
    if source_filenames is None:
        self._source_filenames = ["train.src.txt", "train.key-terms.src.txt"]
    elif isinstance(source_filenames, str):
        self._source_filenames = [source_filenames]
    else:
        self._source_filenames = list(source_filenames)

    if target_filenames is None:
        self._target_filenames = ["train.trg.txt", "train.key-terms.trg.txt"]
    elif isinstance(target_filenames, str):
        self._target_filenames = [target_filenames]
    else:
        self._target_filenames = list(target_filenames)

    self._source_pretranslation_filename = source_pretranslation_filename
    self._target_pretranslation_filename = target_pretranslation_filename

    self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)

def create_source_corpus(self) -> TextCorpus:
    """Build a text corpus from every configured source file.

    Each name in ``self._source_filenames`` is joined to the shared file
    service's ``build_path`` and passed to ``download_file``. The values it
    returns are handed to ``TextFileTextCorpus`` as a generator, so the
    downloads happen when the corpus constructor consumes it.
    """
    # Only the per-file generator is passed; the former single-file argument
    # built from SOURCE_FILENAME is superseded by the configurable list.
    return TextFileTextCorpus(
        self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{source_filename}")
        for source_filename in self._source_filenames
    )

def create_target_corpus(self) -> TextCorpus:
    """Build a text corpus from every configured target file.

    Each name in ``self._target_filenames`` is joined to the shared file
    service's ``build_path`` and passed to ``download_file``. The values it
    returns are handed to ``TextFileTextCorpus`` as a generator, so the
    downloads happen when the corpus constructor consumes it.
    """
    # Only the per-file generator is passed; the former single-file argument
    # built from TARGET_FILENAME is superseded by the configurable list.
    return TextFileTextCorpus(
        self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{target_filename}")
        for target_filename in self._target_filenames
    )

def exists_source_corpus(self) -> bool:
    """Return True only when every configured source file exists.

    Each name in ``self._source_filenames`` is checked under the shared file
    service's ``build_path`` via ``_exists_file``. Checking stops at the
    first missing file. An empty filename list yields True, because
    ``all()`` of nothing is True.
    """
    build_path = self.shared_file_service.build_path
    return all(
        self.shared_file_service._exists_file(f"{build_path}/{source_filename}")
        for source_filename in self._source_filenames
    )

def exists_target_corpus(self) -> bool:
    """Return True only when every configured target file exists.

    Each name in ``self._target_filenames`` is checked under the shared file
    service's ``build_path`` via ``_exists_file``. Checking stops at the
    first missing file. An empty filename list yields True, because
    ``all()`` of nothing is True.
    """
    build_path = self.shared_file_service.build_path
    return all(
        self.shared_file_service._exists_file(f"{build_path}/{target_filename}")
        for target_filename in self._target_filenames
    )

def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
src_pretranslate_path = self.shared_file_service.download_file(
f"{self.shared_file_service.build_path}/{SOURCE_PRETRANSLATION_FILENAME}"
f"{self.shared_file_service.build_path}/{self._source_pretranslation_filename}"
)

def generator() -> Generator[PretranslationInfo, None, None]:
Expand All @@ -77,4 +93,4 @@ def save_model(self, model_path: Path, destination: str) -> None:

@contextmanager
def open_target_pretranslation_writer(self) -> Iterator[DictToJsonWriter]:
    """Open a writer for the configured target pretranslation file.

    Delegates to ``shared_file_service.open_target_writer`` with
    ``self._target_pretranslation_filename``, so a filename supplied to the
    constructor is the one that gets written.
    """
    # NOTE(review): this function returns instead of yielding under
    # @contextmanager, so open_target_writer presumably returns a generator
    # that the decorator can drive. Confirm against the shared file service.
    return self.shared_file_service.open_target_writer(self._target_pretranslation_filename)
31 changes: 22 additions & 9 deletions machine/jobs/word_alignment_file_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Iterator, List, TypedDict
from typing import Any, Iterator, List, Optional, TypedDict, Union

import json_stream

Expand All @@ -23,27 +23,34 @@ def __init__(
self,
type: SharedFileServiceType,
config: Any,
source_filename: str = "train.src.txt",
target_filename: str = "train.trg.txt",
source_filenames: Optional[Union[str, List[str]]] = None,
target_filenames: Optional[Union[str, List[str]]] = None,
word_alignment_input_filename: str = "word_alignments.inputs.json",
word_alignment_output_filename: str = "word_alignments.outputs.json",
) -> None:

self._source_filename = source_filename
self._target_filename = target_filename
if source_filenames is None:
source_filenames = ["train.src.txt", "train.key-terms.src.txt"]
if target_filenames is None:
target_filenames = ["train.trg.txt", "train.key-terms.trg.txt"]

self._source_filenames = [source_filenames] if isinstance(source_filenames, str) else list(source_filenames)
self._target_filenames = [target_filenames] if isinstance(target_filenames, str) else list(target_filenames)
self._word_alignment_input_filename = word_alignment_input_filename
self._word_alignment_output_filename = word_alignment_output_filename

self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)

def create_source_corpus(self) -> TextCorpus:
    """Build a text corpus from every configured source file.

    Each name in ``self._source_filenames`` is joined to the shared file
    service's ``build_path`` and passed to ``download_file``. The values it
    returns are handed to ``TextFileTextCorpus`` as a generator, so the
    downloads happen when the corpus constructor consumes it.
    """
    # Only the per-file generator is passed; the former single-file argument
    # built from self._source_filename is superseded by the filename list.
    return TextFileTextCorpus(
        self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{source_filename}")
        for source_filename in self._source_filenames
    )

def create_target_corpus(self) -> TextCorpus:
    """Build a text corpus from every configured target file.

    Each name in ``self._target_filenames`` is joined to the shared file
    service's ``build_path`` and passed to ``download_file``. The values it
    returns are handed to ``TextFileTextCorpus`` as a generator, so the
    downloads happen when the corpus constructor consumes it.
    """
    # Only the per-file generator is passed; the former single-file argument
    # built from self._target_filename is superseded by the filename list.
    return TextFileTextCorpus(
        self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{target_filename}")
        for target_filename in self._target_filenames
    )

def get_word_alignment_inputs(self) -> List[WordAlignmentInput]:
Expand All @@ -64,10 +71,16 @@ def get_word_alignment_inputs(self) -> List[WordAlignmentInput]:
return wa_inputs

def exists_source_corpus(self) -> bool:
    """Return True only when every configured source file exists.

    Each name in ``self._source_filenames`` is checked under the shared file
    service's ``build_path`` via ``_exists_file``. Checking stops at the
    first missing file. An empty filename list yields True, because
    ``all()`` of nothing is True.
    """
    build_path = self.shared_file_service.build_path
    return all(
        self.shared_file_service._exists_file(f"{build_path}/{source_filename}")
        for source_filename in self._source_filenames
    )

def exists_target_corpus(self) -> bool:
    """Return True only when every configured target file exists.

    Each name in ``self._target_filenames`` is checked under the shared file
    service's ``build_path`` via ``_exists_file``. Checking stops at the
    first missing file. An empty filename list yields True, because
    ``all()`` of nothing is True.
    """
    build_path = self.shared_file_service.build_path
    return all(
        self.shared_file_service._exists_file(f"{build_path}/{target_filename}")
        for target_filename in self._target_filenames
    )

def exists_word_alignment_inputs(self) -> bool:
return self.shared_file_service._exists_file(
Expand Down
11 changes: 11 additions & 0 deletions tests/corpora/test_text_file_text_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ def test_folder() -> None:
assert [t.id for t in corpus.texts] == ["Test1", "Test2", "Test3"]


def test_multiple_files() -> None:
    """A corpus built from an explicit list of files ids each text by its list position."""
    file_paths = [TEXT_TEST_PROJECT_PATH / name for name in ("Test1.txt", "Test2.txt", "Test3.txt")]
    corpus = TextFileTextCorpus(file_paths)

    text_ids = [text.id for text in corpus.texts]
    assert text_ids == ["0", "1", "2"]


def test_single_file() -> None:
    """A corpus built from one file path exposes a single text whose id is ``*all*``."""
    single_path = TEXT_TEST_PROJECT_PATH / "Test1.txt"
    text_ids = [text.id for text in TextFileTextCorpus(single_path).texts]
    assert text_ids == ["*all*"]
Expand Down