服务粉丝

我们一直在努力
当前位置:首页 > 财经 >

30分钟吃掉wandb模型训练可视化

日期: 来源:算法美食屋收集编辑:梁云1991

wandb是"我爱你,大baby"首字母的缩写。

顾名思义,她是炼丹师的大宝贝,是炼丹师最爱的炼丹伴侣。

公众号算法美食屋后台回复关键词:wandb,获取本教程 notebook源码 B站视频演示



just kidding, 开个玩笑!

wandb全称weights&bias,是一款类似TensorBoard

的机器学习可视化分析工具。

相比TensorBoard,wandb具有如下主要优势:

  • 日志上传云端永久存储,便于分享不怕丢失。

  • 可以存管代码,数据集和模型的版本,随时复现。(wandb.Artifact)

  • 可以使用交互式表格进行case分析(wandb.Table)

  • 可以自动化模型调参。(wandb.sweep)

官方文档:https://docs.wandb.ai/


总体来说,wandb目前的核心功能有以下4个:

1,实验跟踪:experiment tracking (wandb.log)

2,版本管理:version management (wandb.log_artifact, wandb.save)

3,case分析:case visualization (wandb.Table, wandb.Image)

4,超参调优:model optimization (wandb.sweep)

本文我们主要介绍 前3个能力,超参调优的介绍在下一篇文章。

〇,注册wandb

使用wandb可视化模型训练过程需要在  https://wandb.ai/ 注册账户,

并在个人settings页面获取 API keys。

#import os
#os.environ["WANDB_API_KEY"] = "xxxx"

import wandb
wandb.login()


一,实验跟踪

wandb 提供了类似 TensorBoard的实验跟踪能力,主要包括:

  • 模型配置超参数的记录

  • 模型训练过程中loss,metric等各种指标的记录和可视化

  • 图像的可视化(wandb.Image)

  • 其他各种Media(wandb.Vedio, wandb.Audio, wandb.Html, 3D点云等)

import os,PIL 
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch 
from torch import nn 
import torchvision 
from torchvision import transforms
import datetime
import wandb 
from argparse import Namespace

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Namespace(
    project_name = 'wandb_demo',
    
    batch_size = 512,
    
    hidden_layer_width = 64,
    dropout_p = 0.1,
    
    lr = 1e-4,
    optim_type = 'Adam',
    
    epochs = 15,
    ckpt_path = 'checkpoint.pt'
)

def create_dataloaders(config):
    transform = transforms.Compose([transforms.ToTensor()])
    ds_train = torchvision.datasets.MNIST(root="./mnist/",train=True,download=True,transform=transform)
    ds_val = torchvision.datasets.MNIST(root="./mnist/",train=False,download=True,transform=transform)

    ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))
    dl_train =  torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,
                                            num_workers=2,drop_last=True)
    dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, 
                                          num_workers=2,drop_last=True)
    return dl_train,dl_val

def create_net(config):
    net = nn.Sequential()
    net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=config.hidden_layer_width,kernel_size = 3))
    net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)) 
    net.add_module("conv2",nn.Conv2d(in_channels=config.hidden_layer_width,
                                     out_channels=config.hidden_layer_width,kernel_size = 5))
    net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
    net.add_module("dropout",nn.Dropout2d(p = config.dropout_p))
    net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
    net.add_module("flatten",nn.Flatten())
    net.add_module("linear1",nn.Linear(config.hidden_layer_width,config.hidden_layer_width))
    net.add_module("relu",nn.ReLU())
    net.add_module("linear2",nn.Linear(config.hidden_layer_width,10))
    net.to(device)
    return net 

def train_epoch(model,dl_train,optimizer):
    model.train()
    for step, batch in enumerate(dl_train):
        features,labels = batch
        features,labels = features.to(device),labels.to(device)

        preds = model(features)
        loss = nn.CrossEntropyLoss()(preds,labels)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
    return model

def eval_epoch(model,dl_val):
    model.eval()
    accurate = 0
    num_elems = 0
    for batch in dl_val:
        features,labels = batch
        features,labels = features.to(device),labels.to(device)
        with torch.no_grad():
            preds = model(features)
        predictions = preds.argmax(dim=-1)
        accurate_preds =  (predictions==labels)
        num_elems += accurate_preds.shape[0]
        accurate += accurate_preds.long().sum()

    val_acc = accurate.item() / num_elems
    return val_acc

