
This repository contains the code for the paper "Self-supervised Text Style Transfer using Cycle-Consistent Adversarial Networks".


gallipoligiuseppe/TST-CycleGAN


Self-supervised Text Style Transfer using Cycle-Consistent Adversarial Networks

This repository contains the code for the paper Self-supervised Text Style Transfer using Cycle-Consistent Adversarial Networks, published in ACM Transactions on Intelligent Systems and Technology.

It includes the Python package to train and test the CycleGAN architecture for Text Style Transfer described in the paper.

Installation

The following command will clone the project:

git clone https://github.com/gallipoligiuseppe/TST-CycleGAN.git

To install the required libraries and dependencies, you can refer to the env.yml file.

Before experimenting, you can create a virtual environment for the project using Conda.

conda env create -f env.yml -n cyclegan_tst
conda activate cyclegan_tst

The installation should also cover all the dependencies. If you find any missing dependency, please let us know by opening an issue.

Usage

The package provides the scripts to implement, train and test the CycleGAN architecture for Text Style Transfer described in the paper.

Specifically, we focus on formality (informal ↔ formal) and sentiment (negative ↔ positive) transfer tasks.

Data

Formality transfer

As required by the dataset license, you can request access to the GYAFC dataset by following the steps described in its official repository.

Once you have gained access, place the data in the family_relationships and entertainment_music directories (for the Family & Relationships and Entertainment & Music domains, respectively) under the data/GYAFC folder. Name the files [train|dev|test].[0|1].txt, where 0 denotes the informal style and 1 the formal style.
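To check the layout before training, a small script can list which of the expected GYAFC files are still missing. The sketch below only encodes the folder and naming convention described above; the helper names expected_files and missing_files are hypothetical and are not part of the package.

```python
from pathlib import Path

# Domain folders, splits, and style labels as described in this README.
GYAFC_DOMAINS = ("family_relationships", "entertainment_music")
SPLITS = ("train", "dev", "test")
STYLES = (0, 1)  # 0 = informal, 1 = formal


def expected_files(root, domain):
    """Return the [train|dev|test].[0|1].txt paths expected for one domain."""
    base = Path(root) / domain
    return [base / f"{split}.{style}.txt" for split in SPLITS for style in STYLES]


def missing_files(root="data/GYAFC"):
    """List every expected GYAFC file (both domains) that does not exist yet."""
    return [
        path
        for domain in GYAFC_DOMAINS
        for path in expected_files(root, domain)
        if not path.exists()
    ]
```

For example, `missing_files()` run from the repository root returns an empty list once all twelve files are in place.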

We can provide access to the mixed-style data used in our work once you have obtained access to the GYAFC dataset and we have verified the dataset license.

Sentiment transfer

We use the Yelp dataset with the same splits as Li et al., available in their official repository. Put the data in the data/yelp folder and name the files [train|dev|test].[0|1].txt, where 0 denotes negative sentiment and 1 positive sentiment.
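Both datasets share the same naming convention, so one loader can read either. The sketch below assumes each file holds one sentence per line (the README does not state the file format explicitly); read_split is a hypothetical helper, not the package's own data loader.

```python
from pathlib import Path


def read_split(root, split):
    """Load both style files of one split, e.g. dev.0.txt and dev.1.txt.

    Assumes one sentence per line. Returns a dict mapping the style label
    (0 or 1) to the list of sentences in the corresponding file.
    """
    data = {}
    for style in (0, 1):
        path = Path(root) / f"{split}.{style}.txt"
        with path.open(encoding="utf-8") as f:
            # Text mode normalizes line endings; strip only the newline.
            data[style] = [line.rstrip("\n") for line in f]
    return data
```

For instance, `read_split("data/yelp", "dev")` returns the negative (0) and positive (1) validation sentences, which makes it easy to inspect their counts before launching training.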

Training

You can train the proposed CycleGAN architecture for Text Style Transfer using the train.py script. It can be customized using several command line arguments such as:

  • style_a/style_b: style A/B (i.e., informal/formal or negative/positive)
  • generator_model_tag: tag or path of the generator model
  • discriminator_model_tag: tag or path of the discriminator model
  • pretrained_classifier_model: tag or path of the style classifier model
  • lambdas: loss weighting factors in the form "λ1|λ2|λ3|λ4|λ5" for cycle-consistency, generator, discriminator (fake), discriminator (real), and classifier-guided losses, respectively
  • path_mono_A/path_mono_B: path to the training dataset for style A/B
  • path_mono_A_eval/path_mono_B_eval: path to the validation dataset for style A/B (if references for validation are not available, as in the Yelp dataset)
  • path_paral_A_eval/path_paral_B_eval: path to the validation dataset for style A/B (if references for validation are available, as in the GYAFC dataset)
  • path_paral_eval_ref: path to the references for validation (if references available, as in the GYAFC dataset)
  • learning_rate, epochs, batch_size: learning rate, number of epochs and batch size for model training
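The lambdas argument packs five weights into one pipe-separated string. The sketch below maps that string onto the loss terms in the order listed above; parse_lambdas and the short key names are hypothetical, and the parsing inside train.py may differ.

```python
# Order documented for --lambdas: cycle-consistency, generator,
# discriminator (fake), discriminator (real), classifier-guided.
LAMBDA_NAMES = ("cycle", "generator", "disc_fake", "disc_real", "classifier")


def parse_lambdas(spec):
    """Split a "λ1|λ2|λ3|λ4|λ5" string into named float weights."""
    values = [float(v) for v in spec.split("|")]
    if len(values) != len(LAMBDA_NAMES):
        raise ValueError(f"expected {len(LAMBDA_NAMES)} weights, got {len(values)}")
    return dict(zip(LAMBDA_NAMES, values))
```

With the value used in the training example, "10|1|1|1|1", the cycle-consistency loss is weighted ten times more than each of the other four terms.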

As an example, to train the CycleGAN architecture for formality transfer using the GYAFC dataset (Family & Relationships domain), you can use the following command:

CUDA_VISIBLE_DEVICES=0 python train.py --style_a=informal --style_b=formal --lang=en \
                       --path_mono_A=./data/GYAFC/family_relationships/train.0.txt --path_mono_B=./data/GYAFC/family_relationships/train.1.txt \
                       --path_paral_A_eval=./data/GYAFC/family_relationships/dev.0.txt --path_paral_B_eval=./data/GYAFC/family_relationships/dev.1.txt --path_paral_eval_ref=./data/GYAFC/family_relationships/references/dev/ --n_references=4 --shuffle \
                       --generator_model_tag=google-t5/t5-large --discriminator_model_tag=distilbert-base-cased --pretrained_classifier_model=./classifiers/GYAFC/family_relationships/bert-base-cased_5/ \
                       --lambdas="10|1|1|1|1" --epochs=30 --learning_rate=5e-5 --max_sequence_length=64 --batch_size=8  \
                       --save_base_folder=./ckpts/ --save_steps=1 --eval_strategy=epochs --eval_steps=1  --pin_memory --use_cuda_if_available
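To run the same configuration on both GYAFC domains, the command can be assembled programmatically. This is a sketch: the flags and values are copied from the example above, while the helper name build_train_cmd, the per-domain checkpoint folder, and the classifier path for entertainment_music are assumptions.

```python
import shlex


def build_train_cmd(domain, data_root="./data/GYAFC", ckpt_root="./ckpts"):
    """Return the train.py argument list for one GYAFC domain (hypothetical helper).

    The per-domain save folder and the classifier path for domains other than
    family_relationships are assumed naming choices, not repository defaults.
    """
    d = f"{data_root}/{domain}"
    args = {
        "style_a": "informal", "style_b": "formal", "lang": "en",
        "path_mono_A": f"{d}/train.0.txt", "path_mono_B": f"{d}/train.1.txt",
        "path_paral_A_eval": f"{d}/dev.0.txt", "path_paral_B_eval": f"{d}/dev.1.txt",
        "path_paral_eval_ref": f"{d}/references/dev/", "n_references": 4,
        "generator_model_tag": "google-t5/t5-large",
        "discriminator_model_tag": "distilbert-base-cased",
        "pretrained_classifier_model": f"./classifiers/GYAFC/{domain}/bert-base-cased_5/",
        "lambdas": "10|1|1|1|1", "epochs": 30, "learning_rate": "5e-5",
        "max_sequence_length": 64, "batch_size": 8,
        "save_base_folder": f"{ckpt_root}/{domain}/", "save_steps": 1,
        "eval_strategy": "epochs", "eval_steps": 1,
    }
    # Boolean switches from the example take no value.
    switches = ["--shuffle", "--pin_memory", "--use_cuda_if_available"]
    return ["python", "train.py"] + [f"--{k}={v}" for k, v in args.items()] + switches


if __name__ == "__main__":
    for domain in ("family_relationships", "entertainment_music"):
        # shlex.join quotes "--lambdas=10|1|1|1|1" so the shell does not read the pipes.
        print(shlex.join(build_train_cmd(domain)))
```

The list can also be passed directly to `subprocess.run(cmd, env={**os.environ, "CUDA_VISIBLE_DEVICES": "0"}, check=True)`. If you change the save folder, remember to point `--from_pretrained` at the same location when testing.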

Testing

Once the models are trained, you can evaluate them on the test set using the test.py script. It can be customized using several command line arguments such as:

  • style_a/style_b: style A/B (i.e., informal/formal or negative/positive)
  • generator_model_tag: tag or path of the generator model
  • discriminator_model_tag: tag or path of the discriminator model
  • from_pretrained: folder to use as base path to load the model checkpoint(s) to test
  • pretrained_classifier_eval: tag or path of the oracle classifier model
  • path_paral_A_test/path_paral_B_test: path to the test dataset for style A/B
  • path_paral_test_ref: path to the references for test

As an example, to test the trained models for formality transfer using the GYAFC dataset (Family & Relationships domain), you can use the following command:

CUDA_VISIBLE_DEVICES=0 python test.py --style_a=informal --style_b=formal --lang=en \
                       --path_paral_A_test=./data/GYAFC/family_relationships/test.0.txt --path_paral_B_test=./data/GYAFC/family_relationships/test.1.txt --path_paral_test_ref=./data/GYAFC/family_relationships/references/test/ --n_references=4 \
                       --generator_model_tag=google-t5/t5-large --discriminator_model_tag=distilbert-base-cased \
                       --pretrained_classifier_eval=./classifiers/GYAFC/family_relationships/bert-base-cased_5/ \
                       --from_pretrained=./ckpts/ --max_sequence_length=64 --batch_size=16 --pin_memory --use_cuda_if_available 
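The oracle classifier passed via pretrained_classifier_eval is the kind of model commonly used to measure style accuracy, i.e., the share of transferred sentences classified as the target style. The sketch below computes that ratio from predicted labels you already have; it is an illustration of the metric under that assumption, not the exact computation performed by test.py, and style_accuracy is a hypothetical name.

```python
from typing import Sequence


def style_accuracy(predicted: Sequence[int], target_style: int) -> float:
    """Fraction of transferred sentences assigned to target_style.

    predicted holds one label per output sentence (0 or 1, following the
    dataset convention), assumed to come from the oracle classifier.
    """
    if not predicted:
        raise ValueError("no predictions to score")
    hits = sum(1 for label in predicted if label == target_style)
    return hits / len(predicted)
```

For informal → formal outputs the target style is 1, so `style_accuracy([1, 1, 0, 1], 1)` gives 0.75.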

Model checkpoints

The checkpoints of the best-performing models for both Text Style Transfer tasks will be made available on Hugging Face 🤗 in the coming weeks. Links to the model checkpoints will be added below.

License

This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

Authors

Moreno La Quatra, Giuseppe Gallipoli, Luca Cagliero

Corresponding author

For any questions about the content of the paper or the implementation, you can contact me at: giuseppe[DOT]gallipoli[AT]polito[DOT]it.

Citation

If you find this work useful, please cite our paper:

@article{LaQuatra24TST,
author = {La Quatra, Moreno and Gallipoli, Giuseppe and Cagliero, Luca},
title = {Self-supervised Text Style Transfer using Cycle-Consistent Adversarial Networks},
year = {2024},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
issn = {2157-6904},
url = {https://doi.org/10.1145/3678179},
doi = {10.1145/3678179},
journal = {ACM Trans. Intell. Syst. Technol.},
month = {jul},
keywords = {Text Style Transfer, Sentiment transfer, Formality transfer, Cycle-consistent Generative Adversarial Networks, Transformers}
}
