diff options
author | Anthony Wang | 2022-02-22 17:55:04 -0600 |
---|---|---|
committer | Anthony Wang | 2022-02-22 17:55:04 -0600 |
commit | f08e5bfc5f2dbc3ab82a1b92aeec9c4c4368524f (patch) | |
tree | 9882f328ad225ce79923a216d14bf4dc50a9adf8 | |
parent | 8dab77d61bd0e8cf53f8e786191e026c724a6500 (diff) |
Reformat code with autopep8
-rw-r--r-- | bot.py | 6 | ||||
-rw-r--r-- | train.py | 6 |
2 files changed, 7 insertions, 5 deletions
@@ -6,8 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM parser = ArgumentParser() parser.add_argument('-t', '--token', help='Mastodon application access token') -parser.add_argument('-i', '--input', default='i am', - help='initial input text for prediction') +parser.add_argument('-i', '--input', help='initial input text for prediction') parser.add_argument('-m', '--model', default='model', help='path to load saved model') args = parser.parse_args() @@ -19,7 +18,8 @@ model = AutoModelForCausalLM.from_pretrained(args.model) # Run the input through the model inputs = tokenizer.encode(args.input, return_tensors="pt") -output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0]) +output = tokenizer.decode(model.generate( + inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0]) print(output) @@ -16,7 +16,8 @@ args = parser.parse_args() # Load and tokenize dataset raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True) tokenizer = AutoTokenizer.from_pretrained('distilgpt2', use_fast=True) -tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns='text') +tokenized_dataset = raw_dataset.map(lambda examples: tokenizer(examples['text']), + batched=True, remove_columns='text') # Generate chunks of block_size @@ -44,6 +45,7 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True) # Create and train the model model = AutoModelForCausalLM.from_pretrained('distilgpt2') -trainer = Trainer(model, TrainingArguments(output_dir=args.output), default_data_collator, lm_dataset['train']) +trainer = Trainer(model, TrainingArguments(output_dir=args.output), + default_data_collator, lm_dataset['train']) trainer.train() trainer.save_model() |