picograd - an autograd engine in c
July 2024
over the past two weeks, i’ve spent a lot of my time building out picograd: an implementation of an automatic differentiation (autograd) engine in fewer than 1,000 lines of c. it provides a foundation for building and training simple neural networks from scratch, supporting various layer types, activation functions, and optimizations. it was heavily inspired by karpathy’s micrograd in python, which was my first introduction to machine learning (thank you big bro andrej).
much like micrograd, everything in picograd is built around the Value struct. each instance represents a single node in the computational graph, and that graph forms the basis for backpropagation during training. each node has an associated backward function that knows how to push the node's gradient back to its inputs, using the local derivative of the operation that produced it.
these values represent things like weights, biases, and activations, but they can also represent intermediate computations throughout the network, such as forward passes and loss calculations. this lets values with different purposes be managed separately, so impermanent values can be freed from memory once they’re no longer needed.
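to make the "each node knows its own backward function" idea concrete, here's a minimal self-contained sketch. it is not picograd's code: Node, node_add, node_mul, and backprop are hypothetical names, and it uses a fixed two-slot child array and a fixed-size topological-sort buffer instead of a ValueArray.

```c
#include <stdlib.h>

/* hypothetical, simplified node: at most two children, no ValueArray, no memory tracking */
typedef struct Node Node;
struct Node {
    double data;
    double grad;
    void (*backward)(Node*);  /* pushes this node's grad into its children */
    Node* prev[2];
    int n_prev;
    int visited;              /* scratch flag for the topological sort */
};

static Node* node_new(double data) {
    Node* n = calloc(1, sizeof *n);  /* grad, prev, n_prev, visited start at zero */
    if (!n) abort();
    n->data = data;
    n->backward = NULL;
    return n;
}

/* out = a + b, so d(out)/da = 1 and d(out)/db = 1 */
static void add_backward(Node* out) {
    out->prev[0]->grad += out->grad;
    out->prev[1]->grad += out->grad;
}

/* out = a * b, so d(out)/da = b and d(out)/db = a; += lets a reused node accumulate */
static void mul_backward(Node* out) {
    Node* a = out->prev[0];
    Node* b = out->prev[1];
    a->grad += b->data * out->grad;
    b->grad += a->data * out->grad;
}

static Node* node_op(Node* a, Node* b, double data, void (*backward)(Node*)) {
    Node* out = node_new(data);
    out->prev[0] = a;
    out->prev[1] = b;
    out->n_prev = 2;
    out->backward = backward;
    return out;
}

static Node* node_add(Node* a, Node* b) { return node_op(a, b, a->data + b->data, add_backward); }
static Node* node_mul(Node* a, Node* b) { return node_op(a, b, a->data * b->data, mul_backward); }

/* depth-first topological sort: every child lands in order[] before its parent */
static void build_topo(Node* n, Node** order, int* count, int cap) {
    if (n->visited) return;
    n->visited = 1;
    for (int i = 0; i < n->n_prev; i++) build_topo(n->prev[i], order, count, cap);
    if (*count >= cap) abort();  /* the sketch only supports small graphs */
    order[(*count)++] = n;
}

/* seed d(root)/d(root) = 1, then run each backward function in reverse topological order */
static void backprop(Node* root) {
    Node* order[64];
    int count = 0;
    build_topo(root, order, &count, 64);
    root->grad = 1.0;
    for (int i = count - 1; i >= 0; i--) {
        if (order[i]->backward) order[i]->backward(order[i]);
    }
}
```

for d = a * b + c with a = 2, b = -3, c = 10, calling backprop(d) leaves a->grad = -3, b->grad = 2, and c->grad = 1. the reverse topological order guarantees a node's grad is complete before it is pushed into its children.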
// TLDR: structure of a Value node in the graph
// data     -> scalar value (1.0, 2.0, etc)
// grad     -> gradient calculated during backprop
// backward -> function pointer to the derivative rule of the operation that created this node
// prev     -> ValueArray of "child" Values
//             forward: these are nodes before the current node in the forward pass
//             backward: these are nodes after the current node in the backward pass
// (Value and ValueArray must be typedef'd beforehand, e.g. typedef struct Value Value;)
struct Value{
    double data;
    double grad;
    void (*backward)(Value*);
    ValueArray* prev;
};
// allocates a node and registers it with one of the two global tracking arrays
Value* create_value(double scalar, bool is_temp){
    Value* node = (Value*)malloc(sizeof(Value));
    node->data = scalar;
    node->grad = 0.0;
    node->backward = NULL;
    node->prev = (ValueArray*)malloc(sizeof(ValueArray));
    initialize_array(node->prev, 1);
    if (is_temp){
        insert_array(&temp_values, node); // neuron forward pass, losses, eps values, etc
    } else{
        insert_array(&network_params, node); // neuron weights / biases, things that need gradients
    }
    return node;
}
despite it working, i’m almost certain this is a very suboptimal way to handle memory, but considering this was my first venture into c, i considered anything that didn’t leak memory a victory. at first i simply had network_params track every node, which i planned on freeing after training was complete; a fool-proof plan, so what could go wrong? it worked at first, but performance degraded drastically as the network grew in size. i wasn’t freeing old, useless memory such as the forward pass of a neuron or layer, so training nearly ground to a halt keeping track of things i didn’t even care about anymore.
after modifying create_value to include an is_temp flag, i could separate “permanent” values like network weights and biases from “temporary” values like the output of a forward pass, the mean squared error of the 117th step, and so on. this is the purpose of temp_values, and it served that purpose well; things sped up and no memory leaks made their way through.
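here's a minimal sketch of the permanent/temporary split, assuming a stripped-down Scalar node with no children. Tracker, scalar_new, and flush_temps are hypothetical names, not picograd's API. the point it shows: the temporary list is emptied after every training step, so it never holds more than one step's worth of nodes, while the parameter list persists for the whole run.

```c
#include <stdlib.h>

/* hypothetical stand-in for a Value node; children are omitted to keep the sketch short */
typedef struct {
    double data;
    double grad;
} Scalar;

/* growable array of node pointers, similar in spirit to a ValueArray */
typedef struct {
    Scalar** items;
    size_t count;
    size_t capacity;
} Tracker;

static Tracker param_nodes;  /* zero-initialized; lives for the whole run */
static Tracker temp_nodes;   /* zero-initialized; emptied after every training step */

static void tracker_push(Tracker* t, Scalar* s) {
    if (t->count == t->capacity) {
        size_t cap = t->capacity ? t->capacity * 2 : 16;
        Scalar** grown = realloc(t->items, cap * sizeof *grown);
        if (!grown) abort();
        t->items = grown;
        t->capacity = cap;
    }
    t->items[t->count++] = s;
}

static Scalar* scalar_new(double x, int is_temp) {
    Scalar* s = malloc(sizeof *s);
    if (!s) abort();
    s->data = x;
    s->grad = 0.0;
    tracker_push(is_temp ? &temp_nodes : &param_nodes, s);
    return s;
}

/* free every temporary node, but keep the buffer so the next step can reuse it */
static void flush_temps(void) {
    for (size_t i = 0; i < temp_nodes.count; i++) free(temp_nodes.items[i]);
    temp_nodes.count = 0;
}
```

in a real engine each temporary node also owns its children array, so a flush like this would have to free that too, and it is only safe once nothing permanent still points at a temporary node.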
still, i don’t really like it. i have to declare and initialize both arrays like this whenever i want to use picograd to train something:
#include "../picograd.h"
ValueArray network_params; // this is global, ew
ValueArray temp_values; // this is global, ew
...
int main(){
    initialize_array(&network_params, 1); // for memory management, this needs to be initialized
    initialize_array(&temp_values, 1); // for memory management, this needs to be initialized
    ...
i could have taken the approach of a requires_grad flag that, when set to false, skips assigning a gradient, a backward function, an array of children, etc. this wouldn’t mess with the graph in any way, so i would have fewer worries about freeing something that might still be tied up somewhere and getting invalid-read segfaults. i could’ve also implemented something like a ref_count, which counts how many other nodes point to a node, and then routinely checked nodes and freed any with zero references. i might try to implement these changes later down the road.
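a rough sketch of the ref_count idea, using a hypothetical RcNode type rather than picograd's Value. one deliberate difference from the description above: this version frees a node eagerly the moment its count hits zero (and cascades into its children) instead of sweeping periodically. rc_new, rc_retain, rc_release, and rc_mul are illustrative names only.

```c
#include <stdlib.h>

typedef struct RcNode RcNode;
struct RcNode {
    double data;
    int ref_count;     /* how many parents (or outside handles) point at this node */
    RcNode* prev[2];   /* children; at most two in this sketch */
    int n_prev;
};

static int live_nodes = 0;  /* allocation counter so the behaviour can be checked */

static RcNode* rc_new(double data) {
    RcNode* n = calloc(1, sizeof *n);
    if (!n) abort();
    n->data = data;
    n->ref_count = 1;  /* the caller owns the first reference */
    live_nodes++;
    return n;
}

static RcNode* rc_retain(RcNode* n) {
    n->ref_count++;
    return n;
}

/* drop one reference; at zero, release the children and free the node */
static void rc_release(RcNode* n) {
    if (--n->ref_count > 0) return;
    for (int i = 0; i < n->n_prev; i++) rc_release(n->prev[i]);
    free(n);
    live_nodes--;
}

/* a new node takes a reference to each child it depends on */
static RcNode* rc_mul(RcNode* a, RcNode* b) {
    RcNode* out = rc_new(a->data * b->data);
    out->prev[0] = rc_retain(a);
    out->prev[1] = rc_retain(b);
    out->n_prev = 2;
    return out;
}
```

with this scheme a weight keeps its own handle and survives when a forward-pass node is released, while an input that only the forward pass referenced gets freed along with it.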
features
picograd has a lot of things you’d find in other small autograd engines such as:
- automatic differentiation:
  - implements backpropagation for 10 operations
- basic neural network components:
  - neurons
  - layers
  - multi-layer perceptron (mlp)
- activation functions:
  - relu
  - sigmoid
  - tanh
- normalization:
  - batch normalization
- loss functions:
  - mean squared error (mse)
  - binary cross-entropy
- optimization:
  - stochastic gradient descent (sgd)
  - learning rate decay
  - l2 regularization
- dynamic memory management:
  - a custom (and as we’ve seen, poorly built) implementation for creating and freeing various structures
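the three optimization features combine into a single parameter update. here's a sketch on plain double arrays, assuming inverse-time decay (lr = lr0 / (1 + decay * epoch)) and an l2 penalty of lambda * sum(w^2). both formulas are common choices but assumptions here, since picograd's exact versions aren't shown in this post; decayed_lr and sgd_step are hypothetical names.

```c
#include <stddef.h>

/* assumed inverse-time schedule: decay = 0 keeps the learning rate constant */
static double decayed_lr(double lr0, double decay, int epoch) {
    return lr0 / (1.0 + decay * epoch);
}

/* one sgd step with l2 regularization:
   the penalty lambda * sum(w^2) contributes 2 * lambda * w to each weight's gradient */
static void sgd_step(double* w, double* grad, size_t n, double lr, double lambda) {
    for (size_t i = 0; i < n; i++) {
        double g = grad[i] + 2.0 * lambda * w[i];
        w[i] -= lr * g;
        grad[i] = 0.0;  /* zero the gradient before the next backward pass */
    }
}
```

the l2 term pulls every weight toward zero in proportion to its size, and the decay shrinks the step size over time so late epochs make smaller adjustments.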
implementing these features from scratch was super helpful in truly understanding each step. it’s very easy to get caught up in tutorial hell, where you’re just watching someone implement something while convincing yourself you’re “learning,” only to end up sitting in front of an empty text editor with no idea what to do.
a couple times throughout the process i found myself going back to karpathy’s zero to hero series to listen to him explain certain things about broadcasting, batch normalization, etc. it mentally clicks a different way when you cannot directly copy lines of code one at a time.
project structure
picograd is composed of 7 header files. you can simply #include picograd.h and be fine.
- types.h: defines the core data structures and enums used throughout the project.
- ops.h: implements basic operations and their corresponding gradients.
- nn.h: contains neural network component implementations (neurons, layers, mlp, normalization).
- matrix.h: provides matrix operations for up to three dimensions.
- memory.h: manages memory allocation and deallocation for the project.
- train.h: implements the training loop and associated functions.
- picograd.h: main header file that includes all other headers.
example usage
here’s a basic example of how i used picograd to train a simple mlp to solve XOR:
#include "../picograd.h"
// tracks permanent Value objects that need gradients (think weights and biases)
ValueArray network_params;
// tracks temporary Value objects that can be routinely flushed (think intermediate results during a forward pass)
ValueArray temp_values;
// EXAMPLE:
// XOR logic gate
// input: 4 XOR states
// output: their expected outputs
#define BATCH_SIZE 4 // how many samples are you providing?
#define NUM_INPUTS 2 // how many inputs are in each sample?
#define NUM_HIDDEN 16 // how many Neurons are in each hidden layer?
#define NUM_OUTPUTS 1 // how many outputs should the MLP calculate?
#define NUM_EPOCHS 1000 // how long should we train the MLP?
#define LEARNING_RATE 1e-1 // what's the learning rate?
#define LEARNING_RATE_DECAY 0 // do you want a learning rate decay? (e.g. 1e-3)
#define LOSS_TYPE CROSS_ENTROPY // loss
int main(){
    initialize_array(&network_params, 1); // for memory management, this needs to be initialized
    initialize_array(&temp_values, 1); // for memory management, this needs to be initialized
    // create input and output data
    double r_xor_inputs[][2] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
    double xor_outputs[4] = {0, 1, 1, 0};
    double** xor_inputs = reshape_input(r_xor_inputs, BATCH_SIZE, NUM_INPUTS, 1, 'd');
    // create network configuration
    size_t num_layers = 3;
    size_t layer_sizes[] = {NUM_INPUTS, NUM_HIDDEN, NUM_OUTPUTS};
    ActivationType layer_activations[] = {RELU, SIGMOID}; // relu: input->hidden, sigmoid: hidden->output
    NormalizationType layer_normalizations[] = {-1, -1}; // no batch normalization needed
    NetworkConfig* config = create_network_config(num_layers, layer_sizes, layer_activations, layer_normalizations);
    MLP* n = initialize_mlp(config);
    print_network_info("Created Network Configuration",
                       config->num_layers,
                       config->layer_sizes,
                       config->layer_activations,
                       config->layer_normalizations,
                       NULL,
                       n);
    free_network_config(config);
    // training parameters
    TrainingParams params = {
        .num_samples = BATCH_SIZE,
        .num_inputs = NUM_INPUTS,
        .num_outputs = NUM_OUTPUTS,
        .num_epochs = NUM_EPOCHS,
        .learning_rate = LEARNING_RATE,
        .lr_decay = LEARNING_RATE_DECAY,
        .loss_type = LOSS_TYPE
    };
    // train the model
    train_mlp(n, xor_inputs, xor_outputs, &params);
    // test the model
    printf("\nTesting the model:\n");
    for (int i = 0; i < BATCH_SIZE; i++){
        ValueArray* input = array_to_value_array(xor_inputs[i], NUM_INPUTS);
        ValueArray* output = mlp_forward(n, input, false);
        printf("Input: (%f, %f), Output: %f\n", xor_inputs[i][0], xor_inputs[i][1], output->values[0]->data);
        free_array(input);
        free_array(output);
    }
    // Epoch 1000, Loss: 0.001823
    // Testing the model:
    // Input: (0.000000, 0.000000), Output: 0.003383
    // Input: (0.000000, 1.000000), Output: 0.998622
    // Input: (1.000000, 0.000000), Output: 0.998600
    // Input: (1.000000, 1.000000), Output: 0.001030
    free_reshape(xor_inputs, BATCH_SIZE); // release the reshaped input rows
    free_mlp(n);
    free_network_params();
    return 0;
}
and here’s an example of using picograd to solve a standard “predict the next number in a sequence” problem:
#include "../picograd.h"
ValueArray network_params;
ValueArray temp_values;
// EXAMPLE:
// simple sequence prediction problem
// input: sequences of 5 numbers
// output: predict the next number in the sequence
#define BATCH_SIZE 20 // how many samples are you providing?
#define SEQUENCE_LENGTH 5 // what's the context window of each sample?
#define NUM_INPUTS 1 // how many features for each sequence step?
#define NUM_HIDDEN 8 // how many Neurons are in each hidden layer?
#define NUM_OUTPUTS 1 // how many outputs should the MLP calculate?
#define NUM_EPOCHS 2000 // how long should we train the MLP?
#define LEARNING_RATE 1e-4 // what's the learning rate?
#define LEARNING_RATE_DECAY 1e-4 // do you want a learning rate decay? (e.g. 1e-3)
#define LOSS_TYPE MSE // loss
// generate a simple sequence: each number is the sum of the two preceding numbers
void generate_sequence(double* sequence, int length){
    sequence[0] = rand() % 10;
    sequence[1] = rand() % 10;
    for (int i = 2; i < length; i++){
        sequence[i] = sequence[i-1] + sequence[i-2];
    }
}
int main(){
    initialize_array(&network_params, 1);
    initialize_array(&temp_values, 1);
    // create input and output data
    double input_data[BATCH_SIZE][SEQUENCE_LENGTH][NUM_INPUTS]; // (20, 5, 1)
    double target_data[BATCH_SIZE][NUM_OUTPUTS]; // (20, 1)
    for (int i = 0; i < BATCH_SIZE; i++){ // for each of the 20 samples
        double sequence[SEQUENCE_LENGTH + 1];
        generate_sequence(sequence, SEQUENCE_LENGTH + 1); // the first SEQUENCE_LENGTH values are the input, the extra one is the target
        for (int j = 0; j < SEQUENCE_LENGTH; j++){
            input_data[i][j][0] = sequence[j]; // insert input element j into sample i
        }
        target_data[i][0] = sequence[SEQUENCE_LENGTH]; // each sample has only 1 output: the target number
    }
    // reshape 3d input data -> 2d input data
    double** reshaped_input = reshape_input(input_data, BATCH_SIZE, SEQUENCE_LENGTH, NUM_INPUTS, 'd');
    // create network configuration
    size_t num_layers = 4;
    size_t layer_sizes[] = {SEQUENCE_LENGTH * NUM_INPUTS, NUM_HIDDEN, NUM_HIDDEN, NUM_OUTPUTS};
    ActivationType layer_activations[] = {TANH, TANH, -1};
    NormalizationType layer_normalizations[] = {BATCH, BATCH, -1};
    NetworkConfig* config = create_network_config(num_layers, layer_sizes, layer_activations, layer_normalizations);
    MLP* n = initialize_mlp(config);
    print_network_info("Created Network Configuration",
                       config->num_layers,
                       config->layer_sizes,
                       config->layer_activations,
                       config->layer_normalizations,
                       NULL,
                       n);
    // set up training parameters
    TrainingParams params = {
        .num_samples = BATCH_SIZE,
        .num_inputs = SEQUENCE_LENGTH * NUM_INPUTS,
        .num_outputs = NUM_OUTPUTS,
        .num_epochs = NUM_EPOCHS,
        .learning_rate = LEARNING_RATE,
        .lr_decay = LEARNING_RATE_DECAY,
        .loss_type = LOSS_TYPE
    };
    // train the model
    train_mlp(n, reshaped_input, target_data[0], &params);
    // test the model
    printf("Testing the model:\n");
    for (int i = 0; i < 5; i++){
        double test_sequence[SEQUENCE_LENGTH + 1];
        generate_sequence(test_sequence, SEQUENCE_LENGTH + 1);
        double test_input[SEQUENCE_LENGTH][NUM_INPUTS]; // 2d because we're just testing single examples, don't need batches
        for (int j = 0; j < SEQUENCE_LENGTH; j++){
            test_input[j][0] = test_sequence[j];
        }
        double** reshaped_test = reshape_input(test_input, 1, SEQUENCE_LENGTH, NUM_INPUTS, 'd');
        ValueArray* input_array = array_to_value_array(reshaped_test[0], SEQUENCE_LENGTH * NUM_INPUTS);
        ValueArray* output = mlp_forward(n, input_array, false); // false because we're at inference
        printf("Input sequence: ");
        for (int j = 0; j < SEQUENCE_LENGTH; j++){
            printf("%.1f ", test_sequence[j]);
        }
        printf("\n");
        printf("Predicted next number: %.1f\n", output->values[0]->data);
        printf("Actual next number: %.1f\n\n", test_sequence[SEQUENCE_LENGTH]);
        free_array(input_array);
        free_array(output);
        free_reshape(reshaped_test, 1);
    }
    // Epoch 2000, Loss: 1.813578
    //
    // Testing the model:
    // Input sequence: 6.0 3.0 9.0 12.0 21.0
    // Predicted next number: 33.5
    // Actual next number: 33.0
    //
    // Input sequence: 8.0 5.0 13.0 18.0 31.0
    // Predicted next number: 47.9
    // Actual next number: 49.0
    //
    // Input sequence: 6.0 1.0 7.0 8.0 15.0
    // Predicted next number: 22.0
    // Actual next number: 23.0
    //
    // Input sequence: 1.0 5.0 6.0 11.0 17.0
    // Predicted next number: 31.3
    // Actual next number: 28.0
    //
    // Input sequence: 9.0 8.0 17.0 25.0 42.0
    // Predicted next number: 61.5
    // Actual next number: 67.0
    // clean up
    free_mlp(n);
    free_network_config(config);
    free_reshape(reshaped_input, BATCH_SIZE);
    free_network_params();
    return 0;
}
you can try out the examples inside the examples folder by running something like:
gcc -O2 -o xor xor.c -lm
gcc -O2 -o binary binary.c -lm
gcc -O2 -o sequence sequence.c -lm
some things to consider
this is my first project in c, and it was meant to serve as my introduction to the language. prior to starting picograd, i saw c as a very intimidating language that i could barely write fizz-buzz in; however, after building all this out, i actually feel somewhat competent when using it (with a bit of help from claude at least). as frustrating as it was resolving segfaults and memory leaks, i now actually enjoy writing c, and i believe this project made me a better programmer and “ml engineer.”
on top of that, i am much more confident in my knowledge of the foundation of neural networks after having to implement nearly everything from scratch in a language i am not fluent in whatsoever. it was hard, but i enjoy doing difficult things.
with that being said, this is clearly not a battle-tested engine, nor was it ever meant to be! i would not recommend using it, as there are plenty of better options out there lol. there’s a million different things i could continue to implement, improve, or refactor, and i may do so as time goes on. the main goal of this was to learn more about c and strengthen my knowledge of the fundamentals of neural networks, and i believe i accomplished both of those things in 1289 lines.