diff options
Diffstat (limited to 'mm')
-rw-r--r-- | mm/memcontrol.c | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/mm/memcontrol.c b/mm/memcontrol.c index b836e7f00309..bf9cf738c836 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -678,9 +678,20 @@ struct mem_cgroup *mem_cgroup_from_task(struct task_struct *p) } EXPORT_SYMBOL(mem_cgroup_from_task); -static struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm) +/** + * get_mem_cgroup_from_mm: Obtain a reference on given mm_struct's memcg. + * @mm: mm from which memcg should be extracted. It can be NULL. + * + * Obtain a reference on mm->memcg and returns it if successful. Otherwise + * root_mem_cgroup is returned. However if mem_cgroup is disabled, NULL is + * returned. + */ +struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm) { - struct mem_cgroup *memcg = NULL; + struct mem_cgroup *memcg; + + if (mem_cgroup_disabled()) + return NULL; rcu_read_lock(); do { @@ -700,6 +711,24 @@ static struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm) rcu_read_unlock(); return memcg; } +EXPORT_SYMBOL(get_mem_cgroup_from_mm); + +/** + * If current->active_memcg is non-NULL, do not fallback to current->mm->memcg. + */ +static __always_inline struct mem_cgroup *get_mem_cgroup_from_current(void) +{ + if (unlikely(current->active_memcg)) { + struct mem_cgroup *memcg = root_mem_cgroup; + + rcu_read_lock(); + if (css_tryget_online(¤t->active_memcg->css)) + memcg = current->active_memcg; + rcu_read_unlock(); + return memcg; + } + return get_mem_cgroup_from_mm(current->mm); +} /** * mem_cgroup_iter - iterate over memory cgroup hierarchy @@ -2261,7 +2290,7 @@ struct kmem_cache *memcg_kmem_get_cache(struct kmem_cache *cachep) if (current->memcg_kmem_skip_account) return cachep; - memcg = get_mem_cgroup_from_mm(current->mm); + memcg = get_mem_cgroup_from_current(); kmemcg_id = READ_ONCE(memcg->kmemcg_id); if (kmemcg_id < 0) goto out; @@ -2345,7 +2374,7 @@ int memcg_kmem_charge(struct page *page, gfp_t gfp, int order) if (memcg_kmem_bypass()) return 0; - memcg = get_mem_cgroup_from_mm(current->mm); + memcg = get_mem_cgroup_from_current(); if (!mem_cgroup_is_root(memcg)) { ret = memcg_kmem_charge_memcg(page, gfp, order, memcg); if (!ret) |