summaryrefslogtreecommitdiffstats
path: root/mm/mempolicy.c
diff options
context:
space:
mode:
Diffstat (limited to 'mm/mempolicy.c')
-rw-r--r--mm/mempolicy.c35
1 files changed, 22 insertions, 13 deletions
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index da858f794eb6..cfd26d7e61a1 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -797,16 +797,19 @@ static void get_policy_nodemask(struct mempolicy *p, nodemask_t *nodes)
}
}
-static int lookup_node(unsigned long addr)
+static int lookup_node(struct mm_struct *mm, unsigned long addr)
{
struct page *p;
int err;
- err = get_user_pages(addr & PAGE_MASK, 1, 0, &p, NULL);
+ int locked = 1;
+ err = get_user_pages_locked(addr & PAGE_MASK, 1, 0, &p, &locked);
if (err >= 0) {
err = page_to_nid(p);
put_page(p);
}
+ if (locked)
+ up_read(&mm->mmap_sem);
return err;
}
@@ -817,7 +820,7 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
int err;
struct mm_struct *mm = current->mm;
struct vm_area_struct *vma = NULL;
- struct mempolicy *pol = current->mempolicy;
+ struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
if (flags &
~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
@@ -857,7 +860,16 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
if (flags & MPOL_F_NODE) {
if (flags & MPOL_F_ADDR) {
- err = lookup_node(addr);
+ /*
+ * Take a refcount on the mpol, lookup_node()
+ * wil drop the mmap_sem, so after calling
+ * lookup_node() only "pol" remains valid, "vma"
+ * is stale.
+ */
+ pol_refcount = pol;
+ vma = NULL;
+ mpol_get(pol);
+ err = lookup_node(mm, addr);
if (err < 0)
goto out;
*policy = err;
@@ -892,7 +904,9 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
out:
mpol_cond_put(pol);
if (vma)
- up_read(&current->mm->mmap_sem);
+ up_read(&mm->mmap_sem);
+ if (pol_refcount)
+ mpol_put(pol_refcount);
return err;
}
@@ -2697,12 +2711,11 @@ static const char * const policy_modes[] =
int mpol_parse_str(char *str, struct mempolicy **mpol)
{
struct mempolicy *new = NULL;
- unsigned short mode;
unsigned short mode_flags;
nodemask_t nodes;
char *nodelist = strchr(str, ':');
char *flags = strchr(str, '=');
- int err = 1;
+ int err = 1, mode;
if (nodelist) {
/* NUL-terminate mode or flags string */
@@ -2717,12 +2730,8 @@ int mpol_parse_str(char *str, struct mempolicy **mpol)
if (flags)
*flags++ = '\0'; /* terminate mode string */
- for (mode = 0; mode < MPOL_MAX; mode++) {
- if (!strcmp(str, policy_modes[mode])) {
- break;
- }
- }
- if (mode >= MPOL_MAX)
+ mode = match_string(policy_modes, MPOL_MAX, str);
+ if (mode < 0)
goto out;
switch (mode) {