aboutsummaryrefslogtreecommitdiff
path: root/src/thread_pool.cpp
blob: 3565ef25a3b5d2ea3faad35aa1c5314e25240057 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
// thread_pool.cpp
//
// Work-stealing thread pool. Every Thread in the pool owns a ring-buffer task
// queue whose head and tail indices are packed into one 64-bit atomic
// (`head_and_tail`). thread_pool_add_task pushes onto the calling thread's own
// queue; worker threads drain their own queue and steal from other threads'
// queues when they run dry.

struct WorkerTask;
struct ThreadPool;

// The pool Thread the code is currently running on: thread_pool_init points it
// at threads[0] for the initializing thread, and each worker sets it to its own
// Thread when it starts.
gb_thread_local Thread *current_thread;

gb_internal void thread_pool_init(ThreadPool *pool, gbAllocator const &a, isize thread_count, char const *worker_name);
gb_internal void thread_pool_destroy(ThreadPool *pool);
gb_internal bool thread_pool_add_task(ThreadPool *pool, WorkerTaskProc *proc, void *data);
gb_internal void thread_pool_wait(ThreadPool *pool);

// State shared by every thread of the pool.
struct ThreadPool {
	// Allocator the `threads` slice is created with; thread_pool_destroy frees through it.
	gbAllocator   allocator;

	// threads[0] is the thread that called thread_pool_init; threads[1..] are the
	// started worker threads. Each Thread carries its own task queue.
	Slice<Thread> threads;
	// Set by thread_pool_init; cleared by thread_pool_destroy to ask workers to exit.
	std::atomic<bool> running;

	// Held around condition_wait when an idle worker goes to sleep.
	BlockingMutex task_lock;
	// Broadcast whenever a task is pushed and during shutdown.
	Condition     tasks_available;

	// Number of tasks pushed but not yet finished: incremented per push and
	// decremented after each task's do_work returns. thread_pool_wait sleeps
	// on this address until it drops to zero.
	Futex tasks_left;
};

// Initializes `pool` with `thread_count` worker threads plus slot 0 for the
// calling thread, which becomes `current_thread`.
//
// `a` allocates the `threads` slice (thread_pool_destroy frees it again).
// NOTE(review): `worker_name` is currently unused.
gb_internal void thread_pool_init(ThreadPool *pool, gbAllocator const &a, isize thread_count, char const *worker_name) {
	mutex_init(&pool->task_lock);
	condition_init(&pool->tasks_available);

	pool->allocator = a;
	slice_init(&pool->threads, a, thread_count + 1);

	// `running` must be published before any worker starts. The worker loop
	// (presumably what thread_init_and_start runs - thread_pool_thread_proc)
	// exits as soon as it reads `running == false`. Setting it only after the
	// threads were started let an early worker quit before the pool was live.
	pool->running = true;

	// setup the main thread
	thread_init(pool, &pool->threads[0], 0);
	current_thread = &pool->threads[0];

	for_array_off(i, 1, pool->threads) {
		Thread *t = &pool->threads[i];
		thread_init_and_start(pool, t, i);
	}
}

// Shuts the pool down:
//  - asks the workers to exit and joins threads[1..];
//  - free()s every thread's queue buffer;
//  - releases the `threads` slice through `pool->allocator`;
//  - destroys the pool's mutex and condition variable.
// Slot 0 (the thread that called thread_pool_init) is not joined.
// NOTE(review): queue buffers are released with free(), so they are presumably
// malloc'd where the Thread is initialized - confirm.
gb_internal void thread_pool_destroy(ThreadPool *pool) {
	// Publish the shutdown request and broadcast while holding `task_lock`.
	// A worker that checks `running` under that same lock before calling
	// condition_wait then either sees `false` or is already waiting and gets
	// this broadcast. Done without the lock, the broadcast can fall into the
	// gap between a worker's check and its wait and be lost, so the join
	// below never returns.
	mutex_lock(&pool->task_lock);
	pool->running = false;
	condition_broadcast(&pool->tasks_available);
	mutex_unlock(&pool->task_lock);

	for_array_off(i, 1, pool->threads) {
		Thread *t = &pool->threads[i];
		// Re-broadcast before each join so a worker that slipped past the
		// broadcast above still gets another wakeup.
		condition_broadcast(&pool->tasks_available);
		thread_join_and_destroy(t);
	}
	for_array(i, pool->threads) {
		free(pool->threads[i].queue);
	}

	gb_free(pool->allocator, pool->threads.data);
	mutex_destroy(&pool->task_lock);
	condition_destroy(&pool->tasks_available);
}

