
Code callbacks and optional console output suppression while training #160

Open
wants to merge 6 commits into master

Conversation

@Meorge Meorge commented Oct 28, 2021

I'm hoping to build a GUI frontend app that makes aitextgen more accessible, but to do that, the app needs to interface more closely with aitextgen itself. To make that possible, I've started adding support for custom callbacks/hooks in aitextgen. This pull request adds basic support for callbacks/hooks while training with train(), along with options for suppressing some of the default console output.

Custom callbacks

The user can pass a dictionary into the train() function as the callbacks argument, mapping the following keys to functions or lambdas (a rough sketch of how train() could dispatch them follows the list):

  • on_train_start(): called when training begins
  • on_train_end(): called when training ends
  • on_batch_end(current_steps: int, total_steps: int, current_loss: float, avg_loss: float): called when a batch is completed, and passes the current and total steps along with the current and average loss to the function
  • on_sample_text_generated(texts: List[str]): called when sample texts are generated, and passes the List of sample texts to the function
  • on_model_saved(current_steps: int, total_steps: int, output_dir: str): called when the model is saved during training, and passes the current and total steps along with the output directory to the function
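
The PR's internal wiring isn't shown here, but conceptually the dispatch amounts to something like the sketch below. The _dispatch helper and the commented call sites are illustrative only; the actual implementation inside train() may differ.

from typing import Callable, Dict, Optional

# Illustrative helper: look up a callback by key and invoke it only if the user supplied one.
def _dispatch(callbacks: Optional[Dict[str, Callable]], key: str, *args) -> None:
    if callbacks and key in callbacks:
        callbacks[key](*args)

# Inside the training loop, the call sites would then look roughly like:
# _dispatch(callbacks, "on_train_start")
# _dispatch(callbacks, "on_batch_end", current_steps, total_steps, current_loss, avg_loss)
# _dispatch(callbacks, "on_sample_text_generated", texts)
# _dispatch(callbacks, "on_model_saved", current_steps, total_steps, output_dir)
# _dispatch(callbacks, "on_train_end")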

Suppressing console output

If set to False, the print_generated and print_saved arguments for train() will prevent generated text and "saving model" messages from being displayed in the console, respectively.
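
Under the hood, these flags just gate the existing print calls while the corresponding callbacks still fire. A rough sketch of the pattern for the save path (illustrative only, not the exact code from this PR):

from typing import Callable, Dict, Optional

# Illustrative only: how a save step might respect print_saved while still
# firing the on_model_saved callback.
def _handle_save(current_steps: int, total_steps: int, output_dir: str,
                 print_saved: bool, callbacks: Optional[Dict[str, Callable]]) -> None:
    if print_saved:
        print(f"{current_steps:,} steps reached: saving model to /{output_dir}")
    if callbacks and "on_model_saved" in callbacks:
        callbacks["on_model_saved"](current_steps, total_steps, output_dir)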

Example script

I modified the Hello World Tutorial from the aitextgen website to test this functionality:

# The doubled import paths here reference a local checkout of the forked repo
# (an `aitextgen/` directory containing the package); with the pip-installed
# package this would be e.g. `from aitextgen.TokenDataset import TokenDataset`.
from aitextgen.aitextgen.TokenDataset import TokenDataset
from aitextgen.aitextgen.tokenizers import train_tokenizer
from aitextgen.aitextgen.utils import GPT2ConfigCPU
from aitextgen.aitextgen import aitextgen

def on_train_start():
    print('Training has started!')

def on_train_end():
    print('Training has ended!')

def on_sample_text_generated(texts):
    print(f'Here are the generated sample texts! {texts}')

def on_batch_end(current_steps, total_steps, current_loss, avg_loss):
    print(f'step {current_steps}/{total_steps}: current loss={current_loss}, avg loss={avg_loss}')

def on_model_saved(steps, total_steps, output_dir):
    print(f'step {steps}/{total_steps}: saving to {output_dir}!')

def main():
    # The name of the downloaded Shakespeare text for training
    file_name = "training.txt"

    # Train a custom BPE Tokenizer on the downloaded text
    # This will save one file: `aitextgen.tokenizer.json`, which contains the
    # information needed to rebuild the tokenizer.
    train_tokenizer(file_name)
    tokenizer_file = "aitextgen.tokenizer.json"

    # GPT2ConfigCPU is a mini variant of GPT-2 optimized for CPU-training
    # e.g. the # of input tokens here is 64 vs. 1024 for base GPT-2.
    config = GPT2ConfigCPU()

    # Instantiate aitextgen using the created tokenizer and config
    ai = aitextgen(tokenizer_file=tokenizer_file, config=config)

    # You can build datasets for training by creating TokenDatasets,
    # which automatically processes the dataset with the appropriate size.
    data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64)

    # Define our callbacks
    callbacks = {
        'on_train_start': on_train_start,
        'on_train_end': on_train_end,
        'on_batch_end': on_batch_end,
        'on_sample_text_generated': on_sample_text_generated,
        'on_model_saved': on_model_saved
    }

    # Train the model!
    ai.train(data,
        batch_size=1,
        num_steps=20,
        generate_every=10,
        save_every=5,
        print_generated=False,
        print_saved=False,
        callbacks=callbacks
        )

if __name__ == "__main__":
    main()

Once it got past setup, the output was:

step 1/20: current loss=6.89, avg loss=6.89
step 2/20: current loss=6.81, avg loss=6.8892
step 3/20: current loss=6.8, avg loss=6.888307999999999
step 4/20: current loss=6.81, avg loss=6.88752492
step 5/20: current loss=6.79, avg loss=6.886549670799999
5/20 steps so output to trained_model
step 6/20: current loss=6.75, avg loss=6.885184174091999
step 7/20: current loss=6.72, avg loss=6.883532332351079
step 8/20: current loss=6.71, avg loss=6.881797009027568
step 9/20: current loss=6.68, avg loss=6.879779038937292
step 10/20: current loss=6.64, avg loss=6.877381248547919
10/20 steps so output to trained_model
Here are the generated sample texts! [' in\n re\n IHHAs,\ns W\ny\n\n\n,:\n\nses the\n And I:\nsees INence thou W isG\n from\n\n\n to onry\n m,se;,owardo Fence on\nri to\n']
step 11/20: current loss=6.62, avg loss=6.87480743606244
step 12/20: current loss=6.58, avg loss=6.871859361701816
step 13/20: current loss=6.55, avg loss=6.868640768084798
step 14/20: current loss=6.53, avg loss=6.8652543604039495
step 15/20: current loss=6.52, avg loss=6.86180181679991
15/20 steps so output to trained_model
step 16/20: current loss=6.5, avg loss=6.858183798631911
step 17/20: current loss=6.46, avg loss=6.854201960645592
step 18/20: current loss=6.45, avg loss=6.850159941039136
step 19/20: current loss=6.43, avg loss=6.845958341628745
step 20/20: current loss=6.41, avg loss=6.8415987582124576
Loss: 6.410 — Avg: 6.842: 100%|██████████| 20/20 [00:07<00:00,  2.56it/s]
20/20 steps so output to trained_model
Here are the generated sample texts! [' ones\n\n\n\n\n itinging a\nisse\n\n to\nuis m\n\n not\n,\n\n\n\ning\n\n\n:\n\n\n to\n\n\n:\ntent:\n\n\ners\nyt toe;he\n\n,t be']
Loss: 6.410 — Avg: 6.842: 100%|██████████| 20/20 [00:07<00:00,  2.52it/s]

  • Currently supports callbacks for when training begins and ends, when a batch ends, and when sample text is generated
  • The print_generated argument in aitextgen.train() allows the user to prevent generated sample text from being printed to the console
  • The print_saved argument in aitextgen.train() allows the user to prevent "X steps reached: saving model" messages from being printed to the console
  • Callback is called whenever the model is saved
@Meorge Meorge changed the title Code callbacks and console output suppression while training Code callbacks and optional console output suppression while training Oct 28, 2021

Meorge commented Oct 28, 2021

Using these features, this evening I started implementing a GUI frontend for aitextgen using PyQt5. It's still obviously very barebones, but I'm happy with how quickly I was able to get it working :)

Screen.Recording.2021-10-27.at.9.20.32.PM.mov
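
The gist of the glue is to turn the training callbacks into Qt signals so the GUI can react to them. A stripped-down sketch of that idea (the TrainingBridge class and the connected progress bar are placeholders, not code from the actual app):

from PyQt5.QtCore import QObject, pyqtSignal

# Placeholder bridge (not from the actual app): re-emit aitextgen callbacks as Qt signals.
class TrainingBridge(QObject):
    batch_ended = pyqtSignal(int, int, float, float)  # current_steps, total_steps, current_loss, avg_loss
    sample_generated = pyqtSignal(list)               # generated sample texts

bridge = TrainingBridge()

callbacks = {
    "on_batch_end": lambda cur, total, loss, avg: bridge.batch_ended.emit(cur, total, loss, avg),
    "on_sample_text_generated": lambda texts: bridge.sample_generated.emit(texts),
}

# The GUI side then connects to the signals, e.g.:
# bridge.batch_ended.connect(lambda cur, total, *_: progress_bar.setValue(int(100 * cur / total)))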

@minimaxir (Owner) commented

I saw the tweets about this UI.

Because I'm working on migrating to the official Hugging Face trainer in the trainer branch, I need to see how much of this is applicable, but this is impressive work!


Meorge commented Nov 18, 2021

Thanks, I'm happy to hear it! 😄

That's a good point regarding the trainer branch. For now, the fork I have seems to be producing output I'm happy enough with, but when I have time I will try to check out that branch and investigate incorporating the callbacks into it and the Hugging Face libraries, roughly along the lines of the sketch below.
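
For illustration, a hypothetical adapter mapping the dictionary-style callbacks onto the Hugging Face TrainerCallback API could look something like this (the class name and the loss lookup are assumptions, not code from either branch):

from transformers import TrainerCallback

# Hypothetical adapter (not from this PR or the trainer branch): forwards
# Hugging Face Trainer events to the dictionary-style callbacks.
class DictCallbackAdapter(TrainerCallback):
    def __init__(self, callbacks):
        self.callbacks = callbacks or {}

    def _fire(self, key, *args):
        if key in self.callbacks:
            self.callbacks[key](*args)

    def on_train_begin(self, args, state, control, **kwargs):
        self._fire("on_train_start")

    def on_step_end(self, args, state, control, **kwargs):
        # The Trainer doesn't pass the loss here directly; the latest logged loss
        # (if any) can be read from state.log_history. It is used below as a rough
        # stand-in for both the current and average loss.
        loss = state.log_history[-1].get("loss") if state.log_history else None
        self._fire("on_batch_end", state.global_step, state.max_steps, loss, loss)

    def on_train_end(self, args, state, control, **kwargs):
        self._fire("on_train_end")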
