diff --git a/scripts/helper.py b/scripts/helper.py index 842758bb..d7e9a813 100644 --- a/scripts/helper.py +++ b/scripts/helper.py @@ -3,6 +3,7 @@ import os import pandas as pd +import pathlib def split_data(data_path, time_column_name, split_date): @@ -16,28 +17,29 @@ def split_data(data_path, time_column_name, split_date): if path not in (train_data_path, inference_data_path)] for file in files_list: - file_name = os.path.basename(file) - file_extension = os.path.splitext(file_name)[1].lower() - df = read_file(file, file_extension) + df = read_file(file) before_split_date = df[time_column_name] < split_date train_df, inference_df = df[before_split_date], df[~before_split_date] - write_file(train_df, os.path.join(train_data_path, file_name), file_extension) - write_file(inference_df, os.path.join(inference_data_path, file_name), file_extension) + file_name = os.path.basename(file) + write_file(train_df, os.path.join(train_data_path, file_name)) + write_file(inference_df, os.path.join(inference_data_path, file_name)) return train_data_path, inference_data_path -def read_file(path, extension): - if extension == ".parquet": +def read_file(path): + extension = pathlib.Path(path.lower()).suffix + if extension == "parquet": return pd.read_parquet(path) else: return pd.read_csv(path) -def write_file(data, path, extension): - if extension == ".parquet": +def write_file(data, path): + extension = pathlib.Path(path.lower()).suffix + if extension == "parquet": data.to_parquet(path) else: data.to_csv(path, index=None, header=True)