// Pushes `task` onto `thread`'s ring-buffer queue, counts it in
// `pool->tasks_left`, and wakes workers sleeping on `tasks_available`.
//
// `head_and_tail` packs two indices so both are read and updated with a single CAS:
//  - the head index (upper 32 bits) is the next slot to write;
//  - the tail index (lower 32 bits) is the next slot to pop.
// One slot is always left empty to distinguish "full" from "empty". A full
// queue is fatal (GB_PANIC).
//
// NOTE(review): only the thread that owns `thread` should push (single
// producer). Two concurrent pushers would both write queue[head] before their
// CAS, so one task could be overwritten and the other queued twice. Callers in
// this file only push onto `current_thread`.
// NOTE(review): `mask = capacity - 1` assumes capacity is a power of two;
// confirm where the queue is allocated.
void thread_pool_queue_push(Thread *thread, WorkerTask task) {
	// Count the task *before* publishing it. Once the CAS below succeeds,
	// another thread may pop, run and `fetch_sub` it immediately. Incrementing
	// afterwards let `tasks_left` briefly reach 0 while the pushing task was
	// still running, so thread_pool_wait could wake up and return early.
	thread->pool->tasks_left.fetch_add(1);

	uint64_t capture;
	uint64_t new_capture;
	do {
		capture = thread->head_and_tail.load();

		uint64_t mask = thread->capacity - 1;
		uint64_t head = (capture >> 32) & mask;
		uint64_t tail = ((uint32_t)capture) & mask;

		uint64_t new_head = (head + 1) & mask;
		if (new_head == tail) {
			GB_PANIC("Thread Queue Full!\n");
		}

		// This *must* be done in here, to avoid a potential race condition where we no longer own the slot by the time we're assigning
		thread->queue[head] = task;
		new_capture = (new_head << 32) | tail;
	} while (!thread->head_and_tail.compare_exchange_weak(capture, new_capture));

	condition_broadcast(&thread->pool->tasks_available);
}

// Pops the oldest task (the slot at `tail`) from `thread`'s queue into `*task`.
// Returns true on success, or false once a snapshot shows the queue empty
// (head == tail). `*task` may still have been written by an earlier, failed
// CAS attempt.
//
// Called both by the queue's owner (draining its own work) and by other
// threads stealing work. The CAS on `head_and_tail` decides which caller
// actually claims the slot; a caller that loses retries with a fresh snapshot.
//
// NOTE(review): indices are stored already masked (capacity is assumed to be
// a power of two). After a full wrap-around, a stale snapshot can therefore
// compare equal again (ABA), and a slow stealer could claim a slot whose
// contents changed under it. Confirm the queue capacity makes this practically
// unreachable.
bool thread_pool_queue_pop(Thread *thread, WorkerTask *task) {
	uint64_t capture;
	uint64_t new_capture;
	do {
		capture = thread->head_and_tail.load();

		uint64_t mask = thread->capacity - 1;
		uint64_t head = (capture >> 32) & mask;
		uint64_t tail = ((uint32_t)capture) & mask;

		uint64_t new_tail = (tail + 1) & mask;
		if (tail == head) {
			return false;
		}

		// Making a copy of the task before we increment the tail, avoiding the same potential race condition as above
		*task = thread->queue[tail];

		new_capture = (head << 32) | new_tail;
	} while (!thread->head_and_tail.compare_exchange_weak(capture, new_capture));

	return true;
}

