Transformer

Implementation of the Transformer model in the paper:

Ashish Vaswani, et al. "Attention is all you need." NIPS 2017.

[Figure: the Transformer model architecture]

Check out my blog post on attention and the Transformer:

Implementations that helped me:

Setup

$ git clone https://github.com/lilianweng/transformer-tensorflow.git
$ cd transformer-tensorflow
$ pip install -r requirements.txt

Train a Model

# Check the help message:

$ python train.py --help

Usage: train.py [OPTIONS]

Options:
  --seq-len INTEGER               Input sequence length.  [default: 20]
  --d-model INTEGER               d_model  [default: 512]
  --d-ff INTEGER                  d_ff  [default: 2048]
  --n-head INTEGER                n_head  [default: 8]
  --batch-size INTEGER            Batch size  [default: 128]
  --max-steps INTEGER             Max train steps.  [default: 300000]
  --dataset [iwslt15|wmt14|wmt15]
                                  Which translation dataset to use.  [default:
                                  iwslt15]
  --help                          Show this message and exit.

# Train a model on the WMT14 dataset:

$ python train.py --dataset wmt14
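
# The flags from the help message above can be combined; for example, a smaller
# (hypothetical) configuration on the default IWSLT15 dataset:

$ python train.py --dataset iwslt15 --seq-len 30 --d-model 256 --n-head 4 --batch-size 64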

Evaluate a Trained Model

Let's say the trained model is saved under the checkpoints folder as transformer-wmt14-seq20-d512-head8-1541573730.

$ python eval.py transformer-wmt14-seq20-d512-head8-1541573730

With the default config, this implementation gets a BLEU score of ~20 on the WMT14 test set.

Implementation Notes

[WIP] A couple of tricky points in the implementation (a minimal sketch addressing them follows this list):

  • How to construct the attention masks correctly?
  • How to correctly shift the decoder input (fed in during training) and the decoder target (the ground truth in the loss function)?
  • How to make predictions in an autoregressive way?
  • Keeping the embedding of <pad> as a constant zero vector turns out to be fairly important.
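
Below is a minimal NumPy sketch of how these pieces typically fit together. It is illustrative only: the repo implements them with TensorFlow ops, and the names here (PAD_ID, make_padding_mask, shift_decoder_io, greedy_decode, predict_fn) are assumptions for the example, not identifiers from this codebase.

```python
# Minimal NumPy sketch of the points above; not the repo's actual TensorFlow code.
import numpy as np

PAD_ID = 0  # assumed vocabulary id of the <pad> token


def make_padding_mask(token_ids):
    # 1.0 where a real token is present, 0.0 at <pad> positions; shaped so it
    # broadcasts against attention logits of shape [batch, heads, q_len, k_len].
    return (token_ids != PAD_ID).astype(np.float32)[:, None, None, :]


def make_look_ahead_mask(seq_len):
    # Lower-triangular matrix: decoder position i may only attend to positions <= i.
    return np.tril(np.ones((seq_len, seq_len), dtype=np.float32))


def shift_decoder_io(target_ids, start_id):
    # Decoder input = target shifted right by one with <s> prepended;
    # decoder target (used in the loss) stays unshifted.
    batch = target_ids.shape[0]
    dec_input = np.concatenate(
        [np.full((batch, 1), start_id, dtype=target_ids.dtype),
         target_ids[:, :-1]], axis=1)
    return dec_input, target_ids


def zero_pad_embedding(embedding):
    # Keep the <pad> row of the embedding table as a constant zero vector.
    embedding[PAD_ID, :] = 0.0
    return embedding


def greedy_decode(predict_fn, enc_ids, start_id, end_id, max_len):
    # Autoregressive prediction: repeatedly feed the tokens generated so far
    # back into the decoder and append the argmax of the last step.
    out = [start_id]
    for _ in range(max_len):
        logits = predict_fn(enc_ids, np.array([out]))  # [1, cur_len, vocab]
        next_id = int(logits[0, -1].argmax())
        out.append(next_id)
        if next_id == end_id:
            break
    return out
```

In the loss, positions where the decoder target is <pad> are typically masked out as well, which pairs naturally with keeping the <pad> embedding at zero.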
