7 tips to reduce your VRAM when training LLMs
3 techniques you must know to evaluate your LLMs. Introduction to deploying private LLMs with AWS SageMaker.
Decoding ML Notes
This week's topics:
3 techniques you must know to evaluate your LLMs
7 tips you must know to reduce the VRAM consumption of your LLMs during training
Introduction to deploying private LLMs with AWS SageMaker
On the 3rd of May, I hosted a free session on Maven for 94 people on how to architect your LLM Twin. If you missed it, here is how you can access it for free ↓
Key takeaways were:
- Why I started building my LLM Twin
- The 3-pipeline design / The FTI pipeline architecture
- System design of the LLM Twin Architecture
- Break down the RAG system of the LLM Twin Architecture
- Live Demo
If you want the recording, you can watch it for free here: https://bit.ly/3PZGV0S
Also, here are other useful links:
- slides: https://lnkd.in/d_MdqGwS
- LLM Twin course GitHub: https://lnkd.in/dzat6PB6
- LLM Twin FREE lessons: https://lnkd.in/dX__4mhX
3 techniques you must know to evaluate your LLMs
Here are 3 techniques you must know to evaluate your LLMs quickly.
Manually testing the output of your LLMs is a tedious and painful process → you need to automate it.
In generative AI, most of the time, you cannot leverage standard metrics.
Thus, the real question is, how do you evaluate the outputs of an LLM?
#1. Structured answers - you know exactly what you want to get
Even if you use an LLM to generate text, you can ask it to generate a response in a structured format (e.g., JSON) that can be parsed.
You know exactly what you want (e.g., a list of products extracted from the user's question).
Thus, you can easily compare the generated and ideal answers using classic approaches.
For example, when extracting the list of products from the user's input, you can do the following:
- check if the LLM outputs a valid JSON structure
- use a classic method to compare the generated and real answers (see the sketch below)
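To make this concrete, here is a minimal Python sketch of that idea. The JSON schema (a top-level "products" list) and the precision/recall comparison are illustrative assumptions, not a required format.

```python
import json

def evaluate_structured_answer(generated: str, expected_products: list[str]) -> dict:
    """Check that the LLM output is valid JSON and compare the extracted
    products against the ideal answer using plain set operations."""
    try:
        payload = json.loads(generated)
    except json.JSONDecodeError:
        return {"valid_json": False, "precision": 0.0, "recall": 0.0}

    generated_products = set(payload.get("products", []))
    expected = set(expected_products)
    true_positives = len(generated_products & expected)

    return {
        "valid_json": True,
        "precision": true_positives / len(generated_products) if generated_products else 0.0,
        "recall": true_positives / len(expected) if expected else 0.0,
    }


# Example: the LLM was asked to extract the products mentioned in a user's question.
print(evaluate_structured_answer(
    '{"products": ["iPhone 15", "AirPods Pro"]}',
    ["iPhone 15", "AirPods Pro", "MacBook Air"],
))
```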
#2. No "right" answer (e.g., generating descriptions, summaries, etc.)
When generating sentences, the LLM can use different styles, words, etc. Thus, traditional metrics (e.g., the BLEU score) are too rigid to be useful.
You can leverage another LLM to test the output of your initial LLM. The trick is in what questions to ask.
Here, we have two sub-scenarios:
↳ 2.1. When you don't have an ideal answer to compare the answer to (you don't have ground truth)
You don't have access to an expert who can write an ideal answer for a given question to compare the generated answer against.
Based on the initial prompt and the generated answer, you can compile a set of questions and pass them to an LLM. Usually, these are Y/N questions whose answers you can easily quantify, which lets you check the validity of the generated answer.
This is known as "Rubric Evaluation".
For example:
"""
- Is there any disagreement between the response and the context? (Y or N)
- Count how many questions the user asked. (output a number)
...
"""
This strategy is intuitive, as you can ask the LLM any question you are interested in as long as it can output a quantifiable answer (Y/N or a number).
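Here is one way rubric evaluation could look in Python. The `openai` client and the `gpt-4o-mini` judge model are arbitrary assumptions (any LLM API works), and the rubric questions are just examples.

```python
from openai import OpenAI  # assumes the `openai` package and an OPENAI_API_KEY env var

client = OpenAI()

RUBRIC = """You are evaluating an answer generated by an LLM.

[Context]: {context}
[User question]: {question}
[Generated answer]: {answer}

Answer the following, one item per line:
1. Is there any disagreement between the answer and the context? (Y or N)
2. Does the answer address the user's question? (Y or N)
3. How many questions did the user ask? (output a number)
"""


def rubric_evaluation(context: str, question: str, answer: str) -> str:
    # The judge model is an arbitrary choice; temperature=0 keeps the verdicts stable.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": RUBRIC.format(
            context=context, question=question, answer=answer)}],
        temperature=0,
    )
    # Every line is a Y/N flag or a number, so the output is easy to quantify.
    return response.choices[0].message.content
```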
↳ 2.2. When you do have an ideal answer to compare the response to (you have ground truth)
When you have access to an answer manually created by a group of experts, things are easier.
You will use an LLM to compare the generated and ideal answers based on semantics, not structure.
For example:
"""
(A) The submitted answer is a subset of the expert answer and entirely consistent.
...
(E) The answers differ, but these differences don't matter.
"""
7 tips you must know to reduce the VRAM consumption of your LLMs during training
Here are 7 tips you must know to reduce the VRAM consumption of your LLMs during training so you can fit them on x1 GPU.
1. Mixed-precision: During training, you use both FP32 and FP16 in the following way: "FP32 weights" -> "FP16 weights" -> "FP16 gradients" -> "FP32 gradients" -> "Update weights" -> "FP32 weights" (and repeat). As you can see, the forward & backward passes are done in FP16, and only the optimization step is done in FP32, which reduces both the VRAM and runtime.
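As an illustration, here is a minimal PyTorch automatic mixed-precision loop. The toy linear layer stands in for a real LLM, and a CUDA device is assumed.

```python
import torch
from torch import nn

device = "cuda"
model = nn.Linear(1024, 1024).to(device)                 # stand-in for a real LLM
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()                     # prevents FP16 gradient underflow

for _ in range(10):
    x = torch.randn(8, 1024, device=device)
    optimizer.zero_grad()

    # Forward (and backward) passes run in FP16 where it is safe to do so ...
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).pow(2).mean()

    # ... while the optimizer step updates the FP32 master weights.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```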
2. Lower-precision: All your computations are done in FP16 instead of FP32. But the key is using bfloat16 ("Brain Floating Point"), a numerical representation Google developed for deep learning. It allows you to represent very large and small numbers, avoiding overflow or underflow scenarios.
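A minimal sketch of full lower-precision training in bfloat16 (assuming an Ampere-or-newer GPU); in practice you would often still keep optimizer states or master weights in higher precision.

```python
import torch
from torch import nn

device = "cuda"
# Keep the whole model, its activations, and its gradients in bfloat16.
model = nn.Linear(1024, 1024).to(device=device, dtype=torch.bfloat16)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(8, 1024, device=device, dtype=torch.bfloat16)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
```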
3. Reducing the batch size: This one is straightforward. Fewer samples per training iteration result in smaller VRAM requirements. The downside of this method is that you can't go too low with your batch size without impacting your model's performance.
4. Gradient accumulation: It is a simple & powerful trick to increase your batch size virtually. You compute the gradients for "micro" batches (forward + backward passes). Once the accumulated gradients reach the given "virtual" target, the model weights are updated with the accumulated gradients. For example, you have a batch size of 4 and a micro-batch size of 1. Then, the forward & backward passes will be done using only x1 sample, and the optimization step will be done using the aggregated gradient of the 4 samples.
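A toy PyTorch sketch of gradient accumulation with a micro-batch size of 1 and a "virtual" batch size of 4 (the model and data are placeholders).

```python
import torch
from torch import nn

device = "cuda"
model = nn.Linear(1024, 1024).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

accumulation_steps = 4  # virtual batch size = 4 x micro-batch size of 1

for step in range(16):
    x = torch.randn(1, 1024, device=device)             # micro-batch of x1 sample
    loss = model(x).pow(2).mean() / accumulation_steps   # scale so the gradients average out
    loss.backward()                                      # gradients accumulate in .grad

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()        # one optimizer step per 4 accumulated micro-batches
        optimizer.zero_grad()
```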
5. Use a stateless optimizer: Adam is the most popular optimizer. It is one of the most stable optimizers, but the downside is that it has 2 additional parameters (a mean & variance) for every model parameter. If you use a stateless optimizer, such as SGD, you can reduce the number of parameters by 2/3, which is significant for LLMs.
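In PyTorch, the swap itself is a one-liner; the comments spell out where the memory goes.

```python
import torch
from torch import nn

model = nn.Linear(1024, 1024)

# AdamW stores a running mean and a running variance for every parameter, so the
# optimizer state alone costs roughly 2x the model's parameter memory.
adamw = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Plain SGD (without momentum) keeps no per-parameter state, so the optimizer adds
# almost nothing on top of the weights and gradients.
sgd = torch.optim.SGD(model.parameters(), lr=1e-3)
```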
6. Gradient (or activation) checkpointing: It drops specific activations during the forward pass and recomputes them during the backward pass. Thus, it eliminates the need to hold all activations simultaneously in VRAM. This technique reduces VRAM consumption but makes the training slower.
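Here is a small PyTorch sketch using `checkpoint_sequential` on a toy stack of layers; with a Hugging Face transformers model you would typically call `model.gradient_checkpointing_enable()` instead.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

device = "cuda"
# A deep stack of layers whose activations would normally all sit in VRAM at once.
layers = [nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()) for _ in range(8)]
model = nn.Sequential(*layers).to(device)

x = torch.randn(8, 1024, device=device, requires_grad=True)

# Split the model into 4 segments: only the segment boundaries keep their
# activations; everything else is recomputed during the backward pass.
out = checkpoint_sequential(model, 4, x)
out.pow(2).mean().backward()
```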
7. CPU parameter offloading: The parameters that do not fit in your GPU's VRAM are loaded onto the CPU. Intuitively, you can see it as model parallelism between your GPU & CPU.
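One common way to get this is DeepSpeed ZeRO offloading. The config below (written as a Python dict) is a rough sketch of what that could look like, not the exact setup of any specific project.

```python
# DeepSpeed ZeRO-3 configuration that offloads parameters and optimizer states
# that do not fit in VRAM to CPU memory.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "bf16": {"enabled": True},
}

# Typically passed to deepspeed.initialize(model=model, config=ds_config, ...)
# or to Hugging Face's TrainingArguments(deepspeed=ds_config).
```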
Most of these methods are orthogonal, so you can combine them and drastically reduce your VRAM requirements during training.
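For example, with the Hugging Face `Trainer`, several of these tips map directly onto `TrainingArguments` flags; the exact values below are placeholders.

```python
from transformers import TrainingArguments  # assumes the `transformers` library

training_args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=1,   # tip 3: small micro-batch
    gradient_accumulation_steps=4,   # tip 4: virtual batch size of 4
    bf16=True,                       # tips 1 & 2: (mixed) lower precision
    gradient_checkpointing=True,     # tip 6: recompute activations in the backward pass
    optim="adafactor",               # tip 5: a much lighter optimizer state than AdamW
    learning_rate=2e-5,
)
```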
Introduction to deploying private LLMs with AWS SageMaker
Ever wondered how to deploy open-source LLMs, such as Llama 2, on AWS SageMaker in <30 minutes? Then wonder no more ↓
Step 1: Deploy the LLM to AWS SageMaker
The sweet thing about SageMaker is that it accelerates the development process, enabling a more efficient and rapid transition to the production stage. Concretely, the deployment boils down to:
- designing a config class for the deployment of the LLM
- setting up AWS and deploying the LLM to SageMaker
- implementing an inference class to call the deployed LLM in real time through a web endpoint
- defining a prompt template function to ensure reproducibility & consistency
...and, ultimately, how to play around with your freshly deployed LLM yourself (a rough deployment sketch follows below).
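As a sketch (not the article's exact code), deploying an open-source LLM with the SageMaker Python SDK and the Hugging Face TGI container could look like this; the IAM role, model id, instance type, and endpoint name are all placeholders.

```python
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

role = "arn:aws:iam::123456789012:role/sagemaker-execution-role"  # placeholder IAM role

llm_image = get_huggingface_llm_image_uri("huggingface")  # TGI inference container

model = HuggingFaceModel(
    role=role,
    image_uri=llm_image,
    env={
        "HF_MODEL_ID": "mistralai/Mistral-7B-Instruct-v0.2",  # example model to deploy
        "SM_NUM_GPUS": "1",
        "MAX_INPUT_LENGTH": "2048",
        "MAX_TOTAL_TOKENS": "4096",
    },
)

# Creates the real-time endpoint; this is the step that takes a few minutes.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    endpoint_name="summarization-llm-endpoint",
)

print(predictor.predict({"inputs": "Deploying LLMs on SageMaker is ..."}))
```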
Here is the full article explaining how to deploy the LLM to AWS SageMaker ↓
Step 2: Call the SageMaker inference endpoint
You've just deployed your Mistral LLM to SageMaker.
Now what?
Unfortunately, you are not done.
That was just the beginning of the journey.
→ Now, you have to write a Python client that calls the LLM.
Let's use a document summary task as an example.
↓↓↓
Step 1: Define a Settings object using pydantic.
Step 2: Create an inference interface that inherits from ABC.
Step 3: Implement an AWS SageMaker version of the inference interface by specifying how to construct the HTTP payload and call the SageMaker endpoint. We want to keep this class independent from the summarization prompt!
Step 4: Create the summarization prompt.
Step 5: Encapsulate the summarization prompt and Python SageMaker client into a SummarizeShortDocument task.
Step 6: Wrap the SummarizeShortDocument task with a FastAPI endpoint.
...and bam!
You have an LLM for summarizing any document.
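Below is a condensed Python sketch of how the six steps could fit together. It is not the article's exact code: the endpoint name, region, prompt, and response parsing are placeholder assumptions.

```python
import json
from abc import ABC, abstractmethod

import boto3
from fastapi import FastAPI
from pydantic_settings import BaseSettings  # with pydantic v1: `from pydantic import BaseSettings`


class Settings(BaseSettings):
    # Step 1: configuration read from environment variables.
    endpoint_name: str = "summarization-llm-endpoint"  # placeholder endpoint name
    region_name: str = "eu-central-1"


class Inference(ABC):
    # Step 2: the inference interface, independent of any concrete provider.
    @abstractmethod
    def inference(self, payload: dict):
        ...


class SageMakerInference(Inference):
    # Step 3: builds the HTTP payload and calls the SageMaker endpoint.
    def __init__(self, endpoint_name: str, region_name: str) -> None:
        self.endpoint_name = endpoint_name
        self.client = boto3.client("sagemaker-runtime", region_name=region_name)

    def inference(self, payload: dict):
        response = self.client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            ContentType="application/json",
            Body=json.dumps(payload),
        )
        return json.loads(response["Body"].read().decode("utf-8"))


# Step 4: the summarization prompt, kept outside the inference client.
SUMMARY_PROMPT = "Summarize the following document in 3 sentences:\n\n{document}"


class SummarizeShortDocument:
    # Step 5: glue the prompt and the SageMaker client together into a task.
    def __init__(self, llm: Inference) -> None:
        self.llm = llm

    def run(self, document: str) -> str:
        payload = {"inputs": SUMMARY_PROMPT.format(document=document)}
        result = self.llm.inference(payload)
        return result[0]["generated_text"]  # response shape depends on the serving container


# Step 6: expose the task through a FastAPI endpoint.
settings = Settings()
app = FastAPI()
task = SummarizeShortDocument(SageMakerInference(settings.endpoint_name, settings.region_name))


@app.post("/summarize")
def summarize(document: str) -> dict:
    return {"summary": task.run(document)}
```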
Here are some advantages of the design described above:
- by using an inference interface, you can quickly swap the LLM implementation
- by decoupling the prompt construction logic from the inference class, you can reuse the inference client with any prompt
- by wrapping everything with a SummarizeShortDocument task, you can quickly define & configure multiple types of tasks and leverage polymorphism to run them
Here is the full article explaining how to design the inference module ↓
Images
If not otherwise stated, all images are created by the author.