// Queues `proc(data)` on the calling thread's own queue (`current_thread`);
// idle workers can steal it from there. Always returns true: a full queue
// panics inside thread_pool_queue_push rather than reporting failure.
// NOTE(review): `pool` is not consulted. The task is counted against whichever
// pool `current_thread` belongs to.
gb_internal bool thread_pool_add_task(ThreadPool *pool, WorkerTaskProc *proc, void *data) {
	(void)pool;

	WorkerTask job = {};
	job.data    = data;
	job.do_work = proc;

	thread_pool_queue_push(current_thread, job);
	return true;
}

// Blocks the calling thread until every pushed task has finished
// (`tasks_left` == 0). While waiting it runs tasks from its own queue; it does
// not steal from other threads. Between drains it sleeps on the `tasks_left`
// address, which workers wake when the count reaches zero.
gb_internal void thread_pool_wait(ThreadPool *pool) {
	WorkerTask job;

	for (;;) {
		if (!pool->tasks_left) {
			return;
		}

		// Help out with whatever is queued on this thread first.
		while (thread_pool_queue_pop(current_thread, &job)) {
			job.do_work(job.data);
			pool->tasks_left.fetch_sub(1);
		}

		// Snapshot the count and hand that exact value to the wait. If another
		// thread changes `tasks_left` between the snapshot and the wait, the
		// wait is meant to return immediately instead of sleeping on a stale
		// value. Reordering these steps could leave this thread asleep forever.
		Footex remaining = pool->tasks_left.load();
		if (!remaining) {
			return;
		}
		tpool_wait_on_addr(&pool->tasks_left, remaining);
	}
}

// Worker thread entry point. Loops until `pool->running` is cleared:
//   1. drain this thread's own queue;
//   2. if tasks are still outstanding anywhere, steal one task from another
//      thread's queue (scanning round-robin starting after our own index),
//      run it, and start over;
//   3. otherwise sleep on `tasks_available` until a push or shutdown.
// Whenever this thread observes `tasks_left` at zero after finishing work, it
// wakes thread_pool_wait.
gb_internal THREAD_PROC(thread_pool_thread_proc) {
	WorkerTask task;
	current_thread = thread;
	ThreadPool *pool = current_thread->pool;

	for (;;) {
work_start:
		if (!pool->running) {
			break;
		}

		// If we've got tasks to process, work through them
		size_t finished_tasks = 0;
		while (thread_pool_queue_pop(current_thread, &task)) {
			task.do_work(task.data);
			pool->tasks_left.fetch_sub(1);

			finished_tasks += 1;
		}
		if (finished_tasks > 0 && !pool->tasks_left) {
			tpool_wake_addr(&pool->tasks_left);
		}

		// If there's still work somewhere and we don't have it, steal it
		if (pool->tasks_left) {
			isize idx = current_thread->idx;
			for_array(i, pool->threads) {
				if (!pool->tasks_left) {
					break;
				}

				idx = (idx + 1) % pool->threads.count;
				// Named `victim` so it does not shadow the THREAD_PROC `thread` parameter.
				Thread *victim = &pool->threads[idx];

				if (!thread_pool_queue_pop(victim, &task)) {
					continue;
				}

				task.do_work(task.data);
				pool->tasks_left.fetch_sub(1);

				if (!pool->tasks_left) {
					tpool_wake_addr(&pool->tasks_left);
				}

				goto work_start;
			}
		}

		// Nothing of our own and nothing to steal: sleep until a push or shutdown.
		// Re-check `running` while holding `task_lock` instead of waiting
		// unconditionally. A shutdown broadcast that happened after the check
		// at the top of the loop but before condition_wait would otherwise be
		// missed, and the thread joining this worker would block forever.
		// Closing that window completely also requires whoever clears
		// `running` to do so while holding `task_lock`. Spurious wakeups simply
		// loop around.
		mutex_lock(&pool->task_lock);
		if (pool->running) {
			condition_wait(&pool->tasks_available, &pool->task_lock);
		}
		mutex_unlock(&pool->task_lock);
	}

	return 0;
}