The Interactive Transformer

by Joshua Carpeggiani

Introduction

Welcome. This is an implementation of a transformer in pure Go with no third-party libraries. This way, everything from tensor operations to tokenization is done inside this notebook.

Because the goal of this project is illustrative, there are almost no optimisations. Apart from the matrix multiplication and attention loops, which fan out over goroutines, everything runs on a single goroutine. This makes it easy to follow and good as a reference guide.

This page is also heavily referenced. The goal is to have everything have a reference to its original paper or reference implementation.

GPT introduction

References

Usually this goes at the bottom, but seeing this entire thing is a reference, it’ll go at the top instead.

  • IEEE-754 - This is binary floating point. The specification gives good detail about the types of errors that can accumulate and affect training and inference, as well as possible optimisations, like fused multiply-add (FMA), which can reduce intermediate errors. CPUs usually use FP32, GPUs FP16.
  • BLAS - Speed up matrix multiplication on CPUs.
  • Layer Normalization - Layernorm was introduced in this paper.
  • Deep Residual Learning for Image Recognition - The introduction of residuals allowed for deeper networks. Before this paper the depth of a neural network was limited because training would diverge and back-propagation was really, really difficult because of vanishing gradients. Residuals essentially add a “short circuit” past a block, which limits how much any one block can influence the result.
  • Gaussian Error Linear Units (GELUs) - Activation function that leaves large positive values almost unchanged but maps negative numbers to near zero. Other architectures use different activation functions. For example, OpenELM uses SwiGLU.
  • Learning representations by back-propagating errors - Back-propagation was introduced here, couldn’t find the original paper. This was done by Hinton and co and was what led to the AI era of the 80s.
  • Adam: A Method for Stochastic Optimization - Introduced the Adam optimiser.
  • DECOUPLED WEIGHT DECAY REGULARIZATION - Introduces the AdamW optimiser: Adam with weight decay applied directly to the weights, decoupled from the gradient-based update.
  • Fast Transformer Decoding: One Write-Head is All You Need - People always point to the original Attention is all you need paper or the GPT paper that introduced the decoder only model, but this one was the first one that actually used it in practice.
  • Language Models are Unsupervised Multitask Learners - This is the GPT-2 paper.
  • Improving Language Understanding by Generative Pre-Training - This paper introduced the “GPT” which was a breakthrough at the time. It introduced the idea of using next token prediction as a way to do self-supervised learning, which meant that we can put all of the internet into it and with a simple loss function over the vocabulary adjust the weights via backpropagation.
  • Attention Is All You Need - The OG introduced the idea of self-attention and the encoder/decoder architecture for language translation tasks (the encoder later got dropped because it was only used for translation). Another breakthrough from this paper was the training; “The Transformer allows for significantly more parallelisation and can reach a new state of the art in translation quality after being trained for as little as twelve hours on eight P100 GPUs.” - This fact here was what let it: overtake RNNs (which weren’t parallelisable), and lead NVIDIA to be worth more than 2.7 Trillion token credits.

Table of contents

Background

Before we can dive into the transformer, we need to cover the basics:

  • Datatypes
  • Matrices + Tensors
  • Matrix multiplication

Datatypes and Math

Because we’re GPU poor, and because it makes the implementation easier, we use float32 for all parameters and calculations.

CPUs can do calculations in either 32 or 64 bits, but the Go standard library is opinionated and its math package only supports 64-bit operations. The code in this section wraps all the math functions we need so that they accept and return float32. Whilst all modern architectures have instructions for both float32 and float64 operations, float32 is still faster because it uses half the bits, so the throughput can be up to 2x that of float64 (citation needed). This is an obvious optimisation for this implementation.

Because graphics applications don’t need to be precise, GPUs often use IEEE 754 half precision, which is 16 bits. The loss in training quality from switching from 32 to 16 bits is negligible (citation needed).
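
To get a feel for what fewer bits cost, here is a small sketch that adds 0.1 to itself a million times in float32 and in float64 and compares both against the exact answer. The helper names (sumRepeated32, sumRepeated64, absErr) are made up for this illustration and are not part of the implementation.

```go
package main

import "fmt"

// sumRepeated32 adds x to an accumulator n times using float32 arithmetic.
func sumRepeated32(x float32, n int) float32 {
	var acc float32
	for i := 0; i < n; i++ {
		acc += x
	}
	return acc
}

// sumRepeated64 does the same accumulation in float64.
func sumRepeated64(x float64, n int) float64 {
	var acc float64
	for i := 0; i < n; i++ {
		acc += x
	}
	return acc
}

// absErr returns |got - want|.
func absErr(got, want float64) float64 {
	if got < want {
		return want - got
	}
	return got - want
}

func main() {
	const n = 1_000_000
	want := 0.1 * n // constant arithmetic is exact, so this is 100000
	err32 := absErr(float64(sumRepeated32(0.1, n)), want)
	err64 := absErr(sumRepeated64(0.1, n), want)
	fmt.Println("float32 error:", err32)
	fmt.Println("float64 error:", err64)
}
```

The float32 error should come out many orders of magnitude larger, because every addition rounds to 24 bits of mantissa instead of 53. This is the kind of accumulated error the IEEE-754 reference describes.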

Papers

  • IEEE-754 - I’ve linked the Wikipedia because you need to pay for the standard.
import "math"

func Abs(x float32) float32 {
	return float32(math.Abs(float64(x)))
}

func Cosh(x float32) float32 {
	return float32(math.Cosh(float64(x)))
}

func Exp(x float32) float32 {
	return float32(math.Exp(float64(x)))
}

func Inf(sign int) float32 {
	return float32(math.Inf(sign))
}

func Log(x float32) float32 {
	return float32(math.Log(float64(x)))
}

func IsNaN(f float32) bool {
	return math.IsNaN(float64(f))
}

func Pow(x, y float32) float32 {
	return float32(math.Pow(float64(x), float64(y)))
}

func Sqrt(x float32) float32 {
	return float32(math.Sqrt(float64(x)))
}

func Tanh(x float32) float32 {
	return float32(math.Tanh(float64(x)))
}

Tensors

What is a tensor? A tensor is a multi-dimensional array. A regular slice is one-dimensional, holding elements in a sequence. A tensor can have multiple dimensions, like a 2D array (grid) or even a 3D array (cube).
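
In practice the dimensions are just a way of addressing one flat slice. The sketch below shows the usual row-major mapping from a multi-dimensional index to a flat offset. flatIndex is a hypothetical helper for illustration, not a function from this notebook.

```go
package main

import "fmt"

// flatIndex converts a multi-dimensional index into an offset into a flat,
// row-major slice. dims and idx must have the same length.
func flatIndex(dims, idx []int) int {
	offset := 0
	for i, d := range dims {
		offset = offset*d + idx[i]
	}
	return offset
}

func main() {
	// A 2x3x4 "cube" stored in one flat slice of 24 elements.
	dims := []int{2, 3, 4}
	data := make([]float32, 2*3*4)
	for i := range data {
		data[i] = float32(i)
	}
	// Element [1][2][3] is the last one: (1*3+2)*4+3 = 23.
	fmt.Println(data[flatIndex(dims, []int{1, 2, 3})])
}
```

The last dimension is contiguous in memory, which is why the inner loops later in this notebook always run over the C dimension.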

Tensor libraries like PyTorch or TensorFlow exist in Python. The most widely used tensor library for local inference is https://ggml.ai/ which powers llama.cpp.

What is a tensor


type tensor struct {
	data   []float32
	dims   []int
	stride []int
}

func (t tensor) Data() []float32 {
	return t.data
}

func newTensor(data []float32, dims ...int) (tensor, int) {
	s := 1
	for _, d := range dims {
		s *= d
	}
	if s > len(data) {
		panic("dimensions larger than supplied data")
	}
	ss := min(s, len(data))
	return tensor{
		data: data[:ss],
		dims: dims,
	}, ss
}

func (t tensor) size() int {
	size := 1
	for _, dim := range t.dims {
		size *= dim
	}
	return size
}

func (t tensor) index(idx ...int) tensor {
	// 1. Error Handling (Partially Adjusted)
	if len(idx) > len(t.dims) {
		panic("Too many indices for tensor dimensions")
	}
	for i, dim := range idx {
		if dim < 0 || dim >= t.dims[i] {
			panic("Index out of bounds")
		}
	}
	// 2. Calculate the row-major linear index of the first element
	linearIndex := 0
	stride := t.size()
	for i, ix := range idx {
		stride /= t.dims[i]
		linearIndex += ix * stride
	}
	// 3. Adjust Dimensions and Return Sub-Tensor
	newDims := t.dims[len(idx):]              // Keep remaining dimensions
	end := linearIndex + t.subTensorSize(idx) // Size based on remaining dimensions

	return tensor{
		data: t.data[linearIndex:end],
		dims: newDims,
	}
}

// Helper function to calculate the size of a sub-tensor
func (t tensor) subTensorSize(idx []int) int {
	subTensorSize := 1
	for _, dim := range t.dims[len(idx):] {
		subTensorSize *= dim
	}
	return subTensorSize
}

Matrix Multiplication

matmulForward performs matrix multiplication and adds bias. Parameters:

  • out: output matrix
  • inp: input matrix
  • weight: weight matrix
  • bias: bias vector
  • B: batch size
  • T: sequence length (number of time steps)
  • C: input dimension (number of features)
  • OC: number of output channels

Most of the time spent in inference is in this function. Because we’re only doing this on a CPU this implementation is very, very slow, and this is where other implementations would use a GPU (CUDA or Metal) to do the computation in parallel.

On CPUs, optimised implementations of the Basic Linear Algebra Subprograms (BLAS) interface speed this up with techniques such as tiling (breaking matrices into smaller pieces and processing them one at a time) and Single Instruction Multiple Data (SIMD) instructions.
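
To make the parameter layout concrete, here is a simplified single-goroutine sketch with a tiny worked example. It is not the function used in this notebook: matmulNaive is a made-up name, and it folds the B and T dimensions into a single row count BT.

```go
package main

import "fmt"

// matmulNaive computes out[bt][o] = bias[o] + sum_i inp[bt][i]*weight[o][i]
// for BT rows of input. All slices are flat and row-major:
// inp is (BT, C), weight is (OC, C), bias is (OC), out is (BT, OC).
func matmulNaive(out, inp, weight, bias []float32, BT, C, OC int) {
	for bt := 0; bt < BT; bt++ {
		for o := 0; o < OC; o++ {
			var val float32
			if bias != nil {
				val = bias[o]
			}
			for i := 0; i < C; i++ {
				val += inp[bt*C+i] * weight[o*C+i]
			}
			out[bt*OC+o] = val
		}
	}
}

func main() {
	inp := []float32{1, 2}                // one position with C=2 features
	weight := []float32{1, 0, 0, 1, 1, 1} // OC=3 rows of C=2 weights
	bias := []float32{10, 20, 30}
	out := make([]float32, 3)
	matmulNaive(out, inp, weight, bias, 1, 2, 3)
	fmt.Println(out) // [11 22 33]
}
```

Each output channel is one dot product between the input vector and one row of the weight matrix, plus that channel's bias.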

Matrix multiplication

func matmulForward(out, inp, weight, bias []float32, B, T, C, OC int) {
	// Iterate over each batch
	var wg sync.WaitGroup
	for b := 0; b < B; b++ {
		// Iterate over each time step in the sequence
		for t := 0; t < T; t++ {
			wg.Add(1)
			go func(b, t int) {
				defer wg.Done()
				// Slice out the input and output rows for this (b, t) position
				inp_bt := inp[b*T*C+t*C:]
				out_bt := out[b*T*OC+t*OC:]
				for o := 0; o < OC; o++ {
					var val float32
					if bias != nil {
						val = bias[o]
					}
					// Row o of the weight matrix, which holds C weights
					wrow := weight[o*C:]
					// Perform the dot product between the input and weight row
					for i := 0; i < C; i++ {
						val += inp_bt[i] * wrow[i]
					}
					// Store the output value in the output slice
					out_bt[o] = val
				}
			}(b, t)
		}
	}
	wg.Wait()
}

GPT

decoder

Table of contents

Parameters vs Activations

  • Parameters - The bulk of what makes up “the model”. Most of the bytes you download comes from this part.

  • Activations - Output of mathematical operations between the input and the parameters.

Forward pass

A forward pass is the “inference” stage - this section is what’s occurring when you talk with ChatGPT.

Preparing

This section transforms text into a vector representation that can be processed by a neural network.

  • Tokenizer - Converts text to numeric ids that can be processed.
  • Data Loading - This section describes how data is loaded, including batching, tokenization, and offsetting.
  • Embedding - Converts these ids into vectors in an n-dimensional space

N-Layers

This section is repeated for every layer. GPT-2 (small) has 12 layers.

  • Masked Multi-Head Attention - Allows tokens in the context window to impact the tokens that come after them (the mask stops information flowing backwards from later tokens)
  • Add and Norm - Adds residual stream and normalises outputs
  • Feed Forward - Feed forward is a standard MLP. Allows for more complex connections to be formed than just the attention mechanism alone.

Final transformations

This section takes the higher dimensionality representations of our activations and processes them to give us our final output.

  • Linear - Transformation that projects the activations onto the vocabulary, producing “logits” which are correlated to how likely each token is (-inf = never, +inf = 100% certainty)
  • Softmax - This takes the logits and creates a probability distribution that adds up to 100%
  • Sampling - This samples the probability distribution and returns the single token that’s needed to make the next prediction
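
As a rough sketch of the last two steps, the code below turns a handful of logits into probabilities and then picks a token index from them. The function names softmax and sampleIndex, and the idea of passing the random number in as coin, are assumptions made for this example rather than the notebook's own API.

```go
package main

import (
	"fmt"
	"math"
)

// softmax turns logits into probabilities that sum to 1. Subtracting the
// largest logit first keeps exp from overflowing and does not change the result.
func softmax(logits []float32) []float32 {
	maxv := logits[0]
	for _, l := range logits[1:] {
		if l > maxv {
			maxv = l
		}
	}
	probs := make([]float32, len(logits))
	var sum float32
	for i, l := range logits {
		probs[i] = float32(math.Exp(float64(l - maxv)))
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs
}

// sampleIndex picks an index by walking the cumulative distribution until it
// passes coin, a random number in [0, 1).
func sampleIndex(probs []float32, coin float32) int {
	var cdf float32
	for i, p := range probs {
		cdf += p
		if coin < cdf {
			return i
		}
	}
	return len(probs) - 1 // guard against rounding error
}

func main() {
	probs := softmax([]float32{1, 2, 3})
	fmt.Println(probs)
	fmt.Println(sampleIndex(probs, 0.5))
}
```

Taking coin as a parameter keeps the sketch deterministic. A real sampler would draw it from a random number generator on every step.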

A complete forward pass

This section puts all of this together.

Backwards pass

This is “training”. Companies spend billions of dollars optimizing to make this as fast as possible.

Data loading

import (
	"bytes"
	"encoding/binary"
	"errors"
	"io"
	"os"
)

const Int32ByteLen = 4

type DataLoader struct {
	filename        string
	batchSize       int
	seqLength       int
	currentPosition int64
	fileSize        int64
	NumBatches      int
	data            []int32
	dataAll         []int32
}

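func NewDataLoader(filename string, batchSize, seqLength int) (*DataLoader, error) {
	file, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	// newDataLoader reads the whole file into memory, so it can be closed on return
	defer file.Close()
	return newDataLoader(file, batchSize, seqLength)
}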

func newDataLoader(file io.Reader, batchSize, seqLength int) (*DataLoader, error) {
	data, err := io.ReadAll(file)
	if err != nil {
		return nil, err
	}
	size := len(data)
	if size < (batchSize*seqLength+1)*Int32ByteLen {
		return nil, errors.New("error: file size is too small for the batch size and sequence length")
	}
	loader := &DataLoader{
		batchSize:  batchSize,
		seqLength:  seqLength,
		NumBatches: size / (batchSize * seqLength * Int32ByteLen),
		data:       make([]int32, size/Int32ByteLen),
		fileSize:   int64(size / Int32ByteLen),
	}
	if err := binary.Read(bytes.NewReader(data), binary.LittleEndian, loader.data); err != nil {
		return nil, err
	}
	return loader, nil
}

func newDataLoaderFromInts(data []int32, batchSize, seqLength int) (*DataLoader, error) {
	size := len(data)
	if size < (batchSize*seqLength + 1) {
		return nil, errors.New("error: file size is too small for the batch size and sequence length")
	}
	loader := &DataLoader{
		batchSize:  batchSize,
		seqLength:  seqLength,
		NumBatches: size / (batchSize * seqLength),
		data:       data,
		fileSize:   int64(size),
	}
	return loader, nil
}

func (loader *DataLoader) Reset() {
	loader.currentPosition = 0
}

func (loader *DataLoader) NextBatch() ([]int32, []int32, error) {
	nextPos := loader.currentPosition + int64(loader.batchSize*loader.seqLength)
	if nextPos+1 > loader.fileSize {
		loader.Reset()
		nextPos = loader.currentPosition + int64(loader.batchSize*loader.seqLength)
	}
	// don't  x4 because we're indexing int32 not byte
	inputs := loader.data[loader.currentPosition:nextPos]
	targets := loader.data[loader.currentPosition+1 : nextPos+1]
	loader.currentPosition = nextPos
	return inputs, targets, nil
}
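
The key idea in NextBatch is that inputs and targets are the same data, offset by one token: the target for every position is simply the token that follows it. A minimal sketch of that offsetting, using a hypothetical helper called nextTokenPairs:

```go
package main

import "fmt"

// nextTokenPairs returns n inputs starting at pos, and the n targets that
// follow them. Both slices share the backing array of data; nothing is copied.
func nextTokenPairs(data []int32, pos, n int) (inputs, targets []int32) {
	return data[pos : pos+n], data[pos+1 : pos+n+1]
}

func main() {
	tokens := []int32{10, 11, 12, 13, 14}
	inputs, targets := nextTokenPairs(tokens, 0, 4)
	fmt.Println(inputs)  // [10 11 12 13]
	fmt.Println(targets) // [11 12 13 14]
}
```

This is also why the loader needs one more token than batchSize*seqLength: the final input still needs a target.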

Parameters

A Parameter is a numerical value that determines the strength of the connection between neurons. These connections are similar to synapses in the human brain, and the parameters are like the knobs that adjust the strength of those connections.

There are two main types of parameters in neural networks:

  • Weights: These are associated with each connection between neurons. They multiply the signal coming from one neuron before it’s passed on to the next neuron. A higher weight means a stronger connection and a greater influence on the receiving neuron.

  • Biases: These are added to the sum of the weighted inputs at each neuron. They act like a baseline shift, allowing the neuron to activate even if the weighted inputs are weak.
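
As a toy illustration of the two kinds of parameter, here is a single neuron written out by hand. The neuron function is invented for this example; the model itself does the same arithmetic in bulk inside matmulForward.

```go
package main

import "fmt"

// neuron returns the weighted sum of its inputs plus a bias.
func neuron(inputs, weights []float32, bias float32) float32 {
	sum := bias
	for i, x := range inputs {
		sum += x * weights[i]
	}
	return sum
}

func main() {
	// 0.25 + 1*0.5 + 2*(-1) = -1.25
	fmt.Println(neuron([]float32{1, 2}, []float32{0.5, -1}, 0.25))
}
```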

// ParameterTensors are the parameters of the model
type ParameterTensors struct {
	Memory        []float32
	WordTokEmbed  tensor // (V, C) - Word/Token Embedding weights (Vocabulary size, Embedding dimension)
	WordPosEmbed  tensor // (maxT, C) - Positional Embedding weights (Maximum Sequence length, Embedding dimension)
	LayerNorm1W   tensor // (L, C) - Weights for Layer Normalization 1 (Number of layers, Embedding dimension)
	LayerNorm1B   tensor // (L, C) - Biases for Layer Normalization 1
	QueryKeyValW  tensor // (L, 3*C, C) - Attention QKV weights (Layers, 3 * Embedding dimension, Embedding dimension)
	QueryKeyValB  tensor // (L, 3*C) - Attention QKV biases
	AttProjW      tensor // (L, C, C) - Attention projection weights (Layers, Embedding dimension, Embedding dimension)
	AttProjB      tensor // (L, C) - Attention projection biases
	Layer2NormW   tensor // (L, C) - Weights for Layer Normalization 2
	Layer2NormB   tensor // (L, C) - Biases for Layer Normalization 2
	FeedFwdW      tensor // (L, 4*C, C) - Feed-forward layer weights (Layers, 4 * Embedding Dimension, Embedding Dimension)
	FeedFwdB      tensor // (L, 4*C) - Feed-forward layer biases
	FeedFwdProjW  tensor // (L, C, 4*C) - Feed-forward projection weights
	FeedFwdProjB  tensor // (L, C)- Feed-forward projection biases
	LayerFinNormW tensor // (C) - Final layer normalization weights
	LayerFinNormB tensor // (C) - Final layer normalization biases
}

func newParameterTensors(V, C, maxSeqLen, L int) ParameterTensors {
	var tensor ParameterTensors
	tensor.Init(V, C, maxSeqLen, L)
	return tensor
}

func (tensor *ParameterTensors) Len() int {
	return len(tensor.Memory)
}

// Init initialises the ParameterTensors with specific sizes for each tensor based on the model architecture.
func (tensor *ParameterTensors) Init(V, C, maxSeqLen, L int) {
	tensor.Memory = make([]float32,
		V*C+ // WordTokEmbed
			maxSeqLen*C+ // WordPosEmbed
			L*C+ // LayerNorm1W
			L*C+ // LayerNorm1B
			L*3*C*C+ // QueryKeyValW
			L*3*C+ // QueryKeyValB
			L*C*C+ // AttProjW
			L*C+ // AttProjB
			L*C+ // Layer2NormW
			L*C+ // Layer2NormB
			L*4*C*C+ // FeedFwdW
			L*4*C+ // FeedFwdB
			L*C*4*C+ // FeedFwdProjW
			L*C+ // FeedFwdProjB
			C+ // LayerFinNormW
			C, // LayerFinNormB
	)
	var ptr int
	memPtr := tensor.Memory
	tensor.WordTokEmbed, ptr = newTensor(memPtr, V, C)
	memPtr = memPtr[ptr:]
	tensor.WordPosEmbed, ptr = newTensor(memPtr, maxSeqLen, C)
	memPtr = memPtr[ptr:]
	tensor.LayerNorm1W, ptr = newTensor(memPtr, L, C)
	memPtr = memPtr[ptr:]
	tensor.LayerNorm1B, ptr = newTensor(memPtr, L, C)
	memPtr = memPtr[ptr:]
	tensor.QueryKeyValW, ptr = newTensor(memPtr, L, 3*C, C)
	memPtr = memPtr[ptr:]
	tensor.QueryKeyValB, ptr = newTensor(memPtr, L, 3*C)
	memPtr = memPtr[ptr:]
	tensor.AttProjW, ptr = newTensor(memPtr, L, C, C)
	memPtr = memPtr[ptr:]
	tensor.AttProjB, ptr = newTensor(memPtr, L, C)
	memPtr = memPtr[ptr:]
	tensor.Layer2NormW, ptr = newTensor(memPtr, L, C)
	memPtr = memPtr[ptr:]
	tensor.Layer2NormB, ptr = newTensor(memPtr, L, C)
	memPtr = memPtr[ptr:]
	tensor.FeedFwdW, ptr = newTensor(memPtr, L, 4*C, C)
	memPtr = memPtr[ptr:]
	tensor.FeedFwdB, ptr = newTensor(memPtr, L, 4*C)
	memPtr = memPtr[ptr:]
	tensor.FeedFwdProjW, ptr = newTensor(memPtr, L, C, 4*C)
	memPtr = memPtr[ptr:]
	tensor.FeedFwdProjB, ptr = newTensor(memPtr, L, C)
	memPtr = memPtr[ptr:]
	tensor.LayerFinNormW, ptr = newTensor(memPtr, C)
	memPtr = memPtr[ptr:]
	tensor.LayerFinNormB, ptr = newTensor(memPtr, C)
	memPtr = memPtr[ptr:]
	if len(memPtr) != 0 {
		panic("something went real bad here")
	}
}
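
The sum inside Init is also a handy way to count parameters. The sketch below mirrors it in a standalone function, paramCount, which is a name made up for this example. The configuration in main (V=50257, C=768, maxSeqLen=1024, L=12) is what I am assuming for GPT-2 small; it is not read from a checkpoint here.

```go
package main

import "fmt"

// paramCount mirrors the sum in ParameterTensors.Init: the total number of
// float32 values for vocabulary size V, embedding dimension C, maximum
// sequence length maxSeqLen and L layers.
func paramCount(V, C, maxSeqLen, L int) int {
	perLayer := 2*C + // layer norm 1 weights and biases
		3*C*C + 3*C + // QKV weights and biases
		C*C + C + // attention projection weights and biases
		2*C + // layer norm 2 weights and biases
		4*C*C + 4*C + // feed-forward weights and biases
		4*C*C + C // feed-forward projection weights and biases
	return V*C + maxSeqLen*C + L*perLayer + 2*C
}

func main() {
	// Assumed GPT-2 small configuration: V=50257, C=768, maxSeqLen=1024, L=12.
	n := paramCount(50257, 768, 1024, 12)
	fmt.Println(n, "parameters,", n*4/(1<<20), "MiB as float32")
}
```

Under those assumptions the count comes to roughly 124 million parameters, with the token embedding table and the per-layer C*C matrices making up almost all of it.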

Activations

An activation is the result of applying a mathematical operation to an input and the parameters. If a weight determines the strength of a function, the activation is that function’s output.

// ActivationTensors holds the intermediate values computed during the forward pass
type ActivationTensors struct {
	Memory             []float32
	Encoded            tensor // (B, T, C) - Initial encoded input representations (Batch size, Sequence length, Embedding dimension)
	Layer1Act          tensor // (L, B, T, C) - Activations after Layer Normalization 1
	LayerNorm1Mean     tensor // (L, B, T) - Mean values for Layer Normalization 1
	LayerNorm1Rstd     tensor // (L, B, T) - Reciprocal of standard deviation for Layer Normalization 1
	QueryKeyVal        tensor // (L, B, T, 3*C) - Combined Query, Key, Value representations for attention
	AttentionInter     tensor // (L, B, T, C) - Intermediate attention-like result
	PreAttention       tensor // (L, B, NH, T, T) - Pre-attention scores
	Attention          tensor // (L, B, NH, T, T) - Normalized attention weights (Number of layers, Batch size, Number of Attention Heads, Sequence length, Sequence length)
	AttentionProj      tensor // (L, B, T, C) - Projected attention outputs
	Residual2          tensor // (L, B, T, C) - Residual connection after attention
	LayerNorm2Act      tensor // (L, B, T, C) - Activations after Layer Normalization 2
	LayerNorm2Mean     tensor // (L, B, T) - Mean values for Layer Normalization 2
	LayerNorm2Rstd     tensor // (L, B, T) - Reciprocal of standard deviation for Layer Normalization 2
	FeedForward        tensor // (L, B, T, 4*C) - Intermediate Feed-Forward Network activations
	FeedForwardGelu    tensor // (L, B, T, 4*C) - FeedForward activations after applying GELU (non-linearity)
	FeedForwardProj    tensor // (L, B, T, C) - Projected output of the Feed-Forward Network
	Residual3          tensor // (L, B, T, C) - Residual connection after Feed-Forward Network
	LayerNormFinal     tensor // (B, T, C) - Final activations after Layer Normalization
	LayerNormFinalMean tensor // (B, T) - Mean values for final Layer Normalization
	LayerNormFinalStd  tensor // (B, T) - Reciprocal of standard deviation for final Layer Normalization
	Logits             tensor // (B, T, V) - Raw output scores (before softmax)
	Probabilities      tensor // (B, T, V) - Softmax probabilities over the vocabulary
	Losses             tensor // (B, T) - Loss values per token in the batch
}

func (tensor *ActivationTensors) Init(B, C, T, L, NH, V int) {
	tensor.Memory = make([]float32,
		B*T*C+
			L*B*T*C+
			L*B*T+
			L*B*T+
			L*B*T*C*3+
			L*B*T*C+
			L*B*NH*T*T+
			L*B*NH*T*T+
			L*B*T*C+
			L*B*T*C+
			L*B*T*C+
			L*B*T+
			L*B*T+
			L*B*T*C*4+
			L*B*T*C*4+
			L*B*T*C+
			L*B*T*C+
			B*T*C+
			B*T+
			B*T+
			B*T*V+
			B*T*V+
			B*T)
	var ptr int
	memPtr := tensor.Memory
	tensor.Encoded, ptr = newTensor(memPtr, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.Layer1Act, ptr = newTensor(memPtr, L, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.LayerNorm1Mean, ptr = newTensor(memPtr, L, B, T)
	memPtr = memPtr[ptr:]
	tensor.LayerNorm1Rstd, ptr = newTensor(memPtr, L, B, T)
	memPtr = memPtr[ptr:]
	tensor.QueryKeyVal, ptr = newTensor(memPtr, L, B, T, C*3)
	memPtr = memPtr[ptr:]
	tensor.AttentionInter, ptr = newTensor(memPtr, L, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.PreAttention, ptr = newTensor(memPtr, L, B, NH, T, T)
	memPtr = memPtr[ptr:]
	tensor.Attention, ptr = newTensor(memPtr, L, B, NH, T, T)
	memPtr = memPtr[ptr:]
	tensor.AttentionProj, ptr = newTensor(memPtr, L, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.Residual2, ptr = newTensor(memPtr, L, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.LayerNorm2Act, ptr = newTensor(memPtr, L, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.LayerNorm2Mean, ptr = newTensor(memPtr, L, B, T)
	memPtr = memPtr[ptr:]
	tensor.LayerNorm2Rstd, ptr = newTensor(memPtr, L, B, T)
	memPtr = memPtr[ptr:]
	tensor.FeedForward, ptr = newTensor(memPtr, L, B, T, C*4)
	memPtr = memPtr[ptr:]
	tensor.FeedForwardGelu, ptr = newTensor(memPtr, L, B, T, C*4)
	memPtr = memPtr[ptr:]
	tensor.FeedForwardProj, ptr = newTensor(memPtr, L, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.Residual3, ptr = newTensor(memPtr, L, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.LayerNormFinal, ptr = newTensor(memPtr, B, T, C)
	memPtr = memPtr[ptr:]
	tensor.LayerNormFinalMean, ptr = newTensor(memPtr, B, T)
	memPtr = memPtr[ptr:]
	tensor.LayerNormFinalStd, ptr = newTensor(memPtr, B, T)
	memPtr = memPtr[ptr:]
	tensor.Logits, ptr = newTensor(memPtr, B, T, V)
	memPtr = memPtr[ptr:]
	tensor.Probabilities, ptr = newTensor(memPtr, B, T, V)
	memPtr = memPtr[ptr:]
	tensor.Losses, ptr = newTensor(memPtr, B, T)
	memPtr = memPtr[ptr:]
	if len(memPtr) != 0 {
		panic("something went real bad here")
	}
}

Tokenizer

tokenization

Tokenization is the fundamental process of transforming text data into a format the model can understand. It involves breaking down sentences into smaller units called tokens.

import (
	"encoding/binary"
	"encoding/json"
	"errors"
	"os"
	"sort"
)

const GPT2_EOT int32 = 50256

type Tokenizer struct {
	vocabSize  uint32
	tokenTable []string         // tokenTable maps token id to string
	tokenMap   map[string]int32 // tokenMap maps token to id
	init       bool
}

func newTokenizer(vocab []string) Tokenizer {
	tokenizer := Tokenizer{
		vocabSize:  uint32(len(vocab)),
		tokenTable: vocab,
		tokenMap:   make(map[string]int32),
		init:       true,
	}
	for i, token := range vocab {
		tokenizer.tokenMap[token] = int32(i)
	}
	return tokenizer
}

func NewTokenizer(filename string) (Tokenizer, error) {
	f, err := os.Open(filename)
	if err != nil {
		return Tokenizer{}, err
	}
	defer f.Close()
	header := make([]uint32, 256)
	if err := binary.Read(f, binary.LittleEndian, header); err != nil {
		return Tokenizer{}, err
	}
	if header[0] != 20240328 || header[1] != 1 {
		return Tokenizer{}, errors.New("incorrect header for tokenizer")
	}
	tok := Tokenizer{
		vocabSize:  header[2],
		tokenTable: make([]string, header[2]),
		tokenMap:   make(map[string]int32),
		init:       true,
	}
	var length byte
	for i := range tok.tokenTable {
		if err := binary.Read(f, binary.LittleEndian, &length); err != nil {
			return tok, err
		}
		if length == 0 {
			return tok, errors.New("tokenizer failure")
		}
		tokenBytes := make([]byte, length)
		if err := binary.Read(f, binary.LittleEndian, tokenBytes); err != nil {
			return tok, err
		}
		tok.tokenTable[i] = string(tokenBytes)
		tok.tokenMap[tok.tokenTable[i]] = int32(i)
	}
	return tok, nil
}

type TokenizerJSON struct {
	Version string `json:"version"`
	Model   struct {
		Type          string            `json:"type"`
		Vocab         map[string]int    `json:"vocab"`
		MergesData    []string          `json:"merges,omitempty"`
		SpecialTokens map[string]string `json:"special_tokens"`
	} `json:"model"`
}

func NewTokenizerJson(filename string) (Tokenizer, error) {
	// Read the JSON file
	fileContent, err := os.ReadFile(filename)
	if err != nil {
		return Tokenizer{}, err
	}

	// Unmarshal JSON into our struct
	var tokenizerData TokenizerJSON
	if err := json.Unmarshal(fileContent, &tokenizerData); err != nil {
		return Tokenizer{}, err
	}

	// Create a new Tokenizer instance
	tok := Tokenizer{
		vocabSize:  uint32(len(tokenizerData.Model.Vocab)),
		tokenTable: make([]string, len(tokenizerData.Model.Vocab)),
		tokenMap:   make(map[string]int32),
		init:       true,
	}

	// Create a slice of token-id pairs for sorting
	var tokenIDPairs []struct {
		Token string
		ID    int
	}
	for token, id := range tokenizerData.Model.Vocab {
		// Tokens that start with 'Ġ' (UTF-8 bytes 0xC4 0xA0) encode a leading space;
		// replace that marker with a real space
		if len(token) >= 2 && token[0] == 0xC4 && token[1] == 0xA0 {
			token = " " + token[2:]
		}
		tokenIDPairs = append(tokenIDPairs, struct {
			Token string
			ID    int
		}{token, id})
	}

	// Sort the token-id pairs by ID
	sort.Slice(tokenIDPairs, func(i, j int) bool {
		return tokenIDPairs[i].ID < tokenIDPairs[j].ID
	})

	// Populate tokenTable and tokenMap
	for i, pair := range tokenIDPairs {
		tok.tokenTable[i] = pair.Token
		tok.tokenMap[pair.Token] = int32(i)
	}

	return tok, nil
}

func (t Tokenizer) Decode(tokens []int32) (string, error) {
	s := ""
	for _, token := range tokens {
		if token < 0 || token >= int32(len(t.tokenTable)) {
			return "", errors.New("not valid token")
		}
		if token != GPT2_EOT {
			s += t.tokenTable[token]
		}
	}
	return s, nil
}

func (t Tokenizer) Encode(text string) ([]int32, error) {
	tokens := []int32{}
	for len(text) > 0 {
		longestMatch := ""
		longestMatchToken := int32(GPT2_EOT)
		for i := len(text); i > 0; i-- {
			subStr := text[:i]
			if token, exists := t.tokenMap[subStr]; exists {
				longestMatch = subStr
				longestMatchToken = token
				break
			}
		}
		if longestMatch == "" {
			// If no match found, treat the first character as an unknown token
			tokens = append(tokens, GPT2_EOT)
			text = text[1:]
		} else {
			tokens = append(tokens, longestMatchToken)
			text = text[len(longestMatch):]
		}
	}
	return tokens, nil
}
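
Note that Encode is a greedy longest-prefix match, which is simpler than applying byte-pair-encoding merges and may not always pick the same tokens as the reference GPT-2 tokenizer. The sketch below shows the same greedy idea on a made-up five-entry vocabulary; greedyEncode and the vocabulary are illustrative only.

```go
package main

import "fmt"

// greedyEncode repeatedly takes the longest prefix of text found in vocab.
// Bytes with no matching token are emitted as the unknown id, one at a time.
func greedyEncode(text string, vocab map[string]int32, unknown int32) []int32 {
	tokens := []int32{}
	for len(text) > 0 {
		matched := false
		for i := len(text); i > 0; i-- {
			if id, ok := vocab[text[:i]]; ok {
				tokens = append(tokens, id)
				text = text[i:]
				matched = true
				break
			}
		}
		if !matched {
			tokens = append(tokens, unknown)
			text = text[1:]
		}
	}
	return tokens
}

func main() {
	// A made-up five-entry vocabulary, not the real GPT-2 one.
	vocab := map[string]int32{"hello": 0, " there": 1, "he": 2, "l": 3, "o": 4}
	fmt.Println(greedyEncode("hello there", vocab, -1)) // [0 1]
	fmt.Println(greedyEncode("helo", vocab, -1))        // [2 3 4]
}
```

The second call shows the fallback: "helo" is not in the vocabulary, so the encoder settles for "he" and then the single letters.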

Tokenize some text

%%
tokenizer, err := NewTokenizerJson("/Users/joshcarp/Documents/the-interactive-transformer/tokenizer.json"); if err != nil {
    panic(err)
}
gonbui.RequestInput("Tokenize some text: ", false)
reader := bufio.NewReader(os.Stdin)
inputText, err := reader.ReadString('\n')
if err != nil {
    panic(err)
}
encoded, err := tokenizer.Encode(inputText)
if err != nil {
    panic(err)
}
fmt.Println("encoded: ", encoded)
decoded, err := tokenizer.Decode(encoded)
if err != nil {
    panic(err)
}
fmt.Println("decoded: ", decoded)
encoded:  [31373 612 220 50256]
decoded:  hello there 

Embedding

embeddings

encoderForward iterates through the batch/sequence and combines the word token embeddings with the word position embeddings. This allows our vector to encode tokens and positions in one vector.

Word embeddings

func encoderForward(out []float32, inp []int32, wte []float32, wpe []float32, B, T, C int) {
	// Iterate over each batch
	for b := 0; b < B; b++ {
		// Iterate over each time step in the sequence
		for t := 0; t < T; t++ {
			// Calculate the index in the output slice. Each vector is C elements long.
			startOutIndex := b*T*C + t*C
			// Calculate the token ID index in the input
			// inp is the tokenized input; each number in inp is an index into wte (word token embeddings)
			ix := inp[b*T+t]
			// Calculate the index in the token embeddings slice
			// inp -> id -> wte[id]
			startWteIndex := ix * int32(C)
			// Calculate the index in the position embeddings slice
			// Wpe starts at 0 (when t is zero) which is basically mapping directly to index
			startWpeIndex := t * C
			// Add the vectors from `wte` and `wpe` and store the result in `out`
			// here we combine the vectors in the C dimensions.
			for i := 0; i < C; i++ {
				out[startOutIndex+i] = wte[startWteIndex+int32(i)] + wpe[startWpeIndex+i]
			}
		}
	}
}
%test
func TestEncoderForwardExplicit(t *testing.T) {
    inp := []int32{1, 0} // [1 -> wte (2, 3), wpe(4, 5)] [0 -> wte (0, 1), wpe(6, 7)]
    wte := []float32{0, 1, 2, 3}
    wpe := []float32{4, 5, 6, 7}
    B := 1 // Batch size
    T := 1 // Sequence Len
    C := 2 // Dimensions
    out := make([]float32, len(inp))
    encoderForward(out, inp, wte, wpe, B, T, C)
    expectedOut := []float32{6, 8}
    assert.Equal(t, expectedOut, out)
}
=== RUN   TestEncoderForwardExplicit
--- PASS: TestEncoderForwardExplicit (0.00s)
PASS

Layernorm forward

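layernormForward normalizes the activations in each layer. It improves convergence in training and reduces sensitivity to initial parameters. For each (b,t) position, the mean and variance of the C-dimensional vector are calculated and used to normalise it before the learnable weight and bias are applied.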

Parameters:

  • out: output activations (B,T,C)
  • mean: mean values (B,T) for each position (b,t)
  • rstd: reciprocal standard deviations (B,T) for each position (b,t)
  • inp: input activations (B,T,C)
  • weight: learnable weight (C) for scaling
  • bias: learnable bias (C) for shifting
  • B: batch size
  • T: sequence length (number of time steps)
  • C: embedding dimension (number of features)
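
Here is a minimal sketch of the same calculation on a single vector, which is easier to check by hand. The layerNorm function is a simplified stand-in written for this example (it allocates its output and takes eps as a parameter), not the function used by the model.

```go
package main

import (
	"fmt"
	"math"
)

// layerNorm normalises one vector x to zero mean and unit variance, then
// applies a per-feature scale (weight) and shift (bias).
func layerNorm(x, weight, bias []float32, eps float32) []float32 {
	n := float32(len(x))
	var mean float32
	for _, v := range x {
		mean += v
	}
	mean /= n
	var variance float32
	for _, v := range x {
		variance += (v - mean) * (v - mean)
	}
	variance /= n
	rstd := 1 / float32(math.Sqrt(float64(variance+eps)))
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = (v-mean)*rstd*weight[i] + bias[i]
	}
	return out
}

func main() {
	x := []float32{1, 2, 3, 4} // mean 2.5, variance 1.25
	ones := []float32{1, 1, 1, 1}
	zeros := []float32{0, 0, 0, 0}
	fmt.Println(layerNorm(x, ones, zeros, 1e-5))
}
```

With a weight of one and a bias of zero the output is centred on zero with a variance very close to one. The small eps keeps the division safe when every element of x is identical.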

func layernormForward(out, mean, rstd, inp, weight, bias []float32, B, T, C int) {
	var eps float32 = 1e-5
	for b := 0; b < B; b++ {
		for t := 0; t < T; t++ {
			x := inp[b*T*C+t*C:]
			// Calculate mean
			var m float32 = 0.0
			for i := 0; i < C; i++ {
				m += x[i]
			}
			m /= float32(C)
			// Calculate variance
			var v float32 = 0.0
			for i := 0; i < C; i++ {
				xshift := x[i] - m
				v += xshift * xshift
			}
			v /= float32(C)
			// Calculate rstd (reciprocal standard deviation)
			s := 1.0 / Sqrt((v)+eps)
			// Normalize, scale, shift, and store output
			outBT := out[b*T*C+t*C:]
			for i := 0; i < C; i++ {
				// subtract mean to center data
				// divide by std to scale variance
				// (val - mean) / std
				n := s * (x[i] - m)
				// Multiply the weight
				o := n*weight[i] + bias[i]
				outBT[i] = o
			}
			// Store mean and rstd for backward pass
			mean[b*T+t] = m
			rstd[b*T+t] = s
		}
	}
}

AttentionForward

attention

attentionForward performs the attention forward pass.

Attention is the only layer that mixes information across time. Every other operation is applied at every (b,t) position independently (and of course, no layer mixes information across batch).

Parameters:

  • out: output matrix (B,T,C)
  • preatt: pre-attention scores (B,NH,T,T)
  • att: post-attention scores (B,NH,T,T)
  • inp: input matrix (B,T,3C) holding Query, Key, Value vectors
  • B: batch size
  • T: sequence length (number of time steps)
  • C: input dimension (number of features)
  • NH: number of attention heads
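
Before reading the full version, it may help to see the same steps for a single head with the query, key and value vectors already split out. This is a sketch under simplifying assumptions: causalAttention is a made-up name, it works on slices of slices rather than the packed (B,T,3C) layout, and it does its intermediate maths in float64.

```go
package main

import (
	"fmt"
	"math"
)

// causalAttention runs single-head attention over T positions. q, k and v are
// (T, hs) matrices; position t only attends to positions 0..t.
func causalAttention(q, k, v [][]float32) [][]float32 {
	T, hs := len(q), len(q[0])
	scale := 1 / math.Sqrt(float64(hs))
	out := make([][]float32, T)
	for t := 0; t < T; t++ {
		// Scaled dot product of query t with every key up to and including t.
		scores := make([]float64, t+1)
		maxScore := math.Inf(-1)
		for t2 := 0; t2 <= t; t2++ {
			var dot float64
			for i := 0; i < hs; i++ {
				dot += float64(q[t][i] * k[t2][i])
			}
			scores[t2] = dot * scale
			maxScore = math.Max(maxScore, scores[t2])
		}
		// Softmax over the visible positions.
		var sum float64
		for t2 := range scores {
			scores[t2] = math.Exp(scores[t2] - maxScore)
			sum += scores[t2]
		}
		// Weighted sum of the value vectors.
		out[t] = make([]float32, hs)
		for t2, s := range scores {
			for i := 0; i < hs; i++ {
				out[t][i] += float32(s/sum) * v[t2][i]
			}
		}
	}
	return out
}

func main() {
	q := [][]float32{{1, 0}, {0, 1}}
	k := [][]float32{{1, 0}, {0, 1}}
	v := [][]float32{{1, 2}, {3, 4}}
	out := causalAttention(q, k, v)
	fmt.Println(out[0]) // [1 2]: position 0 can only see itself
	fmt.Println(out[1]) // a blend of both value vectors
}
```

Position 0 has nothing earlier to look at, so its output is exactly its own value vector. Position 1 blends both values, leaning towards the one whose key matches its query.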

func attentionForward(out, preatt, att, inp []float32, B, T, C, NH int) {
	C3 := C * 3  // This is the dimensions for the key, query and values
	hs := C / NH // head size
	scale := 1.0 / Sqrt(float32(hs))
	// Iterate over batch, sequence length, and number of heads
	var wg sync.WaitGroup
	for b := 0; b < B; b++ {
		// Sequence length
		for t := 0; t < T; t++ {
			for h := 0; h < NH; h++ {
				wg.Add(1)
				go func(b, t, h int) {
					defer wg.Done()
					// Calculate indices for query, pre-attention, and attention arrays
					// query is any particular input asking for information from other inputs
					queryT := inp[b*T*C3+t*C3+h*hs:] // inp[B][T][C3]
					preattBth := preatt[b*NH*T*T+h*T*T+t*T:]
					attBth := att[b*NH*T*T+h*T*T+t*T:]
					// Pass 1: Calculate query dot key and max value
					// The dot product is described in the paper as being better because
					// it can be optimized with matrix multiplication
					var maxval float32 = -10000.0
					// range from 0 to the current inp
					for t2 := 0; t2 <= t; t2++ {
						// Calculate key index for t2
						key_t2 := inp[b*T*C3+t2*C3+h*hs+C:] // +C because it's key
						// Compute dot product and update max value
						var val float32
						for i := 0; i < hs; i++ {
							val += queryT[i] * key_t2[i]
						}
						val *= scale
						if val > maxval {
							maxval = val
						}
						// preatt[b][h][t1][t2] == dot product (similarity) between query vector at position t1 and
						// key vector at t2.
						preattBth[t2] = val
					}
					// Pass 2: Calculate the exp and keep track of sum
					// Subtracting maxval maps the largest score to zero and every other
					// score to a negative number, so each exp result lies in (0, 1]
					// and cannot overflow.
					var expsum float32
					for t2 := 0; t2 <= t; t2++ {
						expv := Exp((preattBth[t2]) - maxval)
						// expsum is a sum of all the exp'd pre_att values
						expsum += expv
						// att_bth[t2] is the exp'd preatt_bth[t2]
						attBth[t2] = expv
					}
					var expsum_inv float32
					if expsum != 0.0 {
						expsum_inv = 1.0 / expsum
					}
					// Pass 3: Normalize to get softmax
					// from 0 -> t2: att_bth[t2] = exp(preatt[t2]) / sum(exp(preatt[:]))
					// for everything else it's zero
					for t2 := 0; t2 < T; t2++ {
						if t2 <= t {
							attBth[t2] *= expsum_inv
						} else {
							// Causal attention mask (optional; used for debugging and comparison)
							attBth[t2] = 0.0
						}
					}

					// Pass 4: Accumulate weighted values into the output of attention
					// out = attention * values
					// The value vectors are linear projections of the layer-normalised activations,
					// which started out as the token/position embeddings.
					// Each output is the sum of the value vectors at positions 0..t, weighted by the attention scores.
					// The scores are recomputed on every forward pass; training changes the weights that produce q, k and v.
					out_bth := out[b*T*C+t*C+h*hs:]
					for i := 0; i < hs; i++ {
						out_bth[i] = 0.0
					}
					for t2 := 0; t2 <= t; t2++ {
						value_t2 := inp[b*T*C3+t2*C3+h*hs+C*2:] // +C*2 because it's value
						att_btht2 := attBth[t2]
						for i := 0; i < hs; i++ {
							out_bth[i] += att_btht2 * value_t2[i]
						}
					}
				}(b, t, h)
			}
		}
	}
	wg.Wait()
}
%test
func TestAttentionForward(t *testing.T) {
	type args struct {
		inp []float32
		B   int
		T   int
		C   int
		NH  int
	}
	tests := []struct {
		name       string
		args       args
		wantOut    []float32
		wantPreatt []float32
		wantAtt    []float32
	}{
		{
			name: "Small Input Test",
			args: args{
				inp: []float32{1, 2, 3, 4, 5, 6},
				B:   1,
				T:   1,
				C:   2,
				NH:  1,
			},
			wantOut:    []float32{5, 6},
			wantPreatt: []float32{7.7781744},
			wantAtt:    []float32{1},
		},
		{
			name: "Larger Input Test",
			args: args{
				inp: []float32{ // (B, T, C3)
					/* B = 1 */
					/* T =  0 */
					/*qry*/ 1, 2, 3, // query compared against (4, 5, 6) but not (13, 14, 15) because it's in the future (t=1)
					/*key*/ 4, 5, 6,
					/*val*/ 7, 8, 9,
					/* T =  1 */
					/*qry*/ 10, 11, 12, // will be compared against (4, 5, 6) (t-1) and (13, 14, 15)
					/*key*/ 13, 14, 15,
					/*val*/ 16, 17, 18, // att at t=1 is ~(0, 1), so the output at t=1 equals this value vector
				},
				B:  1,
				T:  2,
				C:  3,
				NH: 1,
			},
			wantOut: []float32{ // (B, T, C)
				/*      B = 0       */
				/*      T = 0       */
				/* C =  0    1    2 */
				/*  */ 7, 8, 9,
				/* T = 1 */
				/* C =  0    1    2 */
				/*  */ 16, 17, 18,
			},
			wantPreatt: []float32{ // (B, NH, T, T)
				/* B =  0    */
				/* NH = 0    */
				/*T =   1  2 */
				/*T=1*/ 18.475208, 0, // after softmax: 18 -> 1, masked 0 -> 0
				/*T=2*/ 96.417496, 267.89053, // after softmax: 96 -> ~0, 267 -> ~1
			},
			wantAtt: []float32{ // (B, NH, T, T)
				/* B = 0     */
				/* NH = 0    */
				/*T =   1  2 */
				/*T=1*/ 1, 0,
				/*T=2*/ 0, 1,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			out, preatt, att := make([]float32, len(tt.wantOut)), make([]float32, len(tt.wantPreatt)), make([]float32, len(tt.wantAtt))
			attentionForward(out, preatt, att, tt.args.inp, tt.args.B, tt.args.T, tt.args.C, tt.args.NH)
			assert.InDeltaSlice(t, tt.wantOut, out, 1e-4, fmt.Sprintf("want: %v got: %v", tt.wantOut, out))
			assert.InDeltaSlice(t, tt.wantPreatt, preatt, 1e-4, fmt.Sprintf("want: %v got: %v", tt.wantPreatt, preatt))
			assert.InDeltaSlice(t, tt.wantAtt, att, 1e-4, fmt.Sprintf("want: %v got: %v", tt.wantAtt, att))
		})
	}
}
=== RUN   TestAttentionForward
=== RUN   TestAttentionForward/Small_Input_Test
=== RUN   TestAttentionForward/Larger_Input_Test
--- PASS: TestAttentionForward (0.00s)
    --- PASS: TestAttentionForward/Small_Input_Test (0.00s)
    --- PASS: TestAttentionForward/Larger_Input_Test (0.00s)
PASS
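
The second row of the larger test (scores 96.4 and 267.9 turning into weights 0 and 1) can be reproduced with a small sketch of passes 1 to 3 for a single row. The function name causalSoftmaxRow is hypothetical, and it uses the standard math package instead of the notebook's own Exp helper.

```go
package main

import (
	"fmt"
	"math"
)

// causalSoftmaxRow applies a numerically stable softmax to scores[0..t]
// and leaves every position after t at zero (the causal mask).
func causalSoftmaxRow(scores []float32, t int) []float32 {
	out := make([]float32, len(scores))
	maxval := scores[0]
	for _, s := range scores[1 : t+1] {
		if s > maxval {
			maxval = s
		}
	}
	var sum float32
	for i := 0; i <= t; i++ {
		out[i] = float32(math.Exp(float64(scores[i] - maxval)))
		sum += out[i]
	}
	for i := 0; i <= t; i++ {
		out[i] /= sum
	}
	return out
}

func main() {
	row := []float32{96.417496, 267.89053}
	fmt.Println(causalSoftmaxRow(row, 0)) // only position 0 is visible
	fmt.Println(causalSoftmaxRow(row, 1)) // both positions are visible
}
```

Because the two scores are more than 170 apart, the smaller one contributes essentially nothing after the exponential, which is why the test expects an attention row of (0, 1).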

Residual forward

residual

residualForward implements a simple residual connection, a common technique used in deep neural networks to improve training and performance.

Papers

  • Deep Residual Learning for Image Recognition - Introduced residual connections, which made much deeper networks trainable. Before this paper, the depth of a neural network was limited because gradients vanished during back propagation. A residual connection adds a block's input to its output, a “short circuit” around the block, so the block only has to learn an adjustment to its input.

func residualForward(out, inp1, inp2 []float32, N int) {
	for i := 0; i < N; i++ {
		out[i] = inp1[i] + inp2[i]
	}
}

geluForward

The geluForward function applies the GELU activation to the input values stored in the inp slice and writes the activated values to the out slice.

geluForward is the Gaussian Error Linear Unit activation function. It leaves large positive values almost unchanged and maps negative values close to zero. This introduces non-linearity into the network, which lets the model fit functions that a stack of purely linear layers could not.

Papers

  • Gaussian Error Linear Units (GELUs) - Activation function that leaves large positive values almost unchanged and maps negative values to near zero. Other architectures use different activation functions. For example, OpenELM uses SwiGLU.

var GELUSCALEFACTOR = Sqrt(2.0 / math.Pi)
func geluForward(out, inp []float32, n int) {
	for i := 0; i < n; i++ {
		x := inp[i]
		cube := 0.044715 * x * x * x
		out[i] = 0.5 * x * (1.0 + Tanh(GELUSCALEFACTOR*(x+cube)))
	}
}
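
To get a feel for the shape of the function, here is a small sketch that evaluates the same tanh approximation at a few points. It is written in float64 with the standard math package for readability; the notebook version works on float32 slices with its own Tanh and Sqrt helpers.

```go
package main

import (
	"fmt"
	"math"
)

// gelu is the tanh approximation of GELU, written in float64 for clarity.
func gelu(x float64) float64 {
	k := math.Sqrt(2.0 / math.Pi)
	return 0.5 * x * (1.0 + math.Tanh(k*(x+0.044715*x*x*x)))
}

func main() {
	for _, x := range []float64{-3, -1, 0, 1, 3} {
		fmt.Printf("gelu(%+.0f) = %+.4f\n", x, gelu(x))
	}
}
```

Large positive inputs come out almost unchanged, large negative inputs come out close to zero, and small inputs are bent smoothly in between.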

Softmax

softmaxForward calculates the softmax probabilities for a batch of input logits, converting them into a probability distribution over multiple classes. It’s a common operation in neural networks, especially for classification tasks.

func softmaxForward(probs, logits []float32, B, T, V int) {
	var wg sync.WaitGroup
	for b := 0; b < B; b++ {
		for t := 0; t < T; t++ {
			wg.Add(1)
			go func(b, t int) {
				defer wg.Done()
				baseIndex := b*T*V + t*V
				logitsBT := logits[baseIndex : baseIndex+V]
				probsBT := probs[baseIndex : baseIndex+V]
				// Numerical Stability
				var maxval float32 = -10000.0
				for i := 0; i < V; i++ {
					if logitsBT[i] > maxval {
						maxval = logitsBT[i]
					}
				}
				// Calculate exponentials and sum
				var sum float32
				for i := 0; i < V; i++ {
					probsBT[i] = Exp((logitsBT[i] - maxval))
					sum += probsBT[i] // running sum of the exponentials, used to normalise below
				}
				// Normalize
				for i := 0; i < V; i++ {
					probsBT[i] /= sum
				}
			}(b, t)
		}
	}
	wg.Wait()
}
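
The "Numerical Stability" step is easiest to appreciate by comparing against a version without it. The sketch below is an illustration in float64 using the standard math package; naiveSoftmax, stableSoftmax and hasNaN are hypothetical names. Unlike the notebook, it seeds the maximum with the first logit instead of -10000.

```go
package main

import (
	"fmt"
	"math"
)

// naiveSoftmax exponentiates the raw logits; math.Exp overflows to +Inf for large inputs.
func naiveSoftmax(logits []float64) []float64 {
	out := make([]float64, len(logits))
	sum := 0.0
	for i, l := range logits {
		out[i] = math.Exp(l)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum // Inf / Inf is NaN
	}
	return out
}

// stableSoftmax subtracts the largest logit first, so every exponent is <= 0.
func stableSoftmax(logits []float64) []float64 {
	maxval := logits[0]
	for _, l := range logits[1:] {
		if l > maxval {
			maxval = l
		}
	}
	out := make([]float64, len(logits))
	sum := 0.0
	for i, l := range logits {
		out[i] = math.Exp(l - maxval)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// hasNaN reports whether any element is NaN.
func hasNaN(xs []float64) bool {
	for _, x := range xs {
		if math.IsNaN(x) {
			return true
		}
	}
	return false
}

func main() {
	logits := []float64{1000, 1001, 1002}
	fmt.Printf("naive has NaN:  %v\n", hasNaN(naiveSoftmax(logits)))
	fmt.Printf("stable has NaN: %v\n", hasNaN(stableSoftmax(logits)))
	fmt.Println(stableSoftmax(logits))
}
```

Subtracting the maximum does not change the result mathematically, because the common factor cancels between numerator and denominator.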

CrossEntropyForward

The function crossEntropyForward calculates the cross-entropy loss for a batch of predicted probability distributions and their corresponding target labels.

// crossEntropyForward stores the negative log-probability of each target token in losses.
func crossEntropyForward(losses []float32, probs []float32, targets []int32, B, T, V int) {
	// Iterate over each batch
	for b := 0; b < B; b++ {
		// Iterate over each time step in the sequence
		for t := 0; t < T; t++ {
			// Offset of this (b, t) row in the probability slice
			startIndex := int32(b*T*V + t*V)
			// Target token id for the current batch and time step
			ix := targets[b*T+t]
			// Probability the model assigned to the target token
			prob := probs[startIndex+ix]
			// The loss is the negative log of that probability
			losses[b*T+t] = -Log((prob))
		}
	}
}
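
As a worked example for a single position, the loss only looks at the probability assigned to the correct token. The sketch below uses float64 and the standard math package; the function name crossEntropy is hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// crossEntropy returns the loss for one position: the negative log of the
// probability that the model assigned to the target token.
func crossEntropy(probs []float64, target int) float64 {
	return -math.Log(probs[target])
}

func main() {
	probs := []float64{0.7, 0.2, 0.1}
	fmt.Printf("target 0: %.4f\n", crossEntropy(probs, 0)) // confident and right: small loss
	fmt.Printf("target 2: %.4f\n", crossEntropy(probs, 2)) // target was unlikely: large loss
}
```

A probability of 1 for the target gives a loss of 0, and the loss grows without bound as that probability approaches 0.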

Putting it all together

type GPT2Config struct {
	MaxSeqLen int `json:"max_seq_len"`
	V         int `json:"vocab_size"`
	L         int `json:"num_layers"`
	NH        int `json:"num_heads"`
	C         int `json:"channels"`
	EOT       int32
}


type GPT2 struct {
	Tokenizer Tokenizer
	Config    GPT2Config // Hyper-parameters of the model
	// Params has the actual weights of the model. Params.Memory is for convenience to be able to set/reset parameters simply
	Params ParameterTensors // Weights of the model
	// Grads contains the delta/gradient that will eventually be applied to the params in the model
	Grads ParameterTensors // Gradients of the weights
	// Fields for AdamW optimizer
	MMemory []float32         // First moment estimates (for AdamW)
	VMemory []float32         // Second moment estimates (for AdamW)
	Acts    ActivationTensors // Activations of the model
	// gradients of the activations
	GradsActs ActivationTensors
	B         int     // Current batch size (B)
	T         int     // Current sequence length (T)
	Inputs    []int32 // Input tokens
	Targets   []int32 // Target tokens
	MeanLoss  float32 // Mean loss after a forward pass
	Rand      *rand.Rand
}


func loadFromReader(f io.Reader) (*GPT2, error) {
	header := make([]int32, 256)
	err := binary.Read(f, binary.LittleEndian, header)
	if err != nil {
		return nil, fmt.Errorf("error reading model header: %v", err)
	}
	if header[0] != 20240326 || header[1] != 1 {
		return nil, fmt.Errorf("bad model file format")
	}
	model := &GPT2{
		Config: GPT2Config{
			MaxSeqLen: int(header[2]),
			V:         int(header[3]),
			L:         int(header[4]),
			NH:        int(header[5]),
			C:         int(header[6]),
			EOT:       GPT2_EOT,
		},
		Rand: rand.New(rand.NewSource(21)),
	}
	model.Params.Init(model.Config.V, model.Config.C, model.Config.MaxSeqLen, model.Config.L)
	if err := binary.Read(f, binary.LittleEndian, model.Params.Memory); err != nil {
		return nil, fmt.Errorf("error reading model: %v", err)
	}
	return model, nil
}
// LoadGPT2Model loads the GPT-2 model from a checkpoint file.
func LoadGPT2Model(checkpointPath, tokenizerFile string) (*GPT2, error) {
	// File Reading
	f, err := os.Open(checkpointPath)
	if err != nil {
		return nil, fmt.Errorf("Error opening model file: %v", err)
	}
	defer f.Close()
	// Read Model Header
	model, err := loadFromReader(f)
	if err != nil {
		return nil, err
	}
	if tokenizerFile == "" {
		return model, nil
	}
	tok, err := NewTokenizer(tokenizerFile)
	if err != nil {
		return nil, err
	}
	model.Tokenizer = tok
	return model, nil
}
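
The header check in loadFromReader can be tried without a checkpoint file. The sketch below builds a 256-value int32 header in memory and parses it back; parseHeader, fakeHeader and the header struct are hypothetical names, and the field values are illustrative GPT-2 small settings rather than values read from a real file.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// header holds the fields that loadFromReader copies into GPT2Config.
type header struct {
	MaxSeqLen, V, L, NH, C int
}

// parseHeader is a hypothetical stand-alone version of the header check.
func parseHeader(raw []byte) (header, error) {
	h := make([]int32, 256)
	if err := binary.Read(bytes.NewReader(raw), binary.LittleEndian, h); err != nil {
		return header{}, fmt.Errorf("error reading model header: %w", err)
	}
	if h[0] != 20240326 || h[1] != 1 {
		return header{}, fmt.Errorf("bad model file format")
	}
	return header{
		MaxSeqLen: int(h[2]),
		V:         int(h[3]),
		L:         int(h[4]),
		NH:        int(h[5]),
		C:         int(h[6]),
	}, nil
}

// fakeHeader builds an in-memory header for the example (illustrative values).
func fakeHeader(magic int32) []byte {
	h := make([]int32, 256)
	copy(h, []int32{magic, 1, 1024, 50257, 12, 12, 768})
	var buf bytes.Buffer
	// Writing fixed-size data to a bytes.Buffer is not expected to fail.
	_ = binary.Write(&buf, binary.LittleEndian, h)
	return buf.Bytes()
}

func main() {
	h, err := parseHeader(fakeHeader(20240326))
	fmt.Println(h, err)
	_, err = parseHeader(fakeHeader(0))
	fmt.Println(err)
}
```

Checking the magic number and version before allocating parameters means a wrong or truncated file fails early with a clear error.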

Forward

The function Forward implements the forward pass of a GPT-2 language model. It takes a sequence of input tokens and, if available, a sequence of target tokens. It calculates the model’s output probabilities for the next token at every position and, when targets are given, the mean cross-entropy loss.

func (model *GPT2) Forward(input, target []int32, B, T int) {
	V, L, NH, C := model.Config.V, model.Config.L, model.Config.NH, model.Config.C
	if model.Acts.Memory == nil {
		model.B, model.T = B, T
		model.Acts.Init(B, C, T, L, NH, V)
		model.Inputs = make([]int32, B*T)
		model.Targets = make([]int32, B*T)
	}
	copy(model.Inputs, input)
	copy(model.Targets, target)
	params, acts := model.Params, model.Acts
	// This adds the positional embeddings to the word token embeddings
	// so that each vector carries positional information as well as token identity.
	// The result is stored in acts.Encoded.
	// input is a slice of token ids that index rows of WordTokEmbed; each id's index in the sequence is its position.
	encoderForward(acts.Encoded.data, input, params.WordTokEmbed.data, params.WordPosEmbed.data, B, T, C)
	var residual []float32
	for l := 0; l < L; l++ {
		// residual is a connection between the last layers output, or the initial token/pos embedding (as applied above)
		if l == 0 {
			residual = acts.Encoded.data
		} else {
			residual = acts.Residual3.data[(l-1)*B*T*C:]
		}
		// Parameters
		l_ln1w := params.LayerNorm1W.data[l*C:]
		l_ln1b := params.LayerNorm1B.data[l*C:]
		l_qkvw := params.QueryKeyValW.data[l*3*C*C:]
		l_qkvb := params.QueryKeyValB.data[l*3*C:]
		l_attprojw := params.AttProjW.data[l*C*C:]
		l_attprojb := params.AttProjB.data[l*C:]
		l_ln2w := params.Layer2NormW.data[l*C:]
		l_ln2b := params.Layer2NormB.data[l*C:]
		l_fcw := params.FeedFwdW.data[l*4*C*C:]
		l_fcb := params.FeedFwdB.data[l*4*C:]
		l_fcprojw := params.FeedFwdProjW.data[l*C*4*C:]
		l_fcprojb := params.FeedFwdProjB.data[l*C:]
		// Activations
		l_ln1 := acts.Layer1Act.data[l*B*T*C:]
		l_ln1_mean := acts.LayerNorm1Mean.data[l*B*T:]
		l_ln1_rstd := acts.LayerNorm1Rstd.data[l*B*T:]
		l_qkv := acts.QueryKeyVal.data[l*B*T*3*C:]
		l_atty := acts.AttentionInter.data[l*B*T*C:]
		l_preatt := acts.PreAttention.data[l*B*NH*T*T:]
		l_att := acts.Attention.data[l*B*NH*T*T:]
		l_attproj := acts.AttentionProj.data[l*B*T*C:]
		l_residual2 := acts.Residual2.data[l*B*T*C:]
		l_ln2 := acts.LayerNorm2Act.data[l*B*T*C:]
		l_ln2_mean := acts.LayerNorm2Mean.data[l*B*T:]
		l_ln2_rstd := acts.LayerNorm2Rstd.data[l*B*T:]
		l_fch := acts.FeedForward.data[l*B*T*4*C:]
		l_fch_gelu := acts.FeedForwardGelu.data[l*B*T*4*C:]
		l_fcproj := acts.FeedForwardProj.data[l*B*T*C:]
		l_residual3 := acts.Residual3.data[l*B*T*C:]
		// Here we normalise the layer so that the mean is 0 and the standard deviation is ~1.
		// residual contains the un-edited activations
		layernormForward(l_ln1, l_ln1_mean, l_ln1_rstd, residual /*inp*/, l_ln1w /*weight*/, l_ln1b /*bias*/, B, T, C)
		/*
			l_ln1  = inp    = layer-normalised activations (B, T, C)
			l_qkvw = weight = query/key/value weights (3C * C values)
			l_qkvb = bias   = query/key/value bias (3C values)
			l_qkv  = out    = query/key/value vectors (B, T, 3C)
			Here we compute l_ln1(inp) * l_qkvw(weight) + l_qkvb(bias).
			This is a linear projection of each position's activations from C to 3C values:
			one query vector, one key vector and one value vector.
		*/
		matmulForward(l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C)
		/*
			The attention forward pass consumes the query/key/value vectors in l_qkv.
			l_preatt receives the scaled query-key dot products (the un-normalised scores).
			l_att receives those scores after the causal softmax.
			l_atty receives the attention-weighted sum of the value vectors.
			The dot product measures similarity between one position's query and another position's key:
				word a has a query vector: "what am I looking for?"
				word b has a key vector: "what do I contain?"
				if the two vectors point in a similar direction, the score stored in l_preatt is high.
			The value vectors, like the queries and keys, are linear projections of the normalised activations.
		*/
		attentionForward(l_atty, l_preatt, l_att, l_qkv, B, T, C, NH)

		/*
			Another matrix multiplication, this time with the attention projection weights and bias.
			It mixes the per-head outputs in l_atty while keeping the dimension at C.
			These weights are learned during back propagation like every other parameter.
		*/
		matmulForward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C)
		/*
			The residual forward adds the attention projection to residual, which holds the
			activations (not weights) that entered this block, before layernorm and attention.
			The addition gives gradients a direct path around the block during back propagation,
			which helps against vanishing gradients in deep networks.
		*/
		residualForward(l_residual2, residual, l_attproj, B*T*C)
		/*
			The second layernorm normalises l_residual2, the sum computed above,
			using this layer's second set of layernorm weights and biases. The result is stored in l_ln2.
		*/
		layernormForward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C)
		/*
			Feedforward is just another layer of a multi layer perceptron to make the "higher level" connections.
		*/
		matmulForward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C)
		/*
			This is an activation function which leaves large positive values almost unchanged and maps negative values close to zero.
		*/
		geluForward(l_fch_gelu, l_fch, B*T*4*C)
		/*
			This projects the 4C hidden activations back down to C so they can be added to the residual.
		*/
		matmulForward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C)
		/*
			Now we set the next residual layer as the output of this layer. This is the l_fcproj + the current layer residual
		*/
		residualForward(l_residual3, l_residual2, l_fcproj, B*T*C)
	}
	residual = acts.Residual3.data[(L-1)*B*T*C:]

	/*
		Last step before the logits: layernorm the final layer's activations.
	*/
	layernormForward(acts.LayerNormFinal.data, acts.LayerNormFinalMean.data, acts.LayerNormFinalStd.data, residual, params.LayerFinNormW.data, params.LayerFinNormB.data, B, T, C)
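	/*
		Matrix multiplying by the word token embeddings gives us the logits; the embedding matrix is reused as the output projection.
		Each logit is a dot product between the final activation and one token's embedding.
		More likely tokens get large logits; less likely tokens get small or negative ones.
	*/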
	matmulForward(acts.Logits.data, acts.LayerNormFinal.data, params.WordTokEmbed.data, nil, B, T, C, V)
	/*
		After all of this we can softmax the logits to get probabilities over the entire vocabulary
	*/
	softmaxForward(acts.Probabilities.data, acts.Logits.data, B, T, V)
	// also forward the cross-entropy loss function if we have the targets
	if len(target) > 0 {
		/*
			This looks up the probability assigned to each target token and takes its negative log as the loss.
		*/
		crossEntropyForward(model.Acts.Losses.data, model.Acts.Probabilities.data, target, B, T, V)
		// for convenience also evaluate the mean loss
		var meanLoss float32
		for i := range model.Acts.Losses.data {
			meanLoss += model.Acts.Losses.data[i]
		}
		meanLoss /= float32(B * T)
		model.MeanLoss = meanLoss

	} else {
		model.MeanLoss = -1.0
	}
}
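
Forward repeatedly slices one flat buffer per activation type with expressions such as `acts.Residual3.data[l*B*T*C:]`. The sketch below shows that pattern in isolation. The helper layerSlice is hypothetical, and unlike the notebook it also bounds the end of the slice, which is a choice made here so that an overrun panics instead of silently touching the next layer.

```go
package main

import "fmt"

// layerSlice returns the (B, T, C) block that belongs to layer l inside a
// flat buffer laid out as (L, B, T, C). Hypothetical helper for illustration.
func layerSlice(data []float32, l, B, T, C int) []float32 {
	n := B * T * C
	return data[l*n : (l+1)*n]
}

func main() {
	L, B, T, C := 3, 1, 2, 2
	data := make([]float32, L*B*T*C)
	for i := range data {
		data[i] = float32(i)
	}
	s := layerSlice(data, 1, B, T, C)
	fmt.Println(s) // prints "[4 5 6 7]"
	s[0] = 42      // slices share memory with the flat buffer
	fmt.Println(data[4])
}
```

Because Go slices are views, writing through the layer slice updates the shared buffer, which is how each forward function stores its output for the backward pass.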

Sampling

The probabilities are a float array indexed by token id:

index/tokenid:probability

coin is a random value between 0 and 1.

We keep a cumulative sum of the probabilities and return the first index at which the sum exceeds coin.

This way the token with the highest probability is the most likely to be returned, but every other token can still be chosen, in proportion to how likely it is.
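
A small worked example of the cumulative-sum idea, with three tokens and three hand-picked coin values. The sampleMult function is repeated here from the notebook so the snippet runs on its own; the probabilities and coins are made up for illustration.

```go
package main

import "fmt"

// sampleMult returns the first index whose cumulative probability exceeds coin.
func sampleMult(probabilities []float32, coin float32) int {
	var cdf float32
	for i, prob := range probabilities {
		cdf += prob
		if coin < cdf {
			return i
		}
	}
	return len(probabilities) - 1 // guard against rounding error in the sum
}

func main() {
	probs := []float32{0.1, 0.6, 0.3} // cumulative sums: 0.1, 0.7, 1.0
	for _, coin := range []float32{0.05, 0.5, 0.95} {
		fmt.Printf("coin %.2f -> token %d\n", coin, sampleMult(probs, coin))
	}
}
```

Token 1 owns the interval from 0.1 to 0.7, so a uniformly random coin lands on it 60% of the time.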

func sampleMult(probabilities []float32, coin float32) int {
	var cdf float32
	for i, prob := range probabilities {
		cdf += prob
		if coin < cdf {
			return i
		}
	}
	return len(probabilities) - 1
}
func (model *GPT2) Inference(input string) (string, error) {
	B, T, nTokens := 1, 64, 20
	start := time.Now()
	defer func() {
		fmt.Printf("inference time took: %v\n", time.Since(start))
	}()
	tokens, err := model.Tokenizer.Encode(input)
	//prompt_len := len(tokens)
	if err != nil {
		return "", err
	}
	if len(tokens) < T {
		for i := len(tokens); i <= T; i++ {
			tokens = append(tokens, model.Config.EOT)
		}
	}
	fmt.Printf("input is %d tokens long\n", len(tokens))
	model.Forward(tokens, tokens[1:], B, T)
	for t := 1; t < nTokens; t++ {
		// for each t, we re-compute all activations between 0 and t
		// leaving this alone because you want separate code for inference anyway
		// the inference here is just for sanity checking purposes
		model.Forward(tokens, nil, B, t)
		probabilities := model.Acts.Probabilities.data[(t-1)*model.Config.V:]
		coin := model.Rand.Float32()
		nextToken2 := sampleMult(probabilities, coin)
		tokens[t] = rune(nextToken2)
		out, err := model.Tokenizer.Decode([]int32{tokens[t]})
		if err != nil {
			panic(err)
		}
		fmt.Print(out)

	}
	return model.Tokenizer.Decode(tokens)
}
func newGPT2(MaxSeqLen, V, L, NH, C int, vocab []string) GPT2 {
	model := GPT2{
		Config: GPT2Config{
			MaxSeqLen: MaxSeqLen,
			V:         V,
			L:         L,
			NH:        NH,
			C:         C,
		},
		Params:    newParameterTensors(V, C, MaxSeqLen, L),
		Tokenizer: newTokenizer(vocab),
		Rand:      rand.New(rand.NewSource(21)),
	}
	return model
}

Do some inference

%%

path := "/Users/joshcarp/Documents/the-interactive-transformer/"
model, err := LoadGPT2Model(path+"/gpt2_124M.bin", path+"/gpt2_tokenizer.bin")
if err != nil {
    panic(err)
}
gonbui.RequestInput("gpt2 text complete: ", false)
reader := bufio.NewReader(os.Stdin)
inputText, err := reader.ReadString('\n')
if err != nil {
    panic(err)
}
_, err = model.Inference(inputText)
if err != nil {
    panic(err)
}
input is 65 tokens long
ahlvinyl.org> <link rel="stylesheet"><span class="affiliate iconinference time took: 6.370601958s

Backward Pass

The backward pass is where the “learning” happens. It computes the gradients that are used to update the weights of the model. If we’re using the model for inference, deploying it as a chatbot, etc, we don’t do a backward pass.

The backward pass starts from the loss, which measures the difference between the predicted probabilities (before sampling) and the target tokens, and uses the chain rule to work out how much each weight contributed to that loss.

Backpropagation

crossentropySoftmaxBackward

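The function computes the gradient of the loss with respect to the logits (dlogits), given the probabilities (probs), the target labels (targets) and the incoming loss gradients (dlosses). It handles the softmax and the cross-entropy in one step: for each position the gradient is (probability - indicator) * dloss, where the indicator is 1 at the target index and 0 elsewhere. This gradient is then propagated backwards to the weights and biases of the network.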

// crossentropySoftmaxBackward backpropagates through the softmax and the cross-entropy loss in one step
func crossentropySoftmaxBackward(dlogits, dlosses, probs []float32, targets []int32, B, T, V int) {
	for b := 0; b < B; b++ {
		for t := 0; t < T; t++ {
			baseIndex := b*T*V + t*V
			dlogitsBT := dlogits[baseIndex : baseIndex+V]
			probsBT := probs[baseIndex : baseIndex+V]
			dloss := dlosses[b*T+t]
			ix := targets[b*T+t]
			for i := 0; i < V; i++ {
				p := probsBT[i]
				var indicator float32
				if int32(i) == ix {
					indicator = 1.0
				} else {
					indicator = 0.0
				}
				dlogitsBT[i] += (p - indicator) * dloss
			}
		}
	}
}
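
The "probabilities minus one-hot target" rule can be checked numerically. The sketch below compares the closed form against a central-difference estimate for a single position. It is an illustration in float64 with the standard math package, and all function names in it are hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// softmax is a numerically stable softmax over one row of logits.
func softmax(logits []float64) []float64 {
	maxval := logits[0]
	for _, l := range logits[1:] {
		if l > maxval {
			maxval = l
		}
	}
	out := make([]float64, len(logits))
	sum := 0.0
	for i, l := range logits {
		out[i] = math.Exp(l - maxval)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// loss is the cross-entropy of softmax(logits) against one target index.
func loss(logits []float64, target int) float64 {
	return -math.Log(softmax(logits)[target])
}

// analyticGrad is the closed form: probabilities minus the one-hot target.
func analyticGrad(logits []float64, target int) []float64 {
	g := softmax(logits)
	g[target] -= 1.0
	return g
}

// numericGrad estimates the same gradient with central differences.
func numericGrad(logits []float64, target int) []float64 {
	const h = 1e-5
	g := make([]float64, len(logits))
	for i := range logits {
		orig := logits[i]
		logits[i] = orig + h
		up := loss(logits, target)
		logits[i] = orig - h
		down := loss(logits, target)
		logits[i] = orig
		g[i] = (up - down) / (2 * h)
	}
	return g
}

// maxAbsDiff returns the largest element-wise absolute difference.
func maxAbsDiff(a, b []float64) float64 {
	m := 0.0
	for i := range a {
		m = math.Max(m, math.Abs(a[i]-b[i]))
	}
	return m
}

func main() {
	logits := []float64{2.0, -1.0, 0.5}
	fmt.Println("analytic:", analyticGrad(logits, 1))
	fmt.Println("numeric: ", numericGrad(logits, 1))
	fmt.Println("max diff:", maxAbsDiff(analyticGrad(logits, 1), numericGrad(logits, 1)))
}
```

The same finite-difference trick works for checking any of the backward functions in this section, at the cost of one extra forward pass per input element.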

layernormBackward

The function layernormBackward calculates the gradients for the backward pass of a Layer Normalization (LayerNorm) operation in a neural network.

Layer Normalization normalizes the activations of a layer across its features, which improves the training stability and performance of deep neural networks. It normalizes the input to have zero mean and unit variance. This function calculates the gradients for the LayerNorm weights and biases (dweight, dbias) and for its input (dinp) during backpropagation.

func layernormBackward(dinp, dweight, dbias, dout, inp, weight, mean, rstd []float32, B, T, C int) {
	for b := 0; b < B; b++ {
		for t := 0; t < T; t++ {
			baseIndex := b*T*C + t*C
			doutBT := dout[baseIndex : baseIndex+C]
			inpBT := inp[baseIndex : baseIndex+C]
			dinpBT := dinp[baseIndex : baseIndex+C]
			meanBT := mean[b*T+t]
			rstdBT := rstd[b*T+t]

			// Reduce operations
			var dnormMean float32 = 0.0
			var dnormNormMean float32 = 0.0
			for i := 0; i < C; i++ {
				normBTI := (inpBT[i] - meanBT) * rstdBT
				dnormI := weight[i] * doutBT[i]
				dnormMean += dnormI
				dnormNormMean += dnormI * normBTI
			}
			dnormMean /= float32(C)
			dnormNormMean /= float32(C)

			// Accumulation loop
			for i := 0; i < C; i++ {
				normBTI := (inpBT[i] - meanBT) * rstdBT
				dnormI := weight[i] * doutBT[i]
				dbias[i] += doutBT[i]
				dweight[i] += normBTI * doutBT[i]

				var dval float32
				dval += dnormI                  // Term 1
				dval -= dnormMean               // Term 2
				dval -= normBTI * dnormNormMean // Term 3
				dval *= rstdBT                  // Final scale
				dinpBT[i] += dval
			}
		}
	}
}

residualBackward

The function residualBackward calculates the gradients for the backward pass of a residual connection in a neural network. Because the forward pass is a plain addition, the output gradient is passed unchanged to both inputs.

func residualBackward(dinp1, dinp2, dout []float32, N int) {
	for i := 0; i < N; i++ {
		dinp1[i] += dout[i]
		dinp2[i] += dout[i]
	}
}

geluBackward

Computes the gradient of the Gaussian Error Linear Unit (GELU) activation function for backpropagation in a neural network.

// geluBackward computes the backward pass of the GeLU non-linearity
func geluBackward(dinp, inp, dout []float32, n int) {
	for i := 0; i < n; i++ {
		x := inp[i]
		cube := 0.044715 * x * x * x
		tanhArg := GELUSCALEFACTOR * (x + cube)
		tanhOut := Tanh(tanhArg)
		coshfOut := Cosh(tanhArg)
		sechOut := 1.0 / (coshfOut * coshfOut)
		localGrad := 0.5*(1.0+tanhOut) + x*0.5*sechOut*GELUSCALEFACTOR*(1.0+3.0*0.044715*x*x)
		dinp[i] += localGrad * dout[i]
	}
}
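
The derivative used in geluBackward can be sanity-checked against a finite difference. The sketch below is written in float64 with the standard math package; gelu, geluGrad and numericGrad are stand-alone illustrations and not the notebook's slice-based functions.

```go
package main

import (
	"fmt"
	"math"
)

var geluScale = math.Sqrt(2.0 / math.Pi)

// gelu is the tanh approximation used in the forward pass.
func gelu(x float64) float64 {
	return 0.5 * x * (1.0 + math.Tanh(geluScale*(x+0.044715*x*x*x)))
}

// geluGrad is the analytic derivative, following the same formula as geluBackward:
// the product rule on 0.5*x*(1+tanh(u)), with sech^2(u) as the derivative of tanh(u).
func geluGrad(x float64) float64 {
	u := geluScale * (x + 0.044715*x*x*x)
	sech2 := 1.0 / (math.Cosh(u) * math.Cosh(u))
	return 0.5*(1.0+math.Tanh(u)) + 0.5*x*sech2*geluScale*(1.0+3.0*0.044715*x*x)
}

// numericGrad estimates the derivative with a central difference.
func numericGrad(x float64) float64 {
	const h = 1e-5
	return (gelu(x+h) - gelu(x-h)) / (2 * h)
}

func main() {
	for _, x := range []float64{-2, -0.5, 0, 1, 3} {
		fmt.Printf("x=%+.1f analytic=%+.6f numeric=%+.6f\n", x, geluGrad(x), numericGrad(x))
	}
}
```

At x = 0 the slope is exactly 0.5, and for large positive x it approaches 1, matching the picture of GELU as a smoothed ReLU.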

attentionBackward

The attentionBackward function implements the backward pass for a self-attention mechanism in a neural network. This is a crucial part of training attention-based models, like transformers. Starting from the gradient of the attention layer's output, it calculates the gradients of the loss with respect to the attention weights, queries, keys, and values, allowing the model to adjust its parameters to improve performance.

// attentionBackward performs the backward pass for an attention mechanism
func attentionBackward(dinp, dpreatt, datt, dout, inp, att []float32, B, T, C, NH int) {
	// C3 is 3 times C, representing the size of Q, K, and V combined
	C3 := C * 3
	// hs is the size of each head
	hs := C / NH
	// scale is the factor used in the forward pass to scale the dot product
	scale := 1.0 / Sqrt(float32(hs))
	// Iterate through batch, time, and heads
	for b := 0; b < B; b++ {
		for t := 0; t < T; t++ {
			for h := 0; h < NH; h++ {
				// Calculate the indices for the arrays in this specific iteration
				attBTH := att[b*NH*T*T+h*T*T+t*T:]
				dattBTH := datt[b*NH*T*T+h*T*T+t*T:]
				dpreattBTH := dpreatt[b*NH*T*T+h*T*T+t*T:]
				dqueryT := dinp[b*T*C3+t*C3+h*hs:]
				queryT := inp[b*T*C3+t*C3+h*hs:]
				// Backward pass 4: value accumulation
				doutBTH := dout[b*T*C+t*C+h*hs:]
				for t2 := 0; t2 <= t; t2++ {
					valueT2 := inp[b*T*C3+t2*C3+h*hs+C*2:]
					dvalueT2 := dinp[b*T*C3+t2*C3+h*hs+C*2:]
					for i := 0; i < hs; i++ {
						// Compute gradients for attention and value accumulation
						dattBTH[t2] += valueT2[i] * doutBTH[i]
						dvalueT2[i] += attBTH[t2] * doutBTH[i]
					}
				}
				// Backward pass 2 & 3: softmax backward
				// Softmax does not require input (preatt) to backward
				for t2 := 0; t2 <= t; t2++ {
					for t3 := 0; t3 <= t; t3++ {
						var indicator float32
						if t2 == t3 {
							indicator = 1.0
						}
						localDerivative := attBTH[t2] * (indicator - attBTH[t3])
						dpreattBTH[t3] += localDerivative * dattBTH[t2]
					}
				}
				// Backward pass 1: query @ key matmul
				for t2 := 0; t2 <= t; t2++ {
					keyT2 := inp[b*T*C3+t2*C3+h*hs+C:]
					dkeyT2 := dinp[b*T*C3+t2*C3+h*hs+C:]
					for i := 0; i < hs; i++ {
						// Compute gradients for query and key
						dqueryT[i] += keyT2[i] * dpreattBTH[t2] * scale
						dkeyT2[i] += queryT[i] * dpreattBTH[t2] * scale
					}
				}
			}
		}
	}
}
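
The softmax step ("Backward pass 2 & 3") uses a double loop over t2 and t3. For one row, that loop is algebraically the same as a single pass that first computes the dot product of att and datt. The sketch below shows both forms side by side; the function names are hypothetical, and the compact version is an alternative formulation, not what the notebook does.

```go
package main

import "fmt"

// softmaxBackwardLoop mirrors the double loop in attentionBackward for one row.
func softmaxBackwardLoop(att, datt []float64) []float64 {
	dpre := make([]float64, len(att))
	for t2 := range att {
		for t3 := range att {
			indicator := 0.0
			if t2 == t3 {
				indicator = 1.0
			}
			dpre[t3] += att[t2] * (indicator - att[t3]) * datt[t2]
		}
	}
	return dpre
}

// softmaxBackwardCompact computes dpre[j] = att[j] * (datt[j] - sum_i att[i]*datt[i]).
func softmaxBackwardCompact(att, datt []float64) []float64 {
	dot := 0.0
	for i := range att {
		dot += att[i] * datt[i]
	}
	dpre := make([]float64, len(att))
	for j := range att {
		dpre[j] = att[j] * (datt[j] - dot)
	}
	return dpre
}

func main() {
	att := []float64{0.2, 0.3, 0.5}
	datt := []float64{1, -2, 0.5}
	fmt.Println(softmaxBackwardLoop(att, datt))
	fmt.Println(softmaxBackwardCompact(att, datt))
}
```

The double loop costs O(T²) per row and the compact form O(T); the notebook keeps the loop, which matches its stated goal of being easy to follow over being fast.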

matmulBackward

The function computes the gradients of the inputs (dinp), weights (dweight), and biases (dbias) for a matrix multiplication operation. These gradients are necessary for adjusting the model parameters during training to minimize the error.

Parameters:

  • dinp: gradients of the loss with respect to the inputs of the matrix multiplication. Accumulated by this function.
  • dweight: gradients of the loss with respect to the weights. Initially filled with zeros, then accumulated by this function.
  • dbias: gradients of the loss with respect to the biases. Initially filled with zeros, then accumulated by this function. May be nil when the layer has no bias.
  • dout: gradients of the loss with respect to the outputs of the matrix multiplication, as calculated by the subsequent layer
  • inp: the inputs to the matrix multiplication
  • weight: the weights of the matrix multiplication
  • B: batch size (number of samples)
  • T: sequence length (number of time steps)
  • C: number of input features
  • OC: number of output features

func matmulBackward(dinp, dweight, dbias, dout, inp, weight []float32, B, T, C, OC int) {
	var wg sync.WaitGroup
	for b := 0; b < B; b++ {
		for t := 0; t < T; t++ {
			wg.Add(1)
			go func(b, t int) {
				defer wg.Done()
				doutBt := dout[b*T*OC+t*OC:]
				dinpBt := dinp[b*T*C+t*C:]
				for o := 0; o < OC; o++ {
					wrow := weight[o*C:]
					d := doutBt[o]
					for i := 0; i < C; i++ {
						dinpBt[i] += wrow[i] * d
					}
				}
			}(b, t)
		}
	}
	wg.Wait()
	for o := 0; o < OC; o++ {
		wg.Add(1)
		go func(o int) {
			defer wg.Done()
			for b := 0; b < B; b++ {
				for t := 0; t < T; t++ {
					doutBt := dout[b*T*OC+t*OC:]
					inpBt := inp[b*T*C+t*C:]
					dwrow := dweight[o*C:]
					d := doutBt[o]
					if dbias != nil {
						dbias[o] += d
					}
					for i := 0; i < C; i++ {
						dwrow[i] += inpBt[i] * d
					}
				}
			}
		}(o)
	}
	wg.Wait()
}
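
For a single (b, t) position the three gradients reduce to a few multiplications that can be checked by hand. The sketch below is a sequential, single-position illustration; matmulBackwardRow is a hypothetical name, and it assumes the same row-major (OC, C) weight layout as the code above.

```go
package main

import "fmt"

// matmulBackwardRow accumulates the gradients for one position of
// out = weight * inp + bias, with weight stored row-major as (OC, C).
func matmulBackwardRow(dinp, dweight, dbias, dout, inp, weight []float32, C, OC int) {
	for o := 0; o < OC; o++ {
		d := dout[o]
		dbias[o] += d // out[o] depends on bias[o] with slope 1
		for i := 0; i < C; i++ {
			dinp[i] += weight[o*C+i] * d // out[o] depends on inp[i] through weight[o][i]
			dweight[o*C+i] += inp[i] * d // out[o] depends on weight[o][i] through inp[i]
		}
	}
}

func main() {
	C, OC := 3, 2
	weight := []float32{1, 2, 3, 4, 5, 6} // two rows of three
	inp := []float32{1, 0, -1}
	dout := []float32{1, 2}
	dinp := make([]float32, C)
	dweight := make([]float32, OC*C)
	dbias := make([]float32, OC)
	matmulBackwardRow(dinp, dweight, dbias, dout, inp, weight, C, OC)
	fmt.Println(dinp)    // prints "[9 12 15]"
	fmt.Println(dweight) // prints "[1 0 -1 2 0 -2]"
	fmt.Println(dbias)   // prints "[1 2]"
}
```

dinp is the weight matrix transposed times dout, dweight is the outer product of dout and inp, and dbias is dout itself.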

encoderBackward

encoderBackward calculates the gradients of the token and positional embeddings during backpropagation.

Parameters:

  • dwte: gradients with respect to word embeddings (wte)
  • dwpe: gradients with respect to positional embeddings (wpe)
  • dout: the gradient to apply to dwte and dwpe
  • inp: input tokens (ids that refer to indexes within wte)
  • B: batch size
  • T: sequence length (number of time steps)
  • C: embedding dimension (number of features)


func encoderBackward(dwte, dwpe []float32, dout []float32, inp []int32, B, T, C int) {
	// Iterate over the batch and time steps
	for b := 0; b < B; b++ {
		for t := 0; t < T; t++ {
			// Calculate offsets for indexing
			doutBTOffset := b*T*C + t*C
			ix := inp[b*T+t]              // Get the input token id
			dwteIxOffset := ix * int32(C) // Calculate the offset for dwte
			dwpeTOffset := t * C          // Calculate the offset for dwpe

			// Iterate over the embedding dimension and apply computations
			for i := 0; i < C; i++ {
				// Get the gradient value from dout
				d := dout[doutBTOffset+i]
				// Update the gradients for word embeddings (dwte) and positional embeddings (dwpe)
				dwte[dwteIxOffset+int32(i)] += d
				dwpe[dwpeTOffset+i] += d
			}
		}
	}
}
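
One consequence of the `+=` updates is that a token which occurs several times in the input collects gradient from every occurrence, while each position collects gradient only from itself within a sequence. The sketch below demonstrates this with a tiny vocabulary; it restates encoderBackward with plain int indexing so it runs on its own, and the sizes are made up for illustration.

```go
package main

import "fmt"

// encoderBackward accumulates dout into the token rows (dwte) and position rows (dwpe).
func encoderBackward(dwte, dwpe, dout []float32, inp []int32, B, T, C int) {
	for b := 0; b < B; b++ {
		for t := 0; t < T; t++ {
			ix := int(inp[b*T+t])
			for i := 0; i < C; i++ {
				d := dout[b*T*C+t*C+i]
				dwte[ix*C+i] += d // indexed by token id
				dwpe[t*C+i] += d  // indexed by position
			}
		}
	}
}

func main() {
	V, B, T, C := 2, 1, 2, 2
	inp := []int32{1, 1} // the same token at both positions
	dout := []float32{1, 2, 3, 4}
	dwte := make([]float32, V*C)
	dwpe := make([]float32, T*C)
	encoderBackward(dwte, dwpe, dout, inp, B, T, C)
	fmt.Println(dwte) // prints "[0 0 4 6]": token 1 received both gradients
	fmt.Println(dwpe) // prints "[1 2 3 4]": each position received its own
}
```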

func (model *GPT2) ZeroGradient() {
	for i := range model.GradsActs.Memory {
		model.GradsActs.Memory[i] = 0.0
	}
	for i := range model.Grads.Memory {
		model.Grads.Memory[i] = 0.0
	}
}

Optimiser

The optimiser updates the weights and, for each one, keeps running averages of its gradient and squared gradient, which determine how fast that weight changes.

Most neural network training setups use AdamW, which adds weight decay on top of the Adam optimiser.
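
As a rough illustration of what a single AdamW step does, the sketch below applies the update to one scalar parameter. It uses float64 and the standard math package for brevity, whereas the notebook works on float32 slices with its own helpers; the function name adamwStep and the example values are made up for this illustration.

```go
package main

import (
	"fmt"
	"math"
)

// adamwStep applies one AdamW update to a single parameter p with gradient g.
// m and v are the running averages of the gradient and squared gradient, and
// t is the 1-based step count used for bias correction.
func adamwStep(p, g, m, v, lr, beta1, beta2, eps, wd float64, t int) (float64, float64, float64) {
	m = beta1*m + (1-beta1)*g   // first moment (momentum)
	v = beta2*v + (1-beta2)*g*g // second moment (squared gradient)
	mHat := m / (1 - math.Pow(beta1, float64(t)))
	vHat := v / (1 - math.Pow(beta2, float64(t)))
	p -= lr * (mHat/(math.Sqrt(vHat)+eps) + wd*p)
	return p, m, v
}

func main() {
	// First step from zeroed moments: bias correction makes mHat equal to g
	// and vHat equal to g*g, so the step size is close to lr.
	p, m, v := adamwStep(1.0, 0.5, 0, 0, 0.1, 0.9, 0.999, 1e-8, 0.0, 1)
	fmt.Printf("p=%.4f m=%.4f v=%.6f\n", p, m, v)
}
```

With zero weight decay the parameter moves from 1.0 to roughly 0.9, which shows that on the first step the update depends on the sign of the gradient far more than on its magnitude.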

Papers


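// Update applies one AdamW step to every parameter. t is the 1-based step
// count used for bias correction.
func (model *GPT2) Update(learningRate, beta1, beta2, eps, weightDecay float32, t int) {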
	// Lazy memory allocation
	if model.MMemory == nil {
		model.MMemory = make([]float32, model.Params.Len())
		model.VMemory = make([]float32, model.Params.Len())
	}
	// Parameter updates
	for i := 0; i < model.Params.Len(); i++ {
		parameter := model.Params.Memory[i]
		gradient := model.Grads.Memory[i]
		// Momentum update
		m := beta1*model.MMemory[i] + (1.0-beta1)*gradient
		// RMSprop update
		v := beta2*model.VMemory[i] + (1.0-beta2)*gradient*gradient
		// Bias correction
		mHat := m / (1.0 - Pow(beta1, float32(t)))
		vHat := v / (1.0 - Pow(beta2, float32(t)))
		// Parameter update
		model.MMemory[i] = m
		model.VMemory[i] = v
		model.Params.Memory[i] -= learningRate * (mHat/(Sqrt(vHat)+eps) + weightDecay*parameter)
	}
}

func (model *GPT2) Backward() error {
	// double check we forwarded previously, with targets
	if model.MeanLoss == -1.0 {
		return errors.New("error: must forward with targets before backward")
	}
	// convenience shortcuts
	B, T, V, L, NH, C := model.B, model.T, model.Config.V, model.Config.L, model.Config.NH, model.Config.C
	// lazily allocate the memory for gradients of the weights and activations, if needed
	if len(model.Grads.Memory) == 0 {
		model.Grads.Init(V, C, model.Config.MaxSeqLen, L)
		model.GradsActs.Init(B, C, T, L, NH, V)
		model.ZeroGradient()
	}
	// backward pass
	params, grads, acts, gradsActs := model.Params, model.Grads, model.Acts, model.GradsActs
	// we kick off the chain by filling in dlosses with 1/(B*T), the gradient of the mean loss with respect to each per-position loss
	dlossMean := 1.0 / float32(B*T)
	for i := range gradsActs.Losses.data {
		gradsActs.Losses.data[i] = dlossMean
	}
	crossentropySoftmaxBackward(gradsActs.Logits.data, gradsActs.Losses.data, acts.Probabilities.data, model.Targets, B, T, V)
	matmulBackward(gradsActs.LayerNormFinal.data, grads.WordTokEmbed.data, nil, gradsActs.Logits.data, acts.LayerNormFinal.data, params.WordTokEmbed.data, B, T, C, V)
	residual := acts.Residual3.data[(L-1)*B*T*C:]       // last layer's residual
	dresidual := gradsActs.Residual3.data[(L-1)*B*T*C:] // write to last layer's residual
	layernormBackward(dresidual, grads.LayerFinNormW.data, grads.LayerFinNormB.data, gradsActs.LayerNormFinal.data, residual, params.LayerFinNormW.data, acts.LayerNormFinalMean.data, acts.LayerNormFinalStd.data, B, T, C)
	for l := L - 1; l >= 0; l-- {
		if l == 0 {
			residual = acts.Encoded.data
			dresidual = gradsActs.Encoded.data
		} else {
			residual = acts.Residual3.data[(l-1)*B*T*C:]
			dresidual = gradsActs.Residual3.data[(l-1)*B*T*C:]
		}

		// Weights for layer l
		l_ln1w := params.LayerNorm1W.data[l*C:]
		l_qkvw := params.QueryKeyValW.data[l*3*C*C:]
		l_attprojw := params.AttProjW.data[l*C*C:]
		l_ln2w := params.Layer2NormW.data[l*C:]
		l_fcw := params.FeedFwdW.data[l*4*C*C:]
		l_fcprojw := params.FeedFwdProjW.data[l*C*4*C:]
		// Gradients of weights
		dl_ln1w := grads.LayerNorm1W.data[l*C:]
		dl_ln1b := grads.LayerNorm1B.data[l*C:]
		dl_qkvw := grads.QueryKeyValW.data[l*3*C*C:]
		dl_qkvb := grads.QueryKeyValB.data[l*3*C:]
		dl_attprojw := grads.AttProjW.data[l*C*C:]
		dl_attprojb := grads.AttProjB.data[l*C:]
		dl_ln2w := grads.Layer2NormW.data[l*C:]
		dl_ln2b := grads.Layer2NormB.data[l*C:]
		dl_fcw := grads.FeedFwdW.data[l*4*C*C:]
		dl_fcb := grads.FeedFwdB.data[l*4*C:]
		dl_fcprojw := grads.FeedFwdProjW.data[l*C*4*C:]
		dl_fcprojb := grads.FeedFwdProjB.data[l*C:]
		// Activations
		l_ln1 := acts.Layer1Act.data[l*B*T*C:]
		l_ln1_mean := acts.LayerNorm1Mean.data[l*B*T:]
		l_ln1_rstd := acts.LayerNorm1Rstd.data[l*B*T:]
		l_qkv := acts.QueryKeyVal.data[l*B*T*3*C:]
		l_atty := acts.AttentionInter.data[l*B*T*C:]
		l_att := acts.Attention.data[l*B*NH*T*T:]
		l_residual2 := acts.Residual2.data[l*B*T*C:]
		l_ln2 := acts.LayerNorm2Act.data[l*B*T*C:]
		l_ln2_mean := acts.LayerNorm2Mean.data[l*B*T:]
		l_ln2_rstd := acts.LayerNorm2Rstd.data[l*B*T:]
		l_fch := acts.FeedForward.data[l*B*T*4*C:]
		l_fch_gelu := acts.FeedForwardGelu.data[l*B*T*4*C:]

		dl_ln1 := gradsActs.Layer1Act.data[l*B*T*C:]
		dl_qkv := gradsActs.QueryKeyVal.data[l*B*T*3*C:]
		dl_atty := gradsActs.AttentionInter.data[l*B*T*C:]
		dl_preatt := gradsActs.PreAttention.data[l*B*NH*T*T:]
		dl_att := gradsActs.Attention.data[l*B*NH*T*T:]
		dl_attproj := gradsActs.AttentionProj.data[l*B*T*C:]
		dl_residual2 := gradsActs.Residual2.data[l*B*T*C:]
		dl_ln2 := gradsActs.LayerNorm2Act.data[l*B*T*C:]
		dl_fch := gradsActs.FeedForward.data[l*B*T*4*C:]
		dl_fch_gelu := gradsActs.FeedForwardGelu.data[l*B*T*4*C:]
		dl_fcproj := gradsActs.FeedForwardProj.data[l*B*T*C:]
		dl_residual3 := gradsActs.Residual3.data[l*B*T*C:]
		residualBackward(dl_residual2, dl_fcproj, dl_residual3, B*T*C)
		matmulBackward(dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C)
		geluBackward(dl_fch, l_fch, dl_fch_gelu, B*T*4*C)
		matmulBackward(dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C)
		layernormBackward(dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C)
		residualBackward(dresidual, dl_attproj, dl_residual2, B*T*C)
		matmulBackward(dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C)
		attentionBackward(dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH)
		matmulBackward(dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C)
		layernormBackward(dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C)
	}
	// Finally, accumulate the gradient of the encoded input into the token and positional embedding gradients.
	encoderBackward(grads.WordTokEmbed.data, grads.WordPosEmbed.data, gradsActs.Encoded.data, model.Inputs, B, T, C)
	return nil
}
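
The seed value 1/(B*T) used at the start of Backward scales every position's logit gradient. The sketch below assumes the usual softmax cross-entropy gradient, dloss * (probs - onehot(target)); crossentropySoftmaxBackward is defined elsewhere in the notebook and may differ in detail. The function name and the example numbers are illustrative.

```go
package main

import "fmt"

// softmaxCrossEntropyGrad returns dloss * (probs - onehot(target)), the usual
// gradient of the cross-entropy loss with respect to the logits of one position.
func softmaxCrossEntropyGrad(probs []float64, target int, dloss float64) []float64 {
	dlogits := make([]float64, len(probs))
	for i, p := range probs {
		indicator := 0.0
		if i == target {
			indicator = 1.0
		}
		dlogits[i] = (p - indicator) * dloss
	}
	return dlogits
}

func main() {
	// With B*T = 2 positions, each position's loss contributes half of the mean.
	B, T := 1, 2
	dloss := 1.0 / float64(B*T)
	g := softmaxCrossEntropyGrad([]float64{0.7, 0.2, 0.1}, 0, dloss)
	fmt.Printf("%.2f\n", g)
}
```

The gradient for the target class is negative (about -0.15 here) and the others are positive (about 0.10 and 0.05), so a descent step raises the target logit and lowers the rest. The entries sum to zero because the probabilities sum to one.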

func (model *GPT2) Train(valDataloader, trainDataloader *DataLoader, B, T int) error {
	fmt.Printf("val dataset num_batches: %d\n", valDataloader.NumBatches)
	const genMaxLength, valNumBatches = 20, 3
	for step := 0; step <= 3; step++ {
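		// evaluate on the validation set every step (step%1 is always 0);
		// raise the modulus to evaluate less often
		if step%1 == 0 {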
			var valLoss float32
			valDataloader.Reset()
			for i := 0; i < valNumBatches; i++ {
				input, target, err := valDataloader.NextBatch()
				if err != nil {
					return err
				}
				model.Forward(input, target, B, T)
				valLoss += model.MeanLoss
			}
			valLoss /= float32(valNumBatches)
			fmt.Printf("val loss %f\n", valLoss)
		}
		// do a training step
		start := time.Now()
		input, targets, err := trainDataloader.NextBatch()
		if err != nil {
			return err
		}
		model.Forward(input, targets, B, T)
		model.ZeroGradient()
		if err := model.Backward(); err != nil {
			return err
		}
		model.Update(1e-4, 0.9, 0.999, 1e-8, 0.0, step+1)
		fmt.Printf("step %d: train loss %f (took %v)\n", step, model.MeanLoss, time.Since(start))
	}
	return nil
}
%main
model, err := LoadGPT2Model("./gpt2_124M.bin", "./gpt2_tokenizer.bin")
if err != nil {
    log.Fatal(err)
}
B, T := 4, 64
trainDataloader, err := NewDataLoader("./TinyStories_train.bin", B, T)
if err != nil {
    log.Fatal(err)
}
fmt.Printf("train dataset num_batches: %d\n", trainDataloader.NumBatches)
valDataloader, err := NewDataLoader("./TinyStories_val.bin", B, T)
if err != nil {
    log.Fatal(err)
}
if err := model.Train(valDataloader, trainDataloader, B, T); err != nil {
    log.Fatal(err)
}
2024/08/26 17:53:26 Error opening model file: open ./gpt2_124M.bin: no such file or directory
exit status 1