# 7.1. Fine-Tuning for Classification

## What is Fine-Tuning

Fine-tuning is the process of taking a **pre-trained model** that has learned **general language patterns** from vast amounts of data and **adapting** it to perform a **specific task** or to understand domain-specific language. This is achieved by continuing the training of the model on a smaller, task-specific dataset, allowing it to adjust its parameters to better suit the nuances of the new data while leveraging the broad knowledge it has already acquired. Fine-tuning enables the model to deliver more accurate and relevant results in specialized applications without the need to train a new model from scratch.

{% hint style="info" %}
As pre-training an LLM that "understands" text is quite expensive, it's usually easier and cheaper to fine-tune an open-source pre-trained model to perform the specific task we want.
{% endhint %}

{% hint style="success" %}
The goal of this section is to show how to fine-tune an already pre-trained model so that, instead of generating new text, the LLM outputs the **probabilities of the given text being categorized in each of the given categories** (for example, whether a text is spam or not).
{% endhint %}
## Preparing the data set
### Data set size

Of course, in order to fine-tune a model you need some structured data to specialise your LLM. In the example proposed in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01\_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01\_main-chapter-code/ch06.ipynb), GPT2 is fine-tuned to detect whether an email is spam or not using the data from [https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip](https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip).

This data set contains many more examples of "not spam" than of "spam", therefore the book suggests to **only use as many examples of "not spam" as of "spam"** (removing all the extra examples from the training data). In this case, this was 747 examples of each.

Then, **70%** of the data set is used for **training**, **10%** for **validation** and **20%** for **testing**.

* The **validation set** is used during the training phase to fine-tune the model's **hyperparameters** and make decisions about model architecture, effectively helping to prevent overfitting by providing feedback on how the model performs on unseen data. It allows for iterative improvements without biasing the final evaluation.
* This means that although the data in this set is not used for training directly, it is used to tune the best **hyperparameters**, so it cannot be used to evaluate the model's performance the way the test set is.
* In contrast, the **test set** is used **only after** the model has been fully trained and all adjustments are complete; it provides an unbiased assessment of the model's ability to generalize to new, unseen data. This final evaluation on the test set gives a realistic indication of how the model is expected to perform in real-world applications.
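
The following is a minimal sketch of the class balancing and the 70/10/20 split described above (the file name and the `Label`/`Text` column names are assumptions about how the SMS Spam Collection archive was extracted; the referenced notebook does essentially the same with small helper functions):

```python
import pandas as pd

# Load the SMS Spam Collection (path and column names are illustrative)
df = pd.read_csv("SMSSpamCollection.tsv", sep="\t", header=None, names=["Label", "Text"])

# Keep only as many "ham" (not spam) examples as there are "spam" examples
num_spam = df[df["Label"] == "spam"].shape[0]
ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

# Shuffle and split: 70% training, 10% validation, 20% testing
balanced_df = balanced_df.sample(frac=1, random_state=123).reset_index(drop=True)
train_end = int(len(balanced_df) * 0.7)
val_end = train_end + int(len(balanced_df) * 0.1)
train_df = balanced_df[:train_end]
val_df = balanced_df[train_end:val_end]
test_df = balanced_df[val_end:]
```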
### Entries length

As the training expects entries of the same length (email texts in this case), it was decided to make every entry as long as the longest one by adding the ids of `<|endoftext|>` as padding.
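
A minimal sketch of that padding step, assuming the GPT-2 BPE tokenizer from `tiktoken` (the example texts are made up):

```python
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
# Id of the <|endoftext|> token used as padding (50256 for GPT-2)
pad_token_id = tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})[0]

texts = ["You won a FREE prize, click here!", "See you at lunch tomorrow"]
encoded = [tokenizer.encode(text) for text in texts]

# Pad every entry with <|endoftext|> ids up to the length of the longest entry
max_length = max(len(tokens) for tokens in encoded)
padded = [tokens + [pad_token_id] * (max_length - len(tokens)) for tokens in encoded]
```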
### Initialize the model

Using the open-source pre-trained weights, initialize the model to train. We have already done this before, and following the instructions of [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01\_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01\_main-chapter-code/ch06.ipynb) you can easily do it.
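
For reference, a minimal sketch of that initialization, assuming the `GPTModel`, `download_and_load_gpt2` and `load_weights_into_gpt` helpers from the referenced notebook are available (names and config values follow that notebook's GPT-2 124M setup and may differ in other codebases):

```python
# Helpers from the referenced repo (gpt_download.py and previous chapters)
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt

BASE_CONFIG = {
    "vocab_size": 50257, "context_length": 1024,
    "emb_dim": 768, "n_heads": 12, "n_layers": 12,
    "drop_rate": 0.0, "qkv_bias": True,
}

# Download the official GPT-2 124M weights and load them into our model
settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")
model = GPTModel(BASE_CONFIG)
load_weights_into_gpt(model, params)
model.eval()
```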
## Classification head

In this specific example (predicting whether a text is spam or not), we are not interested in fine-tuning over the complete vocabulary of GPT2; we only want the new model to say whether the email is spam (1) or not (0). Therefore, we are going to **replace the final layer**, which gives the probabilities per token of the vocabulary, with one that only gives the probabilities of being spam or not (effectively a vocabulary of 2 words).

```python
# This code replaces the final layer with a Linear one with 2 outputs
num_classes = 2
model.out_head = torch.nn.Linear(
    in_features=BASE_CONFIG["emb_dim"],
    out_features=num_classes
)
```
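
To see the effect of this change, a small hedged check (the token ids are just an illustrative example; `torch` and the model from the previous steps are assumed): the model now returns 2 logits per position instead of one per vocabulary token, and only the last position is used for classification.

```python
# Hypothetical tokenized input of shape [batch_size, num_tokens]
inputs = torch.tensor([[6109, 3626, 6100, 345]])
with torch.no_grad():
    logits = model(inputs)
print(logits.shape)  # torch.Size([1, 4, 2]): 2 class logits per token position
```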
## Parameters to tune

In order to fine-tune quickly, it's easier not to fine-tune all the parameters but only some final ones. This is because it's known that the lower layers generally capture basic language structures and semantics that are broadly applicable. So, just **fine-tuning the last layers is usually enough and faster**.

```python
# This code makes all the parameters of the model untrainable
for param in model.parameters():
    param.requires_grad = False

# Allow fine-tuning of the last transformer block
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True

# Allow fine-tuning of the final layer norm
for param in model.final_norm.parameters():
    param.requires_grad = True
```
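
As a quick sanity check (a minimal sketch assuming the model from the previous sections), you can count how many parameters remain trainable. Note that in the referenced notebook the freeze is applied before the new `out_head` is created, so the classification head stays trainable; if you freeze after replacing the head, remember to re-enable `requires_grad` on `model.out_head` as well.

```python
# Count trainable vs. total parameters after freezing
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,}")
```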
## Entries to use for training

In previous sections the LLM was trained by reducing the loss of every predicted token, even though almost all the predicted tokens were already in the input sentence (only 1 at the end was really predicted), so that the model would understand the language better.

In this case we only care about the model being able to predict whether the text is spam or not, so we only care about the last token predicted. Therefore, it's needed to modify our previous training loss functions to only take that token into account.

This is implemented in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01\_main-chapter-code/ch06.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01\_main-chapter-code/ch06.ipynb) as:
```python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss
```
Note how for each batch we are only interested in the **logits of the last token predicted**.
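
A minimal sketch of how these two functions plug into a training loop (assuming `train_loader`, `val_loader`, `model` and `device` are already set up as in the referenced notebook; the hyperparameters are illustrative, not necessarily the book's exact values):

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    for input_batch, target_batch in train_loader:
        optimizer.zero_grad()
        loss = calc_loss_batch(input_batch, target_batch, model, device)
        loss.backward()   # Gradients only flow to the unfrozen parameters
        optimizer.step()

    # Cheap progress signal: accuracy on a few validation batches
    val_acc = calc_accuracy_loader(val_loader, model, device, num_batches=10)
    print(f"Epoch {epoch + 1}: validation accuracy {val_acc:.2%}")
```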
## Complete GPT2 fine-tune classification code
You can find all the code to fine-tune GPT2 to be a spam classifier in [https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01\_main-chapter-code/load-finetuned-model.ipynb](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01\_main-chapter-code/load-finetuned-model.ipynb)
## References
* [https://www.manning.com/books/build-a-large-language-model-from-scratch](https://www.manning.com/books/build-a-large-language-model-from-scratch)