def train(config = config):
    dl_train, dl_val = create_dataloaders(config)
    model = create_net(config); 
    optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)
    #======================================================================
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    wandb.init(project=config.project_name, config = config.__dict__, name = nowtime, save_code=True)
    model.run_id = wandb.run.id
    #======================================================================
    model.best_metric = -1.0
    for epoch in range(1,config.epochs+1):
        model = train_epoch(model,dl_train,optimizer)
        val_acc = eval_epoch(model,dl_val)
        if val_acc>model.best_metric:
            model.best_metric = val_acc
            torch.save(model.state_dict(),config.ckpt_path)   
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")
        #======================================================================
        wandb.log({'epoch':epoch, 'val_acc': val_acc, 'best_val_acc':model.best_metric})
        #======================================================================        
    #======================================================================
    wandb.finish()
    #======================================================================
    return model   
model = train(config) ##3,2,1 点火        

相关阅读

  • 货拉拉 Android 模块化路由框架:TheRouter

  • 点击上方蓝字关注我,知识会给你力量本文作者:张涛-货拉拉TheRouter 是一个 Kotlin 编写,用于 Android 模块化开发的一整套解决方案框架。Github 项目地址与使用文档详见 https:
  • 说回 TheRouter

  • 点击上方蓝字关注我,知识会给你力量❝补充:开源仓库地址:https://github.com/HuolalaTech/hll-wp-therouter-android❞没错,货拉拉开源的路由库 —— TheRouter 是我写的大约在1
  • TheRouter 的跨模块依赖注入实现原理

  • 点击上方蓝字关注我,知识会给你力量本文作者——张涛(货拉拉)TheRouter用于跨模块通信设计的ServiceProvider,核心设计思想是参考了SOA(面向服务架构)的设计方式。具体到 Androi
  • kotlin修炼指南8—集合中的高阶函数

  • 点击上方蓝字关注我,知识会给你力量Kotlin对集合操作类新增了很多快捷的高阶函数操作,各种操作符让很多开发者傻傻分不清,特别是看一些Kotlin的源码或者是协程的源码,各种眼花缭
  • kotlin修炼指南9-Sequence的秘密

  • 点击上方蓝字关注我,知识会给你力量人们经常忽略Iterable和Sequence之间的区别。这是可以理解的,因为即使它们的定义也几乎是相同的。interface Iterable<out T> { operato
  • Flutter混编工程之异常处理

  • 点击上方蓝字关注我,知识会给你力量Flutter App层和Framework层的异常,通常是不会引起Crash的,但是Engine层的异常会造成Crash。而Flutter Engine部分的异常,主要是libfutter.so
  • 闲言碎语-第八期

  • 点击上方蓝字关注我,知识会给你力量时间一晃就过去了,22年的总结还没来得及写,转眼已经2023年了。22年对于很多人来说,应该都是比较魔幻的一年,特别是在上海的朋友,一小半的时间都
  • 来自亲爹的爱,有但是不多~!

  • 有了宝宝后麻麻们总觉得爸爸跟不上育儿步伐都说爸爸带娃只有三分钟热度麻麻们纷纷表示“带的很好,下次别带了”其实,新一代爸比育儿有方法快来看看他们都是怎么应对—····
  • 所有美好都在冬日沉淀

  • 冬日·美好HELLO WINTER致可爱的你寒冷的冬日来袭,气温骤降外面开始感到阵阵寒意但是我们依然可以通过细微之处将冬日的生活过得热气腾腾幸福满满守护陪伴冬日HELLO SNOW时节

热门文章

  • “复活”半年后 京东拍拍二手杀入公益事业

  • 京东拍拍二手“复活”半年后,杀入公益事业,试图让企业捐的赠品、家庭闲置品变成实实在在的“爱心”。 把“闲置品”变爱心 6月12日,“益心一益·守护梦想每一步”2018年四

最新文章

  • 30分钟吃掉wandb模型训练可视化

  • wandb是"我爱你,大baby"首字母的缩写。顾名思义,她是炼丹师的大宝贝,是炼丹师最爱的炼丹伴侣。公众号算法美食屋后台回复关键词:wandb,获取本教程 notebook源码 和 B站视频演示。