Reduce chance of RCU deadlock in the LinuxKPI by implementing the section feature of the concurrency kit, CK.

Differential Revision:	https://reviews.freebsd.org/D29467
Reviewed by:	kib@ and markj@
MFC after:	1 week
Sponsored by:	Mellanox Technologies // NVIDIA Networking
This commit is contained in:
Hans Petter Selasky 2021-03-28 09:36:48 +02:00
parent 19318a62d7
commit 1777720880
2 changed files with 33 additions and 11 deletions

View File

@ -82,6 +82,7 @@ struct task_struct {
int bsd_interrupt_value; int bsd_interrupt_value;
struct work_struct *work; /* current work struct, if set */ struct work_struct *work; /* current work struct, if set */
struct task_struct *group_leader; struct task_struct *group_leader;
unsigned rcu_section[TS_RCU_TYPE_MAX];
}; };
#define current ({ \ #define current ({ \

View File

@ -1,6 +1,6 @@
/*- /*-
* Copyright (c) 2016 Matthew Macy (mmacy@mattmacy.io) * Copyright (c) 2016 Matthew Macy (mmacy@mattmacy.io)
* Copyright (c) 2017-2020 Hans Petter Selasky (hselasky@freebsd.org) * Copyright (c) 2017-2021 Hans Petter Selasky (hselasky@freebsd.org)
* All rights reserved. * All rights reserved.
* *
* Redistribution and use in source and binary forms, with or without * Redistribution and use in source and binary forms, with or without
@ -85,6 +85,15 @@ struct linux_epoch_record {
*/ */
CTASSERT(sizeof(struct rcu_head) == sizeof(struct callback_head)); CTASSERT(sizeof(struct rcu_head) == sizeof(struct callback_head));
/*
 * Verify that "rcu_section[0]" has the same size as
 * "ck_epoch_section_t". This has been done to avoid having to add
 * special compile flags for including ck_epoch.h to all clients of
 * the LinuxKPI.
 *
 * The comparison must be made between the two sizeof() results.
 * Taking sizeof() of the comparison itself yields sizeof(int), which
 * is always non-zero and would make this assertion a no-op.
 */
CTASSERT(sizeof(((struct task_struct *)0)->rcu_section[0]) ==
    sizeof(ck_epoch_section_t));
/* /*
* Verify that "epoch_record" is at beginning of "struct * Verify that "epoch_record" is at beginning of "struct
* linux_epoch_record": * linux_epoch_record":
@ -189,6 +198,14 @@ linux_rcu_read_lock(unsigned type)
if (RCU_SKIP()) if (RCU_SKIP())
return; return;
ts = current;
/* assert valid refcount */
MPASS(ts->rcu_recurse[type] != INT_MAX);
if (++(ts->rcu_recurse[type]) != 1)
return;
/* /*
* Pin thread to current CPU so that the unlock code gets the * Pin thread to current CPU so that the unlock code gets the
* same per-CPU epoch record: * same per-CPU epoch record:
@ -196,17 +213,15 @@ linux_rcu_read_lock(unsigned type)
sched_pin(); sched_pin();
record = &DPCPU_GET(linux_epoch_record[type]); record = &DPCPU_GET(linux_epoch_record[type]);
ts = current;
/* /*
* Use a critical section to prevent recursion inside * Use a critical section to prevent recursion inside
* ck_epoch_begin(). Else this function supports recursion. * ck_epoch_begin(). Else this function supports recursion.
*/ */
critical_enter(); critical_enter();
ck_epoch_begin(&record->epoch_record, NULL); ck_epoch_begin(&record->epoch_record,
ts->rcu_recurse[type]++; (ck_epoch_section_t *)&ts->rcu_section[type]);
if (ts->rcu_recurse[type] == 1) TAILQ_INSERT_TAIL(&record->ts_head, ts, rcu_entry[type]);
TAILQ_INSERT_TAIL(&record->ts_head, ts, rcu_entry[type]);
critical_exit(); critical_exit();
} }
@ -221,18 +236,24 @@ linux_rcu_read_unlock(unsigned type)
if (RCU_SKIP()) if (RCU_SKIP())
return; return;
record = &DPCPU_GET(linux_epoch_record[type]);
ts = current; ts = current;
/* assert valid refcount */
MPASS(ts->rcu_recurse[type] > 0);
if (--(ts->rcu_recurse[type]) != 0)
return;
record = &DPCPU_GET(linux_epoch_record[type]);
/* /*
* Use a critical section to prevent recursion inside * Use a critical section to prevent recursion inside
* ck_epoch_end(). Else this function supports recursion. * ck_epoch_end(). Else this function supports recursion.
*/ */
critical_enter(); critical_enter();
ck_epoch_end(&record->epoch_record, NULL); ck_epoch_end(&record->epoch_record,
ts->rcu_recurse[type]--; (ck_epoch_section_t *)&ts->rcu_section[type]);
if (ts->rcu_recurse[type] == 0) TAILQ_REMOVE(&record->ts_head, ts, rcu_entry[type]);
TAILQ_REMOVE(&record->ts_head, ts, rcu_entry[type]);
critical_exit(); critical_exit();
sched_unpin(); sched_unpin();