https://github.com/google/trax

# Preamble

In [None]:
import os
import numpy as np

!pip install -q -U trax
import trax

In [None]:
import time

from IPython.display import Audio

# A short sine "beep" played at the end of slow cells as an audible "done" signal.
f = 400          # tone frequency in Hz
srate = 10000    # sample rate in Hz
duration = 0.1   # length in seconds
beep = np.sin(2*np.pi*f *np.arange(duration*srate)/srate)
Audio(beep, rate=srate, autoplay=True)

# Test run

This code needs
* ```ende_32k.subword``` in folder ```./data```
* and ```ende_wmt32k.pkl.gz``` also in folder ```./data```


In [None]:
# Mount Google Drive, where the project folder (with ./data) lives.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# List the project folder on Drive (adjust the path to your own).
!ls drive/MyDrive/Y2024/ANN/Transformer/trax

data	  mycore.py  mytraining.py  __pycache__		       TraxTransformer.ipynb
myatt.py  mymod.py   output_dir     TraceTrainingV5_WMT.ipynb


In [None]:
# Show the current working directory.
!pwd

/content


In [None]:
import sys

# Make the project's helper modules (mymod.py, myatt.py, ...) importable.
# Point this at your own project folder; it must contain the folder "data".
PROJECT_DIR = "/content/drive/MyDrive/Y2024/ANN/Transformer/trax"  # my path
if PROJECT_DIR not in sys.path:  # re-running the cell must not add duplicates
    sys.path.append(PROJECT_DIR)

In [None]:
# Work from the project folder so relative paths such as ./data resolve.
%cd /content/drive/MyDrive/Y2024/ANN/Transformer/trax

/content/drive/MyDrive/Y2024/ANN/Transformer/trax


In [None]:
# ./data must hold the subword vocabulary and the pre-trained weights.
!ls ./data

ende_32k.subword    wmt14_translate_de-en_train.csv
ende_wmt32k.pkl.gz  wmt14_translate_de-en_validation.csv


If you see ```ende_32k.subword``` and ```ende_wmt32k.pkl.gz``` above, you are good to go.

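Otherwise, get these two files first (originally under `gs://trax-ml/vocabs/` and `gs://trax-ml/models/translation/`) and put them in `./data`.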

In [None]:
"""
Based on an example code on https://github.com/google/trax.

Original location 'gs://trax-ml/models/translation/ende_wmt32k.pkl.gz'
is replaced with downloaded './data/ende_wmt32k.pkl.gz'

So is the google server location 'gs://trax-ml/vocabs/' with './data/'
"""


# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
# model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz', weights_only=True)
model.init_from_file('./data/ende_wmt32k.pkl.gz', weights_only=True)

# Tokenize a sentence.
sentence = 'You can cut all the flowers but you cannot keep spring from coming.'
src_tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                        vocab_dir='./data/',
                                        vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = src_tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
EOS_ID = 1  # NOTE(review): assumed to equal autoregressive_sample's default eos_id; confirm.
tokenized_translation = tokenized_translation[0]  # Remove batch dimension.
# Strip EOS only if decoding actually emitted it. When sampling stops at its
# length limit, the last token is a real word and must be kept.
if len(tokenized_translation) and tokenized_translation[-1] == EOS_ID:
    tokenized_translation = tokenized_translation[:-1]
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='./data/',
                                   vocab_file='ende_32k.subword')

print('Source:', sentence)
print('>>> Tokenized=', src_tokenized)
print('>>> Output tokens:', tokenized_translation)
print('Translation:', translation)
Audio(beep, rate=srate, autoplay=True)

Source: You can cut all the flowers but you cannot keep spring from coming.
>>> Tokenized= [  415    85  1977    83     4 15286   466   101    72   442  1356  7191
    55  2097     3]
>>> Output tokens: [   67   119   233 18690    23 27551     5     2   163    67   119    44
  1169 12869    21   710     3]
Translation: Sie können alle Blumen schneiden, aber Sie können nicht halten Frühling von kommen.


See the German translation.

"Sie können alle Blumen schneiden, aber Sie können nicht halten Frühling von kommen."

So the pre-trained model is working.

# Tokenization

We need to break down a sentence to tokens.

In [None]:
# Build a reusable tokenizer from the local subword vocabulary.
tokenizer = trax.data.Tokenize(vocab_file='ende_32k.subword', vocab_dir='./data/')

sentence = "Be the change you wish to see in the world."
# Tokenize works on streams: feed a one-element list and unpack its single result.
(tokenized,) = tokenizer([sentence])
print(tokenized)

[11715     4   436    72  1057     9   372     6     4   172     3]


In [None]:
# Tokenize each space-separated word on its own to see the word -> subword-id mapping.
for word in sentence.split(' '):
    (tokenized,) = tokenizer([word])
    print(word, '>>>', tokenized)

Be >>> [11715]
the >>> [4]
change >>> [436]
you >>> [72]
wish >>> [1057]
to >>> [9]
see >>> [372]
in >>> [6]
the >>> [4]
world. >>> [172   3]


In [None]:
sentence = "MooDeng is a pygmy hippo."

# Whole-sentence tokenization first ...
(tokens,) = tokenizer([sentence])
print(sentence)
print(tokens)
print()

# ... then word by word, to show which words split into several subwords.
for w in sentence.split(' '):
    (tokenized,) = tokenizer([w])
    print(w, '>>>', tokenized)

MooDeng is a pygmy hippo.
[ 8493  8032   299    16    13 22958 17364   105  3481 14875     5     3]

MooDeng >>> [8493 8032  299]
is >>> [16]
a >>> [13]
pygmy >>> [22958 17364   105]
hippo. >>> [ 3481 14875     5     3]


In [None]:
# Map each subword id of "MooDeng" back to its text piece.
for t in (8493, 8032, 299):
    w = trax.data.detokenize([t], vocab_dir='./data/', vocab_file='ende_32k.subword')
    print(t, '>>>', w)

8493 >>> Moo
8032 >>> Den
299 >>> g


Note on tokenization

1. There is no universal standard: different tokenizers split text differently.

2. A word can be a single token, or it can be split into several tokens that together compose the word.

# Single-shot run

In [None]:
from trax import layers as tl

# Teacher-forced, step-by-step decoding. Feed the known output tokens one at a
# time (predict mode) and print the model's most likely next token at each step.
# Token ids are taken from the test run above; outputs starts with 0.
inputs = np.array([[415, 85, 1977, 83, 4, 15286, 466,
                    101, 72, 442, 1356, 7191, 55, 2097, 3]])
outputs = np.array([[0, 67, 119, 233, 18690,    23, 27551,     5,     2,   163,
                     67, 119,  44,  1169, 12869,    21,   710,     3]])

# Reload the pre-trained weights before replaying the sequence.
model.init_from_file('./data/ende_wmt32k.pkl.gz', weights_only=True)

last_index = -1  # take the last (here: only) position of each step's logits
TopN = 1         # number of candidate ids to show per step
for y in outputs[0]:

    current_symbols = np.array([[y]])

    logits = model((inputs, current_symbols))[0]
    print('y in:', y, '; output logits.shape=', logits.shape, end='>>>')

    logits = tl.log_softmax(logits[:, last_index, :])
    eprob = np.exp(logits)

    ids = np.argsort(eprob[0,:])[-TopN:][::-1]  # TopN most probable ids, best first
    print(' top ids:', ids)
    print()

Audio(beep, rate=srate, autoplay=True)

y in: 0 ; output logits.shape= (1, 1, 33300)>>> top ids: [67]

y in: 67 ; output logits.shape= (1, 1, 33300)>>> top ids: [119]

y in: 119 ; output logits.shape= (1, 1, 33300)>>> top ids: [233]

y in: 233 ; output logits.shape= (1, 1, 33300)>>> top ids: [18690]

y in: 18690 ; output logits.shape= (1, 1, 33300)>>> top ids: [23]

y in: 23 ; output logits.shape= (1, 1, 33300)>>> top ids: [27551]

y in: 27551 ; output logits.shape= (1, 1, 33300)>>> top ids: [5]

y in: 5 ; output logits.shape= (1, 1, 33300)>>> top ids: [2]

y in: 2 ; output logits.shape= (1, 1, 33300)>>> top ids: [163]

y in: 163 ; output logits.shape= (1, 1, 33300)>>> top ids: [67]

y in: 67 ; output logits.shape= (1, 1, 33300)>>> top ids: [119]

y in: 119 ; output logits.shape= (1, 1, 33300)>>> top ids: [44]

y in: 44 ; output logits.shape= (1, 1, 33300)>>> top ids: [1169]

y in: 1169 ; output logits.shape= (1, 1, 33300)>>> top ids: [12869]

y in: 12869 ; output logits.shape= (1, 1, 33300)>>> top ids: [21]

y in: 21 ; outp

# Train transformer

machine translation example

* [WMT 2014 English - German](https://www.kaggle.com/datasets/mohamedlotfy50/wmt-2014-english-german)

as used in Vaswani et al. 2017's work.


Machine translation ```model``` needs both ```source``` and ```current_symbol```


[Trax](https://github.com/google/trax)
```python=
    current_symbol = np.array([[yi]])    

    logits = model((inputs, current_symbol))[0]
```

This seems to be properly handled by ```training.Loop``` and the ```Transformer``` model.

Recall

* ```Transformer``` model has [```tl.ShiftRight(mode=mode)```](https://github.com/google/trax/blob/master/trax/models/transformer.py#L379)
  * Target ground-truth, $\mathbf{Y} = [y_1, y_2, \ldots, y_L]$ is used as
    * Ground-truth for the next tokens $[y_1, y_2, \ldots, y_L]$ to compare against the prediction $\hat{\mathbf{Y}} = [\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_L]$
    * [Shift-right](https://github.com/google/trax/blob/master/trax/layers/attention.py#L576) (equivalent to going back one time step), $[0, y_1, y_2, \ldots, y_{L-1}]$. This is used as decoder input during the training. The shift-right is turned off during prediction.
      
      ```
        def ShiftRight(n_positions=1, mode='train'):
          # TODO(jonni): Include pad arg, like PaddingMask, to allow non-default pads?
          def f(x):
            if mode == 'predict':
              return x
            padded = _zero_pad(x, (n_positions, 0), 1)
            return padded[:, :-n_positions]
          return Fn(f'ShiftRight({n_positions})', f)      
      ```


* Causal attention is where the difference between predict and train modes is most pronounced. This contrast can be seen in [DotProductCausalAttention](https://github.com/google/trax/blob/master/trax/layers/attention.py#L488)

```python=
def forward(self, inputs):

    q, k, v = inputs
    
    if self._portal_mask is not None:
      mask_for_predict = self._portal_mask.get_value()
    else:
      mask_for_predict = None
    
    if self._mode == 'predict':
      self.state, mask = _fast_inference_update_state(
          inputs, self.state,
          mask_for_predict=mask_for_predict)
      if self._portal_mask is not None:
        (_, k, v, _) = self.state
      else:
        (k, v, _) = self.state
    else:
      sequence_length = q.shape[-2]
      mask = _causal_mask(sequence_length)
    
    activations, attn_strengths = _per_head_attention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)
    if self._mode == 'viz':
      self.state = attn_strengths
    return activations
```                       

1. Since these ```inputs``` belong to the decoder, they (denoted $\psi$) are:

   \begin{align}
   \psi = \begin{cases}
   \tilde{y} & \mbox{ in predict mode,} \\
   \mathrm{shiftright}(\tilde{y}) & \mbox{ in train mode.} \\
   \end{cases}
   \end{align}

   where
   \begin{align}
   \tilde{y} = \begin{cases}
   \hat{y}_t & \mbox{ in predict mode,} \\
   [y_1, y_2, \ldots, y_L] & \mbox{ in train mode.} \\
   \end{cases}
   \end{align}

2. In predict mode, ```k``` and ```v``` get their values from ```self.state```, which keeps track of the current and previous target tokens: ```k``` and ```v``` are associated with $[y_1, y_2, \ldots, y_t]$ (previous and current target tokens).

3. In train mode, ```k``` and ```v``` get their values directly from ```inputs```, which correspond to the shifted target tokens $[0, y_1, y_2, \ldots, y_{L-1}]$.

4. In predict mode, ```mask``` is set to mask out any association with future positions (positions $> t$).

5. In train mode, ```mask``` is a lower-triangular boolean matrix, masking out any association with future positions (positions $> t$).

    >```
    def _causal_mask(length):
      if fastmath.is_backend(fastmath.Backend.JAX):
        return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0)
      else:
        return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)
    >```

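    This is an ingenious way to do the training. In one shot of a source-target pair $(\mathbf{X}, \mathbf{Y})$, we can train the model to learn the relation $f: \mathbf{X}, \mathrm{shiftright}(\mathbf{Y}) \mapsto \mathbf{Y}$. Because of the causal mask, position $t$ only sees $[0, y_1, \ldots, y_{t-1}]$, so this amounts to learning $f: \mathbf{X}, [0, y_1, \ldots, y_{t-1}] \mapsto y_t$ for all $t = 1, \ldots, L$ at once. Each prediction is conditioned on all previous target tokens, not just the latest one.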
   

Regarding masking and "sequence batch" learning, let's review the attention:

$$\mbox{Attended output } = \mathrm{softmax} ( \mathrm{mask}\left(\frac{\hat{\mathbf{Q}} \hat{\mathbf{K}}^T }{\sqrt{M}}\right) ) \hat{\mathbf{V}}$$
    
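Here $M$ is the dimension of each query/key vector. Consider, e.g., $\hat{\mathbf{Q}}$ of shape (1, 2, 3), or (2, 3) when the leading axis of size 1 is dropped for simplicity: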
    
\begin{align}
\hat{\mathbf{Q}} = \begin{bmatrix}
q_1(t_1) & q_2(t_1) & q_3(t_1) \\
q_1(t_2) & q_2(t_2) & q_3(t_2)
\end{bmatrix}
\end{align}

Similarly, we have

\begin{align}
\hat{\mathbf{K}} = \begin{bmatrix}
k_1(t_1) & k_2(t_1) & k_3(t_1) \\
k_1(t_2) & k_2(t_2) & k_3(t_2)
\end{bmatrix}
\end{align}

The dot product
\begin{align}
\mathbf{P} = \frac{\hat{\mathbf{Q}} \hat{\mathbf{K}}^T }{\sqrt{M}}
= \begin{bmatrix}
p(t_1,t_1) & p(t_1,t_2) \\
p(t_2,t_1) & p(t_2,t_2)
\end{bmatrix}
\end{align}

After masking
\begin{align}
\mathrm{mask}(\mathbf{P})
= \begin{bmatrix}
p(t_1,t_1) & -\infty \\
p(t_2,t_1) & p(t_2,t_2)
\end{bmatrix}
\end{align}


Note that masking sets these elements to $-\infty$, so they become $0$ after the softmax. The implementation uses a very large negative number in place of $-\infty$.

For any sequence length $L$,

\begin{align}
\mathrm{mask}(\mathbf{P})
= \begin{bmatrix}
p(t_1,t_1) & -\infty & -\infty & \cdots & -\infty\\
p(t_2,t_1) & p(t_2,t_2) & -\infty & \cdots & -\infty\\
p(t_3,t_1) & p(t_3,t_2) & p(t_3,t_3) & \cdots & -\infty\\
\vdots & \vdots & \vdots & \ddots & \vdots \\
p(t_L,t_1) & p(t_L,t_2) & p(t_L,t_3) & \cdots & p(t_L,t_L)\\
\end{bmatrix}
\end{align}


## Data

In [None]:
!ls ./data/wmt*

./data/wmt14_translate_de-en_train.csv	./data/wmt14_translate_de-en_validation.csv


In [None]:
import csv
from itertools import islice

# Peek at the first rows of the training CSV. Row 0 is the "de,en" header.
# The csv module requires newline='' so that line breaks inside quoted fields
# are parsed correctly.
with open('data/wmt14_translate_de-en_train.csv', 'r', encoding='utf-8', newline='') as csvfile:

    wmt = csv.reader(csvfile, delimiter=',')

    print(type(wmt))

    for row in islice(wmt, 4):  # header + first 3 sentence pairs
        print('0:', row[0])
        print('1:', row[1])
        print()

<class '_csv.reader'>
0: de
1: en

0: An der B 211 befindet sich in Loyermoor der so genannte „Geest-Abbruch“, der eine Höhendifferenz von gut 30 Meter überbrückt.
1: Here the largest town of the district is located: Nordenham , lying opposite to Bremerhaven at the Weser mouth.

0: Ich begrüße die Erklärung des Herrn Kommissar und die Arbeit des Parlaments, die wirklich deutlich macht, welch konkretes Interesse wir an diesen Fragen haben, gerade unter den derzeitigen Umständen.
1: I should like, in passing, to pay tribute to the Commissioner' s statement and to all the work that Parliament has carried out, which shows that we attach real importance to these issues, particularly in the current circumstances.

0: Das ist das Gegenteil von dem, was getan werden müsste, und trotzdem machen Sie so weiter.
1: That is the opposite of what should be done and yet it is what you continue to do.



### Data stream

In [None]:
import csv

def dataset_stream(csv_path='data/wmt14_translate_de-en_train.csv'):
    """Yield (English, German) sentence pairs from a WMT14 de-en CSV file.

    The file starts with a ``de,en`` header row. Each following row holds a
    German sentence in column 0 and its English translation in column 1.

    Args:
        csv_path: Path to the CSV file.

    Yields:
        ``(en, de)`` string tuples: source (English) first, target (German) second.
        Rows with fewer than two columns (e.g. blank lines) are skipped.
    """
    # newline='' is required by the csv module so quoted fields may contain line breaks.
    with open(csv_path, 'r', encoding='utf-8', newline='') as csvfile:

        wmt = csv.reader(csvfile, delimiter=',')
        # Skip the "de,en" header. The default prevents a bare StopIteration
        # (turned into RuntimeError inside generators, PEP 479) on an empty file.
        next(wmt, None)

        for row in wmt:
            if len(row) < 2:
                continue
            yield (row[1], row[0])  ## (en, de)

In [None]:
# Test: read pairs back from the stream.

ds = dataset_stream('data/wmt14_translate_de-en_train.csv')

# The first two pairs, pulled one at a time with next() ...
x, y = next(ds)
print('x=', x)
print('y=', y)
print()

x, y = next(ds)
print('x=', x)
print('y=', y)

# ... then up to eleven more from the same, already-advanced generator.
count = 0
for count, (x, y) in enumerate(ds, start=1):
    print('x=', x)
    print('y=', y)
    print()
    if count > 10:
        break


x= Here the largest town of the district is located: Nordenham , lying opposite to Bremerhaven at the Weser mouth.
y= An der B 211 befindet sich in Loyermoor der so genannte „Geest-Abbruch“, der eine Höhendifferenz von gut 30 Meter überbrückt.

x= I should like, in passing, to pay tribute to the Commissioner' s statement and to all the work that Parliament has carried out, which shows that we attach real importance to these issues, particularly in the current circumstances.
y= Ich begrüße die Erklärung des Herrn Kommissar und die Arbeit des Parlaments, die wirklich deutlich macht, welch konkretes Interesse wir an diesen Fragen haben, gerade unter den derzeitigen Umständen.
x= That is the opposite of what should be done and yet it is what you continue to do.
y= Das ist das Gegenteil von dem, was getan werden müsste, und trotzdem machen Sie so weiter.

x= .
y= .

x= It was designed by the Viennese architect Ruppelmeyer. In that year the Bulgarian Prince Alexander Battenberg accepted as a

### Data pipeline

In [None]:
# Show just the first (source, target) pair of the stream.
first_pair = next(iter(dataset_stream('data/wmt14_translate_de-en_train.csv')), None)
if first_pair is not None:
    src_line, tgt_line = first_pair
    print('src:', src_line)
    print('tgt:', tgt_line)

src: Here the largest town of the district is located: Nordenham , lying opposite to Bremerhaven at the Weser mouth.
tgt: An der B 211 befindet sich in Loyermoor der so genannte „Geest-Abbruch“, der eine Höhendifferenz von gut 30 Meter überbrückt.


In [None]:
# Define the data stream generator
def tokenized_stream(data_stream, vocab_file='ende_32k.subword', vocab_dir='./data/',
                     tokenizer=None):
    """Tokenize a stream of (source, target) text pairs.

    Args:
        data_stream: Iterable of ``(src_line, tgt_line)`` string pairs, e.g. from
            ``dataset_stream``.
        vocab_file: Subword vocabulary file name.
        vocab_dir: Directory containing ``vocab_file``.
        tokenizer: Optional stream tokenizer: a callable mapping an iterable of
            strings to an iterable of token sequences. If ``None``, a
            ``trax.data.Tokenize`` is built from ``vocab_file``/``vocab_dir``.

    Yields:
        ``(source_tokens, target_tokens)`` pairs. Each line is stripped of
        surrounding whitespace before tokenizing.
    """
    if tokenizer is None:
        tokenizer = trax.data.Tokenize(vocab_file=vocab_file, vocab_dir=vocab_dir)

    for src_line, tgt_line in data_stream:
        # The tokenizer works on streams, so wrap each line in a one-element list.
        source_tokens = list(tokenizer([src_line.strip()]))[0]
        target_tokens = list(tokenizer([tgt_line.strip()]))[0]
        yield (source_tokens, target_tokens)


In [None]:
# Test: tokenize only the first pair of the training file.
first_tokenized = next(iter(tokenized_stream(dataset_stream('data/wmt14_translate_de-en_train.csv'))), None)
if first_tokenized is not None:
    src, tgt = first_tokenized
    print(src)
    print(tgt)

[ 2106     4  1824  1275     7     4  3213    16   737    64  6642 11251
   164   134 12188  5505     9 32631 12402 17817    23    68     4 12628
    70 14989     3]
[ 1208    11   524  5421   135  1032    51     6  3889 26667 17037    76
    11    79  6895   213  7527  2223    15 11746 13293     5  1654    11
    41 31941 17625   387    21   541   448  9803  4374 20622  6863     3]


In [None]:
# Apply Trax Serial pipeline
def transformer_pipeline(batch_size=32, pad_id=0, shuffle=False):
    """Build the batching pipeline for (source_tokens, target_tokens) streams.

    Args:
        batch_size: Number of examples per batch.
        pad_id: Token id used for padding. Target positions equal to it get loss
            weight 0. Pass ``None`` for all-ones weights (no masking).
        shuffle: If True, shuffle examples with ``trax.data.Shuffle()`` before batching.

    Returns:
        A ``trax.data.Serial`` callable that turns a stream of token pairs into
        ``(inputs, targets, loss_weights)`` batches.
    """
    stages = []
    if shuffle:
        stages.append(trax.data.Shuffle())  # shuffle examples, not whole batches
    stages += [
        trax.data.Batch(batch_size=batch_size),  # Batch the data (short sequences padded with 0)
        # AddLossWeights() with its default id_to_mask=None gives all-ones weights;
        # id_to_mask is what zeroes the weights at padded target positions.
        trax.data.AddLossWeights(id_to_mask=pad_id),
    ]
    return trax.data.Serial(*stages)

# Use the pipeline with a data stream
batch_size = 32

# Lazy generators: batches are read, tokenized and padded on demand.
# NOTE(review): dataset_stream reads its CSV only once, so these streams do not
# repeat; a long training run may exhaust train_stream.
train_stream = transformer_pipeline(batch_size)(
    tokenized_stream(dataset_stream('data/wmt14_translate_de-en_train.csv'))
)

val_stream = transformer_pipeline(batch_size)(
    tokenized_stream(dataset_stream('data/wmt14_translate_de-en_validation.csv'))
)

In [None]:
# Test: look at the first batch.
# NOTE: iterating here consumes that batch from the train_stream generator.

for inputs, targets, loss_weights in train_stream:
    print("Inputs (source):", inputs.shape, ':', inputs)
    print("Targets (target):", targets.shape, ':', targets)
    print("Loss Weights:", loss_weights.shape, ':', loss_weights)  # Per-target-position weights (0 = ignored by the loss)
    break

Audio(beep, rate=srate, autoplay=True)

Inputs (source): (32, 65) : [[2106    4 1824 ...    0    0    0]
 [  46  117  151 ...    0    0    0]
 [ 403   16    4 ...    0    0    0]
 ...
 [ 624  933 6081 ...    0    0    0]
 [  26   29  527 ...    0    0    0]
 [ 327   49 5584 ...    0    0    0]]
Targets (target): (32, 69) : [[ 1208    11   524 ...     0     0     0]
 [  161  3728    10 ...     0     0     0]
 [  111    24    34 ...     0     0     0]
 ...
 [ 1750   422   315 ...     0     0     0]
 [   26    57   423 ...     0     0     0]
 [11483    11 15624 ...     0     0     0]]
Loss Weights: (32, 69) : [[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]]


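```loss_weights``` is needed to mask out the target padding.

In the output above, the weights are 1's even at positions where ```targets``` is padded with 0's, so padding is not masked.
This is because ```trax.data.AddLossWeights()``` uses ```id_to_mask=None``` by default, which produces all-ones weights.
To get 0's at the padding positions, call it as ```trax.data.AddLossWeights(id_to_mask=0)```.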

## model

In [None]:
from trax.supervised import training


Original
from trax.models.transformer import Transformer

# Define the model / Trax standard Transformer model
# No `mode` is passed, so the library default is used (presumably 'train').
# NOTE(review): the output shown below this cell ("my_Transformer: mode='train'"
# and timings) cannot come from this code, which prints nothing. It looks stale
# from an earlier version of the cell; re-run to refresh it.
model = Transformer(
    input_vocab_size=33300,
    d_model=512,  d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048
)

Audio(beep, rate=srate, autoplay=True)

my_Transformer: mode='train', dropout=0.1
Elapsed time: 0.027143478393554688
Elapse time 16.89951515197754


## Task

* Train using [```WeightedCategoryCrossEntropy```](https://github.com/google/trax/blob/master/trax/layers/metrics.py#L245)

  * ```WeightedCategoryCrossEntropy``` is essentially categorical cross-entropy with masking: each position is weighted, and the weights are either 0 or 1 to mask out padding.

In [None]:
# Define the training task
train_task = training.TrainTask(
    labeled_data=train_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),   # Cross-entropy weighted by loss_weights (weight 0 = position ignored)
    optimizer=trax.optimizers.Adam(0.01),  # NOTE(review): 0.01 is a high Adam learning rate for a Transformer; consider a smaller rate or a warmup schedule
)

# Evaluation task
eval_task = training.EvalTask(
    labeled_data=val_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=1 # for visibility
)

Audio(beep, rate=srate, autoplay=True)

## Clear space for the coming trained params

In [None]:
# Ensure the output folder exists and delete the files left in it by earlier
# runs (presumably so training.Loop starts fresh rather than resuming).
# Sub-folders such as train/ and eval/ are kept and only listed.
directory_name = 'output_dir'

dpath = os.path.join(os.getcwd(), directory_name)

if not os.path.isdir(dpath):
    # Let a failure (e.g. no permission) raise here, instead of printing it and
    # then crashing in os.listdir below.
    os.makedirs(dpath, exist_ok=True)
    print(f"Directory '{directory_name}' created successfully.")

# `fname` (not `f`) so the beep-frequency global `f` is not overwritten.
for fname in os.listdir(dpath):
    fpath = os.path.join(dpath, fname)
    if os.path.isfile(fpath):
        os.remove(fpath)
        print(fname, 'is removed.')
    else:
        print(f'[{fname}]')

[train]
[eval]


## Set training loop

In [None]:
# Training loop
# Ties the model to the train/eval tasks; output_dir is the folder cleaned above
# (it holds the train/ and eval/ sub-folders). Only construction is timed here.
t1 = time.time()
training_loop = training.Loop(
    model=model,
    tasks=train_task,
    eval_tasks=[eval_task],
    output_dir=directory_name
)
t2 = time.time()

print('Elapse time', t2 - t1)

Audio(beep, rate=srate, autoplay=True)



Elapse time 2.9101624488830566


## Train

In [None]:
# Start training
t1 = time.time()
training_loop.run(2)  # Train for 2 steps
t2 = time.time()

print('Elapse time', t2 - t1)

Audio(beep, rate=srate, autoplay=True)

my_Probe: tag= shift-right in ; n_in= 4 ; n_out= 4: (32, 105)
(32, 1, 1, 110)
(32, 110, 512)
(32, 105)
my_Probe: tag= shift-right out ; n_in= 4 ; n_out= 4: (32, 105)
(32, 1, 1, 110)
(32, 110, 512)
(32, 105)
* mask= Traced<ShapedArray(bool[1,105,105])>with<DynamicJaxprTrace(level=1/0)>
* mask= Traced<ShapedArray(bool[1,105,105])>with<DynamicJaxprTrace(level=1/0)>


  with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf:



Step      1: Total number of trainable weights: 80370196
Step      1: Ran 1 train steps in 167.94 secs
Step      1: train WeightedCategoryCrossEntropy |  21.56775856


  with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf:


my_Probe: tag= shift-right in ; n_in= 4 ; n_out= 4: (32, 99)
(32, 1, 1, 100)
(32, 100, 512)
(32, 99)
my_Probe: tag= shift-right out ; n_in= 4 ; n_out= 4: (32, 99)
(32, 1, 1, 100)
(32, 100, 512)
(32, 99)
* mask= Traced<ShapedArray(bool[1,99,99])>with<DynamicJaxprTrace(level=1/0)>
* mask= Traced<ShapedArray(bool[1,99,99])>with<DynamicJaxprTrace(level=1/0)>
Step      1: eval  WeightedCategoryCrossEntropy |  11.10817528
Step      1: eval      WeightedCategoryAccuracy |  0.00000000
my_Probe: tag= shift-right in ; n_in= 4 ; n_out= 4: (32, 74)
(32, 1, 1, 78)
(32, 78, 512)
(32, 74)
my_Probe: tag= shift-right out ; n_in= 4 ; n_out= 4: (32, 74)
(32, 1, 1, 78)
(32, 78, 512)
(32, 74)
* mask= Traced<ShapedArray(bool[1,74,74])>with<DynamicJaxprTrace(level=1/0)>
* mask= Traced<ShapedArray(bool[1,74,74])>with<DynamicJaxprTrace(level=1/0)>
Elapse time 288.98834013938904


In [None]:
# Rough minutes for the ~289 s training run above.
300/60

5.0

# Probe internal information

Trax code is quite complex and highly intertwined.
It is better done in a modified source file.

E.g.,
to probe information passing in ```Transformer``` model,
we can modify [```transformer.py```](https://github.com/google/trax/blob/master/trax/models/transformer.py).

Here, instead of directly modifying the original, we make another version of it and re-route our code to our new version.

* ```transformer.py``` $\rightarrow$ ```mymod.py```

In [None]:
# Helper modules in the project folder.
!ls *.py

myatt.py  mycore.py  mymod.py  mytraining.py


In [None]:
# Build the instrumented Transformer from mymod.py in predict mode.
# The module is reloaded first so edits to mymod.py take effect.
import mymod

# Reload imported libraries
import importlib
importlib.reload(mymod)

t1 = time.time()
model = mymod.my_Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

t2 = time.time()

print('Elapsed time:', t2 -t1)

t1 = time.time()
model.init_from_file('./data/ende_wmt32k.pkl.gz', weights_only=True)  # same pre-trained weights as the test run
t2 = time.time()

print('Elapsed time:', t2 -t1)
Audio(beep, rate=srate, autoplay=True)


my_Transformer: mode='predict', dropout=0.1
Elapsed time: 0.028961181640625
Elapsed time: 17.29772973060608


In [None]:
# Print the source of the instrumented Transformer defined in mymod.py.
import inspect
lines = inspect.getsource(mymod.my_Transformer)
print(lines)

def my_Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=D_MODEL,
                d_ff=D_FF,
                n_encoder_layers=N_LAYERS,
                n_decoder_layers=N_LAYERS,
                n_heads=N_HEADS,
                max_len=MAX_SEQUENCE_LENGTH,
                dropout=DROPOUT_RATE,
                dropout_shared_axes=DROPOUT_SHARED_AXES,
                mode=MODE,
                ff_activation=FF_ACTIVATION_TYPE):
  """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: Array representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length),
          where sequence_length <= ``max_len``. Array elements are integers in
          ``range(input_vocab_size)``, and 0 values mark padding positions.

        - target: Array representing a batch of 

Here, we add ```my_Probe('shift-right in', n_in=4, n_out=4)``` where we want to probe in.

In [None]:
# Print the probe layer: it returns its input unchanged and logs it while _my_trace is set.
lines = inspect.getsource(mymod.my_Probe)
print(lines)

class my_Probe(tl.base.Layer):

  def __init__(self, tag, *args, **kwargs):

      super(my_Probe, self).__init__(*args, **kwargs)
      self._my_trace = False
      self._tag = tag
      self._note = ""
      self._log = {'count': 0}
    
  def forward(self, x):

      if self._my_trace:
          print("my_Probe: tag=", self._tag, '; n_in=', self._n_in, '; n_out=', self._n_out, end=': ')
          if hasattr(x, 'shape'):
            print(x.shape) 
          elif self._n_in > 1:
            for i in range(self._n_in):
              print(x[i].shape)
          
          
          self._log['count'] = self._log['count'] + 1

          self._log[self._log['count']] = {'tag': self._tag, 'x': x}
      return x



To change where to probe, we just have to edit ```mymod.py``` and reload the module.

In [None]:
# Causal-attention layer under model sublayer 12 (first in the list of causal-attention positions).
o = model.sublayers[12].sublayers[0].sublayers[1].sublayers[1].sublayers[2].sublayers[2]
o

my_DotProductCausalAttention_in3

In [None]:
# Nothing logged yet; presumably filled only once _my_trace is switched on and the model runs.
o._log

{}

In [None]:
# Switch on tracing for selected probe layers (indices into the mymod model's sublayers).
# DotProductCausalAttention (1 of 6): 12, 15, 18, 21, 24, 27
model.sublayers[12].sublayers[0].sublayers[1].sublayers[1].sublayers[2].sublayers[2]._my_trace = True
model.sublayers[27].sublayers[0].sublayers[1].sublayers[1].sublayers[2].sublayers[2]._my_trace = True

# Joint attention: 13, 16, 19, 22, 25, 28

# ShiftRight in
model.sublayers[5]._my_trace = True

# ShiftRight out
model.sublayers[7]._my_trace = True

In [None]:
from trax import layers as tl

# Same teacher-forced, step-by-step decoding as before, now through the
# instrumented model so the traced probes print and log at every step.
inputs = np.array([[415, 85, 1977, 83, 4, 15286, 466,
                    101, 72, 442, 1356, 7191, 55, 2097, 3]])
outputs = np.array([[0, 67, 119, 233, 18690,    23, 27551,     5,     2,   163,
                     67, 119,  44,  1169, 12869,    21,   710,     3]])

last_index = -1  # take the last (here: only) position of each step's logits
TopN = 1         # number of candidate ids to show per step
for y in outputs[0]:

    current_symbols = np.array([[y]])

    logits = model((inputs, current_symbols))[0]
    print('y in:', y, '; output logits.shape=', logits.shape, end='>>>')

    logits = tl.log_softmax(logits[:, last_index, :])
    eprob = np.exp(logits)

    ids = np.argsort(eprob[0,:])[-TopN:][::-1]  # TopN most probable ids, best first
    print(' top ids:', ids)
    print()

Audio(beep, rate=srate, autoplay=True)

my_Probe: tag= shift-right in ; n_in= 4 ; n_out= 4: (1, 1)
(1, 1, 1, 15)
(1, 15, 512)
(1, 1)
my_Probe: tag= shift-right out ; n_in= 4 ; n_out= 4: (1, 1)
(1, 1, 1, 15)
(1, 15, 512)
(1, 1)
* mask= [[[ True  True False ... False False False]]]
* mask= [[[ True  True False ... False False False]]]
y in: 0 ; output logits.shape= (1, 1, 33300)>>> top ids: [67]

my_Probe: tag= shift-right in ; n_in= 4 ; n_out= 4: (1, 1)
(1, 1, 1, 15)
(1, 15, 512)
(1, 1)
my_Probe: tag= shift-right out ; n_in= 4 ; n_out= 4: (1, 1)
(1, 1, 1, 15)
(1, 15, 512)
(1, 1)
* mask= [[[ True  True  True ... False False False]]]
* mask= [[[ True  True  True ... False False False]]]
y in: 67 ; output logits.shape= (1, 1, 33300)>>> top ids: [119]

my_Probe: tag= shift-right in ; n_in= 4 ; n_out= 4: (1, 1)
(1, 1, 1, 15)
(1, 15, 512)
(1, 1)
my_Probe: tag= shift-right out ; n_in= 4 ; n_out= 4: (1, 1)
(1, 1, 1, 15)
(1, 15, 512)
(1, 1)
* mask= [[[ True  True  True ... False False False]]]
* mask= [[[ True  True  True ... False Fa

In [None]:
# The two probes placed before and after ShiftRight.
o1 = model.sublayers[5]
o2 = model.sublayers[7]
print(o1._tag)
print(o2._tag)

shift-right in
shift-right out


In [None]:
# One numbered entry per logged call, plus the running 'count'.
print(o1._log.keys())
print(o2._log.keys())

dict_keys(['count', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
dict_keys(['count', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])


In [None]:
# Each entry stores the probe's tag and its input x.
print(o1._log[19].keys())
print(o2._log[19].keys())

dict_keys(['tag', 'x'])
dict_keys(['tag', 'x'])


In [None]:
# Each logged x is the probe's 4-element input; show its length and the shape
# of its first element. Iterate over every logged call rather than a hard-coded 19.
for i in range(1, o1._log['count'] + 1):
  x1 = o1._log[i]['x']
  print(i,':', len(x1), x1[0].shape)

1 : 4 (1, 1)
2 : 4 (1, 1)
3 : 4 (1, 1)
4 : 4 (1, 1)
5 : 4 (1, 1)
6 : 4 (1, 1)
7 : 4 (1, 1)
8 : 4 (1, 1)
9 : 4 (1, 1)
10 : 4 (1, 1)
11 : 4 (1, 1)
12 : 4 (1, 1)
13 : 4 (1, 1)
14 : 4 (1, 1)
15 : 4 (1, 1)
16 : 4 (1, 1)
17 : 4 (1, 1)
18 : 4 (1, 1)
19 : 4 (1, 1)


In [None]:
# Decoder input before vs. after ShiftRight at step i. They match in predict mode,
# where ShiftRight returns its input unchanged.
i = 3
o1._log[i]['x'][0], o2._log[i]['x'][0]


(array([[67]]), array([[67]]))

In [None]:
# Back to the traced causal-attention layer under sublayer 12.
o = model.sublayers[12].sublayers[0].sublayers[1].sublayers[1].sublayers[2].sublayers[2]

o

my_DotProductCausalAttention_in3

In [None]:
# One entry per decoding step.
o._log.keys()

dict_keys(['count', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

In [None]:
# q covers only the current position, while k and v span max_len=2048 positions
# (presumably the decoding cache kept in the layer's state).
i = 1
o._log[i]['q'].shape, o._log[i]['k'].shape, o._log[i]['v'].shape

((8, 1, 64), (8, 2048, 64), (8, 2048, 64))

# Emulate training

In [None]:
# Rebuild the instrumented Transformer, this time with mode='train', to emulate
# a training forward pass with the same pre-trained weights.
import mymod

# Reload imported libraries
import importlib
importlib.reload(mymod)

t1 = time.time()
model = mymod.my_Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='train')

t2 = time.time()

print('Elapsed time:', t2 -t1)

t1 = time.time()
model.init_from_file('./data/ende_wmt32k.pkl.gz', weights_only=True)
t2 = time.time()

print('Elapsed time:', t2 -t1)
Audio(beep, rate=srate, autoplay=True)


my_Transformer: mode='train', dropout=0.1
Elapsed time: 0.027943849563598633
Elapsed time: 15.563826084136963


In [None]:
# Switch on tracing for the same probe layers in the train-mode model.
# DotProductCausalAttention (1 of 6): 12, 15, 18, 21, 24, 27
model.sublayers[12].sublayers[0].sublayers[1].sublayers[1].sublayers[2].sublayers[2]._my_trace = True
model.sublayers[27].sublayers[0].sublayers[1].sublayers[1].sublayers[2].sublayers[2]._my_trace = True

# Joint attention: 13, 16, 19, 22, 25, 28

# ShiftRight in
model.sublayers[5]._my_trace = True

# ShiftRight out
model.sublayers[7]._my_trace = True

In [None]:
# First tokenized (source, target) pair, read from a fresh stream.
src, tgt = next(tokenized_stream(dataset_stream('data/wmt14_translate_de-en_train.csv')))
print('src:', src)
print('tgt:', tgt)


src: [ 2106     4  1824  1275     7     4  3213    16   737    64  6642 11251
   164   134 12188  5505     9 32631 12402 17817    23    68     4 12628
    70 14989     3]
tgt: [ 1208    11   524  5421   135  1032    51     6  3889 26667 17037    76
    11    79  6895   213  7527  2223    15 11746 13293     5  1654    11
    41 31941 17625   387    21   541   448  9803  4374 20622  6863     3]


In [None]:
# Take the next batch; an earlier cell already consumed one batch from train_stream.
inputs, targets, loss_weights = next(train_stream)

In [None]:
# (batch, padded source length) and (batch, padded target length).
print(inputs.shape, targets.shape)

(32, 65) (32, 69)


In [None]:
# One train-mode forward pass with the whole target batch; ShiftRight shifts it inside the model.
logits = model((inputs, targets))[0]

print('logits=', logits)


my_Probe: tag= shift-right in ; n_in= 4 ; n_out= 4: (32, 69)
(32, 1, 1, 65)
(32, 65, 512)
(32, 69)
my_Probe: tag= shift-right out ; n_in= 4 ; n_out= 4: (32, 69)
(32, 1, 1, 65)
(32, 65, 512)
(32, 69)
* mask= [[[ True False False ... False False False]
  [ True  True False ... False False False]
  [ True  True  True ... False False False]
  ...
  [ True  True  True ...  True False False]
  [ True  True  True ...  True  True False]
  [ True  True  True ...  True  True  True]]]
* mask= [[[ True False False ... False False False]
  [ True  True False ... False False False]
  [ True  True  True ... False False False]
  ...
  [ True  True  True ...  True False False]
  [ True  True  True ...  True  True False]
  [ True  True  True ...  True  True  True]]]
logits= [[[-14.039238    2.7436395   4.5101347 ... -13.924271  -14.075339
   -14.08013  ]
  [-14.582066    4.4777493   6.9550357 ... -15.301696  -15.0304785
   -15.654066 ]
  [-15.589659    2.711014    6.2200456 ... -15.376288  -15.4009495
 

In [None]:
# The traced causal-attention layer of the train-mode model: one entry per forward call.
o = model.sublayers[12].sublayers[0].sublayers[1].sublayers[1].sublayers[2].sublayers[2]

o._log.keys()

dict_keys(['count', 1])

In [None]:
# Fields recorded by the attention probe.
o._log[1].keys()

dict_keys(['inputs', 'outputs', 'tag', 'mask', 'state_in', 'state_out', 'q', 'k', 'v'])

In [None]:
# Shapes of the logged q, k and v in train mode (the third entry is v, not a second k).
o._log[1]['q'].shape, o._log[1]['k'].shape, o._log[1]['v'].shape

((256, 69, 64), (256, 69, 64), (256, 69, 64))

In [None]:
# Train-mode causal mask: lower-triangular, so position t attends only to positions <= t.
o._log[1]['mask']

array([[[ True, False, False, ..., False, False, False],
        [ True,  True, False, ..., False, False, False],
        [ True,  True,  True, ..., False, False, False],
        ...,
        [ True,  True,  True, ...,  True, False, False],
        [ True,  True,  True, ...,  True,  True, False],
        [ True,  True,  True, ...,  True,  True,  True]]])