diff --git a/sys/kern/subr_epoch.c b/sys/kern/subr_epoch.c index 9c0a5e092abd..ebb4e3dcc9a3 100644 --- a/sys/kern/subr_epoch.c +++ b/sys/kern/subr_epoch.c @@ -284,21 +284,15 @@ epoch_free(epoch_t epoch) } while (0) void -epoch_enter(epoch_t epoch) +epoch_enter_internal(epoch_t epoch, struct thread *td) { struct epoch_pcpu_state *eps; - struct thread *td; INIT_CHECK(epoch); - - td = curthread; critical_enter(); eps = epoch->e_pcpu[curcpu]; - td->td_epochnest++; - MPASS(td->td_epochnest < UCHAR_MAX - 2); - if (td->td_epochnest == 1) - TAILQ_INSERT_TAIL(&eps->eps_record.er_tdlist, td, td_epochq); #ifdef INVARIANTS + MPASS(td->td_epochnest < UCHAR_MAX - 2); if (td->td_epochnest > 1) { struct thread *curtd; int found = 0; @@ -307,38 +301,31 @@ epoch_enter(epoch_t epoch) if (curtd == td) found = 1; KASSERT(found, ("recursing on a second epoch")); - } -#endif - if (td->td_epochnest > 1) { critical_exit(); return; } +#endif + TAILQ_INSERT_TAIL(&eps->eps_record.er_tdlist, td, td_epochq); sched_pin(); ck_epoch_begin(&eps->eps_record.er_record, (ck_epoch_section_t*)&td->td_epoch_section); critical_exit(); } void -epoch_exit(epoch_t epoch) +epoch_exit_internal(epoch_t epoch, struct thread *td) { struct epoch_pcpu_state *eps; - struct thread *td; td = curthread; + MPASS(td->td_epochnest == 0); INIT_CHECK(epoch); - MPASS(td->td_epochnest); critical_enter(); eps = epoch->e_pcpu[curcpu]; - td->td_epochnest--; - if (td->td_epochnest == 0) - TAILQ_REMOVE(&eps->eps_record.er_tdlist, td, td_epochq); - else { - critical_exit(); - return; - } - sched_unpin(); + ck_epoch_end(&eps->eps_record.er_record, (ck_epoch_section_t*)&td->td_epoch_section); + TAILQ_REMOVE(&eps->eps_record.er_tdlist, td, td_epochq); eps->eps_record.er_gen++; + sched_unpin(); critical_exit(); } diff --git a/sys/sys/epoch.h b/sys/sys/epoch.h index 4eff193b098c..cfe86ccbe157 100644 --- a/sys/sys/epoch.h +++ b/sys/sys/epoch.h @@ -29,6 +29,8 @@ #ifndef _SYS_EPOCH_H_ #define _SYS_EPOCH_H_ +#include +#include struct epoch; typedef struct epoch *epoch_t; @@ -43,10 +45,35 @@ typedef struct epoch_context *epoch_context_t; epoch_t epoch_alloc(void); void epoch_free(epoch_t epoch); -void epoch_enter(epoch_t epoch); -void epoch_exit(epoch_t epoch); +void epoch_enter_internal(epoch_t epoch, struct thread *td); +void epoch_exit_internal(epoch_t epoch, struct thread *td); void epoch_wait(epoch_t epoch); void epoch_call(epoch_t epoch, epoch_context_t ctx, void (*callback) (epoch_context_t)); int in_epoch(void); +static __inline void +epoch_enter(epoch_t epoch) +{ + struct thread *td; + int nesting; + + td = curthread; + nesting = td->td_epochnest++; +#ifndef INVARIANTS + if (nesting == 0) +#endif + epoch_enter_internal(epoch, td); +} + +static __inline void +epoch_exit(epoch_t epoch) +{ + struct thread *td; + + td = curthread; + MPASS(td->td_epochnest); + if (td->td_epochnest-- == 1) + epoch_exit_internal(epoch, td); +} + #endif