๐งน Sweep์ด๋?
0. Overview
๋จ 1%๋ผ๋ ์ฑ๋ฅ์ ์ฌ๋ฆฌ๊ธฐ์ํด ๋ง์ ์ฌ๋๋ค์ด ๋ถ๋จํ ๋ ธ๋ ฅ์ ํ๋ค. (ex. paperswithcode)
๊ฒฐ๊ตญ, ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ต๋๋ก ๋์ด์ฌ๋ฆฌ๊ธฐ ์ํด์๋ Hyper-parameter๋ฅผ ๋ณ๊ฒฝํ๋ฉฐ ์ต์ ์ ๊ฐ์ ์ฐพ๊ธฐ์ํด ์ด์ ๋ํด์๋ ๋ถ๋จํ ๋ ธ๋ ฅ์ ๊ธฐ์ธ์ฌ์ผ ํ์ง๋ง, ์ด๋ ๋งค์ฐ ํผ๊ณคํ๊ณ Cost๋ฅผ ๋ง์ด ๋ค๊ฒ ๋ง๋ ๋ค.
์ด๋ฅผ ์ํด ๋ฑ์ฅํ ๊ฒ์ด ๋ฐ๋ก W&B์ Sweep์ด๋ค!
1. Sweep์ด๋?
๊ธฐ๋ณธ์ ์ผ๋ก Hyper-parameter๋ฅผ ์๋์ผ๋ก ์ต์ ํ์ฃผ๋ Tool
Hyper-parameter Seach๋ฐฉ์์ผ๋ก ๋ค์ 3๊ฐ์ง๊ฐ ์กด์ฌํ๋ค.
- Grid ๋ฐฉ์
- Random ๋ฐฉ์
- Bayes ๋ฐฉ์
์ ํํ search ๋ฐฉ์์ผ๋ก ํ์ดํผ ํ๋ผ๋ฏธํฐ ํ๋์ด ์๋ฃ ๋๋ฉด WandB์ ์น์์ ์ ๊ณต๋๋ dashboard๋ก ์๊ฐํ๋ ๋ชจ์ต์ ๋ณผ ์ ์๋ค.
์ด๋ ๊ฒ ์๊ฐํ๋ ๋ชจ์ต์ ์์ ๊ทธ๋ฆผ๊ณผ ๊ฐ๋ค. Sweep์ ์๋์ผ๋ก tuningํด์ฃผ๋ ๊ธฐ๋ฅ ๋ฟ๋ง ์๋๋ผ,
๊ฐ๊ฐ์ hyper parameter๋ค์ด metric(accuracy, loss ๋ฑ)์ ์ผ๋ง๋ ์ค์ํ ์ง ์๋ ค์ฃผ๊ณ ์๊ด๊ด๊ณ๋ฅผ ๋ณด์ฌ์ฃผ๊ธฐ์ ํ์์ ์ด๋ผ ํ ์ ์๋ค.
๐งน Sweep ์ฌ์ฉ๋ฒโ๏ธ
Sweep์ ํํ 2๊ฐ์ ๋จ๊ณ(Initialize the Sweep, Run the Sweep Agent)๊ฐ ํ์ํ๋ค.
1. Initialize the Sweep
โ Sweep Configuration๋ฅผ ์ ์
Sweep Initialize๋ฅผ ์ํด ๋จผ์ ๊ตฌ์ฑ์์(configuration)๋ฅผ ์ ์ํด์ผํ๋ค.
์ด๋ฅผ ์ํด required์ option์ผ๋ก ๋๋๋ค.
program(์ด๋์์) method(๋ฌด์์) parameters(์ด๋ป๊ฒ)
์ต์ ํ๋ฅผ ํ ๊ฒ์ธ์ง ์ ์ํด์ผํ๋ค.
์ด๋, ์ต์ ํ ๋ฐฉ๋ฒ์ผ๋ก 3๊ฐ์ง๊ฐ ์กด์ฌํ๋ค.
- Grid ๋ฐฉ์ : ๊ฐ๋ฅํ ๋ชจ๋ ์กฐํฉ ํ์ (= Cost↑)
- Random ๋ฐฉ์ : randomํ๊ฒ ์ ํ (= Cost↓, opt์ฐพ์ํ๋ฅ ↓)
- Bayes ๋ฐฉ์ : ์ด์ ์ ์๋ํ hyper-parameter์กฐํฉ์ ๊ฒฐ๊ณผ๋ฅผ ์ฌ์ฉ, ๋ค์์๋์กฐํฉ ์ถ๋ก ์ ์ฌ์ฉ
→ ๋ชจ๋ธ์ฑ๋ฅ์ ์ต๋๋ก ํฅ์์ํฌ ์ ์๋ hyper-parameter์กฐํฉ์ ์ฐพ๋๋ค. (= ์ด๊ธฐํ์์ด ๋๋ฆผ)
์ด๋, ํนํ๋ parameters ํํธ๊ฐ ์ค์ํ๊ธฐ์ ์ข ๋ ์ดํด๋ณด์.
values value
Hyper-parameter์ ๋ํด ํน์ ๊ฐ์ ์ค์ ํด์ ์ฐ๋ฆฌ๊ฐ ์ํ๋ ๊ฐ๋ง ์ ํํ๊ฒ ํด์ค.
(value๋ 1๊ฐ์ง ๊ฐ์ ์ค์ ํด์ค ๋ ์ฌ์ฉ)
distribution
values์ ๋์กฐ๋๋ ๋ฐฉ์.
ํน์ ๊ฐ์ ์ค์ ํ๋ ๋์ ์ํ๋ ๋ถํฌ ์์์ ๊ฐ์ ์ ํ.
Sweep์์๋ uniform, normal, q_log_uniform๊ณผ ๊ฐ์ด ๋ค์ํ ๋ถํฌ๋ฅผ ์ ๊ณต.
๋ํ ์ ํ๋ ๋ถํฌ๋ฅผ min, max์ mu, sigma, q๋ฅผ ํตํด ์์ ๋กญ๊ฒ ๋ณํ๊ฐ๋ฅ.
min, max
๋ถํฌ์ ์ต์โ์ต๋๊ฐ์ ์ค์ .
mu sigma
ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ ๋ํ๋ด๋ ๊ฐ, ์ ๊ท๋ถํฌ(normal)์ ๋ชจ์์ ๊ฒฐ์ .
q
Quantization์ ์ฝ์๋ก distribution์์ ๋์จ ๊ฐ X๋ฅผ ์์ํ.
ex) q๋ฅผ 2๋ก ์ค์ ํ๋ค๋ฉด X๋ 2์ ๋ฐฐ์๋ก ๋ฐ๋.
(ex. ์ round(X / q) *q๋ฅผ ์ ์ฉํ๋ฉด, -2.96์ -2๋ก 13.27์ 14๋ก 8.43์ 8๋ก ๋ฐ๋.)=
โ project์ ์ฌ์ฉํ๊ธฐ์ํด Sweep API๋ก ์ด๊ธฐํ
Sweep์ config๊ฐ ์ ๋๋ก ์ ์๊ฐ ๋๋ค๋ฉด ์ด์ ํ๋ก์ ํธ์ ์ ์ฉ์ ํด์ค์ผํ๋ค.
sweep ์ด๊ธฐํ ์ฝ๋:
sweep_id = wandb.sweep(config.sweep_config)
์์์ ์ ์๋ config ๋ณ์๋ฅผ ์
๋ ฅ์ผ๋ก ๋ฐ๊ณ sweep id๋ฅผ ์ถ๋ ฅํด์ค๋ค.
์ด id๋ ๋ค์ step์์ sweep์ ์คํ์ํฌ ๋ ๊ณ ์ ํ identifier๋ก ์ฌ์ฉ๋๋ค.
2. Run the Sweep Agent
- ํจ์๋ ํ๋ก๊ทธ๋จ์ W&B์๋ฒ์์ ์คํ.
์ด์ ๋ณธ๊ฒฉ์ ์ธ ์คํ๋ง์ด ๋จ์๋ค.
์์์ ์ ์ํด์ค configuration์ ์ฌ์ฉํด sweep์ ์งํํ์.
sweep ์งํ์ฝ๋:
wandb.agent(sweep_id, function=train, count=count)
์ด๋, ์์์ ์ถ๋ ฅ๋ sweep_id๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฃ์ด์ค๋ค.
๋ํ, function์ ์ฐ๋ฆฌ๊ฐ ์ ์ํ trainํจ์๋ฅผ ๋ฃ์ด์ฃผ๊ณ
sweep์ ๋ช๋ฒ ์งํํ ์ง ์ซ์๋ฅผ count์ ์ ๋ ฅํด์ค๋ค.
cf). yaml ํ์ผ๋ก ์คํํ๋ ๋ฐฉ๋ฒ.
project์ entity๋ฅผ ๊ธฐ์ ๊ฐ๋ฅํ ๊ณณ
- config ์ค์ ํ๋ ํ์ผ (config.py ํน์ config.yaml)
- wandb.sweep()
- wandb.init()
- wandb.agent()
config๋ฅผ .pyํ์ผ๋ก ์ ์ํ๋ ๋ฐฉ์๊ณผ .yamlํ์ผ๋ก ์ ์ํ๋ ๋ฐฉ์์ด ์กด์ฌ.yamlํ์ผ๋ก ์คํํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์์๋ณด์.
1. config.yaml ํ์ผ ์์ฑ
2. . yaml ํ์ผ, Sweep์ ์ ๋ ฅ
wandb sweep config.yaml
3. Sweep id๋ฅผ Agent์ ์ ๋ ฅ
wandb agent SWEEP_ID
cf) wandb terminal์์ ๋ช ๋ น์ด๋ก ์ง์ ํ๊ธฐ.
โ sweep ํ์ ์ ํ
wandb agent --count [LIMIT_NUM] [SWEEPID]
โ Multi-GPU sweep ์ฌ์ฉ
CUDA_VISIBLE_DEVICES=0 wandb agent sweep_id
CUDA_VISIBLE_DEVICES=1 wandb agent sweep_id
๐งน W&B Sweep ์คํ์ ์ํ ์์์ฝ๋
from dataset import SweepDataset
from model import ConvNet
from optimize import build_optimizer
from utils import train_epoch
import wandb
import config
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=8, metavar='N')
parser.add_arguemnt('--epochs', type=int, default=10)
args = parser.parse_args()
wandb.config.update(args)
def train():
wandb.init(config=config.hyperparameter_defaults)
w_config = wandb.config
loader = SweepDataset(w_config.batch_size, config.train_transform)
model = ConvNet(w_config.fc_layer_size, w_config.dropout).to(config.DEVICE)
optimizer = build_optimizer(model, w_config.optimizer, w_config.learning_rate)
wandb.watch(model, log='all')
for epoch in range(w_config.epochs):
avg_loss = train_epoch(model, loader, optimizer, wandb)
print(f"TRAIN: EPOCH {epoch + 1:04d} / {w_config.epochs:04d} | Epoch LOSS {avg_loss:.4f}")
wandb.log({'Epoch': epoch, "loss": avg_loss, "epoch": epoch})
sweep_id = wandb.sweep(config.sweep_config)
wandb.agent(sweep_id, train, count=2)
sweep์ ์ํ config ํ์ผ์ config.py์ ๊ตฌํ๋์ด ์์ต๋๋ค. ์ฝ๋๋ฅผ ์์๋๋ก ์ค๋ช ํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
1. hyper parameter์ ์ด๊ธฐ๊ฐ์ wandb.init์ ์ ๋ ฅ์ผ๋ก ๋ฃ์ด์ค๋๋ค.
2. w_config๋ sweep์ ํ ๋์ hyper parameter์ ๋๋ค.
3. loader, model, optimizer ํจ์์ w_config๋ฅผ ๋งค๊ฐ๋ณ์๋ก ์ ๋ฌํด์ค๋๋ค.
4. model์ ์ ์ํ๋ฉด wandb.watch ํจ์๋ก gradient๋ฅผ ์ถ์ ํฉ๋๋ค.
5. epoch ๋ณ๋ก ๋์ค๋ log๋ฅผ wandb.log์ ์ ์ฅํฉ๋๋ค.
6. config ํ์ผ์ ์ ์ํด๋ ๊ตฌ์ฑ ์์๋ฅผ wandb.sweep์ ์ ๋ ฅํฉ๋๋ค.
7. wandb.sweep์์ ๋์จ id์ ์์ ๊ตฌํ๋ train ํจ์, ๊ทธ๋ฆฌ๊ณ ํ์๋ฅผ wandb.agent์ ์ ๋ ฅํ๊ณ sweep์ ์คํ์ํต๋๋ค.
์ฐธ๊ณ ) https://pebpung.github.io/wandb/2021/10/10/WandB-2.html
๐งน Sweep ์๊ฐํ
WandB์คํ ์ดํ, ์๊ฐํ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ถ์ํด๋ณด์. (๋์๋ณด๋)
์ด๋ฅผ ์ํด์๋ Sweep workspace์ ๊ตฌ์ฑ๋ฐฉ์์ ๋ํด ์์๋ด์ผํ๋ค.
์ข์ธก ๊ทธ๋ํ: y์ถ์ metric, X์ถ์ ์์ฑ๋ ๋ ์ง๋ฅผ ์๋ฏธ.
์ฐ์ธก ํ: hyper parameter๊ฐ metric(accuracy, loss ๋ฑ)์ ์ผ๋ง๋ ์ค์ํ ์ง์ ์๊ด๊ด๊ณ๊ฐ ์ด๋์ ๋ ์ธ์ง๋ ์๋ ค์ค.
์ ๊ทธ๋ฆผ์ hyper-parameter์ ํ๊ณผ์ ์ ์๊ฐ์ ์ผ๋ก ๋ณด์ฌ์ค ๊ทธ๋ฆผ์ด๋ค.
โ X์ถ: config์์ ์ค์ ํ hyper-parameter์ ์ข ๋ฅ
โ y์ถ: config์์ ์ค์ ํ hyper-parameter์ ๋ฒ์
์ถ๊ฐ์ ์ผ๋ก ๋ง์ฐ์ค๋ฅผ ๊ฐ์ ธ๋ค ๋์ผ๋ฉด ํด๋น ๊ทธ๋ํ์์์ ๊ฐ์ ์ ์ ์๋ค.
'Deep Learning : Vision System > Pytorch & MLOps' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[๐ฅPyTorch 2.2]: transformv2 , torch.compile (0) | 2024.01.31 |
---|---|
[WandB] Step 3. WandB ์๊ฐํ ๋ฐฉ๋ฒ. (0) | 2024.01.09 |
[WandB] Step 1. WandB Experiments. with MNIST (2) | 2024.01.09 |