cifar
#include "darknet.h"
train_cifar
/*
 * Train a network on the CIFAR-10 training set.
 *
 * cfgfile    - network configuration path; its base name (from basecfg) is
 *              used to name every weights file written below.
 * weightfile - passed to load_network together with cfgfile.
 *
 * Loops over train_network_sgd until the current batch count reaches
 * net->max_batches; when max_batches is 0 the loop condition is always true
 * and training never stops on its own. Along the way it writes:
 *   - "<base>_<epoch>.weights" each time the number of images seen crosses
 *     another multiple of N (one pass over the training set),
 *   - "<base>.backup" whenever the batch count is a multiple of 100,
 *   - "<base>.weights" once the loop exits.
 */
void train_cifar(char *cfgfile, char *weightfile)
{
    srand(time(0));
    float avg_loss = -1;   /* -1 is the "no loss recorded yet" sentinel */
    char *base = basecfg(cfgfile);
    printf("%s\n", base);
    network *net = load_network(cfgfile, weightfile, 0);
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);

    char *backup_directory = "/home/pjreddie/backup/";
    int classes = 10;
    int N = 50000;         /* images per epoch, used to convert "seen" into epochs */

    char **labels = get_labels("data/cifar/labels.txt");
    int epoch = (*net->seen)/N;
    data train = load_all_cifar10();

    while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
        /* Named "start" rather than "time" so the time() function used above is not shadowed. */
        clock_t start = clock();

        float loss = train_network_sgd(net, train, 1);
        /* Seed the running average with the first loss, then blend 95% old / 5% new. */
        if(avg_loss == -1) avg_loss = loss;
        avg_loss = avg_loss*.95 + loss*.05;

        /* NOTE(review): %ld assumes get_current_batch() and *net->seen are long-sized - confirm against darknet.h. */
        printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-start), *net->seen);

        /* A new epoch has started: keep a numbered snapshot. */
        if(*net->seen/N > epoch){
            epoch = *net->seen/N;
            char buff[256];
            /* snprintf bounds the write, so a long base name cannot overrun buff. */
            snprintf(buff, sizeof buff, "%s/%s_%d.weights",backup_directory,base, epoch);
            save_weights(net, buff);
        }

        /* Rolling backup, overwritten every 100 batches. */
        if(get_current_batch(net)%100 == 0){
            char buff[256];
            snprintf(buff, sizeof buff, "%s/%s.backup",backup_directory,base);
            save_weights(net, buff);
        }
    }

    /* Final weights once max_batches has been reached. */
    char buff[256];
    snprintf(buff, sizeof buff, "%s/%s.weights", backup_directory, base);
    save_weights(net, buff);

    free_network(net);
    free_ptrs((void**)labels, classes);
    free(base);
    free_data(train);
}
train_cifar_distill
test_cifar_multi
test_cifar
extract_cifar
test_cifar_csv
test_cifar_csvtrain
eval_cifar_csv
run_cifar
Last updated