diff --git a/test/lib/ut_multithread.c b/test/lib/ut_multithread.c index ba40d9fbf3..a19b6bb11b 100644 --- a/test/lib/ut_multithread.c +++ b/test/lib/ut_multithread.c @@ -52,10 +52,17 @@ struct ut_thread { struct spdk_thread *thread; struct spdk_io_channel *ch; TAILQ_HEAD(, ut_msg) msgs; + TAILQ_HEAD(, ut_poller) pollers; }; struct ut_thread *g_ut_threads; +struct ut_poller { + spdk_poller_fn fn; + void *arg; + TAILQ_ENTRY(ut_poller) tailq; +}; + static void __send_msg(spdk_thread_fn fn, void *ctx, void *thread_ctx) { @@ -70,6 +77,32 @@ __send_msg(spdk_thread_fn fn, void *ctx, void *thread_ctx) TAILQ_INSERT_TAIL(&thread->msgs, msg, link); } +static struct spdk_poller * +__start_poller(void *thread_ctx, spdk_thread_fn fn, void *arg, uint64_t period_microseconds) +{ + struct ut_thread *thread = thread_ctx; + struct ut_poller *poller = calloc(1, sizeof(struct ut_poller)); + + SPDK_CU_ASSERT_FATAL(poller != NULL); + + poller->fn = fn; + poller->arg = arg; + + TAILQ_INSERT_TAIL(&thread->pollers, poller, tailq); + + return (struct spdk_poller *)poller; +} + +static void +__stop_poller(struct spdk_poller *poller, void *thread_ctx) +{ + struct ut_thread *thread = thread_ctx; + + TAILQ_REMOVE(&thread->pollers, (struct ut_poller *)poller, tailq); + + free(poller); +} + static uintptr_t g_thread_id = MOCK_PASS_THRU; static void @@ -92,11 +125,13 @@ allocate_threads(int num_threads) for (i = 0; i < g_ut_num_threads; i++) { set_thread(i); - spdk_allocate_thread(__send_msg, NULL, NULL, &g_ut_threads[i], NULL); + spdk_allocate_thread(__send_msg, __start_poller, __stop_poller, + &g_ut_threads[i], NULL); thread = spdk_get_thread(); SPDK_CU_ASSERT_FATAL(thread != NULL); g_ut_threads[i].thread = thread; TAILQ_INIT(&g_ut_threads[i].msgs); + TAILQ_INIT(&g_ut_threads[i].pollers); } set_thread(MOCK_PASS_THRU); @@ -124,7 +159,9 @@ poll_thread(uintptr_t thread_id) int count = 0; struct ut_thread *thread = &g_ut_threads[thread_id]; struct ut_msg *msg; + struct ut_poller *poller; uintptr_t original_thread_id; + TAILQ_HEAD(, ut_poller) tmp_pollers; CU_ASSERT(thread_id != (uintptr_t)MOCK_PASS_THRU); CU_ASSERT(thread_id < g_ut_num_threads); @@ -141,6 +178,21 @@ poll_thread(uintptr_t thread_id) free(msg); } + TAILQ_INIT(&tmp_pollers); + + while (!TAILQ_EMPTY(&thread->pollers)) { + poller = TAILQ_FIRST(&thread->pollers); + TAILQ_REMOVE(&thread->pollers, poller, tailq); + + if (poller->fn) { + poller->fn(poller->arg); + } + + TAILQ_INSERT_TAIL(&tmp_pollers, poller, tailq); + } + + TAILQ_SWAP(&tmp_pollers, &thread->pollers, ut_poller, tailq); + set_thread(original_thread_id); return count;