This repo contains a PyTorch implementation of the ACL 2017 paper *Get To The Point: Summarization with Pointer-Generator Networks*. The code framework is based on TextBox.
- python >= 3.8.11
- torch >= 1.6.0
Run `install.sh` to install the other requirements.
The processed dataset can be downloaded from Google Drive. Once downloaded, unzip the data files (`train.src`, `train.tgt`, ...) into `./data`.
An overview of the dataset:

| Split | Examples |
|---|---|
| train | 287,113 |
| dev | 13,368 |
| test | 11,490 |
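For orientation, here is a minimal sketch of reading one split from `./data`, assuming one example per line in the paired `.src`/`.tgt` files; the `load_split` helper is hypothetical, not part of this repo:

```python
# Hypothetical loader for the paired .src/.tgt files, assuming one example per line.
from pathlib import Path

def load_split(data_dir: str, split: str):
    src = Path(data_dir, f'{split}.src').read_text(encoding='utf-8').splitlines()
    tgt = Path(data_dir, f'{split}.tgt').read_text(encoding='utf-8').splitlines()
    assert len(src) == len(tgt), 'source/target files must be aligned'
    return list(zip(src, tgt))

pairs = load_split('data', 'train')
print(len(pairs))  # expected: 287113 for the train split
```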
```yaml
# overall settings
data_path: 'data/'
checkpoint_dir: 'saved/'
generated_text_dir: 'generated/'

# dataset settings
max_vocab_size: 50000
src_len: 400
tgt_len: 100

# model settings
decoding_strategy: 'beam_search'
beam_size: 4
is_attention: True
is_pgen: True
is_coverage: True
cov_loss_lambda: 1.0
```

The log file is located in `./log`; more details can be found in the YAML files.
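To make the `is_pgen`, `is_coverage`, and `cov_loss_lambda` settings concrete, here is a minimal sketch of one decoding step following equations (9) and (12) of the paper; the function and tensor names are illustrative, not this repo's actual variables:

```python
import torch

def pointer_generator_step(p_vocab, attn, p_gen, src_ids, coverage, cov_loss_lambda=1.0):
    """One pointer-generator decoding step with coverage (See et al., 2017).

    p_vocab:  (batch, extended_vocab) generator distribution, zero-padded for source OOVs
    attn:     (batch, src_len)        attention distribution a^t
    p_gen:    (batch, 1)              generation probability in [0, 1]
    src_ids:  (batch, src_len)        source token ids in the extended vocabulary
    coverage: (batch, src_len)        sum of attention over previous steps, c^t
    """
    # Eq. (9): P(w) = p_gen * P_vocab(w) + (1 - p_gen) * sum of attention on copies of w
    final_dist = (p_gen * p_vocab).scatter_add(1, src_ids, (1.0 - p_gen) * attn)

    # Eq. (12): covloss_t = sum_i min(a_i^t, c_i^t), weighted by cov_loss_lambda
    cov_loss = cov_loss_lambda * torch.min(attn, coverage).sum(dim=1)

    # Coverage accumulates the attention history for the next step
    return final_dist, cov_loss, coverage + attn
```

The coverage term is added to the per-step negative log-likelihood (equation 13 of the paper), which is what `cov_loss_lambda: 1.0` weights.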
Note: Distributed Data Parallel (DDP) is not supported yet.
```python
if __name__ == '__main__':
    config = Config(config_dict={'test_only': False,
                                 'load_experiment': None})
    train(config)
```

To resume from a checkpoint, set `'load_experiment': './saved/$model_name$.pth'`. Similarly, when `'test_only'` is set to `True`, `'load_experiment'` is required.
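For example, an evaluation-only run against a saved checkpoint might look like this (the `.pth` file name is a placeholder for whatever checkpoint training wrote to `./saved/`):

```python
# Evaluation-only run: load a saved checkpoint and skip training.
# $model_name$.pth is a placeholder, not an actual file in this repo.
config = Config(config_dict={'test_only': True,
                             'load_experiment': './saved/$model_name$.pth'})
train(config)
```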
The best model was trained on a single TITAN Xp GPU (about 10 GB of memory usage).
| Model | ROUGE-1 | ROUGE-2 | ROUGE-L |
|---|---|---|---|
| Seq2Seq | 22.17 | 7.20 | 20.97 |
| Seq2Seq+attn | 29.35 | 12.58 | 27.38 |
| Seq2Seq+attn+pgen | 39.05 | 18.41 | 36.02 |
| Seq2Seq+attn+pgen+coverage | 41.52 | 19.82 | 38.53 |
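For reference, ROUGE F1 scores like those above can be computed with Google's `rouge-score` package; this is only an assumption about tooling, and the repo may use a different scorer internally:

```python
# Sketch of ROUGE scoring with the rouge-score package (pip install rouge-score).
# This may differ from the scorer used to produce the table above.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
scores = scorer.score(target='the cat sat on the mat',
                      prediction='a cat was sitting on the mat')
for name, score in scores.items():
    print(name, f'F1 = {score.fmeasure:.4f}')
```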
Note: The architecture of the Seq2Seq model is based on an LSTM; I hope to replace it with a Transformer in the future.
