Skip to content

Commit

Permalink
Make Region word-aligned
Browse files Browse the repository at this point in the history
  • Loading branch information
rexim committed Jun 26, 2023
1 parent aaf8e02 commit e240599
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 20 deletions.
2 changes: 1 addition & 1 deletion demos/adder.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ int main(void)
gym_layout_end();

char buffer[256];
snprintf(buffer, sizeof(buffer), "Epoch: %zu/%zu, Rate: %f, Cost: %f, Temporary Memory: %zu\n", epoch, max_epoch, rate, nn_cost(nn, ti, to), temp.size);
snprintf(buffer, sizeof(buffer), "Epoch: %zu/%zu, Rate: %f, Cost: %f, Temporary Memory: %zu\n", epoch, max_epoch, rate, nn_cost(nn, ti, to), region_occupied_bytes(&temp));
DrawTextEx(font, buffer, CLITERAL(Vector2){}, h*0.04, 0, WHITE);
}
EndDrawing();
Expand Down
2 changes: 1 addition & 1 deletion demos/img2nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ int main(int argc, char **argv)
gym_layout_end();

char buffer[256];
snprintf(buffer, sizeof(buffer), "Epoch: %zu/%zu, Rate: %f, Cost: %f, Temporary Memory: %zu\n", epoch, max_epoch, rate, plot.count > 0 ? plot.items[plot.count - 1] : 0, temp.size);
snprintf(buffer, sizeof(buffer), "Epoch: %zu/%zu, Rate: %f, Cost: %f, Temporary Memory: %zu\n", epoch, max_epoch, rate, plot.count > 0 ? plot.items[plot.count - 1] : 0, region_occupied_bytes(&temp));
DrawTextEx(font, buffer, CLITERAL(Vector2) {}, h*0.04, 0, WHITE);
gym_slider(&rate, &rate_dragging, 0, h*0.08, w, h*0.02);
}
Expand Down
2 changes: 1 addition & 1 deletion demos/xor.c
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ int main(void)
gym_layout_end();

char buffer[256];
snprintf(buffer, sizeof(buffer), "Epoch: %zu/%zu, Rate: %f, Cost: %f, Temporary Memory: %zu bytes", epoch, max_epoch, rate, nn_cost(nn, ti, to), temp.size);
snprintf(buffer, sizeof(buffer), "Epoch: %zu/%zu, Rate: %f, Cost: %f, Temporary Memory: %zu bytes", epoch, max_epoch, rate, nn_cost(nn, ti, to), region_occupied_bytes(&temp));
DrawTextEx(font, buffer, CLITERAL(Vector2){}, h*0.04, 0, WHITE);
}
EndDrawing();
Expand Down
43 changes: 26 additions & 17 deletions nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ float dactf(float y, Act act);
typedef struct {
size_t capacity;
size_t size;
char *data;
uintptr_t *words;
} Region;

Region region_alloc_alloc(size_t capacity);
void *region_alloc(Region *r, size_t size);
#define region_reset(r) (r)->size = 0
// capacity is in bytes, but it can allocate more just to keep things
// word aligned
Region region_alloc_alloc(size_t capacity_bytes);
void *region_alloc(Region *r, size_t size_bytes);
#define region_reset(r) (NN_ASSERT((r) != NULL), (r)->size = 0)
#define region_occupied_bytes(r) (NN_ASSERT((r) != NULL), (r)->size*sizeof(*(r)->words))

typedef struct {
size_t rows;
Expand Down Expand Up @@ -557,24 +560,30 @@ void batch_process(Region *r, Batch *b, size_t batch_size, NN nn, Mat t, float r
}
}

Region region_alloc_alloc(size_t capacity)
Region region_alloc_alloc(size_t capacity_bytes)
{
void *data = NN_MALLOC(capacity);
NN_ASSERT(data != NULL);
Region r = {
.capacity = capacity,
.data = data
};
Region r = {0};

size_t word_size = sizeof(*r.words);
size_t capacity_words = (capacity_bytes + word_size - 1)/word_size;

void *words = NN_MALLOC(capacity_words*word_size);
NN_ASSERT(words != NULL);
r.capacity = capacity_words;
r.words = words;
return r;
}

void *region_alloc(Region *r, size_t size)
void *region_alloc(Region *r, size_t size_bytes)
{
if (r == NULL) return NN_MALLOC(size);
NN_ASSERT(r->size + size <= r->capacity);
if (r->size + size > r->capacity) return NULL;
void *result = &r->data[r->size];
r->size += size;
if (r == NULL) return NN_MALLOC(size_bytes);
size_t word_size = sizeof(*r->words);
size_t size_words = (size_bytes + word_size - 1)/word_size;

NN_ASSERT(r->size + size_words <= r->capacity);
if (r->size + size_words > r->capacity) return NULL;
void *result = &r->words[r->size];
r->size += size_words;
return result;
}

Expand Down

0 comments on commit e240599

Please sign in to comment.